turbo_broccoli.native

"Native" saving utilities:

  1"""
  2"Native" saving utilities:
  3* `turbo_broccoli.native.save` takes a serializable/dumpable object and
  4  a path, and uses the file extension to choose the correct way to save the
  5  object;
  6* `turbo_broccoli.native.load` does the opposite.
  7"""
  8
  9from functools import partial
 10from pathlib import Path
 11from typing import Any, Callable
 12
 13try:
 14    import safetensors
 15
 16    HAS_SAFETENSORS = True
 17except ModuleNotFoundError:
 18    HAS_SAFETENSORS = False
 19
 20from .custom import (
 21    HAS_KERAS,
 22    HAS_NUMPY,
 23    HAS_PANDAS,
 24    HAS_PYTORCH,
 25    HAS_TENSORFLOW,
 26)
 27from .turbo_broccoli import load_json, save_json
 28
 29
 30def _is_dict_of(obj: Any, value_type: type, key_type: type = str) -> bool:
 31    """Returns true if `obj` is a `dict[key_type, value_type]`"""
 32    return (
 33        isinstance(obj, dict)
 34        and all(isinstance(k, key_type) for k in obj.keys())
 35        and all(isinstance(v, value_type) for v in obj.values())
 36    )
 37
 38
 39def _load_csv(path: str | Path, **kwargs) -> Any:
 40    if not HAS_PANDAS:
 41        _raise_package_not_installed("pandas", "csv")
 42    import pandas as pd
 43
 44    df = pd.read_csv(path, **kwargs)
 45    if "Unnamed: 0" in df.columns:
 46        df.drop(["Unnamed: 0"], axis=1, inplace=True)
 47    return df
 48
 49
 50def _load_keras(path: str | Path, **kwargs) -> Any:
 51    if not HAS_KERAS:
 52        _raise_package_not_installed("keras", "keras")
 53
 54    import keras
 55
 56    return keras.saving.load_model(path, **kwargs)
 57
 58
 59def _load_np(path: str | Path, **kwargs) -> Any:
 60    if not HAS_NUMPY:
 61        _raise_package_not_installed("numpy", ".npy/.npz")
 62    import numpy as np
 63
 64    return np.load(path, **kwargs)
 65
 66
 67def _load_pq(path: str | Path, **kwargs) -> Any:
 68    if not HAS_PANDAS:
 69        _raise_package_not_installed("pandas", ".parquet/.pq")
 70    import pandas as pd
 71
 72    df = pd.read_parquet(path, **kwargs)
 73    return df
 74
 75
 76def _load_pt(path: str | Path, **kwargs) -> Any:
 77    if not HAS_PYTORCH:
 78        _raise_package_not_installed("torch", "pt")
 79    import torch
 80
 81    return torch.load(path, **kwargs)
 82
 83
 84def _load_st(path: str | Path, **kwargs) -> Any:
 85    if not HAS_SAFETENSORS:
 86        _raise_package_not_installed("safetensors", ".safetensors/.st")
 87    from safetensors import numpy as st
 88
 89    return st.load_file(path, **kwargs)
 90
 91
 92def _raise_package_not_installed(package_name: str, extension: str):
 93    """
 94    Raises a `RuntimeError` with a templated error message
 95
 96    Args:
 97        package_name (str): e.g. "numpy"
 98        extension (str): e.g. "npy"
 99    """
100    if extension[0] != ".":
101        extension = "." + extension
102    raise RuntimeError(
103        f"Cannot create or load `{extension}` file because {package_name} is "
104        f"not installed. You can install {package_name} by running "
105        f"python3 -m pip install {package_name}"
106    )
107
108
109def _raise_wrong_type(path: str | Path, obj_needs_to_be_a: str):
110    """
111    Raises a `TypeError` with a templated error message
112
113    Args:
114        path (str | Path): Path where the file should have been saved
115        extension (str): "pandas DataFrame or Series"
116    """
117    raise TypeError(
118        f"Could not save object to '{path}': object needs to be a "
119        + obj_needs_to_be_a
120    )
121
122
def _save_csv(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Writes a pandas dataframe or series to a CSV file using its `to_csv`
    method.

    Args:
        obj (Any): Must be a `pandas.DataFrame` or a `pandas.Series`
        path (str | Path):
        kwargs: Passed to `to_csv`

    Raises:
        RuntimeError: If pandas is not installed
        TypeError: If `obj` is neither a dataframe nor a series
    """
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", "csv")
    import pandas as pd

    accepted_types = (pd.DataFrame, pd.Series)
    if not isinstance(obj, accepted_types):
        _raise_wrong_type(path, "pandas DataFrame or Series")
    obj.to_csv(path, **kwargs)
131
132
def _save_keras(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a keras model with `keras.saving.save_model`.

    Args:
        obj (Any): Must be a `keras.Model`
        path (str | Path):
        kwargs: Passed to `keras.saving.save_model`

    Raises:
        RuntimeError: If keras is not installed
        TypeError: If `obj` is not a keras model
    """
    if not HAS_KERAS:
        _raise_package_not_installed("keras", "keras")

    import keras

    is_model = isinstance(obj, keras.Model)
    if not is_model:
        _raise_wrong_type(path, "keras model")
    keras.saving.save_model(obj, path, **kwargs)
142
143
def _save_npy(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a numpy array with `numpy.save`.

    Args:
        obj (Any): Must be a `numpy.ndarray`
        path (str | Path): Converted to `str` before being given to numpy
        kwargs: Passed to `numpy.save`

    Raises:
        RuntimeError: If numpy is not installed
        TypeError: If `obj` is not a numpy array
    """
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", "npy")
    import numpy as np

    is_array = isinstance(obj, np.ndarray)
    if not is_array:
        _raise_wrong_type(path, "numpy array")
    destination = str(path)
    np.save(destination, obj, **kwargs)
152
153
def _save_npz(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a dict of numpy arrays with `numpy.savez`. The dict's items are
    given to `numpy.savez` as keyword arguments, alongside `kwargs`.

    Args:
        obj (Any): Must be a dict with `str` keys and `numpy.ndarray` values
        path (str | Path): Converted to `str` before being given to numpy
        kwargs: Passed to `numpy.savez`

    Raises:
        RuntimeError: If numpy is not installed
        TypeError: If `obj` is not a dict of numpy arrays
    """
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", "npz")
    import numpy as np

    arrays_ok = _is_dict_of(obj, np.ndarray)
    if not arrays_ok:
        _raise_wrong_type(path, "dict of numpy arrays")
    destination = str(path)
    np.savez(destination, **obj, **kwargs)
162
163
def _save_pq(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Writes a pandas dataframe to a parquet file using its `to_parquet`
    method.

    Args:
        obj (Any): Must be a `pandas.DataFrame`
        path (str | Path):
        kwargs: Passed to `pandas.DataFrame.to_parquet`

    Raises:
        RuntimeError: If pandas is not installed
        TypeError: If `obj` is not a dataframe
    """
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", ".parquet/.pq")
    import pandas as pd

    is_dataframe = isinstance(obj, pd.DataFrame)
    if not is_dataframe:
        _raise_wrong_type(path, "pandas DataFrame")
    obj.to_parquet(path, **kwargs)
172
173
def _save_pt(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a torch tensor, or a dict of torch tensors, with `torch.save`.

    Args:
        obj (Any): Must be a `torch.Tensor` or a dict with `str` keys and
            `torch.Tensor` values
        path (str | Path):
        kwargs: Passed to `torch.save`

    Raises:
        RuntimeError: If pytorch is not installed
        TypeError: If `obj` is neither a tensor nor a dict of tensors
    """
    if not HAS_PYTORCH:
        _raise_package_not_installed("torch", "pt")
    import torch

    tensor_type = torch.Tensor
    if not isinstance(obj, tensor_type) and not _is_dict_of(obj, tensor_type):
        _raise_wrong_type(path, "torch tensor or a dict of torch tensors")
    torch.save(obj, path, **kwargs)
182
183
def _save_st(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a dict of arrays/tensors to a safetensors file. The flavor of
    `save_file` is chosen from the type of the dict's values, tried in this
    order: numpy arrays, tensorflow tensors, pytorch tensors. A framework is
    only considered if it is installed.

    Args:
        obj (Any): Must be a dict with `str` keys whose values are all numpy
            arrays, all tensorflow tensors, or all pytorch tensors
        path (str | Path): Converted to `str` before being given to
            safetensors
        kwargs: Passed to the selected `save_file` function

    Raises:
        RuntimeError: If safetensors is not installed
        TypeError: If `obj` is not one of the supported dicts
    """
    if not HAS_SAFETENSORS:
        _raise_package_not_installed("safetensors", ".safetensors/.st")

    # The safetensors submodules are imported explicitly (as `_load_st` does):
    # a bare `import safetensors` does not guarantee that
    # `safetensors.numpy`, `safetensors.tensorflow` or `safetensors.torch`
    # are available as attributes.
    if HAS_NUMPY:
        import numpy as np

        if _is_dict_of(obj, np.ndarray):
            from safetensors.numpy import save_file as save_numpy

            save_numpy(obj, str(path), **kwargs)
            return

    if HAS_TENSORFLOW:
        import tensorflow as tf

        if _is_dict_of(obj, tf.Tensor):
            from safetensors.tensorflow import save_file as save_tensorflow

            save_tensorflow(obj, str(path), **kwargs)
            return

    if HAS_PYTORCH:
        import torch

        if _is_dict_of(obj, torch.Tensor):
            from safetensors.torch import save_file as save_torch

            save_torch(obj, str(path), **kwargs)
            return

    # The helper raises by itself; no `raise` keyword needed here
    _raise_wrong_type(
        path,
        "dict of numpy arrays, a dict of tensorflow tensors, or a dict of "
        "pytorch tensors",
    )
215
216
def load(path: str | Path, **kwargs) -> Any:
    """
    Loads an object from a file using format-specific (or "native") methods.
    See `turbo_broccoli.native.save` for the list of supported file extensions.
    Files with any other extension are forwarded to `load_json`.

    Warning:
        Safetensors files (`.st` or `.safetensors`) will be loaded as dicts of
        numpy arrays even if the object was originally a dict of e.g. torch
        tensors.

    Args:
        path (str | Path):
        kwargs: Passed to the deserialization method

    Returns:
        The loaded object
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[str | Path], Any]] = {
        ".csv": _load_csv,
        ".h5": _load_keras,
        ".keras": _load_keras,
        ".npy": _load_np,
        ".npz": _load_np,
        ".parquet": _load_pq,
        ".pq": _load_pq,
        ".pt": _load_pt,
        # `.safetensors` is documented as supported, so it must be dispatched
        # like `.st` instead of falling through to the JSON loader
        ".safetensors": _load_st,
        ".st": _load_st,
        ".tf": _load_keras,
    }
    method: Callable = methods.get(extension, load_json)
    return method(path, **kwargs)
242
243
def save(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves an object using the file extension of `path` to determine the
    serialization/dumping method:

    * `.csv`:
      [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html)
      or
      [`pandas.Series.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.Series.to_csv.html)
    * `.h5`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="h5"`
    * `.keras`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="keras"`
    * `.npy`:
      [`numpy.save`](https://numpy.org/doc/stable/reference/generated/numpy.save.html)
    * `.npz`:
      [`numpy.savez`](https://numpy.org/doc/stable/reference/generated/numpy.savez.html)
    * `.pq`, `.parquet`:
      [`pandas.DataFrame.to_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_parquet.html)
    * `.pt`:
      [`torch.save`](https://pytorch.org/docs/stable/generated/torch.save.html)
    * `.safetensors`, `.st`: (for numpy arrays, pytorch tensors and tensorflow tensors)
      [safetensors](https://huggingface.co/docs/safetensors/index)
    * `.tf`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="tf"`
    * `.json` and anything else: just forwarded to `turbo_broccoli.save_json`

    Args:
        obj (Any):
        path (str | Path):
        kwargs: Passed to the serialization method
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[Any, str | Path], None]] = {
        ".csv": _save_csv,
        ".h5": partial(_save_keras, save_format="h5"),
        ".keras": partial(_save_keras, save_format="keras"),
        ".npy": _save_npy,
        ".npz": _save_npz,
        ".parquet": _save_pq,
        ".pq": _save_pq,
        ".pt": _save_pt,
        # `.safetensors` is documented above, so it must be dispatched like
        # `.st` instead of falling through to the JSON serializer
        ".safetensors": _save_st,
        ".st": _save_st,
        ".tf": partial(_save_keras, save_format="tf"),
    }
    method = methods.get(extension, save_json)
    method(obj, path, **kwargs)
def load(path: str | pathlib.Path, **kwargs) -> Any:
def load(path: str | Path, **kwargs) -> Any:
    """
    Loads an object from a file using format-specific (or "native") methods.
    See `turbo_broccoli.native.save` for the list of supported file extensions.
    Files with any other extension are forwarded to `load_json`.

    Warning:
        Safetensors files (`.st` or `.safetensors`) will be loaded as dicts of
        numpy arrays even if the object was originally a dict of e.g. torch
        tensors.

    Args:
        path (str | Path):
        kwargs: Passed to the deserialization method

    Returns:
        The loaded object
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[str | Path], Any]] = {
        ".csv": _load_csv,
        ".h5": _load_keras,
        ".keras": _load_keras,
        ".npy": _load_np,
        ".npz": _load_np,
        ".parquet": _load_pq,
        ".pq": _load_pq,
        ".pt": _load_pt,
        # `.safetensors` is documented as supported, so it must be dispatched
        # like `.st` instead of falling through to the JSON loader
        ".safetensors": _load_st,
        ".st": _load_st,
        ".tf": _load_keras,
    }
    method: Callable = methods.get(extension, load_json)
    return method(path, **kwargs)

Loads an object from a file using format-specific (or "native") methods. See turbo_broccoli.native.save for the list of supported file extensions.

Warning: Safetensors files (.st or .safetensors) will be loaded as dicts of numpy arrays even if the object was originally a dict of e.g. torch tensors.

def save(obj: Any, path: str | pathlib.Path, **kwargs) -> None:
def save(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves an object using the file extension of `path` to determine the
    serialization/dumping method:

    * `.csv`:
      [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html)
      or
      [`pandas.Series.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.Series.to_csv.html)
    * `.h5`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="h5"`
    * `.keras`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="keras"`
    * `.npy`:
      [`numpy.save`](https://numpy.org/doc/stable/reference/generated/numpy.save.html)
    * `.npz`:
      [`numpy.savez`](https://numpy.org/doc/stable/reference/generated/numpy.savez.html)
    * `.pq`, `.parquet`:
      [`pandas.DataFrame.to_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_parquet.html)
    * `.pt`:
      [`torch.save`](https://pytorch.org/docs/stable/generated/torch.save.html)
    * `.safetensors`, `.st`: (for numpy arrays, pytorch tensors and tensorflow tensors)
      [safetensors](https://huggingface.co/docs/safetensors/index)
    * `.tf`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="tf"`
    * `.json` and anything else: just forwarded to `turbo_broccoli.save_json`

    Args:
        obj (Any):
        path (str | Path):
        kwargs: Passed to the serialization method
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[Any, str | Path], None]] = {
        ".csv": _save_csv,
        ".h5": partial(_save_keras, save_format="h5"),
        ".keras": partial(_save_keras, save_format="keras"),
        ".npy": _save_npy,
        ".npz": _save_npz,
        ".parquet": _save_pq,
        ".pq": _save_pq,
        ".pt": _save_pt,
        # `.safetensors` is documented above, so it must be dispatched like
        # `.st` instead of falling through to the JSON serializer
        ".safetensors": _save_st,
        ".st": _save_st,
        ".tf": partial(_save_keras, save_format="tf"),
    }
    method = methods.get(extension, save_json)
    method(obj, path, **kwargs)

Saves an object using the file extension of path to determine the serialization/dumping method:

Args: obj (Any): path (str | Path): kwargs: Passed to the serialization method