lcc.datasets

Custom Lightning datamodules

 1"""Custom Lightning datamodules"""
 2
 3from .batched_tensor import BatchedTensorDataset
 4from .huggingface import HuggingFaceDataset
 5from .preset import DATASET_PRESETS_CONFIGURATIONS, get_dataset
 6from .utils import dl_head, flatten_batches
 7from .wrapped import DEFAULT_DATALOADER_KWARGS, WrappedDataset
 8
 9__all__ = [
10    "BatchedTensorDataset",
11    "DATASET_PRESETS_CONFIGURATIONS",
12    "DEFAULT_DATALOADER_KWARGS",
13    "dl_head",
14    "flatten_batches",
15    "get_dataset",
16    "HuggingFaceDataset",
17    "WrappedDataset",
18]
class BatchedTensorDataset(torch.utils.data.dataset.Dataset[+T_co], typing.Iterable[+T_co]):
 42class BatchedTensorDataset(IterableDataset):
 43    """
 44    Dataset that loads tensor batches produced by
 45    `BatchedTensorDataset.save`.
 46    """
 47
 48    key: str
 49    paths: list[Path]
 50
 51    _len: int | None = None
 52    """Cached length appriximation, see __len__."""
 53
 54    def __init__(
 55        self,
 56        batch_dir: str | Path,
 57        prefix: str = "batch",
 58        extension: str = "st",
 59        key: str = "",
 60    ):
 61        """
 62        See `BatchedTensorDataset.save` for the precise meaning of the
 63        arguments. In short, this dataset will load batches from
 64        [Safetensors](https://huggingface.co/docs/safetensors/index) files named
 65        after the following scheme:
 66
 67            batch_dir/<prefix>.<unique_id>.<extension>
 68
 69        Safetensor files are essentially dictionaries, and each batch file is
 70        expected to contain:
 71        * `key` (see below): some `(B, ...)` data tensor.
 72        * `"_idx"`: a `(B,)` integer tensor.
 73
 74        When iterating over this dataset, it will yield `(data, idx)` pairs,
 75        where `data` is a row of the dataset, and `idx` is its index (as a
 76        zero-dimensional int tensor).
 77
 78        Warning:
 79            The list of batch files will be globbed from `batch_dir` upon
 80            instantiation of this dataset class. It will not be updated
 81            afterwards.
 82
 83        Warning:
 84            Batches are loaded in the order they are found in the filesystem.
 85            Don't expect this order to be the same as the order in which the
 86            data has been generated.
 87
 88        Args:
 89            batch_dir (str | Path):
 90            prefix (str, optional):
 91            extension (str, optional): Without the first `.`. Defaults to `st`.
 92            key (str, optional): The key to use when saving the file. Batches
 93                are stored in safetensor files, which are essentially
 94                dictionaries.  This arg specifies which key contains the data of
 95                interest. Cannot be `"_idx"`.
 96        """
 97        if key == "_idx":
 98            raise ValueError("Key cannot be '_idx'.")
 99        self.paths = list(Path(batch_dir).glob(f"{prefix}.*.{extension}"))
100        self.key = key
101
102    def __iter__(self) -> Iterator[tuple[Tensor, Tensor]]:
103        for path in self.paths:
104            data = st.load_file(path)
105            for z, i in zip(data[self.key], data["_idx"]):
106                yield z, i
107
108    def __len__(self) -> int:
109        """
110        Returns the length of this dataset, computed by summing the number of
111        rows in each batch file. The result is cached.
112        """
113        if self._len is None:
114            self._len = sum(len(st.load_file(p)["_idx"]) for p in self.paths)
115        return self._len
116
117    def distribute(
118        self, strategy: Strategy | Fabric | None
119    ) -> "BatchedTensorDataset":
120        """
121        Creates a *copy* of a subset of this dataset so that every rank has a
122        different subset. Does not modify the current dataset.
123
124        If the strategy is not a `ParallelStrategy` or a Lightning `Fabric`, or
125        if the world size is less than 2, this method returns `self` (NOT a copy
126        of `self`).
127        """
128        if (
129            not isinstance(strategy, (ParallelStrategy, Fabric))
130            or strategy.world_size < 2
131        ):
132            return self
133        ws, gr = strategy.world_size, strategy.global_rank
134        ds = deepcopy(self)
135        ds.paths, ds._len = ds.paths[gr::ws], None
136        return ds
137
138    def extract_idx(
139        self, tqdm_style: TqdmStyle = None
140    ) -> tuple[IterableDataset, Tensor]:
141        """
142        Splits this dataset in two. The first one yields the data, the second
143        one yields the indices. Then, the index dataset is unrolled into a
144        single index tensor (which therefore has shape `(N,)`, where `N` is the
145        length of the dataset).
146        """
147        a, b = _ProjectionDataset(self, 0), _ProjectionDataset(self, 1)
148        dl = DataLoader(b, batch_size=1024, num_workers=1)
149        dl = make_tqdm(tqdm_style)(dl, "Extracting indices")
150        # TODO: setting num_workers to > 1 makes the index tensor n_workers
151        # times too long... problem with tqdm?
152        return a, torch.cat(list(dl), dim=0)
153
154    def load(
155        self,
156        batch_size: int = 256,
157        num_workers: int = 0,
158        tqdm_style: TqdmStyle = None,
159    ) -> tuple[Tensor, Tensor]:
160        """
161        Loads a batched tensor in one go. See `BatchedTensorDataset.save`.
162
163        Args:
164            batch_size (int, optional): Defaults to 256. Does not impact the
165                actual result.
166            num_workers (int, optional): Defaults to 0, meaning single-process
167                data loading.
168            tqdm_style (TqdmStyle, optional):
169                Progress bar style.
170
171        Returns:
172            A tuple `(data, idx)` where `data` is a `(N, ...)` tensor and `idx`
173            is an `(N,)` int tensor.
174        """
175        dl = DataLoader(self, batch_size=batch_size, num_workers=num_workers)
176        u = make_tqdm(tqdm_style)(dl, "Loading")
177        v = list(zip(*u))
178        return torch.cat(v[0], dim=0), torch.cat(v[1], dim=0)
179
180    @staticmethod
181    def save(
182        x: ArrayLike,
183        output_dir: str | Path,
184        prefix: str = "batch",
185        extension: str = "st",
186        key: str = "",
187        batch_size: int = 256,
188        tqdm_style: TqdmStyle = None,
189    ) -> None:
190        """
191        Saves a tensor in batches of `batch_size` elements. The files will be
192        named
193
194            output_dir/<prefix>.<batch_idx>.<extension>
195
196        The batches are saved using
197        [Safetensors](https://huggingface.co/docs/safetensors/index).
198        Safetensors files are essentially dictionaries, and each batch file is
199        structured as follows:
200        * `key` (see below): some `(batch_size, ...)` slice from `x`,
201        * `"_idx"`: a `(batch_size,)` integer tensor containing the indices in
202          `x`.
203
204        The `batch_idx` string is zero-padded to 4 digits, so choose a batch
205        size that produces fewer than 10000 batches.
206
207        Args:
208            x (ArrayLike):
209            output_dir (str):
210            prefix (str, optional):
211            extension (str, optional): Without the first `.`. Defaults to `st`.
212            key (str, optional): The key to use when saving the file. Batches
213                are stored in safetensor files, which are essentially
214                dictionaries.  This arg specifies which key contains the data of
215                interest. Cannot be `"_idx"`.
216            batch_size (int, optional): Defaults to `256`.
217            tqdm_style (TqdmStyle, optional):
218                Progress bar style.
219        """
220        batches = to_tensor(x).split(batch_size)
221        for i, batch in enumerate(make_tqdm(tqdm_style)(batches, "Saving")):
222            data = {
223                key: batch,
224                "_idx": torch.arange(i * batch_size, (i + 1) * batch_size),
225            }
226            st.save_file(
227                data, Path(output_dir) / f"{prefix}.{i:04}.{extension}"
228            )

Dataset that loads tensor batches produced by BatchedTensorDataset.save.

BatchedTensorDataset( batch_dir: str | pathlib.Path, prefix: str = 'batch', extension: str = 'st', key: str = '')
 54    def __init__(
 55        self,
 56        batch_dir: str | Path,
 57        prefix: str = "batch",
 58        extension: str = "st",
 59        key: str = "",
 60    ):
 61        """
 62        See `BatchedTensorDataset.save` for the precise meaning of the
 63        arguments. In short, this dataset will load batches from
 64        [Safetensors](https://huggingface.co/docs/safetensors/index) files named
 65        after the following scheme:
 66
 67            batch_dir/<prefix>.<unique_id>.<extension>
 68
 69        Safetensor files are essentially dictionaries, and each batch file is
 70        expected to contain:
 71        * `key` (see below): some `(B, ...)` data tensor.
 72        * `"_idx"`: a `(B,)` integer tensor.
 73
 74        When iterating over this dataset, it will yield `(data, idx)` pairs,
 75        where `data` is a row of the dataset, and `idx` is its index (as a
 76        zero-dimensional int tensor).
 77
 78        Warning:
 79            The list of batch files will be globbed from `batch_dir` upon
 80            instantiation of this dataset class. It will not be updated
 81            afterwards.
 82
 83        Warning:
 84            Batches are loaded in the order they are found in the filesystem.
 85            Don't expect this order to be the same as the order in which the
 86            data has been generated.
 87
 88        Args:
 89            batch_dir (str | Path):
 90            prefix (str, optional):
 91            extension (str, optional): Without the first `.`. Defaults to `st`.
 92            key (str, optional): The key to use when saving the file. Batches
 93                are stored in safetensor files, which are essentially
 94                dictionaries.  This arg specifies which key contains the data of
 95                interest. Cannot be `"_idx"`.
 96        """
 97        if key == "_idx":
 98            raise ValueError("Key cannot be '_idx'.")
 99        self.paths = list(Path(batch_dir).glob(f"{prefix}.*.{extension}"))
100        self.key = key

See BatchedTensorDataset.save for the precise meaning of the arguments. In short, this dataset will load batches from Safetensors files named after the following scheme:

batch_dir/<prefix>.<unique_id>.<extension>

Safetensor files are essentially dictionaries, and each batch file is expected to contain:

  • key (see below): some (B, ...) data tensor.
  • "_idx": a (B,) integer tensor.

When iterating over this dataset, it will yield (data, idx) pairs, where data is a row of the dataset and idx is its index (as a zero-dimensional int tensor).

Warning:

The list of batch files will be globbed from batch_dir upon instantiation of this dataset class. It will not be updated afterwards.

Warning:

Batches are loaded in the order they are found in the filesystem. Don't expect this order to be the same as the order in which the data has been generated.

Arguments:
  • batch_dir (str | Path):
  • prefix (str, optional):
  • extension (str, optional): Without the leading dot. Defaults to st.
  • key (str, optional): The key to use when saving the file. Batches are stored in safetensor files, which are essentially dictionaries. This arg specifies which key contains the data of interest. Cannot be "_idx".
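
For illustration, here is a minimal usage sketch. The directory ./batches and the key "embeddings" are hypothetical; the batch files are assumed to have been written beforehand with BatchedTensorDataset.save (see the save example further down):

>>> ds = BatchedTensorDataset("./batches", key="embeddings")
>>> for data, idx in ds:          # yields one row and its index at a time
...     print(data.shape, int(idx))
...     break
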
key: str
paths: list[pathlib.Path]
def distribute( self, strategy: pytorch_lightning.strategies.strategy.Strategy | lightning_fabric.fabric.Fabric | None) -> BatchedTensorDataset:
117    def distribute(
118        self, strategy: Strategy | Fabric | None
119    ) -> "BatchedTensorDataset":
120        """
121        Creates a *copy* of a subset of this dataset so that every rank has a
122        different subset. Does not modify the current dataset.
123
124        If the strategy is not a `ParallelStrategy` or a Lightning `Fabric`, or
125        if the world size is less than 2, this method returns `self` (NOT a copy
126        of `self`).
127        """
128        if (
129            not isinstance(strategy, (ParallelStrategy, Fabric))
130            or strategy.world_size < 2
131        ):
132            return self
133        ws, gr = strategy.world_size, strategy.global_rank
134        ds = deepcopy(self)
135        ds.paths, ds._len = ds.paths[gr::ws], None
136        return ds

Creates a copy of a subset of this dataset so that every rank has a different subset. Does not modify the current dataset.

If the strategy is not a ParallelStrategy or a Lightning Fabric, or if the world size is less than 2, this method returns self (NOT a copy of self).
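
A hedged sketch of how this might be used, assuming a two-process Fabric run inside a script (the batch directory and key are hypothetical as above):

>>> from lightning_fabric import Fabric
>>> fabric = Fabric(accelerator="cpu", devices=2)
>>> fabric.launch()
>>> ds = BatchedTensorDataset("./batches", key="embeddings")
>>> local_ds = ds.distribute(fabric)  # rank r keeps batch files r, r + 2, ...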

def extract_idx( self, tqdm_style: Optional[Literal['notebook', 'console', 'none']] = None) -> tuple[torch.utils.data.dataset.IterableDataset, torch.Tensor]:
138    def extract_idx(
139        self, tqdm_style: TqdmStyle = None
140    ) -> tuple[IterableDataset, Tensor]:
141        """
142        Splits this dataset in two. The first one yields the data, the second
143        one yields the indices. Then, the index dataset is unrolled into a
144        single index tensor (which therefore has shape `(N,)`, where `N` is the
145        length of the dataset).
146        """
147        a, b = _ProjectionDataset(self, 0), _ProjectionDataset(self, 1)
148        dl = DataLoader(b, batch_size=1024, num_workers=1)
149        dl = make_tqdm(tqdm_style)(dl, "Extracting indices")
150        # TODO: setting num_workers to > 1 makes the index tensor n_workers
151        # times too long... problem with tqdm?
152        return a, torch.cat(list(dl), dim=0)

Splits this dataset in two. The first one yields the data, the second one yields the indices. Then, the index dataset is unrolled into a single index tensor (which therefore has shape (N,), where N is the length of the dataset).
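
A minimal sketch (same hypothetical batch directory and key as above):

>>> ds = BatchedTensorDataset("./batches", key="embeddings")
>>> data_ds, idx = ds.extract_idx()  # idx is a (N,) int tensor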

def load( self, batch_size: int = 256, num_workers: int = 0, tqdm_style: Optional[Literal['notebook', 'console', 'none']] = None) -> tuple[torch.Tensor, torch.Tensor]:
154    def load(
155        self,
156        batch_size: int = 256,
157        num_workers: int = 0,
158        tqdm_style: TqdmStyle = None,
159    ) -> tuple[Tensor, Tensor]:
160        """
161        Loads a batched tensor in one go. See `BatchedTensorDataset.save`.
162
163        Args:
164            batch_size (int, optional): Defaults to 256. Does not impact the
165                actual result.
166            num_workers (int, optional): Defaults to 0, meaning single-process
167                data loading.
168            tqdm_style (TqdmStyle, optional):
169                Progress bar style.
170
171        Returns:
172            A tuple `(data, idx)` where `data` is a `(N, ...)` tensor and `idx`
173            is an `(N,)` int tensor.
174        """
175        dl = DataLoader(self, batch_size=batch_size, num_workers=num_workers)
176        u = make_tqdm(tqdm_style)(dl, "Loading")
177        v = list(zip(*u))
178        return torch.cat(v[0], dim=0), torch.cat(v[1], dim=0)

Loads a batched tensor in one go. See BatchedTensorDataset.save.

Arguments:
  • batch_size (int, optional): Defaults to 256. Does not impact the actual result.
  • num_workers (int, optional): Defaults to 0, meaning single-process data loading.
  • tqdm_style (TqdmStyle, optional): Progress bar style.
Returns:

A tuple (data, idx) where data is a (N, ...) tensor and idx is an (N,) int tensor.
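
A minimal sketch (hypothetical directory and key), loading everything into memory at once:

>>> ds = BatchedTensorDataset("./batches", key="embeddings")
>>> data, idx = ds.load(batch_size=512, num_workers=0)
>>> assert data.shape[0] == idx.shape[0]  # one index per row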

@staticmethod
def save( x: ArrayLike, output_dir: str | pathlib.Path, prefix: str = 'batch', extension: str = 'st', key: str = '', batch_size: int = 256, tqdm_style: Optional[Literal['notebook', 'console', 'none']] = None) -> None:
180    @staticmethod
181    def save(
182        x: ArrayLike,
183        output_dir: str | Path,
184        prefix: str = "batch",
185        extension: str = "st",
186        key: str = "",
187        batch_size: int = 256,
188        tqdm_style: TqdmStyle = None,
189    ) -> None:
190        """
191        Saves a tensor in batches of `batch_size` elements. The files will be
192        named
193
194            output_dir/<prefix>.<batch_idx>.<extension>
195
196        The batches are saved using
197        [Safetensors](https://huggingface.co/docs/safetensors/index).
198        Safetensors files are essentially dictionaries, and each batch file is
199        structured as follows:
200        * `key` (see below): some `(batch_size, ...)` slice from `x`,
201        * `"_idx"`: a `(batch_size,)` integer tensor containing the indices in
202          `x`.
203
204        The `batch_idx` string is zero-padded to 4 digits, so choose a batch
205        size that produces fewer than 10000 batches.
206
207        Args:
208            x (ArrayLike):
209            output_dir (str):
210            prefix (str, optional):
211            extension (str, optional): Without the first `.`. Defaults to `st`.
212            key (str, optional): The key to use when saving the file. Batches
213                are stored in safetensor files, which are essentially
214                dictionaries.  This arg specifies which key contains the data of
215                interest. Cannot be `"_idx"`.
216            batch_size (int, optional): Defaults to `256`.
217            tqdm_style (TqdmStyle, optional):
218                Progress bar style.
219        """
220        batches = to_tensor(x).split(batch_size)
221        for i, batch in enumerate(make_tqdm(tqdm_style)(batches, "Saving")):
222            data = {
223                key: batch,
224                "_idx": torch.arange(i * batch_size, (i + 1) * batch_size),
225            }
226            st.save_file(
227                data, Path(output_dir) / f"{prefix}.{i:04}.{extension}"
228            )

Saves a tensor in batches of batch_size elements. The files will be named

output_dir/<prefix>.<batch_idx>.<extension>

The batches are saved using Safetensors. Safetensors files are essentially dictionaries, and each batch file is structured as follows:

  • key (see below): some (batch_size, ...) slice from x,
  • "_idx": a (batch_size,) integer tensor containing the indices in x.

The batch_idx string is zero-padded to 4 digits, so choose a batch size that produces fewer than 10000 batches.

Arguments:
  • x (ArrayLike):
  • output_dir (str):
  • prefix (str, optional):
  • extension (str, optional): Without the leading dot. Defaults to st.
  • key (str, optional): The key to use when saving the file. Batches are stored in safetensor files, which are essentially dictionaries. This arg specifies which key contains the data of interest. Cannot be "_idx".
  • batch_size (int, optional): Defaults to 256.
  • tqdm_style (TqdmStyle, optional): Progress bar style.
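
A minimal sketch of saving a tensor in batches; the tensor, directory, and key are hypothetical, and note that output_dir must already exist since save does not create it:

>>> import torch
>>> from pathlib import Path
>>> Path("./batches").mkdir(parents=True, exist_ok=True)
>>> x = torch.randn(10_000, 128)
>>> BatchedTensorDataset.save(x, "./batches", key="embeddings", batch_size=256)
>>> # writes ./batches/batch.0000.st, ./batches/batch.0001.st, ...
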
DATASET_PRESETS_CONFIGURATIONS = {
    'cats_vs_dogs': {'dataset_name': 'microsoft/cats_vs_dogs', 'fit_split': 'train[:80%]', 'val_split': 'train[80%:]', 'test_split': 'train', 'label_key': 'labels'},
    'cifar10': {'dataset_name': 'cifar10', 'fit_split': 'train[:80%]', 'val_split': 'train[80%:]', 'test_split': 'test', 'label_key': 'label'},
    'cifar100': {'dataset_name': 'cifar100', 'fit_split': 'train[:80%]', 'val_split': 'train[80%:]', 'test_split': 'test', 'label_key': 'fine_label'},
    'eurosat-rgb': {'dataset_name': 'timm/eurosat-rgb', 'fit_split': 'train', 'val_split': 'validation', 'test_split': 'test', 'label_key': 'label'},
    'fashion_mnist': {'dataset_name': 'zalando-datasets/fashion_mnist', 'fit_split': 'train[:80%]', 'val_split': 'train[80%:]', 'test_split': 'test', 'label_key': 'label'},
    'food101': {'dataset_name': 'ethz/food101', 'fit_split': 'train[:80%]', 'val_split': 'train[80%:]', 'test_split': 'validation', 'label_key': 'label'},
    'imagenet-1k': {'dataset_name': 'ILSVRC/imagenet-1k', 'fit_split': 'train[:80%]', 'val_split': 'train[80%:]', 'test_split': 'validation', 'label_key': 'label'},
    'mnist': {'dataset_name': 'ylecun/mnist', 'fit_split': 'train[:80%]', 'val_split': 'train[80%:]', 'test_split': 'test', 'label_key': 'label'},
    'oxford-iiit-pet': {'dataset_name': 'timm/oxford-iiit-pet', 'fit_split': 'train[:80%]', 'val_split': 'train[80%:]', 'test_split': 'test', 'label_key': 'label'},
    'resisc45': {'dataset_name': 'timm/resisc45', 'fit_split': 'train', 'val_split': 'validation', 'test_split': 'test', 'label_key': 'label'},
}
DEFAULT_DATALOADER_KWARGS = {'batch_size': 1024, 'num_workers': 16, 'persistent_workers': True, 'pin_memory': False, 'shuffle': False, 'drop_last': False}
def dl_head( dl: torch.utils.data.dataloader.DataLoader, n: int) -> list[dict[str, torch.Tensor]]:
14def dl_head(dl: DataLoader, n: int) -> list[Batch]:
15    """
16    Returns the first `n` samples of a DataLoader **as a list of batches**.
17
18    Args:
19        dl (DataLoader):
20        n (int):
21
22    Warning:
23        Only supports batches that are dicts of tensors.
24    """
25
26    def _n() -> int:
27        if not batches:
28            return 0
29        k = list(batches[0].keys())[0]
30        return sum(map(lambda b: len(b[k]), batches))
31
32    batches: list[Batch] = []
33    it = iter(dl)
34    try:
35        while _n() < n:
36            batches.append(next(it))
37    except StopIteration:
38        logging.warning(
39            "Tried to extract {} elements from dataset but only found {}",
40            n,
41            _n(),
42        )
43    if (r := _n() - n) > 0:
44        for k in batches[-1].keys():
45            batches[-1][k] = batches[-1][k][:-r]
46    return batches

Returns the first n samples of a DataLoader as a list of batches.

Arguments:
  • dl (DataLoader):
  • n (int):
Warning:

Only supports batches that are dicts of tensors.
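
A minimal sketch, assuming dl is a DataLoader over a dict-of-tensors dataset (for example a HuggingFaceDataset dataloader) with at least 100 samples:

>>> batches = dl_head(dl, 100)           # list of dict batches
>>> k = next(iter(batches[0]))           # any key
>>> sum(len(b[k]) for b in batches)      # the last batch is trimmed
100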

def flatten_batches(batches: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
49def flatten_batches(batches: list[Batch]) -> Batch:
50    """
51    Flattens a list of batches into a single batch.
52
53    Args:
54        batches (list[Batch]):
55    """
56    return {
57        k: torch.concat([b[k] for b in batches]) for k in batches[0].keys()
58    }

Flattens a list of batches into a single batch.

Arguments:
  • batches (list[Batch]):
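
Continuing the dl_head sketch above, the list of batches can be collapsed into a single batch:

>>> flat = flatten_batches(batches)
>>> {k: v.shape[0] for k, v in flat.items()}  # every key now has 100 rows
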
def get_dataset( dataset_name: str, image_processor: Union[str, Callable, NoneType] = None, batch_size: int | None = None, num_workers: int | None = None, **kwargs: Any) -> tuple[HuggingFaceDataset, dict[str, typing.Any]]:
 95def get_dataset(
 96    dataset_name: str,
 97    image_processor: str | Callable | None = None,
 98    batch_size: int | None = None,
 99    num_workers: int | None = None,
100    **kwargs: Any,
101) -> tuple[HuggingFaceDataset, dict[str, Any]]:
102    """
103    Returns a `HuggingFaceDataset` instance from the `dataset_name` key in the
104    `DATASET_PRESETS_CONFIGURATIONS` dictionary.
105
106    Example:
107
108        >>> ds, _ = get_dataset(
109        ...     "cifar10",
110        ...     "microsoft/resnet-18",
111        ...     batch_size=64,
112        ...     num_workers=4,
113        ... )
114
115    Args:
116        dataset_name (str): See `DATASET_PRESETS_CONFIGURATIONS`
117        image_processor (str | Callable | None): The image processor to use. If
118            a str is provided, it is assumed to be a model name, and the image
119            processor will be constructed accordingly using
120            `lcc.classifiers.get_classifier_cls` and
121            `lcc.classifiers.BaseClassifier.get_image_processor`.
122        batch_size (int | None): If provided, will be added to the dataloader
123            parameters for all dataset splits (train, val, test, and predict).
124        num_workers (int | None): If provided, will be added to the dataloader
125            parameters for all dataset splits (train, val, test, and predict).
126        **kwargs: Passed to the `HuggingFaceDataset` constructor
127
128    Returns:
129        The `HuggingFaceDataset` instance and the configuration dictionary
130    """
131    config = dict(DATASET_PRESETS_CONFIGURATIONS[dataset_name])
132    if isinstance(image_processor, str):
133        from ..classifiers import get_classifier_cls
134
135        cls = get_classifier_cls(image_processor)
136        config["image_processor"] = cls.get_image_processor(image_processor)
137    else:
138        config["image_processor"] = image_processor
139    dl_kw = {}
140    if batch_size is not None:
141        dl_kw["batch_size"] = batch_size
142    if num_workers is not None:
143        dl_kw["num_workers"] = num_workers
144    if dl_kw:
145        config["train_dl_kwargs"] = dl_kw
146        config["val_dl_kwargs"] = dl_kw.copy()
147        config["test_dl_kwargs"] = dl_kw.copy()
148        config["predict_dl_kwargs"] = dl_kw.copy()
149    config.update(kwargs)
150    return HuggingFaceDataset(**config), config

Returns a HuggingFaceDataset instance from the dataset_name key in the DATASET_PRESETS_CONFIGURATIONS dictionary.

Example:
>>> ds, _ = get_dataset(
...     "cifar10",
...     "microsoft/resnet-18",
...     batch_size=64,
...     num_workers=4,
... )
Arguments:
  • dataset_name (str): See DATASET_PRESETS_CONFIGURATIONS
  • image_processor (str | Callable | None): The image processor to use. If a str is provided, it is assumed to be a model name, and the image processor will be constructed accordingly using lcc.classifiers.get_classifier_cls and lcc.classifiers.BaseClassifier.get_image_processor.
  • batch_size (int | None): If provided, will be added to the dataloader parameters for all dataset splits (train, val, test, and predict).
  • num_workers (int | None): If provided, will be added to the dataloader parameters for all dataset splits (train, val, test, and predict).
  • **kwargs: Passed to the HuggingFaceDataset constructor
Returns:

The HuggingFaceDataset instance and the configuration dictionary

class HuggingFaceDataset(lcc.datasets.WrappedDataset):
 26class HuggingFaceDataset(WrappedDataset):
 27    """
 28    A Hugging Face image classification dataset wrapped inside a
 29    `lcc.datasets.WrappedDataset`, which is itself a
 30    [`LightningDataModule`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html).
 31
 32    Hugging Face image datasets are dict datasets where the image is a PIL image
 33    object. Here, images are converted to tensors using the `image_processor`
 34    (if provided), which brings this closer to the torchvision API. In this
 35    case, load and call Hugging Face models directly. If you do not provide an
 36    `image_processor`, then it is recommended that you use a Hugging Face
 37    pipeline instead.
 38
 39    Since Hugging Face datasets are dict datasets, batches are dicts of tensors
 40    (see the Hugging Face dataset hub for the list of keys).
 41    `HuggingFaceDataset` adds an extra key `_idx` that has the index of the
 42    samples in the dataset.
 43
 44    See also:
 45        https://huggingface.co/datasets?task_categories=task_categories:image-classification
 46    """
 47
 48    label_key: str
 49
 50    _datasets: dict[str, Dataset] = {}  # Cache
 51
 52    def __init__(
 53        self,
 54        dataset_name: str,
 55        fit_split: str = "training",
 56        val_split: str = "validation",
 57        test_split: str | None = None,
 58        predict_split: str | None = None,
 59        image_processor: Callable | None = None,
 60        train_dl_kwargs: dict[str, Any] | None = None,
 61        val_dl_kwargs: dict[str, Any] | None = None,
 62        test_dl_kwargs: dict[str, Any] | None = None,
 63        predict_dl_kwargs: dict[str, Any] | None = None,
 64        cache_dir: Path | str = DEFAULT_CACHE_DIR,
 65        classes: ArrayLike | None = None,
 66        label_key: str = "label",
 67    ) -> None:
 68        """
 69        Args:
 70            dataset_name (str): Name of the Hugging Face image classification
 71                dataset, as in the [Hugging Face dataset
 72                hub](https://huggingface.co/datasets?task_categories=task_categories:image-classification).
 73            fit_split (str, optional): Name of the split containing the
 74                training data. See also
 75                https://huggingface.co/docs/datasets/en/loading#slice-splits
 76            val_split (str, optional): Name of the split containing the
 77                validation data.
 78            test_split (str | None, optional): Name of the split containing the
 79                test data. If left to `None`, setting up this datamodule at the
 80                `test` stage will raise a `RuntimeError`
 81            predict_split (str | None, optional): Name of the split containing
 82                the prediction samples. If left to `None`, setting up this
 83                datamodule at the `predict` stage will raise a `RuntimeError`
 84            image_processor (Callable | None, optional): The image processor to apply to the samples.
 85            train_dl_kwargs (dict[str, Any] | None, optional): Dataloader
 86                parameters. See also https://pytorch.org/docs/stable/data.html.
 87            val_dl_kwargs (dict[str, Any] | None, optional): Dataloader
 88                parameters.
 89            test_dl_kwargs (dict[str, Any] | None, optional): Dataloader
 90                parameters.
 91            predict_dl_kwargs (dict[str, Any] | None, optional): Dataloader
 92                parameters.
 93            cache_dir (Path | str, optional): Path to the cache directory, where
 94                Hugging Face `datasets` package will download the dataset files.
 95                This should not be a temporary directory and be consistent
 96                between runs.
 97            classes (ArrayLike | None, optional): List of
 98                classes to keep. For example if `classes=[1, 2]`, only those
 99                samples whose label is `1` or `2` will be present in the
100                dataset. If `None`, all classes are kept. Note that this only
101                applies to the `train` and `val` splits, the `test` and
102                `predict` splits (if they exist) will not be filtered.
103            label_key (str, optional): Name of the column containing the
104                label. Only relevant if `classes` is not `None`.
105        """
106
107        classes = classes if classes is None else to_int_array(classes)
108
109        def ds_split_factory(
110            split: str, apply_filter: bool = True
111        ) -> Callable[[], Dataset]:
112            """
113            Returns a function that loads the dataset split.
114
115            Args:
116                split (str): Name of the split, see constructor arguments
117                apply_filter (bool, optional): If `True` and the constructor
118                    has `classes` set, then the dataset is filtered to only
119                    keep those samples whose label is in `classes`. You
120                    probably want to set this to `False` for the prediction
121                    split and maybe even the test split.
122
123            Returns:
124                Callable[[], Dataset]: The dataset factory
125            """
126
127            def wrapped() -> Dataset:
128                ds = load_dataset(
129                    dataset_name,
130                    split=split,
131                    cache_dir=str(cache_dir),
132                    trust_remote_code=True,
133                )
134                if image_processor is not None:
135                    ds.set_transform(image_processor)
136                ds = ds.add_column("_idx", range(len(ds)))
137                if (
138                    apply_filter
139                    and isinstance(classes, np.ndarray)
140                    and len(classes) > 0
141                ):
142                    ds = ds.filter(
143                        lambda lbl: lbl in classes, input_columns=label_key
144                    )
145                return ds
146
147            return wrapped
148
149        super().__init__(
150            train=ds_split_factory(fit_split),
151            val=ds_split_factory(val_split),
152            test=(ds_split_factory(test_split) if test_split else None),
153            predict=(
154                ds_split_factory(predict_split, False)
155                if predict_split
156                else None
157            ),
158            train_dl_kwargs=train_dl_kwargs,
159            val_dl_kwargs=val_dl_kwargs,
160            test_dl_kwargs=test_dl_kwargs,
161            predict_dl_kwargs=predict_dl_kwargs,
162        )
163        self.label_key = label_key
164
165    def __len__(self) -> int:
166        """
167        Returns the size of the train split. See `HuggingFaceDataset.size`.
168        """
169        return self.size("train")
170
171    def _get_dataset(self, split: Literal["train", "val", "test"]) -> Dataset:
172        if split not in self._datasets:
173            if split == "train":
174                factory = self.train
175            elif split == "val":
176                factory = self.val
177            elif split == "test":
178                factory = self.test
179            else:
180                raise ValueError(f"Unknown split: {split}")
181            assert callable(factory)
182            self._datasets[split] = factory()
183        return self._datasets[split]
184
185    def n_classes(
186        self, split: Literal["train", "val", "test"] = "train"
187    ) -> int:
188        """
189        Returns the number of classes in a given split.
190
191        Args:
192            split (Literal["train", "val", "test"], optional): Not the true name
193                of the split (as specified on the dataset's HuggingFace page),
194                just either `train`, `val`, or `test`. Defaults to `train`.
195        """
196        ds = self._get_dataset(split)
197        return len(ds.unique(self.label_key))
198
199    def size(self, split: Literal["train", "val", "test"] = "train") -> int:
200        """
201        Returns the number of samples in a given split. If the split hasn't been
202        loaded, this will load it.
203
204        Args:
205            split (Literal["train", "val", "test"], optional): Not the true name
206                of the split (as specified on the dataset's HuggingFace page),
207                just either `train`, `val`, or `test`. Defaults to `train`.
208        """
209        return len(self._get_dataset(split))
210
211    def y_true(
212        self, split: Literal["train", "val", "test"] = "train"
213    ) -> Tensor:
214        """
215        Gets the vector of true labels of a given split.
216
217        Args:
218            split (Literal["train", "val", "test"], optional): Not the true name
219                of the split (as specified on the dataset's HuggingFace page),
220                just either `train`, `val`, or `test`. Defaults to `train`.
221
222        Returns:
223            An `int` tensor
224        """
225        return Tensor(self._get_dataset(split)[self.label_key]).int()

A Hugging Face image classification dataset wrapped inside a lcc.datasets.WrappedDataset, which is itself a LightningDataModule.

Hugging Face image datasets are dict datasets where the image is a PIL image object. Here, images are converted to tensors using the image_processor (if provided), which brings this closer to the torchvision API. In this case, you can load and call Hugging Face models directly. If you do not provide an image_processor, it is recommended that you use a Hugging Face pipeline instead.

Since Hugging Face datasets are dict datasets, batches are dicts of tensors (see the Hugging Face dataset hub for the list of keys). HuggingFaceDataset adds an extra key _idx that has the index of the samples in the dataset.

See also:

https://huggingface.co/datasets?task_categories=task_categories:image-classification
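
A minimal construction sketch, mirroring the cifar10 entry of DATASET_PRESETS_CONFIGURATIONS (the values shown assume the standard CIFAR-10 splits):

>>> ds = HuggingFaceDataset(
...     "cifar10",
...     fit_split="train[:80%]",
...     val_split="train[80%:]",
...     test_split="test",
...     label_key="label",
... )
>>> ds.n_classes("train")
10
>>> ds.size("val")
10000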

HuggingFaceDataset( dataset_name: str, fit_split: str = 'training', val_split: str = 'validation', test_split: str | None = None, predict_split: str | None = None, image_processor: Optional[Callable] = None, train_dl_kwargs: dict[str, typing.Any] | None = None, val_dl_kwargs: dict[str, typing.Any] | None = None, test_dl_kwargs: dict[str, typing.Any] | None = None, predict_dl_kwargs: dict[str, typing.Any] | None = None, cache_dir: pathlib.Path | str = DEFAULT_CACHE_DIR, classes: ArrayLike | None = None, label_key: str = 'label')
 52    def __init__(
 53        self,
 54        dataset_name: str,
 55        fit_split: str = "training",
 56        val_split: str = "validation",
 57        test_split: str | None = None,
 58        predict_split: str | None = None,
 59        image_processor: Callable | None = None,
 60        train_dl_kwargs: dict[str, Any] | None = None,
 61        val_dl_kwargs: dict[str, Any] | None = None,
 62        test_dl_kwargs: dict[str, Any] | None = None,
 63        predict_dl_kwargs: dict[str, Any] | None = None,
 64        cache_dir: Path | str = DEFAULT_CACHE_DIR,
 65        classes: ArrayLike | None = None,
 66        label_key: str = "label",
 67    ) -> None:
 68        """
 69        Args:
 70            dataset_name (str): Name of the Hugging Face image classification
 71                dataset, as in the [Hugging Face dataset
 72                hub](https://huggingface.co/datasets?task_categories=task_categories:image-classification).
 73            fit_split (str, optional): Name of the split containing the
 74                training data. See also
 75                https://huggingface.co/docs/datasets/en/loading#slice-splits
 76            val_split (str, optional): Name of the split containing the
 77                validation data.
 78            test_split (str | None, optional): Name of the split containing the
 79                test data. If left to `None`, setting up this datamodule at the
 80                `test` stage will raise a `RuntimeError`
 81            predict_split (str | None, optional): Name of the split containing
 82                the prediction samples. If left to `None`, setting up this
 83                datamodule at the `predict` stage will raise a `RuntimeError`
 84            image_processor (Callable | None, optional): The image processor to apply to the samples.
 85            train_dl_kwargs (dict[str, Any] | None, optional): Dataloader
 86                parameters. See also https://pytorch.org/docs/stable/data.html.
 87            val_dl_kwargs (dict[str, Any] | None, optional): Dataloader
 88                parameters.
 89            test_dl_kwargs (dict[str, Any] | None, optional): Dataloader
 90                parameters.
 91            predict_dl_kwargs (dict[str, Any] | None, optional): Dataloader
 92                parameters.
 93            cache_dir (Path | str, optional): Path to the cache directory, where
 94                Hugging Face `datasets` package will download the dataset files.
 95                This should not be a temporary directory and be consistent
 96                between runs.
 97            classes (ArrayLike | None, optional): List of
 98                classes to keep. For example if `classes=[1, 2]`, only those
 99                samples whose label is `1` or `2` will be present in the
100                dataset. If `None`, all classes are kept. Note that this only
101                applies to the `train` and `val` splits, the `test` and
102                `predict` splits (if they exist) will not be filtered.
103            label_key (str, optional): Name of the column containing the
104                label. Only relevant if `classes` is not `None`.
105        """
106
107        classes = classes if classes is None else to_int_array(classes)
108
109        def ds_split_factory(
110            split: str, apply_filter: bool = True
111        ) -> Callable[[], Dataset]:
112            """
113            Returns a function that loads the dataset split.
114
115            Args:
116                split (str): Name of the split, see constructor arguments
117                apply_filter (bool, optional): If `True` and the constructor
118                    has `classes` set, then the dataset is filtered to only
119                    keep those samples whose label is in `classes`. You
120                    probably want to set this to `False` for the prediction
121                    split and maybe even the test split.
122
123            Returns:
124                Callable[[], Dataset]: The dataset factory
125            """
126
127            def wrapped() -> Dataset:
128                ds = load_dataset(
129                    dataset_name,
130                    split=split,
131                    cache_dir=str(cache_dir),
132                    trust_remote_code=True,
133                )
134                if image_processor is not None:
135                    ds.set_transform(image_processor)
136                ds = ds.add_column("_idx", range(len(ds)))
137                if (
138                    apply_filter
139                    and isinstance(classes, np.ndarray)
140                    and len(classes) > 0
141                ):
142                    ds = ds.filter(
143                        lambda lbl: lbl in classes, input_columns=label_key
144                    )
145                return ds
146
147            return wrapped
148
149        super().__init__(
150            train=ds_split_factory(fit_split),
151            val=ds_split_factory(val_split),
152            test=(ds_split_factory(test_split) if test_split else None),
153            predict=(
154                ds_split_factory(predict_split, False)
155                if predict_split
156                else None
157            ),
158            train_dl_kwargs=train_dl_kwargs,
159            val_dl_kwargs=val_dl_kwargs,
160            test_dl_kwargs=test_dl_kwargs,
161            predict_dl_kwargs=predict_dl_kwargs,
162        )
163        self.label_key = label_key
Arguments:
  • dataset_name (str): Name of the Hugging Face image classification dataset, as in the Hugging Face dataset hub.
  • fit_split (str, optional): Name of the split containing the training data. See also https://huggingface.co/docs/datasets/en/loading#slice-splits
  • val_split (str, optional): Name of the split containing the validation data.
  • test_split (str | None, optional): Name of the split containing the test data. If left to None, setting up this datamodule at the test stage will raise a RuntimeError
  • predict_split (str | None, optional): Name of the split containing the prediction samples. If left to None, setting up this datamodule at the predict stage will raise a RuntimeError
  • image_processor (Callable | None, optional): The image processor to apply to the samples.
  • train_dl_kwargs (dict[str, Any] | None, optional): Dataloader parameters. See also https://pytorch.org/docs/stable/data.html.
  • val_dl_kwargs (dict[str, Any] | None, optional): Dataloader parameters.
  • test_dl_kwargs (dict[str, Any] | None, optional): Dataloader parameters.
  • predict_dl_kwargs (dict[str, Any] | None, optional): Dataloader parameters.
  • cache_dir (Path | str, optional): Path to the cache directory, where Hugging Face datasets package will download the dataset files. This should not be a temporary directory and be consistent between runs.
  • classes (ArrayLike | None, optional): List of classes to keep. For example if classes=[1, 2], only those samples whose label is 1 or 2 will be present in the dataset. If None, all classes are kept. Note that this only applies to the train and val splits, the test and predict splits (if they exist) will not be filtered.
  • label_key (str, optional): Name of the column containing the label. Only relevant if classes is not None.
label_key: str
def n_classes(self, split: Literal['train', 'val', 'test'] = 'train') -> int:
185    def n_classes(
186        self, split: Literal["train", "val", "test"] = "train"
187    ) -> int:
188        """
189        Returns the number of classes in a given split.
190
191        Args:
192            split (Literal["train", "val", "test"], optional): Not the true name
193                of the split (as specified on the dataset's HuggingFace page),
194                just either `train`, `val`, or `test`. Defaults to `train`.
195        """
196        ds = self._get_dataset(split)
197        return len(ds.unique(self.label_key))

Returns the number of classes in a given split.

Arguments:
  • split (Literal["train", "val", "test"], optional): Not the true name of the split (as specified on the dataset's HuggingFace page), just either train, val, or test. Defaults to train.
def size(self, split: Literal['train', 'val', 'test'] = 'train') -> int:
199    def size(self, split: Literal["train", "val", "test"] = "train") -> int:
200        """
201        Returns the number of samples in a given split. If the split hasn't been
202        loaded, this will load it.
203
204        Args:
205            split (Literal["train", "val", "test"], optional): Not the true name
206                of the split (as specified on the dataset's HuggingFace page),
207                just either `train`, `val`, or `test`. Defaults to `train`.
208        """
209        return len(self._get_dataset(split))

Returns the number of samples in a given split. If the split hasn't been loaded, this will load it.

Arguments:
  • split (Literal["train", "val", "test"], optional): Not the true name of the split (as specified on the dataset's HuggingFace page), just either train, val, or test. Defaults to train.
def y_true(self, split: Literal['train', 'val', 'test'] = 'train') -> torch.Tensor:
211    def y_true(
212        self, split: Literal["train", "val", "test"] = "train"
213    ) -> Tensor:
214        """
215        Gets the vector of true labels of a given split.
216
217        Args:
218            split (Literal["train", "val", "test"], optional): Not the true name
219                of the split (as specified on the dataset's HuggingFace page),
220                just either `train`, `val`, or `test`. Defaults to `train`.
221
222        Returns:
223            An `int` tensor
224        """
225        return Tensor(self._get_dataset(split)[self.label_key]).int()

Gets the vector of true labels of a given split.

Arguments:
  • split (Literal["train", "val", "test"], optional): Not the true name of the split (as specified on the dataset's HuggingFace page), just either train, val, or test. Defaults to train.
Returns:

An int tensor

class WrappedDataset(pytorch_lightning.core.datamodule.LightningDataModule):
 37class WrappedDataset(pl.LightningDataModule):
 38    """
 39    A dataset wrapped inside a
 40    [`LightningDataModule`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html)
 41    """
 42
 43    train: DatasetOrDatasetFactory
 44    val: DatasetOrDatasetFactory
 45    test: DatasetOrDatasetFactory | None
 46    predict: DatasetOrDatasetFactory | None
 47    train_dl_kwargs: dict[str, Any]
 48    val_dl_kwargs: dict[str, Any]
 49    test_dl_kwargs: dict[str, Any]
 50    predict_dl_kwargs: dict[str, Any]
 51
 52    _prepared: bool = False
 53
 54    def __init__(
 55        self,
 56        train: DatasetOrDatasetFactory,
 57        val: DatasetOrDatasetFactory | None = None,
 58        test: DatasetOrDatasetFactory | None = None,
 59        predict: DatasetOrDatasetFactory | None = None,
 60        train_dl_kwargs: dict[str, Any] | None = None,
 61        val_dl_kwargs: dict[str, Any] | None = None,
 62        test_dl_kwargs: dict[str, Any] | None = None,
 63        predict_dl_kwargs: dict[str, Any] | None = None,
 64    ) -> None:
 65        """
 66        Args:
 67            train (DatasetOrDatasetFactory): A dataset or dataloader for
 68                training. Can be a callable without argument that returns said
 69                dataset or dataloader. In this case, it will be called during
 70                the preparation phase (so only on rank 0 in the case of
 71                multi-device or multi-node training).
 72            val (DatasetOrDatasetFactory | None, optional): Defaults to
 73                whatever was passed to the `train` argument
 74            test (DatasetOrDatasetFactory | None, optional): Defaults to
 75                `None`
 76            predict (DatasetOrDatasetFactory | None, optional): Defaults
 77                to `None`
 78            train_dl_kwargs (dict[str, Any] | None, optional): If
 79                `train` is a dataset or callable that returns a dataset, this
 80                dictionary will be passed to the dataloader constructor.
 81                Defaults to `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
 82            val_dl_kwargs (dict[str, Any] | None, optional):
 83                Analogous to `train_dl_kwargs`, but defaults to (a copy of)
 84                `train_dl_kwargs` instead of
 85                `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
 86            test_dl_kwargs (dict[str, Any] | None, optional):
 87                Analogous to `train_dl_kwargs`, but defaults to (a copy of)
 88                `train_dl_kwargs` instead of
 89                `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
 90            predict_dl_kwargs (dict[str, Any] | None, optional):
 91                Analogous to `train_dl_kwargs`, but defaults to (a copy of)
 92                `train_dl_kwargs` instead of
 93                `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
 94        """
 95        super().__init__()
 96        self.train, self.val = train, val if val is not None else train
 97        self.test, self.predict = test, predict
 98        self.train_dl_kwargs = (
 99            train_dl_kwargs or DEFAULT_DATALOADER_KWARGS.copy()
100        )
101        self.val_dl_kwargs = val_dl_kwargs or self.train_dl_kwargs.copy()
102        self.test_dl_kwargs = test_dl_kwargs or self.train_dl_kwargs.copy()
103        self.predict_dl_kwargs = (
104            predict_dl_kwargs or self.train_dl_kwargs.copy()
105        )
106
107    def get_dataloader(
108        self, split: Literal["train", "val", "test", "predict"]
109    ) -> DataLoader:
110        """
111        Get a dataloader by name. The usual methods `train_dataloader` etc. call
112        this. Make sure you called `prepare_data` before calling this.
113        """
114        if split == "train":
115            obj, kw = self.train, self.train_dl_kwargs
116        elif split == "val":
117            obj, kw = self.val, self.val_dl_kwargs
118        elif split == "test":
119            if self.test is None:
120                raise RuntimeError(
121                    "Cannot get test dataloader: no test dataset or dataset "
122                    "factory has not been specified"
123                )
124            obj, kw = self.test, self.test_dl_kwargs
125        elif split == "predict":
126            if self.predict is None:
127                raise RuntimeError(
128                    "Cannot get prediction dataloader: no prediction dataset "
129                    "or dataset factory has not been specified"
130                )
131            obj, kw = self.predict, self.predict_dl_kwargs
132        else:
133            raise ValueError(f"Unknown split: '{split}'")
134        obj = obj() if callable(obj) else obj
135        return DataLoader(obj, **kw)
136
137    def predict_dataloader(self) -> DataLoader:
138        """
139        Self-explanatory. Make sure you called `prepare_data` before calling
140        this.
141        """
142        return self.get_dataloader("predict")
143
144    def prepare_data(self) -> None:
145        """
146        Overrides
147        [pl.LightningDataModule.prepare_data](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data).
148        This is automatically called so don't worry about it.
149        """
150        if self._prepared:
151            return
152        if callable(self.train):
153            r0_debug("Preparing the training dataset/split")
154            self.train = self.train()
155        if callable(self.val):
156            r0_debug("Preparing the validation dataset/split")
157            self.val = self.val()
158        if callable(self.test):
159            r0_debug("Preparing the testing dataset/split")
160            self.test = self.test()
161        if callable(self.predict):
162            r0_debug("Preparing the prediction dataset/split")
163            self.predict = self.predict()
164        self._prepared = True
165
166    def test_dataloader(self) -> DataLoader:
167        """
168        Self-explanatory. Make sure you called `prepare_data` before calling
169        this.
170        """
171        return self.get_dataloader("test")
172
173    def train_dataloader(self) -> DataLoader:
174        """
175        Self-explanatory. Make sure you called `prepare_data` before calling
176        this.
177        """
178        return self.get_dataloader("train")
179
180    def val_dataloader(self) -> DataLoader:
181        """
182        Self-explanatory. Make sure you called `prepare_data` before calling
183        this.
184        """
185        return self.get_dataloader("val")

A dataset wrapped inside a LightningDataModule

WrappedDataset( train: DatasetOrDatasetFactory, val: DatasetOrDatasetFactory | None = None, test: DatasetOrDatasetFactory | None = None, predict: DatasetOrDatasetFactory | None = None, train_dl_kwargs: dict[str, typing.Any] | None = None, val_dl_kwargs: dict[str, typing.Any] | None = None, test_dl_kwargs: dict[str, typing.Any] | None = None, predict_dl_kwargs: dict[str, typing.Any] | None = None)
 54    def __init__(
 55        self,
 56        train: DatasetOrDatasetFactory,
 57        val: DatasetOrDatasetFactory | None = None,
 58        test: DatasetOrDatasetFactory | None = None,
 59        predict: DatasetOrDatasetFactory | None = None,
 60        train_dl_kwargs: dict[str, Any] | None = None,
 61        val_dl_kwargs: dict[str, Any] | None = None,
 62        test_dl_kwargs: dict[str, Any] | None = None,
 63        predict_dl_kwargs: dict[str, Any] | None = None,
 64    ) -> None:
 65        """
 66        Args:
 67            train (DatasetOrDatasetFactory): A dataset or dataloader for
 68                training. Can be a callable without argument that returns said
 69                dataset or dataloader. In this case, it will be called during
 70                the preparation phase (so only on rank 0 in the case of
 71                multi-device or multi-node training).
 72            val (DatasetOrDatasetFactory | None, optional): Defaults to
 73                whatever was passed to the `train` argument
 74            test (DatasetOrDatasetFactory | None, optional): Defaults to
 75                `None`
 76            predict (DatasetOrDatasetFactory | None, optional): Defaults
 77                to `None`
 78            train_dl_kwargs (dict[str, Any] | None, optional): If
 79                `train` is a dataset or callable that returns a dataset, this
 80                dictionary will be passed to the dataloader constructor.
 81                Defaults to `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
 82            val_dl_kwargs (dict[str, Any] | None, optional):
 83                Analogous to `train_dl_kwargs`, but defaults to (a copy of)
 84                `train_dl_kwargs` instead of
 85                `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
 86            test_dl_kwargs (dict[str, Any] | None, optional):
 87                Analogous to `train_dl_kwargs`, but defaults to (a copy of)
 88                `train_dl_kwargs` instead of
 89                `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
 90            predict_dl_kwargs (dict[str, Any] | None, optional):
 91                Analogous to `train_dl_kwargs`, but defaults to (a copy of)
 92                `train_dl_kwargs` instead of
 93                `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
 94        """
 95        super().__init__()
 96        self.train, self.val = train, val if val is not None else train
 97        self.test, self.predict = test, predict
 98        self.train_dl_kwargs = (
 99            train_dl_kwargs or DEFAULT_DATALOADER_KWARGS.copy()
100        )
101        self.val_dl_kwargs = val_dl_kwargs or self.train_dl_kwargs.copy()
102        self.test_dl_kwargs = test_dl_kwargs or self.train_dl_kwargs.copy()
103        self.predict_dl_kwargs = (
104            predict_dl_kwargs or self.train_dl_kwargs.copy()
105        )
Arguments:
  • train (DatasetOrDatasetFactory): A dataset for training, or a callable without arguments that returns said dataset. In the latter case, it will be called during the preparation phase (so only on rank 0 in the case of multi-device or multi-node training).
  • val (DatasetOrDatasetFactory | None, optional): Defaults to whatever was passed to the train argument.
  • test (DatasetOrDatasetFactory | None, optional): Defaults to None.
  • predict (DatasetOrDatasetFactory | None, optional): Defaults to None.
  • train_dl_kwargs (dict[str, Any] | None, optional): If train is a dataset or a callable that returns a dataset, this dictionary will be passed to the dataloader constructor. Defaults to lcc.datasets.DEFAULT_DATALOADER_KWARGS.
  • val_dl_kwargs (dict[str, Any] | None, optional): Analogous to train_dl_kwargs, but defaults to (a copy of) train_dl_kwargs instead of lcc.datasets.DEFAULT_DATALOADER_KWARGS.
  • test_dl_kwargs (dict[str, Any] | None, optional): Analogous to train_dl_kwargs, but defaults to (a copy of) train_dl_kwargs instead of lcc.datasets.DEFAULT_DATALOADER_KWARGS.
  • predict_dl_kwargs (dict[str, Any] | None, optional): Analogous to train_dl_kwargs, but defaults to (a copy of) train_dl_kwargs instead of lcc.datasets.DEFAULT_DATALOADER_KWARGS.
train: Union[Callable[[], torch.utils.data.dataset.Dataset], torch.utils.data.dataset.Dataset, Callable[[], datasets.arrow_dataset.Dataset], datasets.arrow_dataset.Dataset]
val: Union[Callable[[], torch.utils.data.dataset.Dataset], torch.utils.data.dataset.Dataset, Callable[[], datasets.arrow_dataset.Dataset], datasets.arrow_dataset.Dataset]
test: Union[Callable[[], torch.utils.data.dataset.Dataset], torch.utils.data.dataset.Dataset, Callable[[], datasets.arrow_dataset.Dataset], datasets.arrow_dataset.Dataset, NoneType]
predict: Union[Callable[[], torch.utils.data.dataset.Dataset], torch.utils.data.dataset.Dataset, Callable[[], datasets.arrow_dataset.Dataset], datasets.arrow_dataset.Dataset, NoneType]
train_dl_kwargs: dict[str, typing.Any]
val_dl_kwargs: dict[str, typing.Any]
test_dl_kwargs: dict[str, typing.Any]
predict_dl_kwargs: dict[str, typing.Any]
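
As a quick illustration (not taken from the source), here is a minimal sketch of constructing a WrappedDataset. It assumes only torch and the lcc.datasets exports shown in the module header; the make_train factory and the random tensors are placeholders:

    import torch
    from torch.utils.data import TensorDataset

    from lcc.datasets import WrappedDataset

    def make_train():
        # Factory: called lazily during prepare_data (rank 0 only).
        return TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

    dm = WrappedDataset(
        train=make_train,                    # dataset factory
        val=make_train,                      # val would default to train if omitted
        train_dl_kwargs={"batch_size": 32},  # forwarded to the DataLoader constructor
    )
    dm.prepare_data()  # materializes the factories
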
def get_dataloader(self, split: Literal['train', 'val', 'test', 'predict']) -> torch.utils.data.dataloader.DataLoader:
107    def get_dataloader(
108        self, split: Literal["train", "val", "test", "predict"]
109    ) -> DataLoader:
110        """
111        Get a dataloader by name. The usual methods `train_dataloader` etc. call
112        this. Make sure you called `prepare_data` before calling this.
113        """
114        if split == "train":
115            obj, kw = self.train, self.train_dl_kwargs
116        elif split == "val":
117            obj, kw = self.val, self.val_dl_kwargs
118        elif split == "test":
119            if self.test is None:
120                raise RuntimeError(
121                    "Cannot get test dataloader: no test dataset or dataset "
122                    "factory has not been specified"
123                )
124            obj, kw = self.test, self.test_dl_kwargs
125        elif split == "predict":
126            if self.predict is None:
127                raise RuntimeError(
128                    "Cannot get prediction dataloader: no prediction dataset "
129                    "or dataset factory has been specified"
130                )
131            obj, kw = self.predict, self.predict_dl_kwargs
132        else:
133            raise ValueError(f"Unknown split: '{split}'")
134        obj = obj() if callable(obj) else obj
135        return DataLoader(obj, **kw)

Get a dataloader by name. The usual methods train_dataloader etc. call this. Make sure you called prepare_data before calling this.
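
Continuing the hypothetical sketch from the constructor section above, the split dispatch and the missing-split error path look like this:

    val_dl = dm.get_dataloader("val")  # equivalent to dm.val_dataloader()

    try:
        dm.get_dataloader("test")      # no test dataset was provided above
    except RuntimeError as err:
        print(err)                     # "Cannot get test dataloader: ..."
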

def predict_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
137    def predict_dataloader(self) -> DataLoader:
138        """
139        Self-explanatory. Make sure you called `prepare_data` before calling
140        this.
141        """
142        return self.get_dataloader("predict")

Self-explanatory. Make sure you called prepare_data before calling this.

def prepare_data(self) -> None:
144    def prepare_data(self) -> None:
145        """
146        Overrides
147        [pl.LightningDataModule.prepare_data](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data).
148        This is automatically called so don't worry about it.
149        """
150        if self._prepared:
151            return
152        if callable(self.train):
153            r0_debug("Preparing the training dataset/split")
154            self.train = self.train()
155        if callable(self.val):
156            r0_debug("Preparing the validation dataset/split")
157            self.val = self.val()
158        if callable(self.test):
159            r0_debug("Preparing the testing dataset/split")
160            self.test = self.test()
161        if callable(self.predict):
162            r0_debug("Preparing the prediction dataset/split")
163            self.predict = self.predict()
164        self._prepared = True

Overrides pl.LightningDataModule.prepare_data. This is automatically called so don't worry about it.
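
A short sketch of this materialization behaviour, reusing the hypothetical make_train factory from the constructor example:

    dm = WrappedDataset(train=make_train)
    assert callable(dm.train)      # still a factory at this point
    dm.prepare_data()
    assert not callable(dm.train)  # replaced by the materialized dataset
    dm.prepare_data()              # a second call is a no-op (_prepared flag)
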

def test_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
166    def test_dataloader(self) -> DataLoader:
167        """
168        Self-explanatory. Make sure you called `prepare_data` before calling
169        this.
170        """
171        return self.get_dataloader("test")

Self-explanatory. Make sure you called prepare_data before calling this.

def train_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
173    def train_dataloader(self) -> DataLoader:
174        """
175        Self-explanatory. Make sure you called `prepare_data` before calling
176        this.
177        """
178        return self.get_dataloader("train")

Self-explanatory. Make sure you called prepare_data before calling this.

def val_dataloader(self) -> torch.utils.data.dataloader.DataLoader:
180    def val_dataloader(self) -> DataLoader:
181        """
182        Self-explanatory. Make sure you called `prepare_data` before calling
183        this.
184        """
185        return self.get_dataloader("val")

Self-explanatory. Make sure you called prepare_data before calling this.
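
In practice these methods are rarely called by hand: Lightning invokes them through the Trainer. A hedged end-to-end sketch, assuming the pytorch_lightning import name (the project may equally use lightning.pytorch) and a placeholder model:

    import pytorch_lightning as pl

    trainer = pl.Trainer(max_epochs=1)
    # `model` is any pl.LightningModule (placeholder, not defined here).
    trainer.fit(model, datamodule=dm)   # calls prepare_data, train/val_dataloader
    trainer.test(model, datamodule=dm)  # calls test_dataloader (requires a test split)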