turbo_broccoli.custom.tensorflow

Tensorflow (de)serialization utilities.

  1"""Tensorflow (de)serialization utilities."""
  2
  3from typing import Any, Callable, Tuple
  4
  5import tensorflow as tf
  6from safetensors import tensorflow as st
  7
  8from ..context import Context
  9from ..exceptions import DeserializationError, TypeNotSupported
 10
 11
 12def _json_to_sparse_tensor(dct: dict, ctx: Context) -> tf.Tensor:
 13    decoders = {
 14        2: _json_to_sparse_tensor_v2,
 15    }
 16    return decoders[dct["__version__"]](dct, ctx)
 17
 18
 19def _json_to_sparse_tensor_v2(dct: dict, ctx: Context) -> tf.Tensor:
 20    return tf.SparseTensor(
 21        dense_shape=dct["shape"],
 22        indices=dct["indices"],
 23        values=dct["values"],
 24    )
 25
 26
 27def _json_to_tensor(dct: dict, ctx: Context) -> tf.Tensor:
 28    ctx.raise_if_nodecode("bytes")
 29    decoders = {
 30        4: _json_to_tensor_v4,
 31    }
 32    return decoders[dct["__version__"]](dct, ctx)
 33
 34
 35def _json_to_tensor_v4(dct: dict, ctx: Context) -> tf.Tensor:
 36    return st.load(dct["data"])["data"]
 37
 38
 39def _json_to_variable(dct: dict, ctx: Context) -> tf.Variable:
 40    decoders = {
 41        3: _json_to_variable_v3,
 42    }
 43    return decoders[dct["__version__"]](dct, ctx)
 44
 45
 46def _json_to_variable_v3(dct: dict, ctx: Context) -> tf.Variable:
 47    return tf.Variable(
 48        initial_value=dct["value"],
 49        name=dct["name"],
 50        trainable=dct["trainable"],
 51    )
 52
 53
 54def _ragged_tensor_to_json(obj: tf.Tensor, ctx: Context) -> dict:
 55    raise NotImplementedError(
 56        "Serialization of ragged tensors is not supported"
 57    )
 58
 59
 60def _sparse_tensor_to_json(obj: tf.SparseTensor, ctx: Context) -> dict:
 61    return {
 62        "__type__": "tensorflow.sparse_tensor",
 63        "__version__": 2,
 64        "indices": obj.indices,
 65        "shape": list(obj.dense_shape),
 66        "values": obj.values,
 67    }
 68
 69
 70def _tensor_to_json(obj: tf.Tensor, ctx: Context) -> dict:
 71    return {
 72        "__type__": "tensorflow.tensor",
 73        "__version__": 4,
 74        "data": st.save({"data": obj}),
 75    }
 76
 77
 78def _variable_to_json(var: tf.Variable, ctx: Context) -> dict:
 79    return {
 80        "__type__": "tensorflow.variable",
 81        "__version__": 3,
 82        "name": var.name,
 83        "value": var.value(),
 84        "trainable": var.trainable,
 85    }
 86
 87
 88def from_json(dct: dict, ctx: Context) -> Any:
 89    decoders = {
 90        "tensorflow.sparse_tensor": _json_to_sparse_tensor,
 91        "tensorflow.tensor": _json_to_tensor,
 92        "tensorflow.variable": _json_to_variable,
 93    }
 94    try:
 95        type_name = dct["__type__"]
 96        return decoders[type_name](dct, ctx)
 97    except KeyError as exc:
 98        raise DeserializationError() from exc
 99
100
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a tensorflow object into JSON by cases. See the README for the
    precise list of supported types. The return dict has the following
    structure:

    - `tf.RaggedTensor`: Not supported (raises `NotImplementedError`).

    - `tf.SparseTensor`:

        ```py
        {
            "__type__": "tensorflow.sparse_tensor",
            "__version__": 2,
            "indices": {...},
            "shape": [...],
            "values": {...},
        }
        ```

      where the `indices` and `values` placeholders are the serializations
      of `tf.Tensor` objects (see below), and `shape` lists the dimensions of
      the dense shape.

    - other `tf.Tensor` subtypes:

        ```py
        {
            "__type__": "tensorflow.tensor",
            "__version__": 4,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `tf.Variable`:

        ```py
        {
            "__type__": "tensorflow.variable",
            "__version__": 3,
            "name": <str>,
            "value": {...},
            "trainable": <bool>,
        }
        ```

      where `{...}` is the document produced by serializing the value tensor of
      the variable, see above.

    Raises:
        TypeNotSupported: if `obj` is none of the types above.
    """
    # `tf.Variable` is tested before the generic `tf.Tensor` so that variables
    # always go through `_variable_to_json` (and round-trip as variables),
    # even if the installed TensorFlow makes them pass an
    # `isinstance(..., tf.Tensor)` check. First match wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (tf.RaggedTensor, _ragged_tensor_to_json),
        (tf.SparseTensor, _sparse_tensor_to_json),
        (tf.Variable, _variable_to_json),
        (tf.Tensor, _tensor_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a tensorflow document, dispatching on its `__type__` field.

    Raises:
        DeserializationError: if `__type__` is missing or unknown, or if the
            selected decoder raises `KeyError` (e.g. unknown `__version__`
            or a missing field).
    """
    dispatch = {
        "tensorflow.sparse_tensor": _json_to_sparse_tensor,
        "tensorflow.tensor": _json_to_tensor,
        "tensorflow.variable": _json_to_variable,
    }
    try:
        return dispatch[dct["__type__"]](dct, ctx)
    except KeyError as err:
        raise DeserializationError() from err
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a tensorflow object into JSON by cases. See the README for the
    precise list of supported types. The return dict has the following
    structure:

    - `tf.RaggedTensor`: Not supported (raises `NotImplementedError`).

    - `tf.SparseTensor`:

        ```py
        {
            "__type__": "tensorflow.sparse_tensor",
            "__version__": 2,
            "indices": {...},
            "shape": [...],
            "values": {...},
        }
        ```

      where the `indices` and `values` placeholders are the serializations
      of `tf.Tensor` objects (see below), and `shape` lists the dimensions of
      the dense shape.

    - other `tf.Tensor` subtypes:

        ```py
        {
            "__type__": "tensorflow.tensor",
            "__version__": 4,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `tf.Variable`:

        ```py
        {
            "__type__": "tensorflow.variable",
            "__version__": 3,
            "name": <str>,
            "value": {...},
            "trainable": <bool>,
        }
        ```

      where `{...}` is the document produced by serializing the value tensor of
      the variable, see above.

    Raises:
        TypeNotSupported: if `obj` is none of the types above.
    """
    # `tf.Variable` is tested before the generic `tf.Tensor` so that variables
    # always go through `_variable_to_json` (and round-trip as variables),
    # even if the installed TensorFlow makes them pass an
    # `isinstance(..., tf.Tensor)` check. First match wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (tf.RaggedTensor, _ragged_tensor_to_json),
        (tf.SparseTensor, _sparse_tensor_to_json),
        (tf.Variable, _variable_to_json),
        (tf.Tensor, _tensor_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()

Serializes a tensorflow object into JSON by cases. See the README for the precise list of supported types. The return dict has the following structure:

  • tf.RaggedTensor: Not supported.

  • tf.SparseTensor:

    {
        "__type__": "tensorflow.sparse_tensor",
        "__version__": 2,
        "indices": {...},
        "values": {...},
        "shape": [...],
    }
    

    where the indices and values placeholders are the serializations of tf.Tensor objects (see below), and shape lists the dimensions of the dense shape.

  • other tf.Tensor subtypes:

    {
        "__type__": "tensorflow.tensor",
        "__version__": 4,
        "data": {
            "__type__": "bytes",
            ...
        },
    }
    

    see turbo_broccoli.custom.bytes.to_json.

  • tf.Variable:

    {
        "__type__": "tensorflow.variable",
        "__version__": 3,
        "name": <str>,
        "value": {...},
        "trainable": <bool>,
    }
    

    where {...} is the document produced by serializing the value tensor of the variable, see above.