lcc.training

General model training utilities

  1"""General model training utilities"""
  2
  3import hashlib
  4import json
  5import os
  6import uuid
  7import warnings
  8from datetime import datetime, timedelta
  9from pathlib import Path
 10from tempfile import TemporaryDirectory
 11from typing import Literal
 12
 13import pandas as pd
 14import pytorch_lightning as pl
 15import regex as re
 16import torch
 17import turbo_broccoli as tb
 18
 19from lcc.classifiers.base import validate_lcc_kwargs
 20
 21from .classifiers import get_classifier_cls
 22from .datasets import HuggingFaceDataset, get_dataset
 23from .logging import r0_debug, r0_info
 24from .utils import get_reasonable_n_jobs
 25
 26# from .ema import EMACallback
 27
 28DEFAULT_MAX_GRAD_NORM = 1.0
 29"""For gradient clipping."""
 30
 31
 32class NoCheckpointFound(Exception):
 33    """
 34    Raised by `lcc.training.all_checkpoint_paths` and
 35    `lcc.training.best_checkpoint_path` if no checkpoints are found.
 36    """
 37
 38
 39def _dict_sha1(d: dict) -> str:
 40    """
 41    Quick and dirty way to get a unique hash for a (potentially nested)
 42    dictionary.
 43
 44    Warning:
 45        This method does not sort inner sets.
 46    """
 47    h = hashlib.sha1()
 48    h.update(json.dumps(d, sort_keys=True).encode("utf-8"))
 49    return h.hexdigest()
 50
 51
 52def all_checkpoint_paths(output_path: str | Path) -> list[Path]:
 53    """
 54    Returns the sorted (by epoch) list of all checkpoints. The checkpoint files
 55    must follow the following pattern:
 56
 57        epoch=<digits>-step=<digits>.ckpt
 58
 59    Args:
 60        output_path (str | Path): e.g.
 61            `out.local/ft/cifar100/microsoft-resnet-18`. There is no assumption
 62            on the structure of this folder, as long as it contains `.ckpt`
 63            files either directly or with subfolders in between.
 64
 65    Raises:
 66        NoCheckpointFound: If no checkpoint is found
 67    """
 68    r, d = re.compile(r"/epoch=(\d+)-step=\d+\.ckpt$"), {}
 69    for p in Path(output_path).glob("**/*.ckpt"):
 70        if m := re.search(r, str(p)):
 71            epoch = int(m.group(1))
 72            d[epoch] = p
 73    ckpts = [d[i] for i in sorted(list(d.keys()))]
 74    if not ckpts:
 75        raise NoCheckpointFound
 76    return ckpts
 77
 78
 79def best_checkpoint_path(
 80    output_path: str | Path,
 81    metric: str = "val/acc",
 82    mode: Literal["min", "max"] = "max",
 83) -> tuple[Path, int]:
 84    """
 85    Returns the path to the best checkpoint.
 86
 87    Args:
 88        output_path (str | Path): e.g.
 89            `out.local/ft/cifar100/microsoft-resnet-18`. This folder is expected
 90            to contain a `tb_logs` and `csv_logs` folder, either directly or
 91            with subfolders in between.
 92        metric (str, optional):
 93        mode (Literal["min", "max"], optional):
 94
 95    Returns:
 96        A tuple containing the path to the checkpoint file, and the epoch
 97        number.
 98    """
 99    if not isinstance(output_path, Path):
100        output_path = Path(output_path)
101    ckpts = all_checkpoint_paths(output_path)
102    metrics_path = list(output_path.glob("**/csv_logs/**/metrics.csv"))[0]
103    epoch = best_epoch(metrics_path, metric, mode)
104    return ckpts[epoch], epoch
105
106
107def best_epoch(
108    metrics_path: str | Path,
109    metric: str = "val/acc",
110    mode: Literal["min", "max"] = "max",
111) -> int:
112    """Given the `metrics.csv` path, returns the best epoch index"""
113    df = pd.read_csv(metrics_path)
114    df.drop(columns=["train/loss"], inplace=True)
115    df = df.groupby("epoch").tail(1)
116    df.reset_index(inplace=True, drop=True)
117    return int(df[metric].argmax() if mode == "max" else df[metric].argmin())
118
119
120def checkpoint_ves(path: str | Path) -> tuple[str, int, int]:
121    """
122    Given a checkpoint path that looks like e.g.
123
124        out/resnet18/cifar10/model/tb_logs/resnet18/060516dd86294076878cd278cfc59237/checkpoints/epoch=32-step=5181.ckpt
125
126    returns the **v**ersion name (`060516dd86294076878cd278cfc59237`), the
127    number of **e**pochs (32), and the number of **s**teps (5181).
128    """
129    r = r".*/(\w+)/checkpoints/epoch=(\d+)-step=(\d+).*\.ckpt"
130    if m := re.match(r, str(path)):
131        return str(m.group(1)), int(m.group(2)), int(m.group(3))
132    raise ValueError(f"Path '{path}' is not a valid checkpoint path")
133
134
135def make_trainer(
136    output_dir: Path | str,
137    model_name: str | None = None,
138    max_epochs: int = 50,
139    save_all_checkpoints: bool = False,
140    stage: Literal["train", "test"] = "train",
141    version: int | str | None = None,
142) -> pl.Trainer:
143    """
144    Makes a [PyTorch Lightning
145    `Trainer`](https://lightning.ai/docs/pytorch/stable/common/trainer.html)
146    with some sensible defaults.
147
148    Args:
149        output_dir (Path | str):
150        model_name (str): Ignored if `stage` is `test`, but must be set if
151            `stage` is `train`.
152        max_epochs (int, optional): Ignored if `stage` is `test`.
153        save_all_checkpoints (bool, optional): If set to `False`, then only the
154            best checkpoint is saved.
155        stage (str, optional): Either `train` or `test`.
156    """
157    output_dir = Path(output_dir)
158
159    config = {
160        "default_root_dir": str(output_dir),
161        "log_every_n_steps": 1,
162    }
163    if stage == "train":
164        if model_name is None:
165            raise ValueError("model_name must be set if stage is 'train'")
166        config["accelerator"] = "gpu"
167        config["devices"] = torch.cuda.device_count()
168        config["strategy"] = "ddp"
169        config["max_epochs"] = max_epochs
170        config["gradient_clip_val"] = DEFAULT_MAX_GRAD_NORM
171        config["callbacks"] = [
172            # EMACallback(),
173            # pl.callbacks.EarlyStopping(
174            #     monitor="val/acc", patience=25, mode="max"
175            # ),
176            pl.callbacks.ModelCheckpoint(
177                save_top_k=(-1 if save_all_checkpoints else 1),
178                monitor="val/acc",
179                mode="max",
180                every_n_epochs=1,
181            ),
182            pl.callbacks.TQDMProgressBar(),
183        ]
184        config["logger"] = [
185            pl.loggers.TensorBoardLogger(
186                str(output_dir / "tb_logs"),
187                name=model_name,
188                default_hp_metric=False,
189                version=version,
190            ),
191            pl.loggers.CSVLogger(
192                str(output_dir / "csv_logs"),
193                name=model_name,
194                version=version,
195            ),
196        ]
197    else:
198        config["devices"] = 1
199        config["num_nodes"] = 1
200    return pl.Trainer(**config)  # type: ignore
201
202
203def train(
204    model_name: str,
205    dataset_name: str,
206    output_dir: Path | str,
207    ckpt_path: Path | None = None,
208    ce_weight: float = 1,
209    lcc_submodules: list[str] | None = None,
210    lcc_kwargs: dict | None = None,
211    max_epochs: int = 50,
212    batch_size: int = 256,
213    train_split: str = "train",
214    val_split: str = "val",
215    test_split: str | None = None,
216    image_key: str = "image",
217    label_key: str = "label",
218    logit_key: str | None = "logits",
219    head_name: str | None = None,
220    seed: int | None = None,
221) -> dict:
222    """
223    Performs fine-tuning on a model, possibly with latent clustering correction.
224
225    Args:
226        model_name (str): The model name as in the [Hugging Face model
227            hub](https://huggingface.co/models?pipeline_tag=image-classification).
228        dataset_name (str): The dataset name as in the [Hugging Face dataset
229            hub](https://huggingface.co/datasets?task_categories=task_categories:image-classification).
230        output_dir (Path | str):
231        ckpt_path (Path | None): If `None`, the correction will start from the
232            weights available on the Hugging Face model hub.
233        ce_weight (float, optional): Weight of the cross-entropy loss against
234            the LCC loss. Ignored if LCC is not performed. Defaults to $1$.
235        lcc_submodules (list[str] | None, optional): List of submodule names
236            where to perform LCC. If empty or `None`, LCC is not performed. This
237            is the only way to enable/disable LCC. Defaults to `None`.
238        lcc_kwargs (dict | None, optional): Optional parameters for LCC. See
239            `lcc.classifiers.BaseClassifier.__init__`.
240        max_epochs (int, optional): Defaults to $50$.
241        batch_size (int, optional): Defaults to $2048$.
242        train_split (str, optional):
243        val_split (str, optional):
244        test_split (str | None, optional):
245        image_key (str, optional):
246        label_key (str, optional):
247        logit_key (str | None, optional):
248        head_name (str | None, optional): Name of the output layer of the model.
249            This must be set if the number of classes in the dataset does not
250            match the number components of the output layer of the model. See
251            also `lcc.classifiers.BaseClassifier.__init__`.
252        seed (int | None, optional): Global seed for both CPU and GPU. If not
253            `None`, this is set globally, so one might consider this as a side
254            effect.
255    """
256    if seed is not None:
257        r0_info("Setting global seed to {}", seed)
258        torch.manual_seed(seed)
259
260    lcc_kwargs, do_lcc = lcc_kwargs or {}, bool(lcc_submodules)
261    if do_lcc:
262        r0_info("Performing latent cluster correction")
263        validate_lcc_kwargs(lcc_kwargs)
264
265    output_dir = Path(output_dir)
266    _dataset_name = dataset_name.replace("/", "-")
267    _model_name = model_name.replace("/", "-")
268    _output_dir = output_dir / _dataset_name / _model_name
269    _output_dir.mkdir(parents=True, exist_ok=True)
270
271    classifier_cls = get_classifier_cls(model_name)
272
273    if dataset_name.startswith("PRESET:"):
274        dataset_name = dataset_name[7:]
275        r0_info("Using preset dataset name: {}", dataset_name)
276        dataset, _ = get_dataset(
277            dataset_name,
278            image_processor=model_name,
279            batch_size=batch_size,
280            num_workers=get_reasonable_n_jobs(),
281        )
282    else:
283        dataset = HuggingFaceDataset(
284            dataset_name=dataset_name,
285            fit_split=train_split,
286            val_split=val_split,
287            test_split=test_split,
288            label_key=label_key,
289            image_processor=classifier_cls.get_image_processor(model_name),
290            train_dl_kwargs={
291                "batch_size": batch_size,
292                "num_workers": get_reasonable_n_jobs(),
293            },
294            val_dl_kwargs={
295                "batch_size": batch_size,
296                "num_workers": get_reasonable_n_jobs(),
297            },
298        )
299    n_classes = dataset.n_classes()
300
301    model = classifier_cls(
302        model_name=model_name,
303        n_classes=n_classes,
304        head_name=head_name,
305        image_key=image_key,
306        label_key=label_key,
307        logit_key=logit_key,
308        lcc_submodules=lcc_submodules if do_lcc else None,
309        lcc_kwargs=lcc_kwargs if do_lcc else None,
310        ce_weight=ce_weight,
311    )
312    if isinstance(ckpt_path, Path):
313        model.model = classifier_cls.load_from_checkpoint(  # type: ignore
314            ckpt_path
315        ).model
316        r0_info("Loaded checkpoint {}", ckpt_path)
317    r0_debug("Model hyperparameters:\n{}", json.dumps(model.hparams, indent=4))
318
319    trainer = make_trainer(
320        _output_dir,
321        model_name=_model_name,
322        max_epochs=max_epochs,
323        stage="train",
324        version=str(uuid.uuid4().hex),
325    )
326    start = datetime.now()
327    with warnings.catch_warnings():
328        warnings.filterwarnings("ignore", category=UserWarning)
329        trainer.fit(model, dataset)
330    fit_time = datetime.now() - start
331    r0_info("Finished training in {}", fit_time)
332
333    ckpt = Path(trainer.checkpoint_callback.best_model_path)  # type: ignore
334    ckpt = ckpt.relative_to(output_dir)
335    v, e, s = checkpoint_ves(ckpt)
336    r0_info("Best checkpoint path: {}", ckpt)
337    r0_info("version={}, best_epoch={}, n_steps={}", v, e, s)
338
339    # TODO: fix testing loop. Right now, every rank reinstanciates a single-node
340    # single-device trainer to run the model on the test dataset. So every rank
341    # is testing the model independently which is stupid.
342
343    with TemporaryDirectory(prefix="lcc-") as tmp:
344        trainer = make_trainer(tmp, stage="test")
345        test_results = trainer.test(model, dataset)
346
347    document: dict = {
348        "__meta__": {
349            "version": 3,
350            "hostname": os.uname().nodename,
351            "datetime": start,
352        },
353        "dataset": {
354            "name": dataset_name,
355            "n_classes": n_classes,
356            "train_split": train_split,
357            "val_split": val_split,
358            "test_split": test_split,
359            "image_key": image_key,
360            "label_key": label_key,
361            "batch_size": batch_size,
362        },
363        "model": {"name": model_name, "hparams": dict(model.hparams)},
364        "training": {
365            "best_checkpoint": {
366                "path": str(ckpt),
367                "version": v,
368                "epoch": e,
369                "n_steps": s,
370            },
371            "seed": seed,
372            "time": fit_time / timedelta(seconds=1),
373            "test": test_results,
374        },
375    }
376    document["__meta__"]["hash"] = _dict_sha1(
377        {k: document[k] for k in ["dataset", "model"]}
378    )
379    tb.save_json(document, _output_dir / f"results.{v}.json")
380    return document
DEFAULT_MAX_GRAD_NORM = 1.0

For gradient clipping.

class NoCheckpointFound(builtins.Exception):

Raised by lcc.training.all_checkpoint_paths and lcc.training.best_checkpoint_path if no checkpoints are found.

def all_checkpoint_paths(output_path: str | pathlib.Path) -> list[pathlib.Path]:

Returns the sorted (by epoch) list of all checkpoints. The checkpoint files must match the following pattern:

epoch=<digits>-step=<digits>.ckpt
Arguments:
  • output_path (str | Path): e.g. out.local/ft/cifar100/microsoft-resnet-18. There is no assumption on the structure of this folder, as long as it contains .ckpt files either directly or with subfolders in between.
Raises:
  • NoCheckpointFound: If no checkpoint is found
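
As a usage sketch (the output folder is the illustrative one from the docstring), the exception can be caught to fall back gracefully when a run has not produced any checkpoint yet:

    from lcc.training import NoCheckpointFound, all_checkpoint_paths

    try:
        # Paths are returned sorted by epoch, earliest first
        ckpts = all_checkpoint_paths("out.local/ft/cifar100/microsoft-resnet-18")
        print(f"Found {len(ckpts)} checkpoints, latest: {ckpts[-1]}")
    except NoCheckpointFound:
        print("No checkpoint found under this folder")
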
def best_checkpoint_path(output_path: str | pathlib.Path, metric: str = 'val/acc', mode: Literal['min', 'max'] = 'max') -> tuple[pathlib.Path, int]:

Returns the path to the best checkpoint.

Arguments:
  • output_path (str | Path): e.g. out.local/ft/cifar100/microsoft-resnet-18. This folder is expected to contain a tb_logs and csv_logs folder, either directly or with subfolders in between.
  • metric (str, optional): Name of a column of the metrics.csv file, e.g. val/acc.
  • mode (Literal["min", "max"], optional): Whether the metric should be minimized or maximized.
Returns:

A tuple containing the path to the checkpoint file, and the epoch number.
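
A minimal usage sketch (the folder is the illustrative one from above and must contain the csv_logs produced during training):

    from lcc.training import best_checkpoint_path

    # Select the checkpoint of the epoch with the highest recorded val/acc
    ckpt, epoch = best_checkpoint_path(
        "out.local/ft/cifar100/microsoft-resnet-18",
        metric="val/acc",
        mode="max",
    )
    print(f"Best epoch: {epoch}, checkpoint: {ckpt}")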

def best_epoch(metrics_path: str | pathlib.Path, metric: str = 'val/acc', mode: Literal['min', 'max'] = 'max') -> int:

Given the metrics.csv path, returns the best epoch index
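
As a worked sketch (the metrics.csv content below is made up), only the last row logged for each epoch is considered:

    from lcc.training import best_epoch

    # Hypothetical metrics.csv content:
    #     epoch,train/loss,val/acc
    #     0,2.1,0.41
    #     0,1.9,0.43
    #     1,1.5,0.55
    #     2,1.2,0.52
    # Keeping the last row per epoch gives val/acc = [0.43, 0.55, 0.52],
    # so the best epoch for metric="val/acc" and mode="max" is 1.
    print(best_epoch("path/to/metrics.csv", metric="val/acc", mode="max"))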

def checkpoint_ves(path: str | pathlib.Path) -> tuple[str, int, int]:

Given a checkpoint path that looks like e.g.

out/resnet18/cifar10/model/tb_logs/resnet18/060516dd86294076878cd278cfc59237/checkpoints/epoch=32-step=5181.ckpt

returns the version name (060516dd86294076878cd278cfc59237), the number of epochs (32), and the number of steps (5181).
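
For instance, with the path above:

    from lcc.training import checkpoint_ves

    v, e, s = checkpoint_ves(
        "out/resnet18/cifar10/model/tb_logs/resnet18/"
        "060516dd86294076878cd278cfc59237/checkpoints/epoch=32-step=5181.ckpt"
    )
    # v == "060516dd86294076878cd278cfc59237", e == 32, s == 5181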

def make_trainer(output_dir: pathlib.Path | str, model_name: str | None = None, max_epochs: int = 50, save_all_checkpoints: bool = False, stage: Literal['train', 'test'] = 'train', version: int | str | None = None) -> pytorch_lightning.trainer.trainer.Trainer:

Makes a PyTorch Lightning Trainer with some sensible defaults.

Arguments:
  • output_dir (Path | str):
  • model_name (str | None, optional): Ignored if stage is test, but must be set if stage is train.
  • max_epochs (int, optional): Ignored if stage is test.
  • save_all_checkpoints (bool, optional): If set to False, then only the best checkpoint is saved.
  • stage (str, optional): Either train or test.
  • version (int | str | None, optional): Version name passed to the TensorBoard and CSV loggers. Ignored if stage is test.
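
A minimal usage sketch (the output folder and model name are illustrative; the model and datamodule are assumed to come from lcc.classifiers and lcc.datasets):

    from lcc.training import make_trainer

    # Training: DDP trainer on all available GPUs, with gradient clipping,
    # checkpointing on val/acc, and TensorBoard + CSV loggers
    trainer = make_trainer(
        "out.local/ft/cifar100/microsoft-resnet-18",
        model_name="microsoft-resnet-18",
        max_epochs=50,
        stage="train",
    )
    # trainer.fit(model, datamodule)

    # Testing: bare single-device, single-node trainer
    test_trainer = make_trainer("out.local/tmp", stage="test")
    # test_trainer.test(model, datamodule)
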
def train(model_name: str, dataset_name: str, output_dir: pathlib.Path | str, ckpt_path: pathlib.Path | None = None, ce_weight: float = 1, lcc_submodules: list[str] | None = None, lcc_kwargs: dict | None = None, max_epochs: int = 50, batch_size: int = 256, train_split: str = 'train', val_split: str = 'val', test_split: str | None = None, image_key: str = 'image', label_key: str = 'label', logit_key: str | None = 'logits', head_name: str | None = None, seed: int | None = None) -> dict:

Performs fine-tuning on a model, possibly with latent clustering correction.

Arguments:
  • model_name (str): The model name as in the Hugging Face model hub.
  • dataset_name (str): The dataset name as in the Hugging Face dataset hub.
  • output_dir (Path | str):
  • ckpt_path (Path | None): If None, the correction will start from the weights available on the Hugging Face model hub.
  • ce_weight (float, optional): Weight of the cross-entropy loss against the LCC loss. Ignored if LCC is not performed. Defaults to $1$.
  • lcc_submodules (list[str] | None, optional): List of submodule names where to perform LCC. If empty or None, LCC is not performed. This is the only way to enable/disable LCC. Defaults to None.
  • lcc_kwargs (dict | None, optional): Optional parameters for LCC. See lcc.classifiers.BaseClassifier.__init__.
  • max_epochs (int, optional): Defaults to $50$.
  • batch_size (int, optional): Defaults to $256$.
  • train_split (str, optional):
  • val_split (str, optional):
  • test_split (str | None, optional):
  • image_key (str, optional):
  • label_key (str, optional):
  • logit_key (str | None, optional):
  • head_name (str | None, optional): Name of the output layer of the model. This must be set if the number of classes in the dataset does not match the number of components of the output layer of the model. See also lcc.classifiers.BaseClassifier.__init__.
  • seed (int | None, optional): Global seed for both CPU and GPU. If not None, this is set globally, so one might consider this as a side effect.
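
By way of example, a hedged call for plain fine-tuning (no LCC, since lcc_submodules is not set); the model, dataset, splits and keys below are placeholders to adapt to the experiment at hand:

    from lcc.training import train

    results = train(
        model_name="microsoft/resnet-18",
        dataset_name="cifar100",
        output_dir="out.local/ft",
        train_split="train",
        val_split="test",        # cifar100 has no dedicated validation split
        label_key="fine_label",
        max_epochs=50,
        batch_size=256,
        seed=0,
    )
    # The returned document also holds dataset, model, and test metadata
    print(results["training"]["best_checkpoint"]["path"])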