lcc.classifiers

Classifier models and related stuff

 1"""Classifier models and related stuff"""
 2
 3from .base import (
 4    BaseClassifier,
 5    LatentClusteringData,
 6    validate_lcc_kwargs,
 7)
 8from .huggingface import HuggingFaceClassifier
 9from .timm import TimmClassifier
10from .torchvision import TorchvisionClassifier
11from .wrapped import WrappedClassifier
12
13__all__ = [
14    "BaseClassifier",
15    "get_classifier_cls",
16    "HuggingFaceClassifier",
17    "LatentClusteringData",
18    "TimmClassifier",
19    "TorchvisionClassifier",
20    "validate_lcc_kwargs",
21    "WrappedClassifier",
22]
23
24
def get_classifier_cls(model_name: str) -> type[BaseClassifier]:
    """
    Picks the classifier class matching the naming scheme of `model_name`.

    Args:
        model_name (str): Name of the model. Names starting with `timm/` map to
            `TimmClassifier`; any other name containing a `/` maps to
            `HuggingFaceClassifier`; every remaining name maps to
            `TorchvisionClassifier`.

    Returns:
        The classifier class (not an instance).
    """
    # The `timm/` prefix must be tested first since such names also contain
    # a slash and would otherwise be treated as HuggingFace names.
    is_timm = model_name.startswith("timm/")
    if not is_timm and "/" not in model_name:
        return TorchvisionClassifier
    return TimmClassifier if is_timm else HuggingFaceClassifier
class BaseClassifier(pytorch_lightning.core.module.LightningModule):
 49class BaseClassifier(pl.LightningModule):
 50    """
 51    Base image classifier class that supports LCC.
 52
 53    Warning:
 54        When subclassing this, remember that the forward method must be able to
 55        deal with either `Tensor` or `Batch` inputs, and must return a logit
 56        `Tensor`.
 57    """
 58
 59    lcc_data: dict[str, LatentClusteringData] | None = None
 60    """
 61    If LCC is applied, then this is non `None` and updated at the begining of
 62    each epoch. See also `full_dataset_latent_clustering`.
 63    """
 64
 65    standard_loss: nn.Module
 66    """'Standard' loss to use together with LCC."""
 67
 68    accuracy_top1: nn.Module
 69    """Top-1 accuracy metric."""
 70
 71    accuracy_top5: nn.Module | None = None
 72    """Top-5 accuracy metric."""
 73
 74    def __init__(
 75        self,
 76        n_classes: int,
 77        lcc_submodules: list[str] | None = None,
 78        lcc_kwargs: dict | None = None,
 79        ce_weight: float = 1,
 80        image_key: Any = 0,
 81        label_key: Any = 1,
 82        **kwargs: Any,
 83    ) -> None:
 84        """
 85        Args:
 86            n_classes (int):
 87            lcc_submodules (list[str] | None, optional): Submodules to consider
 88                for the latent correction loss. If `None` or `[]`, LCC is not
 89                performed
 90            lcc_kwargs (dict | None, optional): Optional parameters for LCC.
 91                Expected entries (all optional) are:
 92                * **weight (float):** Defaults to $10^{-4}$
 93                * **class_selection
 94                    (`lcc.correction.LCCClassSelection` | None):** Defaults to
 95                    `None`, which means all classes are considered for
 96                    correction
 97                * **interval (int):** Apply LCC every `interval` epochs.
 98                    Defaults to $1$, meaning LCC will be applied every epoch
 99                    (after warmup).
100                * **warmup (int):** Number of epochs to wait before
101                    starting LCC. Defaults to $0$, meaning LCC will start
102                    immediately.
103                * **k (int):** Number of nearest neighbors to consider for LCC,
104                  and Louvain clustering.
105                * **pca_dim (int):** Samples are reduced to this dimension
106                  before constructing the KNN graph. This must be at most the
107                  batch size.
108                * **loss (`"exact"` or `"randomized"`)**: Way the LCC loss is
109                  computed from the clustering data. See `lcc.correction.loss`.
110                * **ccspc (int)**: If the loss type is `"randomized"`, then this
111                  parameter specified the number of CC samples per cluster to
112                  keep as potential correction targets. See also
113                  `lcc.correction.loss.RandomizedLCCLoss`.
114                * **clustering_method (`"louvain"` or `"peer_pressure"`)**:
115                  Clustering algorithm for the dataset of latent
116                  representations.
117            ce_weight (float, optional): Weight of the cross-entropy loss in the
118                clustering-CE loss. Ignored if LCC is not applied. Defaults to
119                $1$.
120            image_key (Any, optional): A batch passed to the model can be a
121                tuple (most common) or a dict. This parameter specifies the key
122                to use to retrieve the input tensor.
123            label_key (Any, optional): Analogous to `image_key`.
124            kwargs: Forwarded to
125                [`pl.LightningModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#)
126        """
127        super().__init__(**kwargs)
128        self.save_hyperparameters(
129            "ce_weight",
130            "image_key",
131            "label_key",
132            "lcc_kwargs",
133            "lcc_submodules",
134            "n_classes",
135        )
136        if lcc_submodules:
137            validate_lcc_kwargs(lcc_kwargs)
138        self.standard_loss = torch.nn.CrossEntropyLoss()
139        acc_kw: dict[str, Any] = {
140            "task": "multiclass",
141            "num_classes": n_classes,
142            "average": "micro",
143        }
144        self.accuracy_top1 = tm.Accuracy(top_k=1, **acc_kw)  # type: ignore
145        if n_classes > 5:
146            self.accuracy_top5 = tm.Accuracy(top_k=5, **acc_kw)  # type: ignore
147
148    def _evaluate(self, batch: Batch, stage: str | None = None) -> Tensor:
149        """Self-explanatory"""
150        image_key = self.hparams["image_key"]
151        label_key = self.hparams["label_key"]
152        x, y = batch[image_key], batch[label_key].to(self.device)
153        latent: dict[str, Tensor] = {}
154        logits = self.forward_intermediate(
155            x, self.lcc_submodules, latent, keep_gradients=True
156        )
157        assert isinstance(logits, Tensor)
158        loss_ce = self.standard_loss(logits, y)
159        if self.lcc_data and stage == "train":
160            idx = to_array(batch["_idx"])
161            _losses = [
162                self.lcc_data[sm].loss(
163                    z, y_true=y, y_clst=self.lcc_data[sm].y_clst[idx]
164                )
165                for sm, z in latent.items()
166            ]
167            loss_lcc = (
168                torch.stack(_losses).mean()
169                if _losses
170                else torch.tensor(0.0, requires_grad=True)
171                # ↑ actually need grad?
172            )
173            lcc_weight = self.hparams.get("lcc_kwargs", {}).get("weight", 1e-4)
174            loss = self.hparams["ce_weight"] * loss_ce + lcc_weight * loss_lcc
175            self.log(f"{stage}/lcc", loss_lcc, sync_dist=True)
176        else:
177            loss = loss_ce
178        if stage:
179            d = {
180                f"{stage}/loss": loss,
181                f"{stage}/ce": loss_ce,
182            }
183            if self.accuracy_top5 is not None:
184                d[f"{stage}/acc5"] = self.accuracy_top5(logits, y)
185            self.log_dict(d, sync_dist=True)
186            self.log(
187                f"{stage}/acc",
188                self.accuracy_top1(logits, y),
189                prog_bar=True,
190                sync_dist=True,
191            )
192        return loss  # type: ignore
193
194    def configure_optimizers(self) -> Any:
195        optimizer = torch.optim.SGD(self.parameters(), lr=1e-1, momentum=0.9)
196        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
197            optimizer, T_max=50
198        )
199        return {
200            "optimizer": optimizer,
201            "lr_scheduler": {"scheduler": scheduler},
202        }
203
204    def forward_intermediate(
205        self,
206        inputs: Tensor | Batch | list[Tensor] | Sequence[Batch],
207        submodules: list[str],
208        output_dict: dict,
209        keep_gradients: bool = False,
210    ) -> Tensor | list[Tensor]:
211        """
212        Runs the model and collects the output of specified submodules. The
213        intermediate outputs are stored in `output_dict` under the
214        corresponding submodule name. In particular, this method has side
215        effects.
216
217        Args:
218            x (Tensor | Batch | list[Tensor] | list[Batch]): If batched (i.e.
219                `x` is a list), then so is the output of this function and the
220                entries in the `output_dict`
221            submodules (list[str]):
222            output_dict (dict):
223            keep_gradients (bool, optional): If `True`, the tensors in
224                `output_dict` keep their gradients (if they had some on the
225                first place). If `False`, they are detached and moved to the
226                CPU.
227        """
228
229        def maybe_detach(x: Tensor) -> Tensor:
230            return x if keep_gradients else x.detach().cpu()
231
232        def create_hook(key: str) -> Callable[[nn.Module, Any, Any], None]:
233            def hook(_model: nn.Module, _args: Any, out: Any) -> None:
234                if (
235                    isinstance(out, (list, tuple))
236                    and len(out) == 1
237                    and isinstance(out[0], Tensor)
238                ):
239                    out = out[0]
240                elif (  # Special case for ViTs
241                    isinstance(out, (list, tuple))
242                    and len(out) == 2
243                    and isinstance(out[0], Tensor)
244                    and not isinstance(out[1], Tensor)
245                ):
246                    out = out[0]
247                elif not isinstance(out, Tensor):
248                    raise ValueError(
249                        f"Unsupported latent object type: {type(out)}: {out}."
250                    )
251                if batched:
252                    if key not in output_dict:
253                        output_dict[key] = []
254                    output_dict[key].append(maybe_detach(out))
255                else:
256                    output_dict[key] = maybe_detach(out)
257
258            return hook
259
260        batched = isinstance(inputs, (list, tuple))
261        handles: list[RemovableHandle] = []
262        for name in submodules:
263            submodule = self.get_submodule(name)
264            handles.append(submodule.register_forward_hook(create_hook(name)))
265        if batched:
266            logits = [
267                maybe_detach(
268                    self.forward(
269                        batch
270                        if isinstance(batch, Tensor)
271                        else batch[self.hparams["image_key"]]
272                    )
273                )
274                for batch in inputs
275            ]
276        else:
277            logits = maybe_detach(  # type: ignore
278                self.forward(
279                    inputs
280                    if isinstance(inputs, Tensor)
281                    else inputs[self.hparams["image_key"]]
282                )
283            )
284        for h in handles:
285            h.remove()
286        return logits
287
288    @staticmethod
289    def get_image_processor(model_name: str, **kwargs: Any) -> Callable:
290        """
291        Returns an image processor for the model. By defaults, returns the
292        identity function.
293        """
294        return lambda input: input
295
296    @property
297    def lcc_submodules(self) -> list[str]:
298        """
299        Returns the list of submodules considered for LCC, whith correct prefix
300        if needed.
301        """
302        return self.hparams.get("lcc_submodules") or []
303
304    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
305        """Just logs all optimizer's learning rate"""
306        log_optimizers_lr(self, sync_dist=True)
307        super().on_train_batch_end(*args, **kwargs)
308
309    def on_train_epoch_end(self) -> None:
310        """Cleans up training specific temporary attributes"""
311        self.lcc_data = None
312        super().on_train_epoch_end()
313
314    def on_train_epoch_start(self) -> None:
315        """
316        Performs dataset-wide latent clustering and stores the results in
317        private attribute `BaseClassifier._lc_data`.
318        """
319        # wether to apply LCC this epoch
320        lcc_kwargs = self.hparams.get("lcc_kwargs") or {}
321        do_lcc = (
322            # we are passed warmup (lcc_warmup being None is equivalent to no
323            # warmup)...
324            self.current_epoch >= (lcc_kwargs.get("warmup") or 0)
325            and (
326                # ... and an LCC interval is specified...
327                lcc_kwargs.get("interval") is not None
328                # ... and the current epoch can have LCC done...
329                and self.current_epoch % int(lcc_kwargs.get("interval", 1))
330                == 0
331            )
332            # ... and there are submodule selected for LCC...
333            and self.lcc_submodules
334            # ... and the LCC weight is non-zero
335            and lcc_kwargs.get("weight", 0) > 0
336        )
337        if do_lcc:
338            # The import has to be here to prevent circular imports while having
339            # all the FDLC logic in a separate file
340            from .fdlc import full_dataset_latent_clustering
341
342            with temporary_directory(self) as tmp_path:
343                self.lcc_data = full_dataset_latent_clustering(
344                    model=self,
345                    output_dir=tmp_path,
346                    tqdm_style="console",
347                )
348        super().on_train_epoch_start()
349
350    def on_train_start(self) -> None:
351        """
352        Explicitly registers hyperparameters and metrics. You'd think Lightning
353        would do this automatically, but nope.
354        """
355        self.logger.log_hyperparams(  # type: ignore
356            self.hparams,
357            {
358                s + "/" + m: np.nan
359                for s, m in product(
360                    ["train", "val"], ["acc", "loss", "ce", "lcc"]
361                )
362            },
363        )
364        super().on_train_start()
365
366    def test_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor:
367        """Override from `pl.LightningModule.test_step`."""
368        return self._evaluate(batch, "test")
369
370    def training_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor:
371        """Override from `pl.LightningModule.training_step`."""
372        return self._evaluate(batch, "train")
373
374    def validation_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor:
375        """Override from `pl.LightningModule.validation_step`."""
376        return self._evaluate(batch, "val")

Base image classifier class that supports LCC.

Warning:

When subclassing this, remember that the forward method must be able to deal with either Tensor or Batch inputs, and must return a logit Tensor.

BaseClassifier( n_classes: int, lcc_submodules: list[str] | None = None, lcc_kwargs: dict | None = None, ce_weight: float = 1, image_key: Any = 0, label_key: Any = 1, **kwargs: Any)
 74    def __init__(
 75        self,
 76        n_classes: int,
 77        lcc_submodules: list[str] | None = None,
 78        lcc_kwargs: dict | None = None,
 79        ce_weight: float = 1,
 80        image_key: Any = 0,
 81        label_key: Any = 1,
 82        **kwargs: Any,
 83    ) -> None:
 84        """
 85        Args:
 86            n_classes (int):
 87            lcc_submodules (list[str] | None, optional): Submodules to consider
 88                for the latent correction loss. If `None` or `[]`, LCC is not
 89                performed
 90            lcc_kwargs (dict | None, optional): Optional parameters for LCC.
 91                Expected entries (all optional) are:
 92                * **weight (float):** Defaults to $10^{-4}$
 93                * **class_selection
 94                    (`lcc.correction.LCCClassSelection` | None):** Defaults to
 95                    `None`, which means all classes are considered for
 96                    correction
 97                * **interval (int):** Apply LCC every `interval` epochs.
 98                    Defaults to $1$, meaning LCC will be applied every epoch
 99                    (after warmup).
100                * **warmup (int):** Number of epochs to wait before
101                    starting LCC. Defaults to $0$, meaning LCC will start
102                    immediately.
103                * **k (int):** Number of nearest neighbors to consider for LCC,
104                  and Louvain clustering.
105                * **pca_dim (int):** Samples are reduced to this dimension
106                  before constructing the KNN graph. This must be at most the
107                  batch size.
108                * **loss (`"exact"` or `"randomized"`)**: Way the LCC loss is
109                  computed from the clustering data. See `lcc.correction.loss`.
110                * **ccspc (int)**: If the loss type is `"randomized"`, then this
111                  parameter specified the number of CC samples per cluster to
112                  keep as potential correction targets. See also
113                  `lcc.correction.loss.RandomizedLCCLoss`.
114                * **clustering_method (`"louvain"` or `"peer_pressure"`)**:
115                  Clustering algorithm for the dataset of latent
116                  representations.
117            ce_weight (float, optional): Weight of the cross-entropy loss in the
118                clustering-CE loss. Ignored if LCC is not applied. Defaults to
119                $1$.
120            image_key (Any, optional): A batch passed to the model can be a
121                tuple (most common) or a dict. This parameter specifies the key
122                to use to retrieve the input tensor.
123            label_key (Any, optional): Analogous to `image_key`.
124            kwargs: Forwarded to
125                [`pl.LightningModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#)
126        """
127        super().__init__(**kwargs)
128        self.save_hyperparameters(
129            "ce_weight",
130            "image_key",
131            "label_key",
132            "lcc_kwargs",
133            "lcc_submodules",
134            "n_classes",
135        )
136        if lcc_submodules:
137            validate_lcc_kwargs(lcc_kwargs)
138        self.standard_loss = torch.nn.CrossEntropyLoss()
139        acc_kw: dict[str, Any] = {
140            "task": "multiclass",
141            "num_classes": n_classes,
142            "average": "micro",
143        }
144        self.accuracy_top1 = tm.Accuracy(top_k=1, **acc_kw)  # type: ignore
145        if n_classes > 5:
146            self.accuracy_top5 = tm.Accuracy(top_k=5, **acc_kw)  # type: ignore
Arguments:
  • n_classes (int):
  • lcc_submodules (list[str] | None, optional): Submodules to consider for the latent correction loss. If None or [], LCC is not performed
  • lcc_kwargs (dict | None, optional): Optional parameters for LCC. Expected entries (all optional) are:
    • weight (float): Defaults to $10^{-4}$
    • class_selection (lcc.correction.LCCClassSelection | None): Defaults to None, which means all classes are considered for correction
    • interval (int): Apply LCC every interval epochs. Defaults to $1$, meaning LCC will be applied every epoch (after warmup).
    • warmup (int): Number of epochs to wait before starting LCC. Defaults to $0$, meaning LCC will start immediately.
    • k (int): Number of nearest neighbors to consider for LCC, and Louvain clustering.
    • pca_dim (int): Samples are reduced to this dimension before constructing the KNN graph. This must be at most the batch size.
    • loss ("exact" or "randomized"): Way the LCC loss is computed from the clustering data. See lcc.correction.loss.
    • ccspc (int): If the loss type is "randomized", then this parameter specifies the number of CC samples per cluster to keep as potential correction targets. See also lcc.correction.loss.RandomizedLCCLoss.
    • clustering_method ("louvain" or "peer_pressure"): Clustering algorithm for the dataset of latent representations.
  • ce_weight (float, optional): Weight of the cross-entropy loss in the clustering-CE loss. Ignored if LCC is not applied. Defaults to $1$.
  • image_key (Any, optional): A batch passed to the model can be a tuple (most common) or a dict. This parameter specifies the key to use to retrieve the input tensor.
  • label_key (Any, optional): Analogous to image_key.
  • kwargs: Forwarded to pl.LightningModule
lcc_data: dict[str, LatentClusteringData] | None = None

If LCC is applied, then this is non None and updated at the beginning of each epoch. See also full_dataset_latent_clustering.

standard_loss: torch.nn.modules.module.Module

'Standard' loss to use together with LCC.

accuracy_top1: torch.nn.modules.module.Module

Top-1 accuracy metric.

accuracy_top5: torch.nn.modules.module.Module | None = None

Top-5 accuracy metric.

def configure_optimizers(self) -> Any:
194    def configure_optimizers(self) -> Any:
195        optimizer = torch.optim.SGD(self.parameters(), lr=1e-1, momentum=0.9)
196        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
197            optimizer, T_max=50
198        )
199        return {
200            "optimizer": optimizer,
201            "lr_scheduler": {"scheduler": scheduler},
202        }

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

Any of these 6 options.

  • Single optimizer.
  • List or Tuple of optimizers.
  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after a optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

.. testcode::

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging it using self.log('metric_to_track', metric_val) in your ~pytorch_lightning.core.LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.
  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
  • If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
  • If you need to control how often the optimizer steps, override the optimizer_step() hook.
def forward_intermediate( self, inputs: Union[torch.Tensor, dict[str, torch.Tensor], list[torch.Tensor], Sequence[dict[str, torch.Tensor]]], submodules: list[str], output_dict: dict, keep_gradients: bool = False) -> torch.Tensor | list[torch.Tensor]:
204    def forward_intermediate(
205        self,
206        inputs: Tensor | Batch | list[Tensor] | Sequence[Batch],
207        submodules: list[str],
208        output_dict: dict,
209        keep_gradients: bool = False,
210    ) -> Tensor | list[Tensor]:
211        """
212        Runs the model and collects the output of specified submodules. The
213        intermediate outputs are stored in `output_dict` under the
214        corresponding submodule name. In particular, this method has side
215        effects.
216
217        Args:
218            x (Tensor | Batch | list[Tensor] | list[Batch]): If batched (i.e.
219                `x` is a list), then so is the output of this function and the
220                entries in the `output_dict`
221            submodules (list[str]):
222            output_dict (dict):
223            keep_gradients (bool, optional): If `True`, the tensors in
224                `output_dict` keep their gradients (if they had some on the
225                first place). If `False`, they are detached and moved to the
226                CPU.
227        """
228
229        def maybe_detach(x: Tensor) -> Tensor:
230            return x if keep_gradients else x.detach().cpu()
231
232        def create_hook(key: str) -> Callable[[nn.Module, Any, Any], None]:
233            def hook(_model: nn.Module, _args: Any, out: Any) -> None:
234                if (
235                    isinstance(out, (list, tuple))
236                    and len(out) == 1
237                    and isinstance(out[0], Tensor)
238                ):
239                    out = out[0]
240                elif (  # Special case for ViTs
241                    isinstance(out, (list, tuple))
242                    and len(out) == 2
243                    and isinstance(out[0], Tensor)
244                    and not isinstance(out[1], Tensor)
245                ):
246                    out = out[0]
247                elif not isinstance(out, Tensor):
248                    raise ValueError(
249                        f"Unsupported latent object type: {type(out)}: {out}."
250                    )
251                if batched:
252                    if key not in output_dict:
253                        output_dict[key] = []
254                    output_dict[key].append(maybe_detach(out))
255                else:
256                    output_dict[key] = maybe_detach(out)
257
258            return hook
259
260        batched = isinstance(inputs, (list, tuple))
261        handles: list[RemovableHandle] = []
262        for name in submodules:
263            submodule = self.get_submodule(name)
264            handles.append(submodule.register_forward_hook(create_hook(name)))
265        if batched:
266            logits = [
267                maybe_detach(
268                    self.forward(
269                        batch
270                        if isinstance(batch, Tensor)
271                        else batch[self.hparams["image_key"]]
272                    )
273                )
274                for batch in inputs
275            ]
276        else:
277            logits = maybe_detach(  # type: ignore
278                self.forward(
279                    inputs
280                    if isinstance(inputs, Tensor)
281                    else inputs[self.hparams["image_key"]]
282                )
283            )
284        for h in handles:
285            h.remove()
286        return logits

Runs the model and collects the output of specified submodules. The intermediate outputs are stored in output_dict under the corresponding submodule name. In particular, this method has side effects.

Arguments:
  • inputs (Tensor | Batch | list[Tensor] | Sequence[Batch]): If batched (i.e. inputs is a list or a tuple), then so is the output of this function and the entries in the output_dict
  • submodules (list[str]):
  • output_dict (dict):
  • keep_gradients (bool, optional): If True, the tensors in output_dict keep their gradients (if they had some in the first place). If False, they are detached and moved to the CPU.
@staticmethod
def get_image_processor(model_name: str, **kwargs: Any) -> Callable:
288    @staticmethod
289    def get_image_processor(model_name: str, **kwargs: Any) -> Callable:
290        """
291        Returns an image processor for the model. By defaults, returns the
292        identity function.
293        """
294        return lambda input: input

Returns an image processor for the model. By default, returns the identity function.

lcc_submodules: list[str]
296    @property
297    def lcc_submodules(self) -> list[str]:
298        """
299        Returns the list of submodules considered for LCC, whith correct prefix
300        if needed.
301        """
302        return self.hparams.get("lcc_submodules") or []

Returns the list of submodules considered for LCC, with correct prefix if needed.

def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
304    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
305        """Just logs all optimizer's learning rate"""
306        log_optimizers_lr(self, sync_dist=True)
307        super().on_train_batch_end(*args, **kwargs)

Just logs all optimizer's learning rate

def on_train_epoch_end(self) -> None:
309    def on_train_epoch_end(self) -> None:
310        """Cleans up training specific temporary attributes"""
311        self.lcc_data = None
312        super().on_train_epoch_end()

Cleans up training specific temporary attributes

def on_train_epoch_start(self) -> None:
def on_train_epoch_start(self) -> None:
    """
    Runs dataset-wide latent clustering when LCC is due this epoch and
    stores the result in `self.lcc_data`, then calls the parent hook.

    LCC is due when all of the following hold:

    * the current epoch is at or past the `warmup` entry of the
      `lcc_kwargs` hyperparameter (a missing or `None` warmup counts as 0);
    * an `interval` is given and the current epoch is a multiple of it;
    * at least one submodule is selected for LCC;
    * the LCC `weight` is strictly positive (a missing weight counts as 0).
    """
    lcc_kwargs = self.hparams.get("lcc_kwargs") or {}
    warmup = lcc_kwargs.get("warmup") or 0
    interval = lcc_kwargs.get("interval")
    # The chain is evaluated left to right and short-circuits, so later
    # conditions are only looked at when the earlier ones hold.
    lcc_is_due = (
        self.current_epoch >= warmup
        and interval is not None
        and self.current_epoch % int(interval) == 0
        and self.lcc_submodules
        and lcc_kwargs.get("weight", 0) > 0
    )
    if lcc_is_due:
        # Imported here rather than at module level to avoid a circular
        # import while keeping all the FDLC logic in its own file.
        from .fdlc import full_dataset_latent_clustering

        with temporary_directory(self) as tmp_path:
            self.lcc_data = full_dataset_latent_clustering(
                model=self,
                output_dir=tmp_path,
                tqdm_style="console",
            )
    super().on_train_epoch_start()

Performs dataset-wide latent clustering and stores the results in the attribute BaseClassifier.lcc_data.

def on_train_start(self) -> None:
def on_train_start(self) -> None:
    """
    Registers the hyperparameters with the logger, together with a NaN
    placeholder for each `<stage>/<metric>` pair (stages `train` and `val`;
    metrics `acc`, `loss`, `ce` and `lcc`), then calls the parent hook.
    """
    placeholder_metrics = {
        f"{stage}/{metric}": np.nan
        for stage in ("train", "val")
        for metric in ("acc", "loss", "ce", "lcc")
    }
    self.logger.log_hyperparams(  # type: ignore
        self.hparams, placeholder_metrics
    )
    super().on_train_start()

Explicitly registers hyperparameters and metrics. You'd think Lightning would do this automatically, but nope.

def test_step(self, batch: dict[str, torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
def test_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor:
    """
    Override from `pl.LightningModule.test_step`: evaluates `batch` under
    the `"test"` stage and returns what `_evaluate` returns. Extra
    positional and keyword arguments are ignored.
    """
    return self._evaluate(batch, "test")

Override from pl.LightningModule.test_step.

def training_step(self, batch: dict[str, torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
def training_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor:
    """
    Override from `pl.LightningModule.training_step`: evaluates `batch`
    under the `"train"` stage and returns what `_evaluate` returns. Extra
    positional and keyword arguments are ignored.
    """
    return self._evaluate(batch, "train")

Override from pl.LightningModule.training_step.

def validation_step(self, batch: dict[str, torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
def validation_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor:
    """
    Override from `pl.LightningModule.validation_step`: evaluates `batch`
    under the `"val"` stage and returns what `_evaluate` returns. Extra
    positional and keyword arguments are ignored.
    """
    return self._evaluate(batch, "val")

Override from pl.LightningModule.validation_step.

def get_classifier_cls(model_name: str) -> type[BaseClassifier]:
def get_classifier_cls(model_name: str) -> type[BaseClassifier]:
    """
    Picks the classifier class that matches a model name.

    Args:
        model_name (str): Names without a `/` map to
            `TorchvisionClassifier`, names starting with `timm/` map to
            `TimmClassifier`, and any other name containing a `/` maps to
            `HuggingFaceClassifier`.

    Returns:
        The classifier class (not an instance).
    """
    if "/" not in model_name:
        return TorchvisionClassifier
    is_timm = model_name.startswith("timm/")
    return TimmClassifier if is_timm else HuggingFaceClassifier

Returns the classifier class to use for a given model name.

Arguments:
  • model_name (str):
class HuggingFaceClassifier(lcc.classifiers.WrappedClassifier):
class HuggingFaceClassifier(WrappedClassifier):
    """
    Pretrained classifier model loaded from the [HuggingFace model
    hub](https://huggingface.co/models?pipeline_tag=image-classification).
    """

    def __init__(
        self,
        model_name: str,
        n_classes: int,
        head_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        See also:
            `lcc.classifiers.WrappedClassifier.__init__` and
            `lcc.classifiers.BaseClassifier.__init__`.

        Args:
            model_name (str): Model name as in the [HuggingFace model
                hub](https://huggingface.co/models?pipeline_tag=image-classification).
                Names starting with `timm/` are rejected; use
                `lcc.classifiers.TimmClassifier` for those.
            n_classes (int): See `lcc.classifiers.WrappedClassifier.__init__`.
            head_name (str | None, optional): See
                `lcc.classifiers.WrappedClassifier.__init__`.

        Raises:
            ValueError: If `model_name` starts with `timm/`.
        """
        if model_name.startswith("timm/"):
            raise ValueError(
                "If the model name starts with `timm/`, use "
                "`lcc.classifiers.TimmClassifier` instead."
            )
        model = AutoModelForImageClassification.from_pretrained(model_name)
        super().__init__(model, n_classes, head_name, **kwargs)
        self.save_hyperparameters()

    @staticmethod
    def get_image_processor(
        model_name: str, **__: Any
    ) -> Callable[[dict[str, Any]], dict[str, Any]]:
        """
        Builds a batch transform around the HuggingFace `AutoImageProcessor`
        associated with a given model.

        Args:
            model_name (str): Model name as in the [HuggingFace model
                hub](https://huggingface.co/models?pipeline_tag=image-classification).
                Must not start with `timm/`.

        Returns:
            A callable mapping a batch dict to a new dict. Values stored
            under an image key (`img`, `image`, `jpg` or `png`) are
            converted to RGB and replaced by the `pixel_values` produced by
            a
            [`transformers.AutoImageProcessor`](https://huggingface.co/docs/transformers/v4.44.2/en/model_doc/auto#transformers.AutoImageProcessor);
            every other entry is passed through untouched.
        """
        processor = AutoImageProcessor.from_pretrained(model_name)
        # TODO: pass image_key from DS
        image_keys = ["img", "image", "jpg", "png"]

        def _transform(batch: dict[str, Any]) -> dict[str, Any]:
            processed: dict[str, Any] = {}
            for key, value in batch.items():
                if key in image_keys:
                    rgb_images = [img.convert("RGB") for img in value]
                    value = processor(rgb_images, return_tensors="pt")[
                        "pixel_values"
                    ]
                processed[key] = value
            return processed

        return _transform

Pretrained classifier model loaded from the HuggingFace model hub.

HuggingFaceClassifier( model_name: str, n_classes: int, head_name: str | None = None, **kwargs: Any)
17    def __init__(
18        self,
19        model_name: str,
20        n_classes: int,
21        head_name: str | None = None,
22        **kwargs: Any,
23    ) -> None:
24        """
25        See also:
26            `lcc.classifiers.WrappedClassifier.__init__` and
27            `lcc.classifiers.BaseClassifier.__init__`.
28
29        Args:
30            model_name (str): Model name as in the [HuggingFace model
31                hub](https://huggingface.co/models?pipeline_tag=image-classification).
32                If the model name starts with `timm/`, use
33                `lcc.classifiers.TimmClassifier` instead.
34            n_classes (int): See `lcc.classifiers.WrappedClassifier.__init__`.
35            head_name (str | None, optional): See
36                `lcc.classifiers.WrappedClassifier.__init__`.
37        """
38        if model_name.startswith("timm/"):
39            raise ValueError(
40                "If the model name starts with `timm/`, use "
41                "`lcc.classifiers.TimmClassifier` instead."
42            )
43        model = AutoModelForImageClassification.from_pretrained(model_name)
44        super().__init__(model, n_classes, head_name, **kwargs)
45        self.save_hyperparameters()
See also:

lcc.classifiers.WrappedClassifier.__init__ and lcc.classifiers.BaseClassifier.__init__.

Arguments:
@staticmethod
def get_image_processor(model_name: str, **__: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
47    @staticmethod
48    def get_image_processor(
49        model_name: str, **__: Any
50    ) -> Callable[[dict[str, Any]], dict[str, Any]]:
51        """
52        Wraps the HuggingFace `AutoImageProcessor` associated to a given model.
53
54        Args:
55            model_name (str): Model name as in the [HuggingFace model
56                hub](https://huggingface.co/models?pipeline_tag=image-classification).
57                Must not start with `timm/`.
58
59        Returns:
60            A callable that uses a
61            [`transformers.AutoImageProcessor`](https://huggingface.co/docs/transformers/v4.44.2/en/model_doc/auto#transformers.AutoImageProcessor)
62            under the hood.
63        """
64        hf_transorm = AutoImageProcessor.from_pretrained(model_name)
65
66        def _transform(batch: dict[str, Any]) -> dict[str, Any]:
67            return {
68                k: (
69                    hf_transorm(
70                        [img.convert("RGB") for img in v], return_tensors="pt"
71                    )["pixel_values"]
72                    # TODO: pass image_key from DS ↓
73                    if k in ["img", "image", "jpg", "png"]
74                    else v
75                )
76                for k, v in batch.items()
77            }
78
79        return _transform

Wraps the HuggingFace AutoImageProcessor associated to a given model.

Arguments:
Returns:

A callable that uses a transformers.AutoImageProcessor under the hood.

@dataclass
class LatentClusteringData:
@dataclass
class LatentClusteringData:
    """
    Small record bundling the latent clustering correction data that belongs
    to one latent space.
    """

    loss: LCCLoss
    """
    Callable LCC loss object that computes the LCC loss used for the
    correction.
    """

    # TODO: Write the y_clst to disk and loading alongside the train dataset, so
    # that this whole datastructure becomes unnecessary
    y_clst: np.ndarray
    """Cluster labels, as a vector of shape `(N,)`."""

Convenience struct that holds some latent clustering correction data for a given latent space.

LatentClusteringData(loss: lcc.correction.LCCLoss, y_clst: numpy.ndarray)

The actual LCC loss object (which is callable) that computes the LCC loss for LCC correction.

y_clst: numpy.ndarray

(N,) vector of cluster labels.

class TimmClassifier(lcc.classifiers.WrappedClassifier):
class TimmClassifier(WrappedClassifier):
    """
    Pretrained classifier model uploaded by the `timm` team to the
    [HuggingFace model
    hub](https://huggingface.co/models?pipeline_tag=image-classification).
    """

    image_transform: Callable

    def __init__(
        self,
        model_name: str,
        n_classes: int,
        head_name: str | None = None,
        pretrained: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        See also:
            `lcc.classifiers.WrappedClassifier.__init__` and
            `lcc.classifiers.BaseClassifier.__init__`.

        Args:
            model_name (str): Model name as in the [HuggingFace model
                hub](https://huggingface.co/models?pipeline_tag=image-classification).
                Must start with `timm/`.
            n_classes (int): See `lcc.classifiers.WrappedClassifier.__init__`.
            head_name (str | None, optional): See
                `lcc.classifiers.WrappedClassifier.__init__`.
            pretrained (bool, optional): Forwarded to `timm.create_model`.
                Defaults to `True`.

        Raises:
            ValueError: If `model_name` does not start with `timm/`.
        """
        if not model_name.startswith("timm/"):
            raise ValueError(
                "The model isn't a timm model (its name does not start with "
                "`timm/`). Use `lcc.classifiers.HuggingFaceClassifier` "
                "instead."
            )
        model = timm.create_model(model_name, pretrained=pretrained)
        super().__init__(model, n_classes, head_name, **kwargs)
        self.save_hyperparameters()

    @staticmethod
    def get_image_processor(model_name: str, **kwargs: Any) -> Callable:
        """
        Builds a batch transform around the `timm` transform associated with
        a given model.

        See also:
            [`timm.create_transform`](https://huggingface.co/docs/timm/reference/data#timm.data.create_transform).

        Args:
            model_name (str): Model name as in the [HuggingFace model
                hub](https://huggingface.co/models?pipeline_tag=image-classification).
                Must start with `timm/`.

        Returns:
            A callable mapping a batch dict to a new dict. Values stored
            under an image key (`img`, `image`, `jpg` or `png`) go through
            the transform created by `timm.data.create_transform`; every
            other entry is passed through untouched.
        """
        # The model is instantiated without pretrained weights: it is only
        # used to resolve the data configuration.
        model = timm.create_model(model_name, pretrained=False)
        conf = timm.data.resolve_model_data_config(model)
        conf["is_training"] = True
        conf["no_aug"] = True
        timm_transform = timm.data.create_transform(**conf)
        # TODO: pass image_key from DS
        image_keys = ["img", "image", "jpg", "png"]

        def _apply(images: Any) -> Any:
            if isinstance(images, list):
                return [timm_transform(img.convert("RGB")) for img in images]
            # NOTE(review): a non-list value is transformed as-is, without
            # the RGB conversion applied to list items -- confirm intended.
            return timm_transform(images)

        def _transform(batch: dict[str, Any]) -> dict[str, Any]:
            return {
                key: (_apply(value) if key in image_keys else value)
                for key, value in batch.items()
            }

        return _transform

    def forward(self, inputs: Tensor | Batch, *_: Any, **__: Any) -> Tensor:
        """
        Runs the wrapped model on a tensor, or on the entry of a batch found
        under `self.image_key`, after moving it to the module's device.
        """
        if isinstance(inputs, Tensor):
            images: Tensor = inputs
        else:
            images = inputs[self.image_key]
        return self.model(images.to(self.device))  # type: ignore

Pretrained classifier model uploaded by the timm team to the HuggingFace model hub.

TimmClassifier( model_name: str, n_classes: int, head_name: str | None = None, pretrained: bool = True, **kwargs: Any)
21    def __init__(
22        self,
23        model_name: str,
24        n_classes: int,
25        head_name: str | None = None,
26        pretrained: bool = True,
27        **kwargs: Any,
28    ) -> None:
29        """
30        See also:
31            `lcc.classifiers.WrappedClassifier.__init__` and
32            `lcc.classifiers.BaseClassifier.__init__`.
33
34        Args:
35            model_name (str): Model name as in the [HuggingFace model
36                hub](https://huggingface.co/models?pipeline_tag=image-classification).
37                Must start with `timm/`.
38            n_classes (int): See `lcc.classifiers.WrappedClassifier.__init__`.
39            head_name (str | None, optional): See
40                `lcc.classifiers.WrappedClassifier.__init__`.
41            pretrained (bool, optional): Defaults to `True`.
42        """
43        if not model_name.startswith("timm/"):
44            raise ValueError(
45                "The model isn't a timm model (its name does not start with "
46                "`timm/`). Use `lcc.classifiers.HuggingFaceClassifier` "
47                "instead."
48            )
49        model = timm.create_model(model_name, pretrained=pretrained)
50        super().__init__(model, n_classes, head_name, **kwargs)
51        self.save_hyperparameters()
See also:

lcc.classifiers.WrappedClassifier.__init__ and lcc.classifiers.BaseClassifier.__init__.

Arguments:
image_transform: Callable
@staticmethod
def get_image_processor(model_name: str, **kwargs: Any) -> Callable:
53    @staticmethod
54    def get_image_processor(model_name: str, **kwargs: Any) -> Callable:
55        """
56        Wraps the HuggingFace `AutoImageProcessor` associated to a given model.
57
58        See also:
59            [`timm.create_transform`](https://huggingface.co/docs/timm/reference/data#timm.data.create_transform).
60
61        Args:
62            model_name (str): Model name as in the [HuggingFace model
63                hub](https://huggingface.co/models?pipeline_tag=image-classification).
64                Must start with `timm/`.
65
66        Returns:
67            A callable that uses a [Torchvision
68            transform](https://pytorch.org/vision/0.19/transforms.html) under
69            the hood.
70        """
71        model = timm.create_model(model_name, pretrained=False)
72        conf = timm.data.resolve_model_data_config(model)
73        conf["is_training"], conf["no_aug"] = True, True
74        timm_transform = timm.data.create_transform(**conf)
75
76        def _transform(batch: dict[str, Any]) -> dict[str, Any]:
77            return {
78                k: (
79                    (
80                        [timm_transform(img.convert("RGB")) for img in v]
81                        if isinstance(v, list)
82                        else timm_transform(v)
83                    )
84                    # TODO: pass image_key from DS ↓
85                    if k in ["img", "image", "jpg", "png"]
86                    else v
87                )
88                for k, v in batch.items()
89            }
90
91        return _transform

Wraps the timm transform (created with timm.data.create_transform) associated with a given model.

See also:

timm.create_transform.

Arguments:
Returns:

A callable that uses a Torchvision transform under the hood.

def forward( self, inputs: torch.Tensor | dict[str, torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
93    def forward(self, inputs: Tensor | Batch, *_: Any, **__: Any) -> Tensor:
94        x: Tensor = (
95            inputs if isinstance(inputs, Tensor) else inputs[self.image_key]
96        )
97        return self.model(x.to(self.device))  # type: ignore

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

class TorchvisionClassifier(lcc.classifiers.WrappedClassifier):
class TorchvisionClassifier(WrappedClassifier):
    """
    A torchvision classifier wrapped as a `WrappedClassifier`. See
    https://pytorch.org/vision/stable/models.html#classification .
    """

    def __init__(
        self,
        model_name: str,
        n_classes: int,
        head_name: str | None = None,
        weights: Any = "DEFAULT",
        **kwargs: Any,
    ) -> None:
        """
        See also:
            `lcc.classifiers.WrappedClassifier.__init__`.

        Args:
            model_name (str): Name handed to `get_model`.
            n_classes (int): See `lcc.classifiers.WrappedClassifier.__init__`.
            head_name (str | None, optional): See
                `lcc.classifiers.WrappedClassifier.__init__`.
            weights (Any, optional): Handed to `get_model` as its `weights`
                argument. Defaults to `"DEFAULT"`.
        """
        model = get_model(model_name, weights=weights)
        super().__init__(model, n_classes, head_name, **kwargs)
        self.save_hyperparameters()

    @staticmethod
    def get_image_processor(
        model_name: str, weights: str = "DEFAULT", **__: Any
    ) -> Callable[[dict[str, Any]], dict[str, Any]]:
        """
        Creates an image processor based on the transform object of the
        model's chosen weights, preceded by an RGB conversion. For example,

            TorchvisionClassifier.get_image_processor("alexnet")

        is analogous to

            get_model_weights("alexnet")["DEFAULT"].transforms()

        The returned callable maps a batch dict to a new dict in which every
        item of the values stored under an image key (`img`, `image`, `jpg`
        or `png`) has been transformed; other entries are left untouched.
        """
        steps = [tr.RGB()]
        steps.append(get_model_weights(model_name)[weights].transforms())
        transform = tr.Compose(steps)
        # TODO: pass image_key from DS
        image_keys = ["img", "image", "jpg", "png"]

        def _processor(batch: dict[str, Any]) -> dict[str, Any]:
            processed: dict[str, Any] = {}
            for key, value in batch.items():
                if key in image_keys:
                    value = [transform(img) for img in value]
                processed[key] = value
            return processed

        return _processor
TorchvisionClassifier( model_name: str, n_classes: int, head_name: str | None = None, weights: Any = 'DEFAULT', **kwargs: Any)
18    def __init__(
19        self,
20        model_name: str,
21        n_classes: int,
22        head_name: str | None = None,
23        weights: Any = "DEFAULT",
24        **kwargs: Any,
25    ) -> None:
26        model = get_model(model_name, weights=weights)
27        super().__init__(model, n_classes, head_name, **kwargs)
28        self.save_hyperparameters()
See also:

lcc.classifiers.BaseClassifier.__init__.

Arguments:
  • model (nn.Module):
  • n_classes (int): Number of classes in the dataset on which the model will be trained.
  • head_name (str | None, optional): Name of the head submodule in model, which is the fully connected (aka nn.Linear) layer that outputs the logits. If not None, the head is replaced by a new fully connected layer with n_classes output neurons. Specify this to fine-tune a pretrained model on a new dataset with the same or a different number of classes. The name of a submodule can be retrieved by inspecting the output of nn.Module.named_modules or lcc.utils.pretty_print_submodules.
  • logit_key (str | None, optional): If the wrapped model outputs a dict-like object instead of a tensor, this key is used to access the actual logits.
@staticmethod
def get_image_processor( model_name: str, weights: str = 'DEFAULT', **__: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
30    @staticmethod
31    def get_image_processor(
32        model_name: str, weights: str = "DEFAULT", **__: Any
33    ) -> Callable[[dict[str, Any]], dict[str, Any]]:
34        """
35        Creates an image processor based on the transform object of the model's
36        chosen weights. For example,
37
38            TorchvisionClassifier.get_image_processor("alexnet")
39
40        is analogous to
41
42            get_model_weights("alexnet")["DEFAULT"].transforms()
43        """
44
45        transform = tr.Compose(
46            [tr.RGB(), get_model_weights(model_name)[weights].transforms()]
47        )
48
49        def _processor(batch: dict[str, Any]) -> dict[str, Any]:
50            return {
51                k: (
52                    [transform(img) for img in v]
53                    # TODO: pass image_key from DS ↓
54                    if k in ["img", "image", "jpg", "png"]
55                    else v
56                )
57                for k, v in batch.items()
58            }
59
60        return _processor

Creates an image processor based on the transform object of the model's chosen weights. For example,

TorchvisionClassifier.get_image_processor("alexnet")

is analogous to

get_model_weights("alexnet")["DEFAULT"].transforms()
def validate_lcc_kwargs(lcc_kwargs: dict[str, typing.Any] | None) -> None:
 79def validate_lcc_kwargs(lcc_kwargs: dict[str, Any] | None) -> None:
 80    """
 81    Makes sure that an LCC hyperparameter dict is valid. Used in constructors of
 82    LCC enabled classifiers.
 83
 84    Args:
 85        lcc_kwargs (dict[str, Any] | None):
 86    """
 87    if not lcc_kwargs:
 88        return
 89    if (x := lcc_kwargs.get("weight", 1)) <= 0:
 90        raise ValueError(f"LCC weight must be positive, got {x}")
 91    if (x := lcc_kwargs.get("interval", 1)) < 1:
 92        raise ValueError(f"LCC interval must be at least 1, got {x}")
 93    if (x := lcc_kwargs.get("warmup", 0)) < 0:
 94        raise ValueError(f"LCC warmup must be at least 0, got {x}")
 95    if (x := lcc_kwargs.get("class_selection")) not in LCC_CLASS_SELECTIONS + [
 96        None
 97    ]:
 98        raise ValueError(
 99            f"Invalid class selection policy '{x}'. Available policies are: "
100            + ", ".join(map(lambda a: f"'{a}'", LCC_CLASS_SELECTIONS))
101            + ", or `None`"
102        )
103    LCC_LOSS_TYPES = ["exact", "randomized"]
104    if (x := lcc_kwargs.get("loss", "exact")) not in LCC_LOSS_TYPES:
105        raise ValueError(
106            f"Invalid LCC loss type '{x}'. Available types are: "
107            + ", ".join(map(lambda a: f"'{a}'", LCC_LOSS_TYPES))
108        )
109    LCC_CLUSTERING_METHODS = ["louvain", "peer_pressure"]
110    if (
111        x := lcc_kwargs.get("clustering_method", "louvain")
112    ) not in LCC_CLUSTERING_METHODS:
113        raise ValueError(
114            f"Invalid latent clustering method '{x}'. Available methods are: "
115            + ", ".join(map(lambda a: f"'{a}'", LCC_CLUSTERING_METHODS))
116        )

Makes sure that an LCC hyperparameter dict is valid. Used in constructors of LCC enabled classifiers.

Arguments:
  • lcc_kwargs (dict[str, Any] | None):
class WrappedClassifier(lcc.classifiers.BaseClassifier):
class WrappedClassifier(BaseClassifier):
    """
    An image classifier model
    ([`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html))
    wrapped inside a `lcc.classifiers.BaseClassifier`.
    """

    model: nn.Module

    def __init__(
        self,
        model: nn.Module,
        n_classes: int,
        head_name: str | None = None,
        logit_key: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        See also:
            `lcc.classifiers.BaseClassifier.__init__`.

        Args:
            model (nn.Module): The model to wrap.
            n_classes (int): Number of classes in the dataset on which the model
                will be trained.
            head_name (str | None, optional): Name of the head submodule in
                `model`, which is the fully connected (aka
                [`nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html))
                layer that outputs the logits. If not `None`, the head is
                replaced by a new fully connected layer with `n_classes`
                output neurons. Specify this to fine-tune a pretrained model
                on a new dataset with the same or a different number of
                classes. The name of a submodule can be retrieved by
                inspecting the output of `nn.Module.named_modules` or
                `lcc.utils.pretty_print_submodules`.
            logit_key (str | None, optional): If the wrapped model outputs a
                dict-like object instead of a tensor, this key is used to access
                the actual logits.
        """
        # The wrapped model itself is kept out of the saved hyperparameters.
        self.save_hyperparameters(ignore=["model"])
        super().__init__(n_classes, **kwargs)
        self.model = model
        if head_name:
            replace_head(self.model, head_name, n_classes)

    def forward(self, inputs: Tensor | Batch, *_: Any, **__: Any) -> Tensor:
        """
        Runs the wrapped model on a tensor, or on the entry of a batch found
        under the `image_key` hyperparameter, after moving it to the
        module's device. If the `logit_key` hyperparameter is set, the
        logits are looked up under that key in the model's output.
        """
        image_key = self.hparams["image_key"]
        logit_key = self.hparams.get("logit_key")
        if isinstance(inputs, Tensor):
            x: Tensor = inputs
        else:
            x = inputs[image_key]
        output = self.model(x.to(self.device))
        if logit_key is None:
            return output
        return output[logit_key]  # type: ignore

    @property
    def lcc_submodules(self) -> list[str]:
        """
        Names of the submodules considered for LCC, each prefixed with
        `model.` unless it already starts with it. Empty if the
        `lcc_submodules` hyperparameter is falsy.
        """
        if not self.hparams["lcc_submodules"]:
            return []
        return [
            name if name.startswith("model.") else "model." + name
            for name in self.hparams["lcc_submodules"]
        ]

    def on_train_start(self) -> None:
        """
        Puts the wrapped model in training mode, then calls the parent
        class' `on_train_start` hook.
        """
        self.model.train()
        super().on_train_start()

An image classifier model (torch.nn.Module) wrapped inside a lcc.classifiers.BaseClassifier.

WrappedClassifier( model: torch.nn.modules.module.Module, n_classes: int, head_name: str | None = None, logit_key: str | None = None, **kwargs: Any)
22    def __init__(
23        self,
24        model: nn.Module,
25        n_classes: int,
26        head_name: str | None = None,
27        logit_key: str | None = None,
28        **kwargs: Any,
29    ) -> None:
30        """
31        See also:
32            `lcc.classifiers.BaseClassifier.__init__`.
33
34        Args:
35            model (nn.Module):
36            n_classes (int): Number of classes in the dataset on which the model
37                will be trained.
38            head_name (str | None, optional): Name of the head submodule in
39                `model`, which is the fully connected (aka
40                [`nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html))
41                layer that outputs the logits. If not `None`, the head is
42                replaced by a new fully connected layer with `n_classes` output
43                neurons classes. Specify this to fine-tune a pretrained model
44                on a new dataset with the same or a different number of
45                classes. The name of a submodule can be retried by inspecting
46                the output of `nn.Module.named_modules` or
47                `lcc.utils.pretty_print_submodules`.
48            logit_key (str | None, optional): If the wrapped model outputs a
49                dict-like object instead of a tensor, this key is used to access
50                the actual logits.
51        """
52        self.save_hyperparameters(ignore=["model"])
53        super().__init__(n_classes, **kwargs)
54        self.model = model
55        if head_name:
56            replace_head(self.model, head_name, n_classes)
See also:

lcc.classifiers.BaseClassifier.__init__.

Arguments:
  • model (nn.Module):
  • n_classes (int): Number of classes in the dataset on which the model will be trained.
  • head_name (str | None, optional): Name of the head submodule in model, which is the fully connected (aka nn.Linear) layer that outputs the logits. If not None, the head is replaced by a new fully connected layer with n_classes output neurons. Specify this to fine-tune a pretrained model on a new dataset with the same or a different number of classes. The name of a submodule can be retrieved by inspecting the output of nn.Module.named_modules or lcc.utils.pretty_print_submodules.
  • logit_key (str | None, optional): If the wrapped model outputs a dict-like object instead of a tensor, this key is used to access the actual logits.
model: torch.nn.modules.module.Module
def forward( self, inputs: torch.Tensor | dict[str, torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
58    def forward(self, inputs: Tensor | Batch, *_: Any, **__: Any) -> Tensor:
59        image_key = self.hparams["image_key"]
60        logit_key = self.hparams.get("logit_key")
61        x: Tensor = inputs if isinstance(inputs, Tensor) else inputs[image_key]
62        output = self.model(x.to(self.device))
63        return output if logit_key is None else output[logit_key]  # type: ignore

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

lcc_submodules: list[str]
65    @property
66    def lcc_submodules(self) -> list[str]:
67        """
68        Returns the list of submodules considered for LCC, whith correct prefix
69        if needed.
70        """
71        return (
72            []
73            if not self.hparams["lcc_submodules"]
74            else [
75                (sm if sm.startswith("model.") else "model." + sm)
76                for sm in self.hparams["lcc_submodules"]
77            ]
78        )

Returns the list of submodules considered for LCC, with the correct prefix if needed.

def on_train_start(self) -> None:
80    def on_train_start(self) -> None:
81        self.model.train()
82        super().on_train_start()

Explicitly registers hyperparameters and metrics. You'd think Lightning would do this automatically, but nope.