turbo_broccoli.custom.sklearn

Scikit-learn estimators

  1"""Scikit-learn estimators"""
  2
  3from typing import Any, Callable, Tuple
  4
  5# Sklearn recommends joblib rather than direct pickle
  6# https://scikit-learn.org/stable/model_persistence.html#python-specific-serialization
  7import joblib
  8from sklearn import (
  9    calibration,
 10    cluster,
 11    compose,
 12    covariance,
 13    cross_decomposition,
 14    datasets,
 15    decomposition,
 16    discriminant_analysis,
 17    dummy,
 18    ensemble,
 19    exceptions,
 20    feature_extraction,
 21    feature_selection,
 22    gaussian_process,
 23    impute,
 24    inspection,
 25    isotonic,
 26    kernel_approximation,
 27    kernel_ridge,
 28    linear_model,
 29    manifold,
 30    metrics,
 31    mixture,
 32    model_selection,
 33    multiclass,
 34    multioutput,
 35    naive_bayes,
 36    neighbors,
 37    neural_network,
 38    pipeline,
 39    preprocessing,
 40    random_projection,
 41    semi_supervised,
 42    svm,
 43    tree,
 44)
 45from sklearn.base import BaseEstimator
 46from sklearn.tree._tree import Tree
 47
 48from ..context import Context
 49from ..exceptions import DeserializationError, TypeNotSupported
 50
# Submodules scanned by `_all_base_estimators` for `BaseEstimator` subclasses
# via their `__all__` lists. Commented-out entries are skipped here; several
# of them (calibration, dummy, kernel_approximation, kernel_ridge) are
# covered by the explicit class list in `_all_base_estimators` instead —
# presumably because they lack a usable `__all__`.
_SKLEARN_SUBMODULES = [
    # calibration,
    cluster,
    covariance,
    cross_decomposition,
    datasets,
    decomposition,
    # dummy,
    ensemble,
    exceptions,
    # experimental,
    # externals,
    feature_extraction,
    feature_selection,
    gaussian_process,
    inspection,
    isotonic,
    # kernel_approximation,
    # kernel_ridge,
    linear_model,
    manifold,
    metrics,
    mixture,
    model_selection,
    multiclass,
    multioutput,
    naive_bayes,
    neighbors,
    neural_network,
    pipeline,
    preprocessing,
    random_projection,
    semi_supervised,
    svm,
    tree,
    discriminant_analysis,
    impute,
    compose,
]
 90
# Attributes of `sklearn.tree._tree.Tree` that `_sklearn_tree_to_json`
# copies verbatim into the JSON document.
_SKLEARN_TREE_ATTRIBUTES = [
    "capacity",
    "children_left",
    "children_right",
    "feature",
    "impurity",
    "max_depth",
    "max_n_classes",
    "n_classes",
    "n_features",
    "n_leaves",
    "n_node_samples",
    "n_outputs",
    "node_count",
    "threshold",
    "value",
    "weighted_n_node_samples",
]
109
# Consistency fix: `Tree` is already imported at the top of the file
# (`from sklearn.tree._tree import Tree`), so use it directly instead of
# the redundant `tree._tree.Tree` attribute chain. Same class object.
_SUPPORTED_PICKLABLE_TYPES = [
    Tree,
    neighbors.KDTree,
]
"""sklearn types that shall be pickled"""
115
116
def _all_base_estimators() -> dict[str, type]:
    """
    Returns (hopefully) all classes of sklearn that inherit from
    `BaseEstimator`, keyed by class name.
    """
    classes: list[type] = []
    for submodule in _SKLEARN_SUBMODULES:
        exported = getattr(submodule, "__all__", None)
        if not isinstance(exported, list):
            continue
        for symbol in exported:
            candidate = getattr(submodule, symbol)
            if isinstance(candidate, type) and issubclass(
                candidate, BaseEstimator
            ):
                classes.append(candidate)
    # Some sklearn submodules don't have __all__; list their estimators
    # explicitly
    classes += [
        calibration.CalibratedClassifierCV,
        dummy.DummyClassifier,
        dummy.DummyRegressor,
        kernel_approximation.PolynomialCountSketch,
        kernel_approximation.RBFSampler,
        kernel_approximation.SkewedChi2Sampler,
        kernel_approximation.AdditiveChi2Sampler,
        kernel_approximation.Nystroem,
        kernel_ridge.KernelRidge,
    ]
    return {c.__name__: c for c in classes}
146
147
def _sklearn_estimator_to_json(obj: BaseEstimator, ctx: Context) -> dict:
    """
    Serializes an estimator as its (shallow) constructor parameters plus a
    snapshot of its instance attributes.
    """
    class_name = obj.__class__.__name__
    document = {
        "__type__": "sklearn.estimator." + class_name,
        "__version__": 2,
        "params": obj.get_params(deep=False),
        "attrs": obj.__dict__,
    }
    return document
155
156
def _sklearn_to_raw(obj: Any, ctx: Context) -> dict:
    """
    Pickles an otherwise unserializable sklearn object to a file artifact
    using `joblib.dump`.

    TODO:
        Don't dump to file if the object is small enough. Unfortunately
        `joblib` can't dump to a string.
    """
    artifact_path, artifact_name = ctx.new_artifact_path()
    joblib.dump(obj, artifact_path)
    document = {
        "__type__": "sklearn.raw",
        "__version__": 2,
        "data": artifact_name,
    }
    return document
173
174
def _sklearn_tree_to_json(obj: Tree, ctx: Context) -> dict:
    """
    Serializes an internal sklearn `Tree` by copying the attributes listed
    in `_SKLEARN_TREE_ATTRIBUTES` into the document.
    """
    document: dict = {
        "__type__": "sklearn.tree",
        "__version__": 2,
    }
    for attribute in _SKLEARN_TREE_ATTRIBUTES:
        document[attribute] = getattr(obj, attribute)
    return document
181
182
def _json_raw_to_sklearn(dct: dict, ctx: Context) -> Any:
    """
    Dispatches a `sklearn.raw` document to the decoder matching its
    `__version__`. Raises `KeyError` on unknown versions.
    """
    decoder_by_version = {
        # 1: _json_raw_to_sklearn_v1,  # Use turbo_broccoli v3
        2: _json_raw_to_sklearn_v2,
    }
    decode = decoder_by_version[dct["__version__"]]
    return decode(dct, ctx)
189
190
def _json_raw_to_sklearn_v2(dct: dict, ctx: Context) -> Any:
    """Loads the pickled artifact referenced by `dct["data"]` via joblib."""
    artifact_path = ctx.id_to_artifact_path(dct["data"])
    return joblib.load(artifact_path)
193
194
def _json_to_sklearn_estimator(dct: dict, ctx: Context) -> BaseEstimator:
    """
    Dispatches a `sklearn.estimator.*` document to the decoder matching its
    `__version__`. Raises `KeyError` on unknown versions.
    """
    decoder_by_version = {
        2: _json_to_sklearn_estimator_v2,
    }
    decode = decoder_by_version[dct["__version__"]]
    return decode(dct, ctx)
200
201
def _json_to_sklearn_estimator_v2(dct: dict, ctx: Context) -> BaseEstimator:
    """
    Rebuilds an estimator: resolves the class from the `__type__` suffix,
    constructs it from `params`, then restores every saved attribute.
    """
    class_name = dct["__type__"].split(".")[-1]
    estimator_class = _all_base_estimators()[class_name]
    estimator = estimator_class(**dct["params"])
    for attr_name, attr_value in dct["attrs"].items():
        setattr(estimator, attr_name, attr_value)
    return estimator
209
210
def from_json(dct: dict, ctx: Context) -> BaseEstimator:
    """
    Deserializes a document produced by `to_json`. Raises
    `DeserializationError` (chained from the underlying `KeyError`) when the
    type or version is unknown.
    """
    try:
        type_name = dct["__type__"]
        if type_name.startswith("sklearn.estimator."):
            return _json_to_sklearn_estimator(dct, ctx)
        # Non-estimator documents (dict lookup raises KeyError on unknown
        # types, which the handler below converts)
        decoder = {"sklearn.raw": _json_raw_to_sklearn}[type_name]
        return decoder(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
222
223
def to_json(obj: BaseEstimator, ctx: Context) -> dict:
    """
    Serializes a sklearn estimator into JSON by cases. See the README for the
    precise list of supported types. The return dict has the following
    structure:

    - if the object is an estimator:

        ```py
        {
            "__type__": "sklearn.estimator.<CLASS NAME>",
            "__version__": 2,
            "params": <dict returned by get_params(deep=False)>,
            "attrs": {...}
        }
        ```

      where the `attrs` dict contains all the attributes of the estimator as
      specified in the sklearn API documentation.

    - otherwise:

        ```py
        {
            "__type__": "sklearn.raw",
            "__version__": 2,
            "data": <uuid4>
        }
        ```

      where the UUID4 value points to a pickle file artifact.
    """
    # Picklable internal types (Tree, KDTree) take precedence over the
    # generic BaseEstimator case
    for picklable_type in _SUPPORTED_PICKLABLE_TYPES:
        if isinstance(obj, picklable_type):
            return _sklearn_to_raw(obj, ctx)
    if isinstance(obj, BaseEstimator):
        return _sklearn_estimator_to_json(obj, ctx)
    raise TypeNotSupported()
def from_json( dct: dict, ctx: turbo_broccoli.context.Context) -> sklearn.base.BaseEstimator:
212def from_json(dct: dict, ctx: Context) -> BaseEstimator:
213    decoders = {  # Except sklearn estimators
214        "sklearn.raw": _json_raw_to_sklearn,
215    }
216    try:
217        type_name = dct["__type__"]
218        if type_name.startswith("sklearn.estimator."):
219            return _json_to_sklearn_estimator(dct, ctx)
220        return decoders[type_name](dct, ctx)
221    except KeyError as exc:
222        raise DeserializationError() from exc
def to_json( obj: sklearn.base.BaseEstimator, ctx: turbo_broccoli.context.Context) -> dict:
225def to_json(obj: BaseEstimator, ctx: Context) -> dict:
226    """
227    Serializes a sklearn estimator into JSON by cases. See the README for the
228    precise list of supported types. The return dict has the following
229    structure:
230
231    - if the object is an estimator:
232
233        ```py
234        {
235            "__type__": "sklearn.estimator.<CLASS NAME>",
236            "__version__": 2,
237            "params": <dict returned by get_params(deep=False)>,
238            "attrs": {...}
239        }
240        ```
241
242      where the `attrs` dict contains all the attributes of the estimator as
243      specified in the sklearn API documentation.
244
245    - otherwise:
246
247        ```py
248        {
249            "__type__": "sklearn.raw",
250            "__version__": 2,
251            "data": <uuid4>
252        }
253        ```
254
 255      where the UUID4 value points to a pickle file artifact.
256    """
257
258    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
259        (t, _sklearn_to_raw) for t in _SUPPORTED_PICKLABLE_TYPES
260    ] + [
261        (BaseEstimator, _sklearn_estimator_to_json),
262    ]
263    for t, f in encoders:
264        if isinstance(obj, t):
265            return f(obj, ctx)
266    raise TypeNotSupported()

Serializes a sklearn estimator into JSON by cases. See the README for the precise list of supported types. The return dict has the following structure:

  • if the object is an estimator:

    {
        "__type__": "sklearn.estimator.<CLASS NAME>",
        "__version__": 2,
        "params": <dict returned by get_params(deep=False)>,
        "attrs": {...}
    }
    

    where the attrs dict contains all the attributes of the estimator as specified in the sklearn API documentation.

  • otherwise:

    {
        "__type__": "sklearn.raw",
        "__version__": 2,
        "data": <uuid4>
    }
    

    where the UUID4 value points to a pickle file artifact.