lcc.training
General model training utilities
1"""General model training utilities""" 2 3import hashlib 4import json 5import os 6import uuid 7import warnings 8from datetime import datetime, timedelta 9from pathlib import Path 10from tempfile import TemporaryDirectory 11from typing import Literal 12 13import pandas as pd 14import pytorch_lightning as pl 15import regex as re 16import torch 17import turbo_broccoli as tb 18 19from lcc.classifiers.base import validate_lcc_kwargs 20 21from .classifiers import get_classifier_cls 22from .datasets import HuggingFaceDataset, get_dataset 23from .logging import r0_debug, r0_info 24from .utils import get_reasonable_n_jobs 25 26# from .ema import EMACallback 27 28DEFAULT_MAX_GRAD_NORM = 1.0 29"""For gradient clipping.""" 30 31 32class NoCheckpointFound(Exception): 33 """ 34 Raised by `lcc.training.all_checkpoint_paths` and 35 `lcc.training.best_checkpoint_path` if no checkpoints are found. 36 """ 37 38 39def _dict_sha1(d: dict) -> str: 40 """ 41 Quick and dirty way to get a unique hash for a (potentially nested) 42 dictionary. 43 44 Warning: 45 This method does not sort inner sets. 46 """ 47 h = hashlib.sha1() 48 h.update(json.dumps(d, sort_keys=True).encode("utf-8")) 49 return h.hexdigest() 50 51 52def all_checkpoint_paths(output_path: str | Path) -> list[Path]: 53 """ 54 Returns the sorted (by epoch) list of all checkpoints. The checkpoint files 55 must follow the following pattern: 56 57 epoch=<digits>-step=<digits>.ckpt 58 59 Args: 60 output_path (str | Path): e.g. 61 `out.local/ft/cifar100/microsoft-resnet-18`. There is no assumption 62 on the structure of this folder, as long as it contains `.ckpt` 63 files either directly or with subfolders in between. 64 65 Raises: 66 NoCheckpointFound: If no checkpoint is found 67 """ 68 r, d = re.compile(r"/epoch=(\d+)-step=\d+\.ckpt$"), {} 69 for p in Path(output_path).glob("**/*.ckpt"): 70 if m := re.search(r, str(p)): 71 epoch = int(m.group(1)) 72 d[epoch] = p 73 ckpts = [d[i] for i in sorted(list(d.keys()))] 74 if not ckpts: 75 raise NoCheckpointFound 76 return ckpts 77 78 79def best_checkpoint_path( 80 output_path: str | Path, 81 metric: str = "val/acc", 82 mode: Literal["min", "max"] = "max", 83) -> tuple[Path, int]: 84 """ 85 Returns the path to the best checkpoint. 86 87 Args: 88 output_path (str | Path): e.g. 89 `out.local/ft/cifar100/microsoft-resnet-18`. This folder is expected 90 to contain a `tb_logs` and `csv_logs` folder, either directly or 91 with subfolders in between. 92 metric (str, optional): 93 mode (Literal["min", "max"], optional): 94 95 Returns: 96 A tuple containing the path to the checkpoint file, and the epoch 97 number. 98 """ 99 if not isinstance(output_path, Path): 100 output_path = Path(output_path) 101 ckpts = all_checkpoint_paths(output_path) 102 metrics_path = list(output_path.glob("**/csv_logs/**/metrics.csv"))[0] 103 epoch = best_epoch(metrics_path, metric, mode) 104 return ckpts[epoch], epoch 105 106 107def best_epoch( 108 metrics_path: str | Path, 109 metric: str = "val/acc", 110 mode: Literal["min", "max"] = "max", 111) -> int: 112 """Given the `metrics.csv` path, returns the best epoch index""" 113 df = pd.read_csv(metrics_path) 114 df.drop(columns=["train/loss"], inplace=True) 115 df = df.groupby("epoch").tail(1) 116 df.reset_index(inplace=True, drop=True) 117 return int(df[metric].argmax() if mode == "max" else df[metric].argmin()) 118 119 120def checkpoint_ves(path: str | Path) -> tuple[str, int, int]: 121 """ 122 Given a checkpoint path that looks like e.g. 
123 124 out/resnet18/cifar10/model/tb_logs/resnet18/060516dd86294076878cd278cfc59237/checkpoints/epoch=32-step=5181.ckpt 125 126 returns the **v**ersion name (`060516dd86294076878cd278cfc59237`), the 127 number of **e**pochs (32), and the number of **s**teps (5181). 128 """ 129 r = r".*/(\w+)/checkpoints/epoch=(\d+)-step=(\d+).*\.ckpt" 130 if m := re.match(r, str(path)): 131 return str(m.group(1)), int(m.group(2)), int(m.group(3)) 132 raise ValueError(f"Path '{path}' is not a valid checkpoint path") 133 134 135def make_trainer( 136 output_dir: Path | str, 137 model_name: str | None = None, 138 max_epochs: int = 50, 139 save_all_checkpoints: bool = False, 140 stage: Literal["train", "test"] = "train", 141 version: int | str | None = None, 142) -> pl.Trainer: 143 """ 144 Makes a [PyTorch Lightning 145 `Trainer`](https://lightning.ai/docs/pytorch/stable/common/trainer.html) 146 with some sensible defaults. 147 148 Args: 149 output_dir (Path | str): 150 model_name (str): Ignored if `stage` is `test`, but must be set if 151 `stage` is `train`. 152 max_epochs (int, optional): Ignored if `stage` is `test`. 153 save_all_checkpoints (bool, optional): If set to `False`, then only the 154 best checkpoint is saved. 155 stage (str, optional): Either `train` or `test`. 156 """ 157 output_dir = Path(output_dir) 158 159 config = { 160 "default_root_dir": str(output_dir), 161 "log_every_n_steps": 1, 162 } 163 if stage == "train": 164 if model_name is None: 165 raise ValueError("model_name must be set if stage is 'train'") 166 config["accelerator"] = "gpu" 167 config["devices"] = torch.cuda.device_count() 168 config["strategy"] = "ddp" 169 config["max_epochs"] = max_epochs 170 config["gradient_clip_val"] = DEFAULT_MAX_GRAD_NORM 171 config["callbacks"] = [ 172 # EMACallback(), 173 # pl.callbacks.EarlyStopping( 174 # monitor="val/acc", patience=25, mode="max" 175 # ), 176 pl.callbacks.ModelCheckpoint( 177 save_top_k=(-1 if save_all_checkpoints else 1), 178 monitor="val/acc", 179 mode="max", 180 every_n_epochs=1, 181 ), 182 pl.callbacks.TQDMProgressBar(), 183 ] 184 config["logger"] = [ 185 pl.loggers.TensorBoardLogger( 186 str(output_dir / "tb_logs"), 187 name=model_name, 188 default_hp_metric=False, 189 version=version, 190 ), 191 pl.loggers.CSVLogger( 192 str(output_dir / "csv_logs"), 193 name=model_name, 194 version=version, 195 ), 196 ] 197 else: 198 config["devices"] = 1 199 config["num_nodes"] = 1 200 return pl.Trainer(**config) # type: ignore 201 202 203def train( 204 model_name: str, 205 dataset_name: str, 206 output_dir: Path | str, 207 ckpt_path: Path | None = None, 208 ce_weight: float = 1, 209 lcc_submodules: list[str] | None = None, 210 lcc_kwargs: dict | None = None, 211 max_epochs: int = 50, 212 batch_size: int = 256, 213 train_split: str = "train", 214 val_split: str = "val", 215 test_split: str | None = None, 216 image_key: str = "image", 217 label_key: str = "label", 218 logit_key: str | None = "logits", 219 head_name: str | None = None, 220 seed: int | None = None, 221) -> dict: 222 """ 223 Performs fine-tuning on a model, possibly with latent clustering correction. 224 225 Args: 226 model_name (str): The model name as in the [Hugging Face model 227 hub](https://huggingface.co/models?pipeline_tag=image-classification). 228 dataset_name (str): The dataset name as in the [Hugging Face dataset 229 hub](https://huggingface.co/datasets?task_categories=task_categories:image-classification). 
230 output_dir (Path | str): 231 ckpt_path (Path | None): If `None`, the correction will start from the 232 weights available on the Hugging Face model hub. 233 ce_weight (float, optional): Weight of the cross-entropy loss against 234 the LCC loss. Ignored if LCC is not performed. Defaults to $1$. 235 lcc_submodules (list[str] | None, optional): List of submodule names 236 where to perform LCC. If empty or `None`, LCC is not performed. This 237 is the only way to enable/disable LCC. Defaults to `None`. 238 lcc_kwargs (dict | None, optional): Optional parameters for LCC. See 239 `lcc.classifiers.BaseClassifier.__init__`. 240 max_epochs (int, optional): Defaults to $50$. 241 batch_size (int, optional): Defaults to $2048$. 242 train_split (str, optional): 243 val_split (str, optional): 244 test_split (str | None, optional): 245 image_key (str, optional): 246 label_key (str, optional): 247 logit_key (str | None, optional): 248 head_name (str | None, optional): Name of the output layer of the model. 249 This must be set if the number of classes in the dataset does not 250 match the number components of the output layer of the model. See 251 also `lcc.classifiers.BaseClassifier.__init__`. 252 seed (int | None, optional): Global seed for both CPU and GPU. If not 253 `None`, this is set globally, so one might consider this as a side 254 effect. 255 """ 256 if seed is not None: 257 r0_info("Setting global seed to {}", seed) 258 torch.manual_seed(seed) 259 260 lcc_kwargs, do_lcc = lcc_kwargs or {}, bool(lcc_submodules) 261 if do_lcc: 262 r0_info("Performing latent cluster correction") 263 validate_lcc_kwargs(lcc_kwargs) 264 265 output_dir = Path(output_dir) 266 _dataset_name = dataset_name.replace("/", "-") 267 _model_name = model_name.replace("/", "-") 268 _output_dir = output_dir / _dataset_name / _model_name 269 _output_dir.mkdir(parents=True, exist_ok=True) 270 271 classifier_cls = get_classifier_cls(model_name) 272 273 if dataset_name.startswith("PRESET:"): 274 dataset_name = dataset_name[7:] 275 r0_info("Using preset dataset name: {}", dataset_name) 276 dataset, _ = get_dataset( 277 dataset_name, 278 image_processor=model_name, 279 batch_size=batch_size, 280 num_workers=get_reasonable_n_jobs(), 281 ) 282 else: 283 dataset = HuggingFaceDataset( 284 dataset_name=dataset_name, 285 fit_split=train_split, 286 val_split=val_split, 287 test_split=test_split, 288 label_key=label_key, 289 image_processor=classifier_cls.get_image_processor(model_name), 290 train_dl_kwargs={ 291 "batch_size": batch_size, 292 "num_workers": get_reasonable_n_jobs(), 293 }, 294 val_dl_kwargs={ 295 "batch_size": batch_size, 296 "num_workers": get_reasonable_n_jobs(), 297 }, 298 ) 299 n_classes = dataset.n_classes() 300 301 model = classifier_cls( 302 model_name=model_name, 303 n_classes=n_classes, 304 head_name=head_name, 305 image_key=image_key, 306 label_key=label_key, 307 logit_key=logit_key, 308 lcc_submodules=lcc_submodules if do_lcc else None, 309 lcc_kwargs=lcc_kwargs if do_lcc else None, 310 ce_weight=ce_weight, 311 ) 312 if isinstance(ckpt_path, Path): 313 model.model = classifier_cls.load_from_checkpoint( # type: ignore 314 ckpt_path 315 ).model 316 r0_info("Loaded checkpoint {}", ckpt_path) 317 r0_debug("Model hyperparameters:\n{}", json.dumps(model.hparams, indent=4)) 318 319 trainer = make_trainer( 320 _output_dir, 321 model_name=_model_name, 322 max_epochs=max_epochs, 323 stage="train", 324 version=str(uuid.uuid4().hex), 325 ) 326 start = datetime.now() 327 with warnings.catch_warnings(): 328 
warnings.filterwarnings("ignore", category=UserWarning) 329 trainer.fit(model, dataset) 330 fit_time = datetime.now() - start 331 r0_info("Finished training in {}", fit_time) 332 333 ckpt = Path(trainer.checkpoint_callback.best_model_path) # type: ignore 334 ckpt = ckpt.relative_to(output_dir) 335 v, e, s = checkpoint_ves(ckpt) 336 r0_info("Best checkpoint path: {}", ckpt) 337 r0_info("version={}, best_epoch={}, n_steps={}", v, e, s) 338 339 # TODO: fix testing loop. Right now, every rank reinstanciates a single-node 340 # single-device trainer to run the model on the test dataset. So every rank 341 # is testing the model independently which is stupid. 342 343 with TemporaryDirectory(prefix="lcc-") as tmp: 344 trainer = make_trainer(tmp, stage="test") 345 test_results = trainer.test(model, dataset) 346 347 document: dict = { 348 "__meta__": { 349 "version": 3, 350 "hostname": os.uname().nodename, 351 "datetime": start, 352 }, 353 "dataset": { 354 "name": dataset_name, 355 "n_classes": n_classes, 356 "train_split": train_split, 357 "val_split": val_split, 358 "test_split": test_split, 359 "image_key": image_key, 360 "label_key": label_key, 361 "batch_size": batch_size, 362 }, 363 "model": {"name": model_name, "hparams": dict(model.hparams)}, 364 "training": { 365 "best_checkpoint": { 366 "path": str(ckpt), 367 "version": v, 368 "epoch": e, 369 "n_steps": s, 370 }, 371 "seed": seed, 372 "time": fit_time / timedelta(seconds=1), 373 "test": test_results, 374 }, 375 } 376 document["__meta__"]["hash"] = _dict_sha1( 377 {k: document[k] for k in ["dataset", "model"]} 378 ) 379 tb.save_json(document, _output_dir / f"results.{v}.json") 380 return document
DEFAULT_MAX_GRAD_NORM = 1.0

For gradient clipping.
class NoCheckpointFound(builtins.Exception):
Raised by lcc.training.all_checkpoint_paths and lcc.training.best_checkpoint_path if no checkpoints are found.
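Both checkpoint lookup helpers below raise this exception, so callers can fall back gracefully when a run directory contains no checkpoints yet. A minimal sketch, using a hypothetical run directory:

from lcc.training import NoCheckpointFound, all_checkpoint_paths

try:
    ckpts = all_checkpoint_paths("out.local/ft/cifar100/microsoft-resnet-18")
except NoCheckpointFound:
    ckpts = []  # no .ckpt file under this directory yet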
def all_checkpoint_paths(output_path: str | pathlib.Path) -> list[pathlib.Path]:
Returns the sorted (by epoch) list of all checkpoints. The checkpoint files must match the following pattern:

epoch=<digits>-step=<digits>.ckpt

Arguments:
- output_path (str | Path): e.g. out.local/ft/cifar100/microsoft-resnet-18. There is no assumption on the structure of this folder, as long as it contains .ckpt files either directly or with subfolders in between.

Raises:
- NoCheckpointFound: If no checkpoint is found.
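For instance, since the returned list is sorted by epoch, the most recent checkpoint of the (hypothetical) run directory used as an example above is simply the last entry:

from lcc.training import all_checkpoint_paths

ckpts = all_checkpoint_paths("out.local/ft/cifar100/microsoft-resnet-18")
latest = ckpts[-1]  # checkpoint of the highest epoch found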
def best_checkpoint_path(output_path: str | pathlib.Path, metric: str = 'val/acc', mode: Literal['min', 'max'] = 'max') -> tuple[pathlib.Path, int]:
Returns the path to the best checkpoint.

Arguments:
- output_path (str | Path): e.g. out.local/ft/cifar100/microsoft-resnet-18. This folder is expected to contain a tb_logs and csv_logs folder, either directly or with subfolders in between.
- metric (str, optional):
- mode (Literal["min", "max"], optional):

Returns:
A tuple containing the path to the checkpoint file, and the epoch number.
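A minimal sketch, again with the hypothetical run directory from the docstring; by default this picks the epoch with the highest val/acc:

from lcc.training import best_checkpoint_path

ckpt, epoch = best_checkpoint_path("out.local/ft/cifar100/microsoft-resnet-18")
print(f"best checkpoint (epoch {epoch}): {ckpt}")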
def best_epoch(metrics_path: str | pathlib.Path, metric: str = 'val/acc', mode: Literal['min', 'max'] = 'max') -> int:
Given the metrics.csv path, returns the best epoch index.
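A sketch of selecting the epoch with the lowest validation loss instead of the highest accuracy; the metrics.csv path and the presence of a val/loss column are assumptions:

from lcc.training import best_epoch

epoch = best_epoch(
    "out.local/ft/cifar100/microsoft-resnet-18/csv_logs/metrics.csv",  # hypothetical path
    metric="val/loss",
    mode="min",
)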
def checkpoint_ves(path: str | pathlib.Path) -> tuple[str, int, int]:
Given a checkpoint path that looks like e.g.

out/resnet18/cifar10/model/tb_logs/resnet18/060516dd86294076878cd278cfc59237/checkpoints/epoch=32-step=5181.ckpt

returns the version name (060516dd86294076878cd278cfc59237), the number of epochs (32), and the number of steps (5181).
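Applied to the example path above:

from lcc.training import checkpoint_ves

p = "out/resnet18/cifar10/model/tb_logs/resnet18/060516dd86294076878cd278cfc59237/checkpoints/epoch=32-step=5181.ckpt"
version, n_epochs, n_steps = checkpoint_ves(p)
# version == "060516dd86294076878cd278cfc59237", n_epochs == 32, n_steps == 5181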
def make_trainer(output_dir: pathlib.Path | str, model_name: str | None = None, max_epochs: int = 50, save_all_checkpoints: bool = False, stage: Literal['train', 'test'] = 'train', version: int | str | None = None) -> pytorch_lightning.trainer.trainer.Trainer:
Makes a PyTorch Lightning Trainer with some sensible defaults.

Arguments:
- output_dir (Path | str):
- model_name (str): Ignored if stage is test, but must be set if stage is train.
- max_epochs (int, optional): Ignored if stage is test.
- save_all_checkpoints (bool, optional): If set to False, then only the best checkpoint is saved.
- stage (str, optional): Either train or test.
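A minimal sketch of building a training Trainer; the output directory and model name reuse the examples from this page, and a GPU is required since the train stage configures a gpu accelerator with DDP:

from lcc.training import make_trainer

trainer = make_trainer(
    "out.local/ft/cifar100/microsoft-resnet-18",
    model_name="microsoft-resnet-18",
    max_epochs=50,
    stage="train",
)
# trainer.fit(model, datamodule)  # model and datamodule built elsewhere, e.g. as in train() below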
def train(model_name: str, dataset_name: str, output_dir: pathlib.Path | str, ckpt_path: pathlib.Path | None = None, ce_weight: float = 1, lcc_submodules: list[str] | None = None, lcc_kwargs: dict | None = None, max_epochs: int = 50, batch_size: int = 256, train_split: str = 'train', val_split: str = 'val', test_split: str | None = None, image_key: str = 'image', label_key: str = 'label', logit_key: str | None = 'logits', head_name: str | None = None, seed: int | None = None) -> dict:
Performs fine-tuning on a model, possibly with latent clustering correction.

Arguments:
- model_name (str): The model name as in the Hugging Face model hub.
- dataset_name (str): The dataset name as in the Hugging Face dataset hub.
- output_dir (Path | str):
- ckpt_path (Path | None): If None, the correction will start from the weights available on the Hugging Face model hub.
- ce_weight (float, optional): Weight of the cross-entropy loss against the LCC loss. Ignored if LCC is not performed. Defaults to $1$.
- lcc_submodules (list[str] | None, optional): List of submodule names where to perform LCC. If empty or None, LCC is not performed. This is the only way to enable/disable LCC. Defaults to None.
- lcc_kwargs (dict | None, optional): Optional parameters for LCC. See lcc.classifiers.BaseClassifier.__init__.
- max_epochs (int, optional): Defaults to $50$.
- batch_size (int, optional): Defaults to $256$.
- train_split (str, optional):
- val_split (str, optional):
- test_split (str | None, optional):
- image_key (str, optional):
- label_key (str, optional):
- logit_key (str | None, optional):
- head_name (str | None, optional): Name of the output layer of the model. This must be set if the number of classes in the dataset does not match the number of components of the output layer of the model. See also lcc.classifiers.BaseClassifier.__init__.
- seed (int | None, optional): Global seed for both CPU and GPU. If not None, this is set globally, so one might consider this as a side effect.
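A hypothetical end-to-end call; the model and dataset names are example values, and the commented-out arguments are placeholders whose actual values depend on the model architecture:

from lcc.training import train

results = train(
    model_name="microsoft/resnet-18",
    dataset_name="cifar100",
    output_dir="out.local/ft",
    # lcc_submodules=["<submodule.name>"],  # placeholder; set to enable latent cluster correction
    # head_name="<output layer name>",      # needed if the dataset's class count differs from the model's head
    max_epochs=50,
    batch_size=256,
    seed=0,
)
print(results["training"]["best_checkpoint"]["path"])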