turbo_broccoli.custom.numpy

numpy (de)serialization utilities.

Todo: Handle numpy's generic type (which supersedes the number type).

  1"""
  2numpy (de)serialization utilities.
  3
  4Todo:
  5    Handle numpy's `generic` type (which supersedes the `number` type).
  6"""
  7
  8from typing import Any, Callable, Tuple
  9
 10import joblib
 11import numpy as np
 12from safetensors import numpy as st
 13
 14from ..context import Context
 15from ..exceptions import DeserializationError, TypeNotSupported
 16
 17
 18def _json_to_dtype(dct: dict, ctx: Context) -> np.dtype:
 19    decoders = {
 20        2: _json_to_dtype_v2,
 21    }
 22    return decoders[dct["__version__"]](dct, ctx)
 23
 24
 25def _json_to_dtype_v2(dct: dict, ctx: Context) -> np.dtype:
 26    return np.lib.format.descr_to_dtype(dct["dtype"])
 27
 28
 29def _json_to_ndarray(dct: dict, ctx: Context) -> np.ndarray:
 30    ctx.raise_if_nodecode("bytes")
 31    decoders = {
 32        5: _json_to_ndarray_v5,
 33    }
 34    return decoders[dct["__version__"]](dct, ctx)
 35
 36
 37def _json_to_ndarray_v5(dct: dict, ctx: Context) -> np.ndarray:
 38    return st.load(dct["data"])["data"]
 39
 40
 41def _json_to_number(dct: dict, ctx: Context) -> np.number:
 42    decoders = {
 43        3: _json_to_number_v3,
 44    }
 45    return decoders[dct["__version__"]](dct, ctx)
 46
 47
 48def _json_to_number_v3(dct: dict, ctx: Context) -> np.number:
 49    return np.frombuffer(dct["value"], dtype=dct["dtype"])[0]
 50
 51
 52def _json_to_random_state(dct: dict, ctx: Context) -> np.number:
 53    decoders = {
 54        3: _json_to_random_state_v3,
 55    }
 56    return decoders[dct["__version__"]](dct, ctx)
 57
 58
 59def _json_to_random_state_v3(dct: dict, ctx: Context) -> np.number:
 60    return joblib.load(ctx.id_to_artifact_path(dct["data"]))
 61
 62
 63def _dtype_to_json(d: np.dtype, ctx: Context) -> dict:
 64    return {
 65        "__type__": "numpy.dtype",
 66        "__version__": 2,
 67        "dtype": np.lib.format.dtype_to_descr(d),
 68    }
 69
 70
 71def _ndarray_to_json(arr: np.ndarray, ctx: Context) -> dict:
 72    return {
 73        "__type__": "numpy.ndarray",
 74        "__version__": 5,
 75        "data": st.save({"data": arr}),
 76    }
 77
 78
 79def _number_to_json(num: np.number, ctx: Context) -> dict:
 80    return {
 81        "__type__": "numpy.number",
 82        "__version__": 3,
 83        "value": bytes(np.array(num).data),
 84        "dtype": num.dtype,
 85    }
 86
 87
 88# pylint bug?
 89
 90
 91def _random_state_to_json(obj: np.random.RandomState, ctx: Context) -> dict:
 92    path, name = ctx.new_artifact_path()
 93    with path.open(mode="wb") as fp:
 94        joblib.dump(obj, fp)
 95    return {
 96        "__type__": "numpy.random_state",
 97        "__version__": 3,
 98        "data": name,
 99    }
100
101
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Turns a dict back into a numpy object. The decoder is picked from the
    document's `__type__` field; see `to_json` for the layout `dct` is
    expected to have.

    Any `KeyError` raised while looking up the decoder or while decoding
    (e.g. a missing field or an unknown `__version__`) is re-raised as a
    `DeserializationError`.
    """
    dispatch = {
        "numpy.ndarray": _json_to_ndarray,
        "numpy.number": _json_to_number,
        "numpy.dtype": _json_to_dtype,
        "numpy.random_state": _json_to_random_state,
    }
    try:
        decode = dispatch[dct["__type__"]]
        return decode(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
118
119
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a `numpy` object into JSON by cases. See the README for the
    precise list of supported types. The encoders are tried in the order
    listed below and the first one whose type matches `obj` is used; if none
    matches, `TypeNotSupported` is raised. The return dict has the following
    structure:

    - `numpy.ndarray`: the array is packed into a single safetensors blob
      (under the tensor name `"data"`) and that `bytes` object is placed in
      the document's `data` field, which is in turn serialized as

        ```py
        {
            "__type__": "numpy.ndarray",
            "__version__": 5,
            "data": {
                "__type__": "bytes",
                ...
            }
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

      NOTE(review): this docstring used to tie the layout to the array size
      and the `TB_MAX_NBYTES` environment variable; nothing in this module
      reads that setting, so confirm against the `bytes` serializer.

    - `numpy.number`:

        ```py
        {
            "__type__": "numpy.number",
            "__version__": 3,
            "value": <bytes>,
            "dtype": {...},
        }
        ```

        where `value` is the raw buffer of the scalar and the `dtype`
        document follows the specification below.

    - `numpy.dtype`:

        ```py
        {
            "__type__": "numpy.dtype",
            "__version__": 2,
            "dtype": <dtype_to_descr string>,
        }
        ```

    - `numpy.random.RandomState`:

        ```py
        {
            "__type__": "numpy.random_state",
            "__version__": 3,
            "data": <uuid4>,
        }
        ```

      where `data` is the name of the artifact the state was dumped to.

    """
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (np.ndarray, _ndarray_to_json),
        (np.number, _number_to_json),
        (np.dtype, _dtype_to_json),
        (np.random.RandomState, _random_state_to_json),
    ]
    for kind, encode in encoders:
        if isinstance(obj, kind):
            return encode(obj, ctx)
    raise TypeNotSupported()
def from_json(dct: dict, ctx: turbo_broccoli.context.Context) -> Any:
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Turns a dict back into a numpy object. The decoder is picked from the
    document's `__type__` field; see `to_json` for the layout `dct` is
    expected to have.

    Any `KeyError` raised while looking up the decoder or while decoding
    (e.g. a missing field or an unknown `__version__`) is re-raised as a
    `DeserializationError`.
    """
    dispatch = {
        "numpy.ndarray": _json_to_ndarray,
        "numpy.number": _json_to_number,
        "numpy.dtype": _json_to_dtype,
        "numpy.random_state": _json_to_random_state,
    }
    try:
        decode = dispatch[dct["__type__"]]
        return decode(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc

Deserializes a dict into a numpy object. See to_json for the specification dct is expected to follow.

def to_json(obj: Any, ctx: turbo_broccoli.context.Context) -> dict:
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a `numpy` object into JSON by cases. See the README for the
    precise list of supported types. The encoders are tried in the order
    listed below and the first one whose type matches `obj` is used; if none
    matches, `TypeNotSupported` is raised. The return dict has the following
    structure:

    - `numpy.ndarray`: the array is packed into a single safetensors blob
      (under the tensor name `"data"`) and that `bytes` object is placed in
      the document's `data` field, which is in turn serialized as

        ```py
        {
            "__type__": "numpy.ndarray",
            "__version__": 5,
            "data": {
                "__type__": "bytes",
                ...
            }
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

      NOTE(review): this docstring used to tie the layout to the array size
      and the `TB_MAX_NBYTES` environment variable; nothing in this module
      reads that setting, so confirm against the `bytes` serializer.

    - `numpy.number`:

        ```py
        {
            "__type__": "numpy.number",
            "__version__": 3,
            "value": <bytes>,
            "dtype": {...},
        }
        ```

        where `value` is the raw buffer of the scalar and the `dtype`
        document follows the specification below.

    - `numpy.dtype`:

        ```py
        {
            "__type__": "numpy.dtype",
            "__version__": 2,
            "dtype": <dtype_to_descr string>,
        }
        ```

    - `numpy.random.RandomState`:

        ```py
        {
            "__type__": "numpy.random_state",
            "__version__": 3,
            "data": <uuid4>,
        }
        ```

      where `data` is the name of the artifact the state was dumped to.

    """
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (np.ndarray, _ndarray_to_json),
        (np.number, _number_to_json),
        (np.dtype, _dtype_to_json),
        (np.random.RandomState, _random_state_to_json),
    ]
    for kind, encode in encoders:
        if isinstance(obj, kind):
            return encode(obj, ctx)
    raise TypeNotSupported()

Serializes a numpy object into JSON by cases. See the README for the precise list of supported types. The return dict has the following structure:

  • numpy.ndarray: An array is processed differently depending on its size and on the TB_MAX_NBYTES environment variable. If the array is small, i.e. arr.nbytes <= TB_MAX_NBYTES, then it is directly stored in the resulting JSON document as

    {
        "__type__": "numpy.ndarray",
        "__version__": 5,
        "data": {
            "__type__": "bytes",
            ...
        }
    }
    

    see turbo_broccoli.custom.bytes.to_json.

  • numpy.number:

    {
        "__type__": "numpy.number",
        "__version__": 3,
        "value": <bytes>,
        "dtype": {...},
    }
    

    where the dtype document follows the specification below.

  • numpy.dtype:

    {
        "__type__": "numpy.dtype",
        "__version__": 2,
        "dtype": <dtype_to_descr string>,
    }
    
  • numpy.random.RandomState:

    {
        "__type__": "numpy.random_state",
        "__version__": 3,
        "data": <uuid4>,
    }