turbo_broccoli.custom.tensorflow
Tensorflow (de)serialization utilities.
"""Tensorflow (de)serialization utilities."""

from typing import Any, Callable, Tuple

import tensorflow as tf
from safetensors import tensorflow as st

from ..context import Context
from ..exceptions import DeserializationError, TypeNotSupported


def _json_to_sparse_tensor(dct: dict, ctx: Context) -> tf.SparseTensor:
    """
    Deserializes a `tensorflow.sparse_tensor` document by dispatching on its
    `__version__` field. Only version 2 is known; any other (or missing)
    version raises a `KeyError`.
    """
    decoders = {
        2: _json_to_sparse_tensor_v2,
    }
    return decoders[dct["__version__"]](dct, ctx)


def _json_to_sparse_tensor_v2(dct: dict, ctx: Context) -> tf.SparseTensor:
    """
    Rebuilds a `tf.SparseTensor` from the `shape`, `indices` and `values`
    entries of a version 2 document. `ctx` is not used.
    """
    # NOTE(review): "indices" and "values" are presumably already
    # deserialized tensors by the time this runs -- confirm with the caller.
    return tf.SparseTensor(
        dense_shape=dct["shape"],
        indices=dct["indices"],
        values=dct["values"],
    )


def _json_to_tensor(dct: dict, ctx: Context) -> tf.Tensor:
    """
    Deserializes a `tensorflow.tensor` document by dispatching on its
    `__version__` field. Only version 4 is known; any other (or missing)
    version raises a `KeyError`.
    """
    # The payload is stored as a "bytes" document (see `to_json`).
    # NOTE(review): presumably this raises when the context forbids decoding
    # "bytes" documents -- confirm against `Context.raise_if_nodecode`.
    ctx.raise_if_nodecode("bytes")
    decoders = {
        4: _json_to_tensor_v4,
    }
    return decoders[dct["__version__"]](dct, ctx)


def _json_to_tensor_v4(dct: dict, ctx: Context) -> tf.Tensor:
    """
    Loads the safetensors payload stored under `data` and returns the tensor
    registered in it under the key `"data"`. `ctx` is not used.
    """
    return st.load(dct["data"])["data"]


def _json_to_variable(dct: dict, ctx: Context) -> tf.Variable:
    """
    Deserializes a `tensorflow.variable` document by dispatching on its
    `__version__` field. Only version 3 is known; any other (or missing)
    version raises a `KeyError`.
    """
    decoders = {
        3: _json_to_variable_v3,
    }
    return decoders[dct["__version__"]](dct, ctx)


def _json_to_variable_v3(dct: dict, ctx: Context) -> tf.Variable:
    """
    Rebuilds a `tf.Variable` from the `value`, `name` and `trainable` entries
    of a version 3 document. `ctx` is not used.
    """
    return tf.Variable(
        initial_value=dct["value"],
        name=dct["name"],
        trainable=dct["trainable"],
    )


def _ragged_tensor_to_json(obj: tf.RaggedTensor, ctx: Context) -> dict:
    """
    Placeholder encoder for `tf.RaggedTensor`: always raises
    `NotImplementedError` since ragged tensors are not supported.
    """
    raise NotImplementedError(
        "Serialization of ragged tensors is not supported"
    )


def _sparse_tensor_to_json(obj: tf.SparseTensor, ctx: Context) -> dict:
    """
    Converts a sparse tensor into a version 2 `tensorflow.sparse_tensor`
    document. `ctx` is not used.
    """
    # "indices" and "values" are left as tensors here; per the `to_json`
    # documentation they end up serialized as `tf.Tensor` documents.
    return {
        "__type__": "tensorflow.sparse_tensor",
        "__version__": 2,
        "indices": obj.indices,
        "shape": list(obj.dense_shape),
        "values": obj.values,
    }


def _tensor_to_json(obj: tf.Tensor, ctx: Context) -> dict:
    """
    Converts a tensor into a version 4 `tensorflow.tensor` document whose
    `data` entry is the safetensors encoding of `{"data": obj}`. `ctx` is not
    used.
    """
    return {
        "__type__": "tensorflow.tensor",
        "__version__": 4,
        "data": st.save({"data": obj}),
    }


def _variable_to_json(var: tf.Variable, ctx: Context) -> dict:
    """
    Converts a variable into a version 3 `tensorflow.variable` document
    holding its name, its value tensor and its `trainable` flag. `ctx` is not
    used.
    """
    return {
        "__type__": "tensorflow.variable",
        "__version__": 3,
        "name": var.name,
        "value": var.value(),
        "trainable": var.trainable,
    }


def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a tensorflow document by dispatching on its `__type__`
    field (`tensorflow.sparse_tensor`, `tensorflow.tensor` or
    `tensorflow.variable`).

    Raises a `DeserializationError` if a `KeyError` occurs at any point,
    e.g. if `__type__` is missing or unknown, or if the selected decoder
    fails on a missing key or an unknown `__version__`.
    """
    decoders = {
        "tensorflow.sparse_tensor": _json_to_sparse_tensor,
        "tensorflow.tensor": _json_to_tensor,
        "tensorflow.variable": _json_to_variable,
    }
    try:
        type_name = dct["__type__"]
        return decoders[type_name](dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc


def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a tensorflow object into JSON by cases. See the README for the
    precise list of supported types. The return dict has the following
    structure:

    - `tf.RaggedTensor`: Not supported (a `NotImplementedError` is raised).

    - `tf.SparseTensor`:

        ```py
        {
            "__type__": "tensorflow.sparse_tensor",
            "__version__": 2,
            "indices": {...},
            "values": {...},
            "shape": {...},
        }
        ```

      where the first two `{...}` placeholders result in the serialization of
      `tf.Tensor` (see below).

    - other `tf.Tensor` subtypes:

        ```py
        {
            "__type__": "tensorflow.tensor",
            "__version__": 4,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `tf.Variable`:

        ```py
        {
            "__type__": "tensorflow.variable",
            "__version__": 3,
            "name": <str>,
            "value": {...},
            "trainable": <bool>,
        }
        ```

      where `{...}` is the document produced by serializing the value tensor of
      the variable, see above.

    Raises `TypeNotSupported` if `obj` is an instance of none of the types
    above.
    """
    # Order matters: the first entry whose type matches `obj` wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (tf.RaggedTensor, _ragged_tensor_to_json),
        (tf.SparseTensor, _sparse_tensor_to_json),
        (tf.Tensor, _tensor_to_json),
        (tf.Variable, _variable_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Turns a tensorflow document back into the object it describes.

    The decoder is selected from the document's `__type__` field, which must
    be one of `tensorflow.sparse_tensor`, `tensorflow.tensor` or
    `tensorflow.variable`. Any `KeyError` raised during the lookup or inside
    the selected decoder is re-raised as a `DeserializationError`.
    """
    try:
        decode = {
            "tensorflow.sparse_tensor": _json_to_sparse_tensor,
            "tensorflow.tensor": _json_to_tensor,
            "tensorflow.variable": _json_to_variable,
        }[dct["__type__"]]
        return decode(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a tensorflow object into JSON by cases. See the README for the
    precise list of supported types. The return dict has the following
    structure:

    - `tf.RaggedTensor`: Not supported (a `NotImplementedError` is raised).

    - `tf.SparseTensor`:

        ```py
        {
            "__type__": "tensorflow.sparse_tensor",
            "__version__": 2,
            "indices": {...},
            "values": {...},
            "shape": {...},
        }
        ```

      where the first two `{...}` placeholders result in the serialization of
      `tf.Tensor` (see below).

    - other `tf.Tensor` subtypes:

        ```py
        {
            "__type__": "tensorflow.tensor",
            "__version__": 4,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `tf.Variable`:

        ```py
        {
            "__type__": "tensorflow.variable",
            "__version__": 3,
            "name": <str>,
            "value": {...},
            "trainable": <bool>,
        }
        ```

      where `{...}` is the document produced by serializing the value tensor of
      the variable, see above.

    Raises `TypeNotSupported` if `obj` is an instance of none of the types
    above.
    """
    # Order matters: the first entry whose type matches `obj` wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (tf.RaggedTensor, _ragged_tensor_to_json),
        (tf.SparseTensor, _sparse_tensor_to_json),
        (tf.Tensor, _tensor_to_json),
        (tf.Variable, _variable_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()
Serializes a tensorflow object into JSON by cases. See the README for the precise list of supported types. The return dict has the following structure:
- `tf.RaggedTensor`: Not supported.

- `tf.SparseTensor`:

  ```py
  {
      "__type__": "tensorflow.sparse_tensor",
      "__version__": 2,
      "indices": {...},
      "values": {...},
      "shape": {...},
  }
  ```

  where the first two `{...}` placeholders result in the serialization of `tf.Tensor` (see below).

- other `tf.Tensor` subtypes:

  ```py
  {
      "__type__": "tensorflow.tensor",
      "__version__": 4,
      "data": {
          "__type__": "bytes",
          ...
      },
  }
  ```

  see `turbo_broccoli.custom.bytes.to_json`.

- `tf.Variable`:

  ```py
  {
      "__type__": "tensorflow.variable",
      "__version__": 3,
      "name": <str>,
      "value": {...},
      "trainable": <bool>,
  }
  ```

  where `{...}` is the document produced by serializing the value tensor of the variable, see above.