lcc.classifiers
Classifier models and related stuff
1"""Classifier models and related stuff""" 2 3from .base import ( 4 BaseClassifier, 5 LatentClusteringData, 6 validate_lcc_kwargs, 7) 8from .huggingface import HuggingFaceClassifier 9from .timm import TimmClassifier 10from .torchvision import TorchvisionClassifier 11from .wrapped import WrappedClassifier 12 13__all__ = [ 14 "BaseClassifier", 15 "get_classifier_cls", 16 "HuggingFaceClassifier", 17 "LatentClusteringData", 18 "TimmClassifier", 19 "TorchvisionClassifier", 20 "validate_lcc_kwargs", 21 "WrappedClassifier", 22] 23 24 25def get_classifier_cls(model_name: str) -> type[BaseClassifier]: 26 """ 27 Returns the classifier class to use for a given model name. 28 29 Args: 30 model_name (str): 31 """ 32 if model_name.startswith("timm/"): 33 return TimmClassifier 34 if "/" in model_name: 35 return HuggingFaceClassifier 36 return TorchvisionClassifier
```python
class BaseClassifier(pl.LightningModule):
    """
    Base image classifier class that supports LCC.

    Warning:
        When subclassing this, remember that the forward method must be able to
        deal with either `Tensor` or `Batch` inputs, and must return a logit
        `Tensor`.
    """

    lcc_data: dict[str, LatentClusteringData] | None = None
    """
    If LCC is applied, then this is non `None` and updated at the beginning of
    each epoch. See also `full_dataset_latent_clustering`.
    """

    standard_loss: nn.Module
    """'Standard' loss to use together with LCC."""

    accuracy_top1: nn.Module
    """Top-1 accuracy metric."""

    accuracy_top5: nn.Module | None = None
    """Top-5 accuracy metric."""

    def __init__(
        self,
        n_classes: int,
        lcc_submodules: list[str] | None = None,
        lcc_kwargs: dict | None = None,
        ce_weight: float = 1,
        image_key: Any = 0,
        label_key: Any = 1,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            n_classes (int):
            lcc_submodules (list[str] | None, optional): Submodules to consider
                for the latent correction loss. If `None` or `[]`, LCC is not
                performed.
            lcc_kwargs (dict | None, optional): Optional parameters for LCC.
                Expected entries (all optional) are:
                * **weight (float):** Defaults to $10^{-4}$.
                * **class_selection
                  (`lcc.correction.LCCClassSelection` | None):** Defaults to
                  `None`, which means all classes are considered for
                  correction.
                * **interval (int):** Apply LCC every `interval` epochs.
                  Defaults to $1$, meaning LCC will be applied every epoch
                  (after warmup).
                * **warmup (int):** Number of epochs to wait before
                  starting LCC. Defaults to $0$, meaning LCC will start
                  immediately.
                * **k (int):** Number of nearest neighbors to consider for LCC
                  and Louvain clustering.
                * **pca_dim (int):** Samples are reduced to this dimension
                  before constructing the KNN graph. This must be at most the
                  batch size.
                * **loss (`"exact"` or `"randomized"`):** How the LCC loss is
                  computed from the clustering data. See `lcc.correction.loss`.
                * **ccspc (int):** If the loss type is `"randomized"`, this
                  parameter specifies the number of CC samples per cluster to
                  keep as potential correction targets. See also
                  `lcc.correction.loss.RandomizedLCCLoss`.
                * **clustering_method (`"louvain"` or `"peer_pressure"`):**
                  Clustering algorithm for the dataset of latent
                  representations.
            ce_weight (float, optional): Weight of the cross-entropy loss in the
                clustering-CE loss. Ignored if LCC is not applied. Defaults to
                $1$.
            image_key (Any, optional): A batch passed to the model can be a
                tuple (most common) or a dict. This parameter specifies the key
                to use to retrieve the input tensor.
            label_key (Any, optional): Analogous to `image_key`.
            kwargs: Forwarded to
                [`pl.LightningModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#)
        """
        super().__init__(**kwargs)
        self.save_hyperparameters(
            "ce_weight",
            "image_key",
            "label_key",
            "lcc_kwargs",
            "lcc_submodules",
            "n_classes",
        )
        if lcc_submodules:
            validate_lcc_kwargs(lcc_kwargs)
        self.standard_loss = torch.nn.CrossEntropyLoss()
        acc_kw: dict[str, Any] = {
            "task": "multiclass",
            "num_classes": n_classes,
            "average": "micro",
        }
        self.accuracy_top1 = tm.Accuracy(top_k=1, **acc_kw)  # type: ignore
        if n_classes > 5:
            self.accuracy_top5 = tm.Accuracy(top_k=5, **acc_kw)  # type: ignore

    def _evaluate(self, batch: Batch, stage: str | None = None) -> Tensor:
        """Computes the loss and logs the metrics for a batch."""
        image_key = self.hparams["image_key"]
        label_key = self.hparams["label_key"]
        x, y = batch[image_key], batch[label_key].to(self.device)
        latent: dict[str, Tensor] = {}
        logits = self.forward_intermediate(
            x, self.lcc_submodules, latent, keep_gradients=True
        )
        assert isinstance(logits, Tensor)
        loss_ce = self.standard_loss(logits, y)
        if self.lcc_data and stage == "train":
            idx = to_array(batch["_idx"])
            _losses = [
                self.lcc_data[sm].loss(
                    z, y_true=y, y_clst=self.lcc_data[sm].y_clst[idx]
                )
                for sm, z in latent.items()
            ]
            loss_lcc = (
                torch.stack(_losses).mean()
                if _losses
                else torch.tensor(0.0, requires_grad=True)
                # ↑ actually need grad?
            )
            lcc_weight = self.hparams.get("lcc_kwargs", {}).get("weight", 1e-4)
            loss = self.hparams["ce_weight"] * loss_ce + lcc_weight * loss_lcc
            self.log(f"{stage}/lcc", loss_lcc, sync_dist=True)
        else:
            loss = loss_ce
        if stage:
            d = {
                f"{stage}/loss": loss,
                f"{stage}/ce": loss_ce,
            }
            if self.accuracy_top5 is not None:
                d[f"{stage}/acc5"] = self.accuracy_top5(logits, y)
            self.log_dict(d, sync_dist=True)
            self.log(
                f"{stage}/acc",
                self.accuracy_top1(logits, y),
                prog_bar=True,
                sync_dist=True,
            )
        return loss  # type: ignore

    def configure_optimizers(self) -> Any:
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-1, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=50
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler},
        }

    def forward_intermediate(
        self,
        inputs: Tensor | Batch | list[Tensor] | Sequence[Batch],
        submodules: list[str],
        output_dict: dict,
        keep_gradients: bool = False,
    ) -> Tensor | list[Tensor]:
        """
        Runs the model and collects the outputs of the specified submodules.
        The intermediate outputs are stored in `output_dict` under the
        corresponding submodule names. In particular, this method has side
        effects.

        Args:
            inputs (Tensor | Batch | list[Tensor] | Sequence[Batch]): If
                batched (i.e. `inputs` is a list), then so is the output of
                this function and the entries in `output_dict`.
            submodules (list[str]):
            output_dict (dict):
            keep_gradients (bool, optional): If `True`, the tensors in
                `output_dict` keep their gradients (if they had some in the
                first place). If `False`, they are detached and moved to the
                CPU.
        """

        def maybe_detach(x: Tensor) -> Tensor:
            return x if keep_gradients else x.detach().cpu()

        def create_hook(key: str) -> Callable[[nn.Module, Any, Any], None]:
            def hook(_model: nn.Module, _args: Any, out: Any) -> None:
                if (
                    isinstance(out, (list, tuple))
                    and len(out) == 1
                    and isinstance(out[0], Tensor)
                ):
                    out = out[0]
                elif (  # Special case for ViTs
                    isinstance(out, (list, tuple))
                    and len(out) == 2
                    and isinstance(out[0], Tensor)
                    and not isinstance(out[1], Tensor)
                ):
                    out = out[0]
                elif not isinstance(out, Tensor):
                    raise ValueError(
                        f"Unsupported latent object type: {type(out)}: {out}."
                    )
                if batched:
                    if key not in output_dict:
                        output_dict[key] = []
                    output_dict[key].append(maybe_detach(out))
                else:
                    output_dict[key] = maybe_detach(out)

            return hook

        batched = isinstance(inputs, (list, tuple))
        handles: list[RemovableHandle] = []
        for name in submodules:
            submodule = self.get_submodule(name)
            handles.append(submodule.register_forward_hook(create_hook(name)))
        if batched:
            logits = [
                maybe_detach(
                    self.forward(
                        batch
                        if isinstance(batch, Tensor)
                        else batch[self.hparams["image_key"]]
                    )
                )
                for batch in inputs
            ]
        else:
            logits = maybe_detach(  # type: ignore
                self.forward(
                    inputs
                    if isinstance(inputs, Tensor)
                    else inputs[self.hparams["image_key"]]
                )
            )
        for h in handles:
            h.remove()
        return logits

    @staticmethod
    def get_image_processor(model_name: str, **kwargs: Any) -> Callable:
        """
        Returns an image processor for the model. By default, returns the
        identity function.
        """
        return lambda input: input

    @property
    def lcc_submodules(self) -> list[str]:
        """
        Returns the list of submodules considered for LCC, with the correct
        prefix if needed.
        """
        return self.hparams.get("lcc_submodules") or []

    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
        """Just logs every optimizer's learning rate"""
        log_optimizers_lr(self, sync_dist=True)
        super().on_train_batch_end(*args, **kwargs)

    def on_train_epoch_end(self) -> None:
        """Cleans up training-specific temporary attributes"""
        self.lcc_data = None
        super().on_train_epoch_end()

    def on_train_epoch_start(self) -> None:
        """
        Performs dataset-wide latent clustering and stores the results in the
        `BaseClassifier.lcc_data` attribute.
        """
        # whether to apply LCC this epoch
        lcc_kwargs = self.hparams.get("lcc_kwargs") or {}
        do_lcc = (
            # we are past warmup (lcc_warmup being None is equivalent to no
            # warmup)...
            self.current_epoch >= (lcc_kwargs.get("warmup") or 0)
            and (
                # ... and an LCC interval is specified...
                lcc_kwargs.get("interval") is not None
                # ... and the current epoch can have LCC done...
                and self.current_epoch % int(lcc_kwargs.get("interval", 1))
                == 0
            )
            # ... and there are submodules selected for LCC...
            and self.lcc_submodules
            # ... and the LCC weight is non-zero
            and lcc_kwargs.get("weight", 0) > 0
        )
        if do_lcc:
            # The import has to be here to prevent circular imports while
            # having all the FDLC logic in a separate file
            from .fdlc import full_dataset_latent_clustering

            with temporary_directory(self) as tmp_path:
                self.lcc_data = full_dataset_latent_clustering(
                    model=self,
                    output_dir=tmp_path,
                    tqdm_style="console",
                )
        super().on_train_epoch_start()

    def on_train_start(self) -> None:
        """
        Explicitly registers hyperparameters and metrics. You'd think Lightning
        would do this automatically, but nope.
        """
        self.logger.log_hyperparams(  # type: ignore
            self.hparams,
            {
                s + "/" + m: np.nan
                for s, m in product(
                    ["train", "val"], ["acc", "loss", "ce", "lcc"]
                )
            },
        )
        super().on_train_start()

    def test_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor:
        """Override from `pl.LightningModule.test_step`."""
        return self._evaluate(batch, "test")

    def training_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor:
        """Override from `pl.LightningModule.training_step`."""
        return self._evaluate(batch, "train")

    def validation_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor:
        """Override from `pl.LightningModule.validation_step`."""
        return self._evaluate(batch, "val")
```
Base image classifier class that supports LCC.

Warning:
    When subclassing this, remember that the forward method must be able to deal with either `Tensor` or `Batch` inputs, and must return a logit `Tensor`.
```python
def __init__(
    self,
    n_classes: int,
    lcc_submodules: list[str] | None = None,
    lcc_kwargs: dict | None = None,
    ce_weight: float = 1,
    image_key: Any = 0,
    label_key: Any = 1,
    **kwargs: Any,
) -> None
```
Arguments:
- n_classes (int):
- lcc_submodules (list[str] | None, optional): Submodules to consider for the latent correction loss. If `None` or `[]`, LCC is not performed.
- lcc_kwargs (dict | None, optional): Optional parameters for LCC. Expected entries (all optional) are:
  - weight (float): Defaults to $10^{-4}$.
  - class_selection (`lcc.correction.LCCClassSelection` | None): Defaults to `None`, which means all classes are considered for correction.
  - interval (int): Apply LCC every `interval` epochs. Defaults to $1$, meaning LCC will be applied every epoch (after warmup).
  - warmup (int): Number of epochs to wait before starting LCC. Defaults to $0$, meaning LCC will start immediately.
  - k (int): Number of nearest neighbors to consider for LCC and Louvain clustering.
  - pca_dim (int): Samples are reduced to this dimension before constructing the KNN graph. This must be at most the batch size.
  - loss (`"exact"` or `"randomized"`): How the LCC loss is computed from the clustering data. See `lcc.correction.loss`.
  - ccspc (int): If the loss type is `"randomized"`, this parameter specifies the number of CC samples per cluster to keep as potential correction targets. See also `lcc.correction.loss.RandomizedLCCLoss`.
  - clustering_method (`"louvain"` or `"peer_pressure"`): Clustering algorithm for the dataset of latent representations.
- ce_weight (float, optional): Weight of the cross-entropy loss in the clustering-CE loss. Ignored if LCC is not applied. Defaults to $1$.
- image_key (Any, optional): A batch passed to the model can be a tuple (most common) or a dict. This parameter specifies the key to use to retrieve the input tensor.
- label_key (Any, optional): Analogous to `image_key`.
- kwargs: Forwarded to [`pl.LightningModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#).
`lcc_data`: If LCC is applied, then this is non-`None` and updated at the beginning of each epoch. See also `full_dataset_latent_clustering`.
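As a usage sketch (the model, head, and submodule names below are illustrative assumptions, not prescriptions), enabling LCC when constructing a concrete subclass could look like:

```python
from lcc.classifiers import TorchvisionClassifier

# Hypothetical setup: fine-tune a torchvision ResNet-18 with LCC applied to
# two late submodules. Inspect named_modules() to find valid submodule names.
model = TorchvisionClassifier(
    model_name="resnet18",
    n_classes=10,
    head_name="fc",  # replace the 1000-class head with a 10-class one
    lcc_submodules=["layer3", "layer4"],
    lcc_kwargs={
        "weight": 1e-4,  # documented default
        "interval": 1,  # apply LCC every epoch
        "warmup": 0,  # start immediately
        "loss": "exact",  # or "randomized"
    },
)
```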
```python
def configure_optimizers(self) -> Any
```
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
Return:
    Any of the following options.
    - Single optimizer.
    - List or Tuple of optimizers.
    - Two lists: the first list has multiple optimizers, and the second has multiple LR schedulers (or multiple `lr_scheduler_config`).
    - Dictionary, with an `"optimizer"` key, and (optionally) a `"lr_scheduler"` key whose value is a single LR scheduler or `lr_scheduler_config`.
    - None: Fit will run without any optimizer.
The `lr_scheduler_config` is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
```python
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified by 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
```
When there are schedulers in which the `.step()` method is conditioned on a value, such as the `torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, Lightning requires that the `lr_scheduler_config` contains the keyword `"monitor"` set to the metric name that the scheduler should be conditioned on.
```python
# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )
```
Metrics can be made available to monitor by simply logging them using `self.log('metric_to_track', metric_val)` in your `pytorch_lightning.core.LightningModule`.
Note:
    Some things to know:
    - Lightning calls `.backward()` and `.step()` automatically in case of automatic optimization.
    - If a learning rate scheduler is specified in `configure_optimizers()` with key `"interval"` (default "epoch") in the scheduler configuration, Lightning will call the scheduler's `.step()` method automatically in case of automatic optimization.
    - If you use 16-bit precision (`precision=16`), Lightning will automatically handle the optimizer.
    - If you use `torch.optim.LBFGS`, Lightning handles the closure function automatically for you.
    - If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
    - If you need to control how often the optimizer steps, override the `optimizer_step()` hook.
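Since `BaseClassifier.configure_optimizers` hard-codes SGD with a cosine schedule (see the class listing above), a subclass can override it in the usual Lightning way. A minimal sketch, assuming AdamW is wanted instead (the optimizer choice and learning rate are illustrative):

```python
import torch


class AdamWClassifier(TorchvisionClassifier):
    def configure_optimizers(self):
        # Replace the default SGD + CosineAnnealingLR pair with plain AdamW
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return {"optimizer": optimizer}
```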
```python
def forward_intermediate(
    self,
    inputs: Tensor | Batch | list[Tensor] | Sequence[Batch],
    submodules: list[str],
    output_dict: dict,
    keep_gradients: bool = False,
) -> Tensor | list[Tensor]
```
Runs the model and collects the outputs of the specified submodules. The intermediate outputs are stored in `output_dict` under the corresponding submodule names. In particular, this method has side effects.

Arguments:
- inputs (Tensor | Batch | list[Tensor] | Sequence[Batch]): If batched (i.e. `inputs` is a list), then so is the output of this function and the entries in `output_dict`.
- submodules (list[str]):
- output_dict (dict):
- keep_gradients (bool, optional): If `True`, the tensors in `output_dict` keep their gradients (if they had some in the first place). If `False`, they are detached and moved to the CPU.
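For illustration, here is how one might collect a submodule's activations for a single input tensor (continuing the construction sketch above; the submodule name is hypothetical and depends on the wrapped model):

```python
latent: dict[str, Tensor] = {}
logits = model.forward_intermediate(
    images,  # a (B, C, H, W) tensor
    submodules=["model.layer4"],  # hypothetical submodule name
    output_dict=latent,
    keep_gradients=False,  # detach activations and move them to the CPU
)
# latent["model.layer4"] now holds that submodule's output for this input
```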
```python
@staticmethod
def get_image_processor(model_name: str, **kwargs: Any) -> Callable
```
Returns an image processor for the model. By default, returns the identity function.
```python
@property
def lcc_submodules(self) -> list[str]
```
Returns the list of submodules considered for LCC, with the correct prefix if needed.
```python
def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None
```
Just logs every optimizer's learning rate.
```python
def on_train_epoch_end(self) -> None
```
Cleans up training-specific temporary attributes.
```python
def on_train_epoch_start(self) -> None
```
Performs dataset-wide latent clustering and stores the results in the `BaseClassifier.lcc_data` attribute.
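To make the gating concrete: with `warmup=5` and `interval=2`, LCC runs on epochs 6, 8, 10, and so on (the epoch index must be at least the warmup and divisible by the interval). A standalone check mirroring the `do_lcc` condition (the weight and submodule checks are omitted):

```python
def lcc_epochs(n_epochs: int, warmup: int = 5, interval: int = 2) -> list[int]:
    """Epochs on which LCC would run, per the condition above."""
    return [e for e in range(n_epochs) if e >= warmup and e % interval == 0]


print(lcc_epochs(12))  # [6, 8, 10]
```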
```python
def on_train_start(self) -> None
```
Explicitly registers hyperparameters and metrics. You'd think Lightning would do this automatically, but nope.
```python
def test_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor
```
Override from `pl.LightningModule.test_step`.
```python
def training_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor
```
Override from `pl.LightningModule.training_step`.
```python
def validation_step(self, batch: Batch, *_: Any, **__: Any) -> Tensor
```
Override from `pl.LightningModule.validation_step`.
```python
def get_classifier_cls(model_name: str) -> type[BaseClassifier]
```
Returns the classifier class to use for a given model name: names starting with `timm/` map to `TimmClassifier`, other names containing a `/` map to `HuggingFaceClassifier`, and everything else maps to `TorchvisionClassifier`.
Arguments:
- model_name (str):
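For example (the model names are illustrative):

```python
from lcc.classifiers import get_classifier_cls

get_classifier_cls("timm/resnet18.a1_in1k")  # -> TimmClassifier
get_classifier_cls("microsoft/resnet-18")  # -> HuggingFaceClassifier ("/" in name)
get_classifier_cls("alexnet")  # -> TorchvisionClassifier
```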
```python
class HuggingFaceClassifier(WrappedClassifier):
    """
    Pretrained classifier model loaded from the [HuggingFace model
    hub](https://huggingface.co/models?pipeline_tag=image-classification).
    """

    def __init__(
        self,
        model_name: str,
        n_classes: int,
        head_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        See also:
            `lcc.classifiers.WrappedClassifier.__init__` and
            `lcc.classifiers.BaseClassifier.__init__`.

        Args:
            model_name (str): Model name as in the [HuggingFace model
                hub](https://huggingface.co/models?pipeline_tag=image-classification).
                If the model name starts with `timm/`, use
                `lcc.classifiers.TimmClassifier` instead.
            n_classes (int): See `lcc.classifiers.WrappedClassifier.__init__`.
            head_name (str | None, optional): See
                `lcc.classifiers.WrappedClassifier.__init__`.
        """
        if model_name.startswith("timm/"):
            raise ValueError(
                "If the model name starts with `timm/`, use "
                "`lcc.classifiers.TimmClassifier` instead."
            )
        model = AutoModelForImageClassification.from_pretrained(model_name)
        super().__init__(model, n_classes, head_name, **kwargs)
        self.save_hyperparameters()

    @staticmethod
    def get_image_processor(
        model_name: str, **__: Any
    ) -> Callable[[dict[str, Any]], dict[str, Any]]:
        """
        Wraps the HuggingFace `AutoImageProcessor` associated with a given
        model.

        Args:
            model_name (str): Model name as in the [HuggingFace model
                hub](https://huggingface.co/models?pipeline_tag=image-classification).
                Must not start with `timm/`.

        Returns:
            A callable that uses a
            [`transformers.AutoImageProcessor`](https://huggingface.co/docs/transformers/v4.44.2/en/model_doc/auto#transformers.AutoImageProcessor)
            under the hood.
        """
        hf_transform = AutoImageProcessor.from_pretrained(model_name)

        def _transform(batch: dict[str, Any]) -> dict[str, Any]:
            return {
                k: (
                    hf_transform(
                        [img.convert("RGB") for img in v], return_tensors="pt"
                    )["pixel_values"]
                    # TODO: pass image_key from DS ↓
                    if k in ["img", "image", "jpg", "png"]
                    else v
                )
                for k, v in batch.items()
            }

        return _transform
```
Pretrained classifier model loaded from the [HuggingFace model hub](https://huggingface.co/models?pipeline_tag=image-classification).
```python
def __init__(
    self,
    model_name: str,
    n_classes: int,
    head_name: str | None = None,
    **kwargs: Any,
) -> None
```
See also:
    `lcc.classifiers.WrappedClassifier.__init__` and `lcc.classifiers.BaseClassifier.__init__`.

Arguments:
- model_name (str): Model name as in the [HuggingFace model hub](https://huggingface.co/models?pipeline_tag=image-classification). If the model name starts with `timm/`, use `lcc.classifiers.TimmClassifier` instead.
- n_classes (int): See `lcc.classifiers.WrappedClassifier.__init__`.
- head_name (str | None, optional): See `lcc.classifiers.WrappedClassifier.__init__`.
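A construction sketch (the hub model name is used for illustration; the head submodule name is a hypothetical placeholder to check against the actual model, e.g. via `named_modules()`):

```python
from lcc.classifiers import HuggingFaceClassifier

clf = HuggingFaceClassifier(
    model_name="microsoft/resnet-18",  # illustrative hub entry
    n_classes=10,
    head_name="classifier.1",  # hypothetical; verify on the actual model
)
```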
```python
@staticmethod
def get_image_processor(
    model_name: str, **__: Any
) -> Callable[[dict[str, Any]], dict[str, Any]]
```
Wraps the HuggingFace `AutoImageProcessor` associated with a given model.

Arguments:
- model_name (str): Model name as in the [HuggingFace model hub](https://huggingface.co/models?pipeline_tag=image-classification). Must not start with `timm/`.

Returns:
    A callable that uses a [`transformers.AutoImageProcessor`](https://huggingface.co/docs/transformers/v4.44.2/en/model_doc/auto#transformers.AutoImageProcessor) under the hood.
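A usage sketch, assuming the illustrative model name above and a batch whose image column is named `img` (one of the keys the transform currently recognizes):

```python
from PIL import Image

processor = HuggingFaceClassifier.get_image_processor("microsoft/resnet-18")
batch = {
    "img": [Image.new("RGB", (64, 64)), Image.new("RGB", (64, 64))],
    "label": [0, 1],
}
batch = processor(batch)  # "img" is now a (2, C, H, W) pixel-value tensor
```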
```python
@dataclass
class LatentClusteringData:
    """
    Convenience struct that holds some latent clustering correction data for a
    given latent space.
    """

    loss: LCCLoss
    """
    The actual LCC loss object (which is callable) that computes the LCC loss
    for LCC correction.
    """

    # TODO: Write the y_clst to disk and load it alongside the train dataset,
    # so that this whole data structure becomes unnecessary
    y_clst: np.ndarray
    """`(N,)` vector of cluster labels."""
```
Convenience struct that holds some latent clustering correction data for a given latent space.
`loss`: The actual LCC loss object (which is callable) that computes the LCC loss for LCC correction.

`y_clst`: `(N,)` vector of cluster labels.
```python
class TimmClassifier(WrappedClassifier):
    """
    Pretrained classifier model loaded from the [HuggingFace model
    hub](https://huggingface.co/models?pipeline_tag=image-classification),
    uploaded by the `timm` team.
    """

    image_transform: Callable

    def __init__(
        self,
        model_name: str,
        n_classes: int,
        head_name: str | None = None,
        pretrained: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        See also:
            `lcc.classifiers.WrappedClassifier.__init__` and
            `lcc.classifiers.BaseClassifier.__init__`.

        Args:
            model_name (str): Model name as in the [HuggingFace model
                hub](https://huggingface.co/models?pipeline_tag=image-classification).
                Must start with `timm/`.
            n_classes (int): See `lcc.classifiers.WrappedClassifier.__init__`.
            head_name (str | None, optional): See
                `lcc.classifiers.WrappedClassifier.__init__`.
            pretrained (bool, optional): Defaults to `True`.
        """
        if not model_name.startswith("timm/"):
            raise ValueError(
                "The model isn't a timm model (its name does not start with "
                "`timm/`). Use `lcc.classifiers.HuggingFaceClassifier` "
                "instead."
            )
        model = timm.create_model(model_name, pretrained=pretrained)
        super().__init__(model, n_classes, head_name, **kwargs)
        self.save_hyperparameters()

    @staticmethod
    def get_image_processor(model_name: str, **kwargs: Any) -> Callable:
        """
        Wraps the `timm` image transform associated with a given model.

        See also:
            [`timm.create_transform`](https://huggingface.co/docs/timm/reference/data#timm.data.create_transform).

        Args:
            model_name (str): Model name as in the [HuggingFace model
                hub](https://huggingface.co/models?pipeline_tag=image-classification).
                Must start with `timm/`.

        Returns:
            A callable that uses a [Torchvision
            transform](https://pytorch.org/vision/0.19/transforms.html) under
            the hood.
        """
        model = timm.create_model(model_name, pretrained=False)
        conf = timm.data.resolve_model_data_config(model)
        conf["is_training"], conf["no_aug"] = True, True
        timm_transform = timm.data.create_transform(**conf)

        def _transform(batch: dict[str, Any]) -> dict[str, Any]:
            return {
                k: (
                    (
                        [timm_transform(img.convert("RGB")) for img in v]
                        if isinstance(v, list)
                        else timm_transform(v)
                    )
                    # TODO: pass image_key from DS ↓
                    if k in ["img", "image", "jpg", "png"]
                    else v
                )
                for k, v in batch.items()
            }

        return _transform

    def forward(self, inputs: Tensor | Batch, *_: Any, **__: Any) -> Tensor:
        x: Tensor = (
            inputs if isinstance(inputs, Tensor) else inputs[self.image_key]
        )
        return self.model(x.to(self.device))  # type: ignore
```
Pretrained classifier model loaded from the [HuggingFace model hub](https://huggingface.co/models?pipeline_tag=image-classification), uploaded by the `timm` team.
```python
def __init__(
    self,
    model_name: str,
    n_classes: int,
    head_name: str | None = None,
    pretrained: bool = True,
    **kwargs: Any,
) -> None
```
See also:
    `lcc.classifiers.WrappedClassifier.__init__` and `lcc.classifiers.BaseClassifier.__init__`.

Arguments:
- model_name (str): Model name as in the [HuggingFace model hub](https://huggingface.co/models?pipeline_tag=image-classification). Must start with `timm/`.
- n_classes (int): See `lcc.classifiers.WrappedClassifier.__init__`.
- head_name (str | None, optional): See `lcc.classifiers.WrappedClassifier.__init__`.
- pretrained (bool, optional): Defaults to `True`.
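A construction sketch (the `timm/` hub entry is illustrative; the head submodule name is hypothetical and depends on the architecture):

```python
from lcc.classifiers import TimmClassifier

clf = TimmClassifier(
    model_name="timm/resnet18.a1_in1k",  # illustrative hub entry
    n_classes=100,
    head_name="fc",  # hypothetical; verify via named_modules()
)
```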
```python
@staticmethod
def get_image_processor(model_name: str, **kwargs: Any) -> Callable
```
Wraps the `timm` image transform associated with a given model.

See also:
    [`timm.create_transform`](https://huggingface.co/docs/timm/reference/data#timm.data.create_transform).

Arguments:
- model_name (str): Model name as in the [HuggingFace model hub](https://huggingface.co/models?pipeline_tag=image-classification). Must start with `timm/`.

Returns:
    A callable that uses a [Torchvision transform](https://pytorch.org/vision/0.19/transforms.html) under the hood.
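A usage sketch under the same assumptions as above (illustrative model name, image column named `img`):

```python
from PIL import Image

processor = TimmClassifier.get_image_processor("timm/resnet18.a1_in1k")
batch = {"img": [Image.new("RGB", (64, 64))], "label": [0]}
batch = processor(batch)  # "img" is now a list of transformed tensors
```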
```python
def forward(self, inputs: Tensor | Batch, *_: Any, **__: Any) -> Tensor
```
Runs the wrapped `timm` model on the input and returns the logits. If `inputs` is a `Batch` rather than a `Tensor`, the input tensor is first retrieved using the classifier's image key.
```python
class TorchvisionClassifier(WrappedClassifier):
    """
    A torchvision classifier wrapped as a `WrappedClassifier`. See
    https://pytorch.org/vision/stable/models.html#classification .
    """

    def __init__(
        self,
        model_name: str,
        n_classes: int,
        head_name: str | None = None,
        weights: Any = "DEFAULT",
        **kwargs: Any,
    ) -> None:
        model = get_model(model_name, weights=weights)
        super().__init__(model, n_classes, head_name, **kwargs)
        self.save_hyperparameters()

    @staticmethod
    def get_image_processor(
        model_name: str, weights: str = "DEFAULT", **__: Any
    ) -> Callable[[dict[str, Any]], dict[str, Any]]:
        """
        Creates an image processor based on the transform object of the model's
        chosen weights. For example,

            TorchvisionClassifier.get_image_processor("alexnet")

        is analogous to

            get_model_weights("alexnet")["DEFAULT"].transforms()
        """

        transform = tr.Compose(
            [tr.RGB(), get_model_weights(model_name)[weights].transforms()]
        )

        def _processor(batch: dict[str, Any]) -> dict[str, Any]:
            return {
                k: (
                    [transform(img) for img in v]
                    # TODO: pass image_key from DS ↓
                    if k in ["img", "image", "jpg", "png"]
                    else v
                )
                for k, v in batch.items()
            }

        return _processor
```
A torchvision classifier wrapped as a `WrappedClassifier`. See https://pytorch.org/vision/stable/models.html#classification.
```python
def __init__(
    self,
    model_name: str,
    n_classes: int,
    head_name: str | None = None,
    weights: Any = "DEFAULT",
    **kwargs: Any,
) -> None
```
See also:
    `lcc.classifiers.WrappedClassifier.__init__` and `lcc.classifiers.BaseClassifier.__init__`.

Arguments:
- model_name (str): Name of the torchvision model, forwarded to [`get_model`](https://pytorch.org/vision/stable/models.html).
- n_classes (int): Number of classes in the dataset on which the model will be trained.
- head_name (str | None, optional): See `lcc.classifiers.WrappedClassifier.__init__`.
- weights (Any, optional): Torchvision weights specification. Defaults to `"DEFAULT"`.
```python
@staticmethod
def get_image_processor(
    model_name: str, weights: str = "DEFAULT", **__: Any
) -> Callable[[dict[str, Any]], dict[str, Any]]
```
Creates an image processor based on the transform object of the model's chosen weights. For example,

```python
TorchvisionClassifier.get_image_processor("alexnet")
```

is analogous to

```python
get_model_weights("alexnet")["DEFAULT"].transforms()
```
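A usage sketch (the batch layout is assumed, with the image column named `img`):

```python
from PIL import Image

processor = TorchvisionClassifier.get_image_processor("alexnet")
batch = {"img": [Image.new("RGB", (64, 64))], "label": [0]}
batch = processor(batch)  # "img" is now a list of preprocessed tensors
```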
```python
def validate_lcc_kwargs(lcc_kwargs: dict[str, Any] | None) -> None:
    """
    Makes sure that an LCC hyperparameter dict is valid. Used in constructors
    of LCC-enabled classifiers.

    Args:
        lcc_kwargs (dict[str, Any] | None):
    """
    if not lcc_kwargs:
        return
    if (x := lcc_kwargs.get("weight", 1)) <= 0:
        raise ValueError(f"LCC weight must be positive, got {x}")
    if (x := lcc_kwargs.get("interval", 1)) < 1:
        raise ValueError(f"LCC interval must be at least 1, got {x}")
    if (x := lcc_kwargs.get("warmup", 0)) < 0:
        raise ValueError(f"LCC warmup must be at least 0, got {x}")
    if (x := lcc_kwargs.get("class_selection")) not in LCC_CLASS_SELECTIONS + [
        None
    ]:
        raise ValueError(
            f"Invalid class selection policy '{x}'. Available policies are: "
            + ", ".join(map(lambda a: f"'{a}'", LCC_CLASS_SELECTIONS))
            + ", or `None`"
        )
    LCC_LOSS_TYPES = ["exact", "randomized"]
    if (x := lcc_kwargs.get("loss", "exact")) not in LCC_LOSS_TYPES:
        raise ValueError(
            f"Invalid LCC loss type '{x}'. Available types are: "
            + ", ".join(map(lambda a: f"'{a}'", LCC_LOSS_TYPES))
        )
    LCC_CLUSTERING_METHODS = ["louvain", "peer_pressure"]
    if (
        x := lcc_kwargs.get("clustering_method", "louvain")
    ) not in LCC_CLUSTERING_METHODS:
        raise ValueError(
            f"Invalid latent clustering method '{x}'. Available methods are: "
            + ", ".join(map(lambda a: f"'{a}'", LCC_CLUSTERING_METHODS))
        )
```
Makes sure that an LCC hyperparameter dict is valid. Used in constructors of LCC-enabled classifiers.

Arguments:
- lcc_kwargs (dict[str, Any] | None):
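For example, based on the checks in the listing above:

```python
validate_lcc_kwargs({"weight": 1e-4, "interval": 2, "warmup": 5})  # passes
validate_lcc_kwargs(None)  # passes (nothing to validate)
validate_lcc_kwargs({"weight": -1.0})  # raises ValueError (weight must be positive)
```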
```python
class WrappedClassifier(BaseClassifier):
    """
    An image classifier model
    ([`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html))
    wrapped inside a `lcc.classifiers.BaseClassifier`.
    """

    model: nn.Module

    def __init__(
        self,
        model: nn.Module,
        n_classes: int,
        head_name: str | None = None,
        logit_key: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        See also:
            `lcc.classifiers.BaseClassifier.__init__`.

        Args:
            model (nn.Module):
            n_classes (int): Number of classes in the dataset on which the
                model will be trained.
            head_name (str | None, optional): Name of the head submodule in
                `model`, which is the fully connected (aka
                [`nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html))
                layer that outputs the logits. If not `None`, the head is
                replaced by a new fully connected layer with `n_classes`
                output neurons. Specify this to fine-tune a pretrained model
                on a new dataset with the same or a different number of
                classes. The name of a submodule can be retrieved by
                inspecting the output of `nn.Module.named_modules` or
                `lcc.utils.pretty_print_submodules`.
            logit_key (str | None, optional): If the wrapped model outputs a
                dict-like object instead of a tensor, this key is used to
                access the actual logits.
        """
        self.save_hyperparameters(ignore=["model"])
        super().__init__(n_classes, **kwargs)
        self.model = model
        if head_name:
            replace_head(self.model, head_name, n_classes)

    def forward(self, inputs: Tensor | Batch, *_: Any, **__: Any) -> Tensor:
        image_key = self.hparams["image_key"]
        logit_key = self.hparams.get("logit_key")
        x: Tensor = inputs if isinstance(inputs, Tensor) else inputs[image_key]
        output = self.model(x.to(self.device))
        return output if logit_key is None else output[logit_key]  # type: ignore

    @property
    def lcc_submodules(self) -> list[str]:
        """
        Returns the list of submodules considered for LCC, with the correct
        prefix if needed.
        """
        return (
            []
            if not self.hparams["lcc_submodules"]
            else [
                (sm if sm.startswith("model.") else "model." + sm)
                for sm in self.hparams["lcc_submodules"]
            ]
        )

    def on_train_start(self) -> None:
        self.model.train()
        super().on_train_start()
```
An image classifier model ([`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)) wrapped inside a `lcc.classifiers.BaseClassifier`.
```python
def __init__(
    self,
    model: nn.Module,
    n_classes: int,
    head_name: str | None = None,
    logit_key: str | None = None,
    **kwargs: Any,
) -> None
```
See also:
    `lcc.classifiers.BaseClassifier.__init__`.

Arguments:
- model (nn.Module):
- n_classes (int): Number of classes in the dataset on which the model will be trained.
- head_name (str | None, optional): Name of the head submodule in `model`, which is the fully connected (aka [`nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)) layer that outputs the logits. If not `None`, the head is replaced by a new fully connected layer with `n_classes` output neurons. Specify this to fine-tune a pretrained model on a new dataset with the same or a different number of classes. The name of a submodule can be retrieved by inspecting the output of `nn.Module.named_modules` or `lcc.utils.pretty_print_submodules`.
- logit_key (str | None, optional): If the wrapped model outputs a dict-like object instead of a tensor, this key is used to access the actual logits.
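A construction sketch (torchvision's ResNet-18 is used for illustration; its head submodule is named `fc`):

```python
from torchvision.models import resnet18

from lcc.classifiers import WrappedClassifier

backbone = resnet18(weights="DEFAULT")
clf = WrappedClassifier(
    model=backbone,
    n_classes=10,  # replaces the original 1000-class head
    head_name="fc",  # final nn.Linear of torchvision's ResNet
)
```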
```python
def forward(self, inputs: Tensor | Batch, *_: Any, **__: Any) -> Tensor
```
Runs the wrapped model on the input and returns the logits. If `inputs` is a `Batch` rather than a `Tensor`, the input tensor is first retrieved using `image_key`; if the wrapped model returns a dict-like object, the logits are extracted using `logit_key`.
```python
@property
def lcc_submodules(self) -> list[str]
```
Returns the list of submodules considered for LCC, with the correct prefix if needed (e.g. `layer4` becomes `model.layer4`).