turbo_broccoli.custom.pytorch
PyTorch (de)serialization utilities.
1"""Pytorch (de)serialization utilities.""" 2 3from typing import Any, Callable, Tuple 4 5import safetensors.torch as st 6from torch import Tensor 7from torch.nn import Module 8from torch.utils.data import ConcatDataset, StackDataset, Subset, TensorDataset 9 10from ..context import Context 11from ..exceptions import DeserializationError, TypeNotSupported 12 13 14def _concatdataset_to_json(obj: ConcatDataset, ctx: Context) -> dict: 15 return { 16 "__type__": "pytorch.concatdataset", 17 "__version__": 1, 18 "datasets": obj.datasets, 19 } 20 21 22def _json_to_concatdataset(dct: dict, ctx: Context) -> ConcatDataset: 23 decoders = {1: _json_to_concatdataset_v1} 24 return decoders[dct["__version__"]](dct, ctx) 25 26 27def _json_to_module(dct: dict, ctx: Context) -> Module: 28 ctx.raise_if_nodecode("bytes") 29 decoders = { 30 3: _json_to_module_v3, 31 } 32 return decoders[dct["__version__"]](dct, ctx) 33 34 35def _json_to_module_v3(dct: dict, ctx: Context) -> Module: 36 parts = dct["__type__"].split(".") 37 type_name = ".".join(parts[2:]) # remove "pytorch.module." 
prefix 38 module: Module = ctx.pytorch_module_types[type_name]() 39 state = st.load(dct["state"]) 40 module.load_state_dict(state) 41 return module 42 43 44def _json_to_concatdataset_v1(dct: dict, ctx: Context) -> ConcatDataset: 45 return ConcatDataset(dct["datasets"]) 46 47 48def _json_to_stackdataset(dct: dict, ctx: Context) -> StackDataset: 49 decoders = {1: _json_to_stackdataset_v1} 50 return decoders[dct["__version__"]](dct, ctx) 51 52 53def _json_to_stackdataset_v1(dct: dict, ctx: Context) -> StackDataset: 54 d = dct["datasets"] 55 if isinstance(d, dict): 56 return StackDataset(**d) 57 return StackDataset(*d) 58 59 60def _json_to_subset(dct: dict, ctx: Context) -> Subset: 61 decoders = {1: _json_to_subset_v1} 62 return decoders[dct["__version__"]](dct, ctx) 63 64 65def _json_to_subset_v1(dct: dict, ctx: Context) -> Subset: 66 return Subset(dct["dataset"], dct["indices"]) 67 68 69def _json_to_tensor(dct: dict, ctx: Context) -> Tensor: 70 ctx.raise_if_nodecode("bytes") 71 decoders = { 72 3: _json_to_tensor_v3, 73 } 74 return decoders[dct["__version__"]](dct, ctx) 75 76 77def _json_to_tensor_v3(dct: dict, ctx: Context) -> Tensor: 78 data = dct["data"] 79 return Tensor() if data is None else st.load(data)["data"] 80 81 82def _json_to_tensordataset(dct: dict, ctx: Context) -> TensorDataset: 83 decoders = {1: _json_to_tensordataset_v1} 84 return decoders[dct["__version__"]](dct, ctx) 85 86 87def _json_to_tensordataset_v1(dct: dict, ctx: Context) -> TensorDataset: 88 return TensorDataset(*dct["tensors"]) 89 90 91def _module_to_json(module: Module, ctx: Context) -> dict: 92 return { 93 "__type__": "pytorch.module." 
+ module.__class__.__name__, 94 "__version__": 3, 95 "state": st.save(module.state_dict()), 96 } 97 98 99def _stackdataset_to_json(obj: StackDataset, ctx: Context) -> dict: 100 return { 101 "__type__": "pytorch.stackdataset", 102 "__version__": 1, 103 "datasets": obj.datasets, 104 } 105 106 107def _subset_to_json(obj: Subset, ctx: Context) -> dict: 108 return { 109 "__type__": "pytorch.subset", 110 "__version__": 1, 111 "dataset": obj.dataset, 112 "indices": obj.indices, 113 } 114 115 116def _tensor_to_json(tens: Tensor, ctx: Context) -> dict: 117 x = tens.detach().cpu().contiguous() 118 return { 119 "__type__": "pytorch.tensor", 120 "__version__": 3, 121 "data": st.save({"data": x}) if x.numel() > 0 else None, 122 } 123 124 125def _tensordataset_to_json(obj: TensorDataset, ctx: Context) -> dict: 126 return { 127 "__type__": "pytorch.tensordataset", 128 "__version__": 1, 129 "tensors": obj.tensors, 130 } 131 132 133def from_json(dct: dict, ctx: Context) -> Any: 134 decoders = { 135 "pytorch.concatdataset": _json_to_concatdataset, 136 "pytorch.stackdataset": _json_to_stackdataset, 137 "pytorch.subset": _json_to_subset, 138 "pytorch.tensor": _json_to_tensor, 139 "pytorch.tensordataset": _json_to_tensordataset, 140 } 141 try: 142 type_name = dct["__type__"] 143 if type_name.startswith("pytorch.module."): 144 return _json_to_module(dct, ctx) 145 return decoders[type_name](dct, ctx) 146 except KeyError as exc: 147 raise DeserializationError() from exc 148 149 150def to_json(obj: Any, ctx: Context) -> dict: 151 """ 152 Serializes a tensor into JSON by cases. See the README for the precise list 153 of supported types. The return dict has the following structure: 154 155 - Tensor: 156 157 ```py 158 { 159 "__type__": "pytorch.tensor", 160 "__version__": 3, 161 "data": { 162 "__type__": "bytes", 163 ... 164 }, 165 } 166 ``` 167 168 see `turbo_broccoli.custom.bytes.to_json`. 
169 170 - Module: 171 172 ```py 173 { 174 "__type__": "pytorch.module.<class name>", 175 "__version__": 3, 176 "state": { 177 "__type__": "bytes", 178 ... 179 }, 180 } 181 ``` 182 183 see `turbo_broccoli.custom.bytes.to_json`. 184 185 """ 186 encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [ 187 (Module, _module_to_json), 188 (Tensor, _tensor_to_json), 189 (ConcatDataset, _concatdataset_to_json), 190 (StackDataset, _stackdataset_to_json), 191 (Subset, _subset_to_json), 192 (TensorDataset, _tensordataset_to_json), 193 ] 194 for t, f in encoders: 195 if isinstance(obj, t): 196 return f(obj, ctx) 197 raise TypeNotSupported()
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a JSON document into a PyTorch object. Any missing or
    unrecognized `__type__`/`__version__` key surfaces as a
    `DeserializationError`.
    """
    exact_decoders = {
        "pytorch.concatdataset": _json_to_concatdataset,
        "pytorch.stackdataset": _json_to_stackdataset,
        "pytorch.subset": _json_to_subset,
        "pytorch.tensor": _json_to_tensor,
        "pytorch.tensordataset": _json_to_tensordataset,
    }
    try:
        name = dct["__type__"]
        # Module documents embed the class name in __type__, so they are
        # matched by prefix rather than through the exact-match table.
        decoder = (
            _json_to_module
            if name.startswith("pytorch.module.")
            else exact_decoders[name]
        )
        return decoder(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a tensor into JSON by cases. See the README for the precise list
    of supported types. The return dict has the following structure:

    - Tensor:

        ```py
        {
            "__type__": "pytorch.tensor",
            "__version__": 3,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - Module:

        ```py
        {
            "__type__": "pytorch.module.<class name>",
            "__version__": 3,
            "state": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    """
    # Candidate encoders are tried in order; the first isinstance match wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (Module, _module_to_json),
        (Tensor, _tensor_to_json),
        (ConcatDataset, _concatdataset_to_json),
        (StackDataset, _stackdataset_to_json),
        (Subset, _subset_to_json),
        (TensorDataset, _tensordataset_to_json),
    ]
    encode = next(
        (f for t, f in encoders if isinstance(obj, t)),
        None,
    )
    if encode is None:
        raise TypeNotSupported()
    return encode(obj, ctx)
Serializes a PyTorch object (tensor, module, or dataset) into JSON by cases. See the README for the precise list of supported types. The returned dict has the following structure:
Tensor:
{ "__type__": "pytorch.tensor", "__version__": 3, "data": { "__type__": "bytes", ... }, }
Module:
{ "__type__": "pytorch.module.<class name>", "__version__": 3, "state": { "__type__": "bytes", ... }, }