turbo_broccoli.custom.pytorch

Pytorch (de)serialization utilities.

  1"""Pytorch (de)serialization utilities."""
  2
  3from typing import Any, Callable, Tuple
  4
  5import safetensors.torch as st
  6from torch import Tensor
  7from torch.nn import Module
  8from torch.utils.data import ConcatDataset, StackDataset, Subset, TensorDataset
  9
 10from ..context import Context
 11from ..exceptions import DeserializationError, TypeNotSupported
 12
 13
 14def _concatdataset_to_json(obj: ConcatDataset, ctx: Context) -> dict:
 15    return {
 16        "__type__": "pytorch.concatdataset",
 17        "__version__": 1,
 18        "datasets": obj.datasets,
 19    }
 20
 21
 22def _json_to_concatdataset(dct: dict, ctx: Context) -> ConcatDataset:
 23    decoders = {1: _json_to_concatdataset_v1}
 24    return decoders[dct["__version__"]](dct, ctx)
 25
 26
 27def _json_to_module(dct: dict, ctx: Context) -> Module:
 28    ctx.raise_if_nodecode("bytes")
 29    decoders = {
 30        3: _json_to_module_v3,
 31    }
 32    return decoders[dct["__version__"]](dct, ctx)
 33
 34
 35def _json_to_module_v3(dct: dict, ctx: Context) -> Module:
 36    parts = dct["__type__"].split(".")
 37    type_name = ".".join(parts[2:])  # remove "pytorch.module." prefix
 38    module: Module = ctx.pytorch_module_types[type_name]()
 39    state = st.load(dct["state"])
 40    module.load_state_dict(state)
 41    return module
 42
 43
 44def _json_to_concatdataset_v1(dct: dict, ctx: Context) -> ConcatDataset:
 45    return ConcatDataset(dct["datasets"])
 46
 47
 48def _json_to_stackdataset(dct: dict, ctx: Context) -> StackDataset:
 49    decoders = {1: _json_to_stackdataset_v1}
 50    return decoders[dct["__version__"]](dct, ctx)
 51
 52
 53def _json_to_stackdataset_v1(dct: dict, ctx: Context) -> StackDataset:
 54    d = dct["datasets"]
 55    if isinstance(d, dict):
 56        return StackDataset(**d)
 57    return StackDataset(*d)
 58
 59
 60def _json_to_subset(dct: dict, ctx: Context) -> Subset:
 61    decoders = {1: _json_to_subset_v1}
 62    return decoders[dct["__version__"]](dct, ctx)
 63
 64
 65def _json_to_subset_v1(dct: dict, ctx: Context) -> Subset:
 66    return Subset(dct["dataset"], dct["indices"])
 67
 68
 69def _json_to_tensor(dct: dict, ctx: Context) -> Tensor:
 70    ctx.raise_if_nodecode("bytes")
 71    decoders = {
 72        3: _json_to_tensor_v3,
 73    }
 74    return decoders[dct["__version__"]](dct, ctx)
 75
 76
 77def _json_to_tensor_v3(dct: dict, ctx: Context) -> Tensor:
 78    data = dct["data"]
 79    return Tensor() if data is None else st.load(data)["data"]
 80
 81
 82def _json_to_tensordataset(dct: dict, ctx: Context) -> TensorDataset:
 83    decoders = {1: _json_to_tensordataset_v1}
 84    return decoders[dct["__version__"]](dct, ctx)
 85
 86
 87def _json_to_tensordataset_v1(dct: dict, ctx: Context) -> TensorDataset:
 88    return TensorDataset(*dct["tensors"])
 89
 90
 91def _module_to_json(module: Module, ctx: Context) -> dict:
 92    return {
 93        "__type__": "pytorch.module." + module.__class__.__name__,
 94        "__version__": 3,
 95        "state": st.save(module.state_dict()),
 96    }
 97
 98
 99def _stackdataset_to_json(obj: StackDataset, ctx: Context) -> dict:
100    return {
101        "__type__": "pytorch.stackdataset",
102        "__version__": 1,
103        "datasets": obj.datasets,
104    }
105
106
107def _subset_to_json(obj: Subset, ctx: Context) -> dict:
108    return {
109        "__type__": "pytorch.subset",
110        "__version__": 1,
111        "dataset": obj.dataset,
112        "indices": obj.indices,
113    }
114
115
116def _tensor_to_json(tens: Tensor, ctx: Context) -> dict:
117    x = tens.detach().cpu().contiguous()
118    return {
119        "__type__": "pytorch.tensor",
120        "__version__": 3,
121        "data": st.save({"data": x}) if x.numel() > 0 else None,
122    }
123
124
125def _tensordataset_to_json(obj: TensorDataset, ctx: Context) -> dict:
126    return {
127        "__type__": "pytorch.tensordataset",
128        "__version__": 1,
129        "tensors": obj.tensors,
130    }
131
132
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a document produced by `to_json` back into the
    corresponding pytorch object.

    Raises:
        DeserializationError: If the document's `__type__` is absent or not
            recognized, or if a versioned decoder rejects the document.
    """
    exact_decoders = {
        "pytorch.concatdataset": _json_to_concatdataset,
        "pytorch.stackdataset": _json_to_stackdataset,
        "pytorch.subset": _json_to_subset,
        "pytorch.tensor": _json_to_tensor,
        "pytorch.tensordataset": _json_to_tensordataset,
    }
    try:
        name = dct["__type__"]
        # Module documents embed the class name in __type__, so they are
        # matched by prefix rather than by exact key.
        decoder = (
            _json_to_module
            if name.startswith("pytorch.module.")
            else exact_decoders[name]
        )
        return decoder(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
148
149
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a pytorch object into JSON by cases. See the README for the
    precise list of supported types. The returned dict has the following
    structure:

    - Tensor:

        ```py
        {
            "__type__": "pytorch.tensor",
            "__version__": 3,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - Module:

        ```py
        {
            "__type__": "pytorch.module.<class name>",
            "__version__": 3,
            "state": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    Raises:
        TypeNotSupported: If `obj` is none of the supported types.
    """
    # Order matters: the first matching type wins.
    dispatch: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (Module, _module_to_json),
        (Tensor, _tensor_to_json),
        (ConcatDataset, _concatdataset_to_json),
        (StackDataset, _stackdataset_to_json),
        (Subset, _subset_to_json),
        (TensorDataset, _tensordataset_to_json),
    ]
    for cls, encoder in dispatch:
        if isinstance(obj, cls):
            return encoder(obj, ctx)
    raise TypeNotSupported()
def from_json(dct: dict, ctx: turbo_broccoli.context.Context) -> Any:
134def from_json(dct: dict, ctx: Context) -> Any:
135    decoders = {
136        "pytorch.concatdataset": _json_to_concatdataset,
137        "pytorch.stackdataset": _json_to_stackdataset,
138        "pytorch.subset": _json_to_subset,
139        "pytorch.tensor": _json_to_tensor,
140        "pytorch.tensordataset": _json_to_tensordataset,
141    }
142    try:
143        type_name = dct["__type__"]
144        if type_name.startswith("pytorch.module."):
145            return _json_to_module(dct, ctx)
146        return decoders[type_name](dct, ctx)
147    except KeyError as exc:
148        raise DeserializationError() from exc
def to_json(obj: Any, ctx: turbo_broccoli.context.Context) -> dict:
151def to_json(obj: Any, ctx: Context) -> dict:
152    """
153    Serializes a tensor into JSON by cases. See the README for the precise list
154    of supported types. The return dict has the following structure:
155
156    - Tensor:
157
158        ```py
159        {
160            "__type__": "pytorch.tensor",
161            "__version__": 3,
162            "data": {
163                "__type__": "bytes",
164                ...
165            },
166        }
167        ```
168
169      see `turbo_broccoli.custom.bytes.to_json`.
170
171    - Module:
172
173        ```py
174        {
175            "__type__": "pytorch.module.<class name>",
176            "__version__": 3,
177            "state": {
178                "__type__": "bytes",
179                ...
180            },
181        }
182        ```
183
184      see `turbo_broccoli.custom.bytes.to_json`.
185
186    """
187    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
188        (Module, _module_to_json),
189        (Tensor, _tensor_to_json),
190        (ConcatDataset, _concatdataset_to_json),
191        (StackDataset, _stackdataset_to_json),
192        (Subset, _subset_to_json),
193        (TensorDataset, _tensordataset_to_json),
194    ]
195    for t, f in encoders:
196        if isinstance(obj, t):
197            return f(obj, ctx)
198    raise TypeNotSupported()

Serializes a tensor into JSON by cases. See the README for the precise list of supported types. The return dict has the following structure: