turbo_broccoli.native
"Native" saving utilities:
turbo_broccoli.native.save
takes a serializable/dumpable object and a path, and uses the file extension to choose the correct way to save the object; turbo_broccoli.native.load
does the opposite.
"""
"Native" saving utilities:

* `turbo_broccoli.native.save` takes a serializable/dumpable object and
  a path, and uses the file extension to choose the correct way to save the
  object;
* `turbo_broccoli.native.load` does the opposite.
"""

from functools import partial
from pathlib import Path
from typing import Any, Callable, NoReturn

try:
    # Only imported to detect whether the package is available; the
    # format-specific submodules are imported lazily where they are needed.
    import safetensors

    HAS_SAFETENSORS = True
except ModuleNotFoundError:
    HAS_SAFETENSORS = False

from .custom import (
    HAS_KERAS,
    HAS_NUMPY,
    HAS_PANDAS,
    HAS_PYTORCH,
    HAS_TENSORFLOW,
)
from .turbo_broccoli import load_json, save_json


def _is_dict_of(obj: Any, value_type: type, key_type: type = str) -> bool:
    """Returns true if `obj` is a `dict[key_type, value_type]`"""
    return (
        isinstance(obj, dict)
        and all(isinstance(k, key_type) for k in obj.keys())
        and all(isinstance(v, value_type) for v in obj.values())
    )


def _load_csv(path: str | Path, **kwargs) -> Any:
    """
    Reads a CSV file with `pandas.read_csv`. If the resulting dataframe has a
    column named `"Unnamed: 0"`, that column is dropped before returning.
    """
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", "csv")
    import pandas as pd

    df = pd.read_csv(path, **kwargs)
    # NOTE(review): presumably the unnamed index column that `to_csv` writes
    # by default -- confirm
    if "Unnamed: 0" in df.columns:
        df.drop(["Unnamed: 0"], axis=1, inplace=True)
    return df


def _load_keras(path: str | Path, **kwargs) -> Any:
    """Loads a model with `keras.saving.load_model`."""
    if not HAS_KERAS:
        _raise_package_not_installed("keras", "keras")

    import keras

    return keras.saving.load_model(path, **kwargs)


def _load_np(path: str | Path, **kwargs) -> Any:
    """Loads a `.npy` or `.npz` file with `numpy.load`."""
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", ".npy/.npz")
    import numpy as np

    return np.load(path, **kwargs)


def _load_pq(path: str | Path, **kwargs) -> Any:
    """Reads a parquet file with `pandas.read_parquet`."""
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", ".parquet/.pq")
    import pandas as pd

    df = pd.read_parquet(path, **kwargs)
    return df


def _load_pt(path: str | Path, **kwargs) -> Any:
    """Loads a file with `torch.load`."""
    if not HAS_PYTORCH:
        _raise_package_not_installed("torch", "pt")
    import torch

    return torch.load(path, **kwargs)


def _load_st(path: str | Path, **kwargs) -> Any:
    """Loads a safetensors file with `safetensors.numpy.load_file`."""
    if not HAS_SAFETENSORS:
        _raise_package_not_installed("safetensors", ".safetensors/.st")
    from safetensors import numpy as st

    return st.load_file(path, **kwargs)


def _raise_package_not_installed(package_name: str, extension: str) -> NoReturn:
    """
    Raises a `RuntimeError` with a templated error message

    Args:
        package_name (str): e.g. "numpy"
        extension (str): e.g. "npy". A leading dot is added if missing.
    """
    # `startswith` rather than `extension[0]` so that an empty string does
    # not raise an unrelated `IndexError`
    if not extension.startswith("."):
        extension = "." + extension
    raise RuntimeError(
        f"Cannot create or load `{extension}` file because {package_name} is "
        f"not installed. You can install {package_name} by running "
        f"python3 -m pip install {package_name}"
    )


def _raise_wrong_type(path: str | Path, obj_needs_to_be_a: str) -> NoReturn:
    """
    Raises a `TypeError` with a templated error message

    Args:
        path (str | Path): Path where the file should have been saved
        obj_needs_to_be_a (str): Description of the expected type, e.g.
            "pandas DataFrame or Series"
    """
    raise TypeError(
        f"Could not save object to '{path}': object needs to be a "
        + obj_needs_to_be_a
    )


def _save_csv(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a pandas dataframe or series with its `to_csv` method."""
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", "csv")
    import pandas as pd

    if not isinstance(obj, (pd.DataFrame, pd.Series)):
        _raise_wrong_type(path, "pandas DataFrame or Series")
    obj.to_csv(path, **kwargs)


def _save_keras(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a `keras.Model` with `keras.saving.save_model`."""
    if not HAS_KERAS:
        _raise_package_not_installed("keras", "keras")

    import keras

    if not isinstance(obj, keras.Model):
        _raise_wrong_type(path, "keras model")
    keras.saving.save_model(obj, path, **kwargs)


def _save_npy(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a numpy array with `numpy.save`."""
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", "npy")
    import numpy as np

    if not isinstance(obj, np.ndarray):
        _raise_wrong_type(path, "numpy array")
    np.save(str(path), obj, **kwargs)


def _save_npz(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a dict of numpy arrays (with `str` keys) with `numpy.savez`."""
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", "npz")
    import numpy as np

    if not _is_dict_of(obj, np.ndarray):
        _raise_wrong_type(path, "dict of numpy arrays")
    np.savez(str(path), **obj, **kwargs)


def _save_pq(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a pandas dataframe with its `to_parquet` method."""
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", ".parquet/.pq")
    import pandas as pd

    if not isinstance(obj, pd.DataFrame):
        _raise_wrong_type(path, "pandas DataFrame")
    obj.to_parquet(path, **kwargs)


def _save_pt(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a torch tensor, or a dict of torch tensors (with `str` keys), with
    `torch.save`.
    """
    if not HAS_PYTORCH:
        _raise_package_not_installed("torch", "pt")
    import torch

    if not (isinstance(obj, torch.Tensor) or _is_dict_of(obj, torch.Tensor)):
        _raise_wrong_type(path, "torch tensor or a dict of torch tensors")
    torch.save(obj, path, **kwargs)


def _save_st(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a dict (with `str` keys) of numpy arrays, tensorflow tensors or
    pytorch tensors to a safetensors file. The three flavors are tried in
    that order, each only if the corresponding package is available.

    Raises:
        RuntimeError: If safetensors is not installed
        TypeError: If `obj` is none of the supported dict types
    """
    if not HAS_SAFETENSORS:
        _raise_package_not_installed("safetensors", ".safetensors/.st")

    # The flavor-specific submodules are imported explicitly: a bare
    # `import safetensors` is not relied on to make `safetensors.numpy` etc.
    # available as attributes.
    if HAS_NUMPY:
        import numpy as np

        if _is_dict_of(obj, np.ndarray):
            from safetensors import numpy as st_numpy

            st_numpy.save_file(obj, str(path), **kwargs)
            return

    if HAS_TENSORFLOW:
        import tensorflow as tf

        if _is_dict_of(obj, tf.Tensor):
            from safetensors import tensorflow as st_tensorflow

            st_tensorflow.save_file(obj, str(path), **kwargs)
            return

    if HAS_PYTORCH:
        import torch

        if _is_dict_of(obj, torch.Tensor):
            from safetensors import torch as st_torch

            st_torch.save_file(obj, str(path), **kwargs)
            return

    _raise_wrong_type(
        path,
        "dict of numpy arrays, a dict of tensorflow tensors, or a dict of "
        "pytorch tensors",
    )


def load(path: str | Path, **kwargs) -> Any:
    """
    Loads an object from a file using format-specific (or "native") methods.
    See `turbo_broccoli.native.save` for the list of supported file extensions.

    Warning:
        Safetensors files (`.st` or `.safetensors`) will be loaded as dicts of
        numpy arrays even if the object was originally a dict of e.g. torch
        tensors.
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[str | Path], Any]] = {
        ".csv": _load_csv,
        ".h5": _load_keras,
        ".keras": _load_keras,
        ".npy": _load_np,
        ".npz": _load_np,
        ".parquet": _load_pq,
        ".pq": _load_pq,
        ".pt": _load_pt,
        ".safetensors": _load_st,
        ".st": _load_st,
        ".tf": _load_keras,
    }
    # Unknown extensions (including `.json`) fall back to `load_json`
    method: Callable = methods.get(extension, load_json)
    return method(path, **kwargs)


def save(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves an object using the file extension of `path` to determine the
    serialization/dumping method:

    * `.csv`:
      [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html)
      or
      [`pandas.Series.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.Series.to_csv.html)
    * `.h5`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="h5"`
    * `.keras`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="keras"`
    * `.npy`:
      [`numpy.save`](https://numpy.org/doc/stable/reference/generated/numpy.save.html)
    * `.npz`:
      [`numpy.savez`](https://numpy.org/doc/stable/reference/generated/numpy.savez.html)
    * `.pq`, `.parquet`:
      [`pandas.DataFrame.to_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_parquet.html)
    * `.pt`:
      [`torch.save`](https://pytorch.org/docs/stable/generated/torch.save.html)
    * `.safetensors`, `.st`: (for numpy arrays, pytorch tensors and tensorflow tensors)
      [safetensors](https://huggingface.co/docs/safetensors/index)
    * `.tf`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="tf"`
    * `.json` and anything else: just forwarded to `turbo_broccoli.save_json`

    Args:
        obj (Any):
        path (str | Path):
        kwargs: Passed to the serialization method
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[Any, str | Path], None]] = {
        ".csv": _save_csv,
        ".h5": partial(_save_keras, save_format="h5"),
        ".keras": partial(_save_keras, save_format="keras"),
        ".npy": _save_npy,
        ".npz": _save_npz,
        ".parquet": _save_pq,
        ".pq": _save_pq,
        ".pt": _save_pt,
        ".safetensors": _save_st,
        ".st": _save_st,
        ".tf": partial(_save_keras, save_format="tf"),
    }
    method = methods.get(extension, save_json)
    method(obj, path, **kwargs)
def
load(path: str | pathlib.Path, **kwargs) -> Any:
def load(path: str | Path, **kwargs) -> Any:
    """
    Loads an object from a file using format-specific (or "native") methods.
    See `turbo_broccoli.native.save` for the list of supported file extensions.

    Warning:
        Safetensors files (`.st` or `.safetensors`) will be loaded as dicts of
        numpy arrays even if the object was originally a dict of e.g. torch
        tensors.

    Args:
        path (str | Path): File to load; its suffix selects the loader
        kwargs: Passed to the underlying loading method
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[str | Path], Any]] = {
        ".csv": _load_csv,
        ".h5": _load_keras,
        ".keras": _load_keras,
        ".npy": _load_np,
        ".npz": _load_np,
        ".parquet": _load_pq,
        ".pq": _load_pq,
        ".pt": _load_pt,
        # `.safetensors` is documented as supported, so it must be dispatched
        # like `.st` instead of falling through to the JSON loader
        ".safetensors": _load_st,
        ".st": _load_st,
        ".tf": _load_keras,
    }
    # Unknown extensions (including `.json`) fall back to `load_json`
    method: Callable = methods.get(extension, load_json)
    return method(path, **kwargs)
Loads an object from a file using format-specific (or "native") methods.
See turbo_broccoli.native.save
for the list of supported file extensions.
Warning:
Safetensors files (.st
or .safetensors
) will be loaded as dicts of
numpy arrays even if the object was originally a dict of e.g. torch
tensors.
def
save(obj: Any, path: str | pathlib.Path, **kwargs) -> None:
def save(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves an object using the file extension of `path` to determine the
    serialization/dumping method:

    * `.csv`:
      [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html)
      or
      [`pandas.Series.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.Series.to_csv.html)
    * `.h5`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="h5"`
    * `.keras`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="keras"`
    * `.npy`:
      [`numpy.save`](https://numpy.org/doc/stable/reference/generated/numpy.save.html)
    * `.npz`:
      [`numpy.savez`](https://numpy.org/doc/stable/reference/generated/numpy.savez.html)
    * `.pq`, `.parquet`:
      [`pandas.DataFrame.to_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_parquet.html)
    * `.pt`:
      [`torch.save`](https://pytorch.org/docs/stable/generated/torch.save.html)
    * `.safetensors`, `.st`: (for numpy arrays, pytorch tensors and tensorflow tensors)
      [safetensors](https://huggingface.co/docs/safetensors/index)
    * `.tf`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="tf"`
    * `.json` and anything else: just forwarded to `turbo_broccoli.save_json`

    Args:
        obj (Any):
        path (str | Path):
        kwargs: Passed to the serialization method
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[Any, str | Path], None]] = {
        ".csv": _save_csv,
        ".h5": partial(_save_keras, save_format="h5"),
        ".keras": partial(_save_keras, save_format="keras"),
        ".npy": _save_npy,
        ".npz": _save_npz,
        ".parquet": _save_pq,
        ".pq": _save_pq,
        ".pt": _save_pt,
        # `.safetensors` is listed above as supported, so it must be
        # dispatched like `.st` instead of falling through to `save_json`
        ".safetensors": _save_st,
        ".st": _save_st,
        ".tf": partial(_save_keras, save_format="tf"),
    }
    method = methods.get(extension, save_json)
    method(obj, path, **kwargs)
Saves an object using the file extension of path
to determine the
serialization/dumping method:
.csv
:pandas.DataFrame.to_csv
or pandas.Series.to_csv
.h5
:tf.keras.saving.save_model
with save_format="h5"
.keras
:tf.keras.saving.save_model
with save_format="keras"
.npy
:numpy.save
.npz
:numpy.savez
.pq
,.parquet
:pandas.DataFrame.to_parquet
.pt
:torch.save
.safetensors
,.st
: (for numpy arrays, pytorch tensors and tensorflow tensors) safetensors
.tf
:tf.keras.saving.save_model
with save_format="tf"
.json
and anything else: just forwarded to turbo_broccoli.save_json
Args: obj (Any): path (str | Path): kwargs: Passed to the serialization method