turbo_broccoli.custom.sklearn
Scikit-learn estimators
"""Scikit-learn estimators"""

from typing import Any, Callable, Tuple

# Sklearn recommends joblib rather than direct pickle
# https://scikit-learn.org/stable/model_persistence.html#python-specific-serialization
import joblib
from sklearn import (
    calibration,
    cluster,
    compose,
    covariance,
    cross_decomposition,
    datasets,
    decomposition,
    discriminant_analysis,
    dummy,
    ensemble,
    exceptions,
    feature_extraction,
    feature_selection,
    gaussian_process,
    impute,
    inspection,
    isotonic,
    kernel_approximation,
    kernel_ridge,
    linear_model,
    manifold,
    metrics,
    mixture,
    model_selection,
    multiclass,
    multioutput,
    naive_bayes,
    neighbors,
    neural_network,
    pipeline,
    preprocessing,
    random_projection,
    semi_supervised,
    svm,
    tree,
)
from sklearn.base import BaseEstimator
from sklearn.tree._tree import Tree

from ..context import Context
from ..exceptions import DeserializationError, TypeNotSupported

# Submodules whose `__all__` is scanned for `BaseEstimator` subclasses. The
# order matters: when two classes share a name, the one scanned last wins.
#
# Deliberately NOT listed: calibration, dummy, kernel_approximation and
# kernel_ridge (their estimators are enumerated by hand in
# `_all_base_estimators`), as well as experimental and externals.
_SKLEARN_SUBMODULES = [
    cluster,
    covariance,
    cross_decomposition,
    datasets,
    decomposition,
    ensemble,
    exceptions,
    feature_extraction,
    feature_selection,
    gaussian_process,
    inspection,
    isotonic,
    linear_model,
    manifold,
    metrics,
    mixture,
    model_selection,
    multiclass,
    multioutput,
    naive_bayes,
    neighbors,
    neural_network,
    pipeline,
    preprocessing,
    random_projection,
    semi_supervised,
    svm,
    tree,
    discriminant_analysis,
    impute,
    compose,
]

# Attribute names of a low-level sklearn `Tree` object, in alphabetical order.
# Read by `_sklearn_tree_to_json`.
_SKLEARN_TREE_ATTRIBUTES = (
    "capacity children_left children_right feature impurity max_depth "
    "max_n_classes n_classes n_features n_leaves n_node_samples n_outputs "
    "node_count threshold value weighted_n_node_samples"
).split()
_SUPPORTED_PICKLABLE_TYPES = [
    tree._tree.Tree,
    neighbors.KDTree,
]
"""sklearn types that shall be pickled"""

# Lazily filled by `_all_base_estimators`. Scanning every sklearn submodule is
# comparatively expensive and its result does not change while the process
# runs, so the scan happens at most once instead of once per decoded estimator.
_BASE_ESTIMATORS_CACHE: dict[str, type] = {}


def _scan_base_estimators() -> dict[str, type]:
    """
    Collects the sklearn classes that inherit from `BaseEstimator`, indexed by
    class name. If two classes share a name, the one collected last wins.
    """
    result: list[type] = []
    for s in _SKLEARN_SUBMODULES:
        s_all = getattr(s, "__all__", None)
        # `__all__` may be a list or a tuple; anything else (or absence) is
        # skipped
        if not isinstance(s_all, (list, tuple)):
            continue
        for k in s_all:
            # Use a default so that a name advertised in `__all__` but not
            # actually present on the module is skipped instead of crashing
            cls = getattr(s, k, None)
            if isinstance(cls, type) and issubclass(cls, BaseEstimator):
                result.append(cls)
    # Some sklearn submodules don't have __all__
    result += [
        calibration.CalibratedClassifierCV,
        dummy.DummyClassifier,
        dummy.DummyRegressor,
        kernel_approximation.PolynomialCountSketch,
        kernel_approximation.RBFSampler,
        kernel_approximation.SkewedChi2Sampler,
        kernel_approximation.AdditiveChi2Sampler,
        kernel_approximation.Nystroem,
        kernel_ridge.KernelRidge,
    ]
    return {cls.__name__: cls for cls in result}


def _all_base_estimators() -> dict[str, type]:
    """
    Returns (hopefully) all classes of sklearn that inherit from
    `BaseEstimator`, indexed by class name.

    The mapping is computed on the first call and cached afterwards. The
    returned dict is shared between callers and must not be mutated.
    """
    if not _BASE_ESTIMATORS_CACHE:
        _BASE_ESTIMATORS_CACHE.update(_scan_base_estimators())
    return _BASE_ESTIMATORS_CACHE


def _sklearn_estimator_to_json(obj: BaseEstimator, ctx: Context) -> dict:
    """
    Encodes an estimator as its constructor parameters
    (`get_params(deep=False)`) plus its instance attributes (`__dict__`).
    """
    return {
        "__type__": "sklearn.estimator." + obj.__class__.__name__,
        "__version__": 2,
        "params": obj.get_params(deep=False),
        # Shallow copy: the encoded document must not alias the estimator's
        # live attribute dict
        "attrs": dict(obj.__dict__),
    }


def _sklearn_to_raw(obj: Any, ctx: Context) -> dict:
    """
    Pickles an otherwise unserializable sklearn object. Actually uses the
    `joblib.dump`.

    TODO:
        Don't dump to file if the object is small enough. Unfortunately
        `joblib` can't dump to a string.
    """
    path, name = ctx.new_artifact_path()
    joblib.dump(obj, path)
    return {
        "__type__": "sklearn.raw",
        "__version__": 2,
        "data": name,
    }


def _sklearn_tree_to_json(obj: Tree, ctx: Context) -> dict:
    """
    Encodes a low-level sklearn `Tree` attribute by attribute.

    NOTE: not registered in `to_json` (trees are pickled via
    `_SUPPORTED_PICKLABLE_TYPES`), and `from_json` has no decoder for
    `"sklearn.tree"`.
    """
    return {
        "__type__": "sklearn.tree",
        "__version__": 2,
        **{a: getattr(obj, a) for a in _SKLEARN_TREE_ATTRIBUTES},
    }


def _json_raw_to_sklearn(dct: dict, ctx: Context) -> Any:
    """
    Dispatches a `"sklearn.raw"` document on its `__version__`. Raises
    `KeyError` for unknown versions.
    """
    decoders = {
        # 1: _json_raw_to_sklearn_v1,  # Use turbo_broccoli v3
        2: _json_raw_to_sklearn_v2,
    }
    return decoders[dct["__version__"]](dct, ctx)


def _json_raw_to_sklearn_v2(dct: dict, ctx: Context) -> Any:
    """Loads the joblib artifact referenced by `dct["data"]`."""
    # SECURITY: joblib.load unpickles the file and can execute arbitrary
    # code. Only load artifacts that come from a trusted source.
    return joblib.load(ctx.id_to_artifact_path(dct["data"]))


def _json_to_sklearn_estimator(dct: dict, ctx: Context) -> BaseEstimator:
    """
    Dispatches a `"sklearn.estimator.*"` document on its `__version__`.
    Raises `KeyError` for unknown versions.
    """
    decoders = {
        2: _json_to_sklearn_estimator_v2,
    }
    return decoders[dct["__version__"]](dct, ctx)


def _json_to_sklearn_estimator_v2(dct: dict, ctx: Context) -> BaseEstimator:
    """
    Rebuilds an estimator from its class name, its constructor parameters and
    its attributes. Raises `KeyError` if the class name is unknown.
    """
    bes = _all_base_estimators()
    cls = bes[dct["__type__"].split(".")[-1]]
    obj = cls(**dct["params"])
    for k, v in dct["attrs"].items():
        setattr(obj, k, v)
    return obj


def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a document produced by `to_json`.

    Documents whose `__type__` starts with `"sklearn.estimator."` are rebuilt
    as estimators. `"sklearn.raw"` documents are loaded from their joblib
    artifact, so the result is not necessarily a `BaseEstimator`.

    Raises:
        DeserializationError: if any `KeyError` is raised while decoding, e.g.
            a missing key, an unknown type, or an unknown version.
    """
    decoders = {  # Except sklearn estimators
        "sklearn.raw": _json_raw_to_sklearn,
    }
    try:
        type_name = dct["__type__"]
        if type_name.startswith("sklearn.estimator."):
            return _json_to_sklearn_estimator(dct, ctx)
        return decoders[type_name](dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc


def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a sklearn estimator into JSON by cases. See the README for the
    precise list of supported types. The return dict has the following
    structure:

    - if the object is an estimator:

        ```py
        {
            "__type__": "sklearn.estimator.<CLASS NAME>",
            "__version__": 2,
            "params": <dict returned by get_params(deep=False)>,
            "attrs": {...}
        }
        ```

      where the `attrs` dict contains all the attributes of the estimator
      (a shallow copy of its `__dict__`), including those specified in the
      sklearn API documentation.

    - otherwise, for the types in `_SUPPORTED_PICKLABLE_TYPES` (`Tree`,
      `KDTree`):

        ```py
        {
            "__type__": "sklearn.raw",
            "__version__": 2,
            "data": <uuid4>
        }
        ```

      where the UUID4 value points to a pickle file artifact written with
      `joblib.dump`.

    Raises:
        TypeNotSupported: if `obj` is none of the above.
    """
    # Picklable types are checked first, before the generic BaseEstimator case
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (t, _sklearn_to_raw) for t in _SUPPORTED_PICKLABLE_TYPES
    ] + [
        (BaseEstimator, _sklearn_estimator_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a document produced by `to_json`.

    Documents whose `__type__` starts with `"sklearn.estimator."` are rebuilt
    as estimators. `"sklearn.raw"` documents are loaded from their artifact,
    so the result is not necessarily a `BaseEstimator`.

    Raises:
        DeserializationError: if any `KeyError` is raised while decoding, e.g.
            a missing key, an unknown type, or an unknown version.
    """
    decoders = {  # Except sklearn estimators
        "sklearn.raw": _json_raw_to_sklearn,
    }
    try:
        type_name = dct["__type__"]
        if type_name.startswith("sklearn.estimator."):
            return _json_to_sklearn_estimator(dct, ctx)
        return decoders[type_name](dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a sklearn estimator into JSON by cases. See the README for the
    precise list of supported types. The return dict has the following
    structure:

    - if the object is an estimator:

        ```py
        {
            "__type__": "sklearn.estimator.<CLASS NAME>",
            "__version__": 2,
            "params": <dict returned by get_params(deep=False)>,
            "attrs": {...}
        }
        ```

      where the `attrs` dict contains all the attributes of the estimator
      (its `__dict__`), including those specified in the sklearn API
      documentation.

    - otherwise, for the types in `_SUPPORTED_PICKLABLE_TYPES`:

        ```py
        {
            "__type__": "sklearn.raw",
            "__version__": 2,
            "data": <uuid4>
        }
        ```

      where the UUID4 value points to a pickle file artifact.

    Raises:
        TypeNotSupported: if `obj` is none of the above.
    """
    # Picklable types are checked first, before the generic BaseEstimator case
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (t, _sklearn_to_raw) for t in _SUPPORTED_PICKLABLE_TYPES
    ] + [
        (BaseEstimator, _sklearn_estimator_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()
Serializes a sklearn estimator into JSON by cases. See the README for the precise list of supported types. The return dict has the following structure:
if the object is an estimator:
{ "__type__": "sklearn.estimator.<CLASS NAME>", "__version__": 2, "params": <dict returned by get_params(deep=False)>, "attrs": {...} }
where the `attrs` dict contains all the attributes of the estimator (its `__dict__`), including those specified in the sklearn API documentation.
otherwise:
{ "__type__": "sklearn.raw", "__version__": 2, "data": <uuid4> }
where the UUID4 value points to a pickle file artifact (written with `joblib.dump`).