lcc.datasets
Custom Lightning datamodules
1"""Custom Lightning datamodules""" 2 3from .batched_tensor import BatchedTensorDataset 4from .huggingface import HuggingFaceDataset 5from .preset import DATASET_PRESETS_CONFIGURATIONS, get_dataset 6from .utils import dl_head, flatten_batches 7from .wrapped import DEFAULT_DATALOADER_KWARGS, WrappedDataset 8 9__all__ = [ 10 "BatchedTensorDataset", 11 "DATASET_PRESETS_CONFIGURATIONS", 12 "DEFAULT_DATALOADER_KWARGS", 13 "dl_head", 14 "flatten_batches", 15 "get_dataset", 16 "HuggingFaceDataset", 17 "WrappedDataset", 18]
```python
class BatchedTensorDataset(IterableDataset):
    """
    Dataset that loads tensor batches produced by
    `lcc.datasets.save_tensor_batched`.
    """

    key: str
    paths: list[Path]

    _len: int | None = None
    """Cached length, see `__len__`."""

    def __init__(
        self,
        batch_dir: str | Path,
        prefix: str = "batch",
        extension: str = "st",
        key: str = "",
    ):
        """
        See `lcc.datasets.save_tensor_batched` for the precise meaning of the
        arguments. In a few words, this dataset loads batches from
        [Safetensors](https://huggingface.co/docs/safetensors/index) files
        named after the following scheme:

            batch_dir/<prefix>.<unique_id>.<extension>

        Safetensor files are essentially dictionaries, and each batch file is
        expected to contain:
        * `key` (see below): some `(B, ...)` data tensor.
        * `"_idx"`: a `(B,)` integer tensor.

        When iterating over this dataset, it yields `(data, idx)` pairs,
        where `data` is a row of the dataset and `idx` is an int (as a `(,)`
        int tensor).

        Warning:
            The list of batch files is globbed from `batch_dir` upon
            instantiation of this dataset class. It is not updated
            afterwards.

        Warning:
            Batches are loaded in the order they are found in the filesystem.
            Don't expect this order to be the same as the order in which the
            data was generated.

        Args:
            batch_dir (str | Path):
            prefix (str, optional):
            extension (str, optional): Without the leading `.`. Defaults to
                `st`.
            key (str, optional): The key under which the data of interest was
                saved. Batches are stored in safetensor files, which are
                essentially dictionaries, and this arg specifies which key
                contains the data. Cannot be `"_idx"`.
        """
        if key == "_idx":
            raise ValueError("Key cannot be '_idx'.")
        self.paths = list(Path(batch_dir).glob(f"{prefix}.*.{extension}"))
        self.key = key

    def __iter__(self) -> Iterator[tuple[Tensor, Tensor]]:
        for path in self.paths:
            data = st.load_file(path)
            for z, i in zip(data[self.key], data["_idx"]):
                yield z, i

    def __len__(self) -> int:
        """
        Returns the length of this dataset, computed by summing the lengths
        of the `"_idx"` tensors across all batch files. The result is cached.
        """
        if self._len is None:
            self._len = sum(len(st.load_file(p)["_idx"]) for p in self.paths)
        return self._len

    def distribute(
        self, strategy: Strategy | Fabric | None
    ) -> "BatchedTensorDataset":
        """
        Creates a *copy* of a subset of this dataset so that every rank has a
        different subset. Does not modify the current dataset.

        If the strategy is not a `ParallelStrategy` or a Lightning `Fabric`,
        or if the world size is less than 2, this method returns `self` (NOT
        a copy of `self`).
        """
        if (
            not isinstance(strategy, (ParallelStrategy, Fabric))
            or strategy.world_size < 2
        ):
            return self
        ws, gr = strategy.world_size, strategy.global_rank
        ds = deepcopy(self)
        ds.paths, ds._len = ds.paths[gr::ws], None
        return ds

    def extract_idx(
        self, tqdm_style: TqdmStyle = None
    ) -> tuple[IterableDataset, Tensor]:
        """
        Splits this dataset in two. The first one yields the data, the second
        one yields the indices. Then, the index dataset is unrolled into a
        single index tensor (which therefore has shape `(N,)`, where `N` is
        the length of the dataset).
        """
        a, b = _ProjectionDataset(self, 0), _ProjectionDataset(self, 1)
        dl = DataLoader(b, batch_size=1024, num_workers=1)
        dl = make_tqdm(tqdm_style)(dl, "Extracting indices")
        # TODO: setting num_workers to > 1 makes the index tensor n_workers
        # times too long... problem with tqdm?
        return a, torch.cat(list(dl), dim=0)

    def load(
        self,
        batch_size: int = 256,
        num_workers: int = 0,
        tqdm_style: TqdmStyle = None,
    ) -> tuple[Tensor, Tensor]:
        """
        Loads a batched tensor in one go. See `BatchedTensorDataset.save`.

        Args:
            batch_size (int, optional): Defaults to 256. Does not impact the
                actual result.
            num_workers (int, optional): Defaults to 0, meaning
                single-process data loading.
            tqdm_style (TqdmStyle, optional):

        Returns:
            A tuple `(data, idx)` where `data` is a `(N, ...)` tensor and
            `idx` is a `(N,)` int tensor.
        """
        dl = DataLoader(self, batch_size=batch_size, num_workers=num_workers)
        u = make_tqdm(tqdm_style)(dl, "Loading")
        v = list(zip(*u))
        return torch.cat(v[0], dim=0), torch.cat(v[1], dim=0)

    @staticmethod
    def save(
        x: ArrayLike,
        output_dir: str | Path,
        prefix: str = "batch",
        extension: str = "st",
        key: str = "",
        batch_size: int = 256,
        tqdm_style: TqdmStyle = None,
    ) -> None:
        """
        Saves a tensor in batches of `batch_size` elements. The files will be
        named

            output_dir/<prefix>.<batch_idx>.<extension>

        The batches are saved using
        [Safetensors](https://huggingface.co/docs/safetensors/index).
        Safetensors files are essentially dictionaries, and each batch file
        is structured as follows:
        * `key` (see below): some `(batch_size, ...)` slice from `x`,
        * `"_idx"`: a `(batch_size,)` integer tensor containing the indices
          in `x`.

        The `batch_idx` string is 4 digits long, so it's best to choose a
        batch size that produces fewer than 10000 batches.

        Args:
            x (ArrayLike):
            output_dir (str):
            prefix (str, optional):
            extension (str, optional): Without the leading `.`. Defaults to
                `st`.
            key (str, optional): The key to use when saving the file. Batches
                are stored in safetensor files, which are essentially
                dictionaries, and this arg specifies which key contains the
                data of interest. Cannot be `"_idx"`.
            batch_size (int, optional): Defaults to `256`.
            tqdm_style (TqdmStyle, optional): Progress bar style.
        """
        batches = to_tensor(x).split(batch_size)
        for i, batch in enumerate(make_tqdm(tqdm_style)(batches, "Saving")):
            data = {
                key: batch,
                # The last batch may be shorter than batch_size, so derive
                # the index range from the actual batch length
                "_idx": torch.arange(
                    i * batch_size, i * batch_size + len(batch)
                ),
            }
            st.save_file(
                data, Path(output_dir) / f"{prefix}.{i:04}.{extension}"
            )
```
Dataset that loads tensor batches produced by `lcc.datasets.save_tensor_batched`.
```python
def __init__(self, batch_dir: str | Path, prefix: str = "batch", extension: str = "st", key: str = "")
```
See `lcc.datasets.save_tensor_batched` for the precise meaning of the arguments. In a few words, this dataset loads batches from [Safetensors](https://huggingface.co/docs/safetensors/index) files named after the following scheme:

    batch_dir/<prefix>.<unique_id>.<extension>

Safetensor files are essentially dictionaries, and each batch file is expected to contain:
- `key` (see below): some `(B, ...)` data tensor.
- `"_idx"`: a `(B,)` integer tensor.

When iterating over this dataset, it yields `(data, idx)` pairs, where `data` is a row of the dataset and `idx` is an int (as a `(,)` int tensor).

Warning: The list of batch files is globbed from `batch_dir` upon instantiation of this dataset class. It is not updated afterwards.

Warning: Batches are loaded in the order they are found in the filesystem. Don't expect this order to be the same as the order in which the data was generated.

Arguments:
- `batch_dir` (str | Path):
- `prefix` (str, optional):
- `extension` (str, optional): Without the leading `.`. Defaults to `st`.
- `key` (str, optional): The key under which the data of interest was saved. Batches are stored in safetensor files, which are essentially dictionaries, and this arg specifies which key contains the data. Cannot be `"_idx"`.
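For illustration, a minimal usage sketch; the directory and key below (`out/embeddings`, `"embeddings"`) are hypothetical and must match whatever was used at save time:

```python
from lcc.datasets import BatchedTensorDataset

ds = BatchedTensorDataset("out/embeddings", key="embeddings")
for z, i in ds:
    print(i.item(), z.shape)  # one sample row and its index
    break
```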
```python
def distribute(self, strategy: Strategy | Fabric | None) -> "BatchedTensorDataset"
```
Creates a *copy* of a subset of this dataset so that every rank has a different subset. Does not modify the current dataset.

If the strategy is not a `ParallelStrategy` or a Lightning `Fabric`, or if the world size is less than 2, this method returns `self` (NOT a copy of `self`).
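A hedged sketch, continuing from the `ds` created above: inside a launched distributed run (here a Lightning `Fabric` with two CPU processes, an assumption for the example), each rank keeps every `world_size`-th batch file, while in a single-process run `distribute` simply returns `self`:

```python
from lightning.fabric import Fabric

def run(fabric: Fabric) -> None:
    local_ds = ds.distribute(fabric)  # rank r keeps paths[r::world_size]
    data, idx = local_ds.load()       # load only this rank's shard

Fabric(accelerator="cpu", devices=2).launch(run)
```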
```python
def extract_idx(self, tqdm_style: TqdmStyle = None) -> tuple[IterableDataset, Tensor]
```
Splits this dataset in two: the first one yields the data, the second one yields the indices. Then, the index dataset is unrolled into a single index tensor (which therefore has shape `(N,)`, where `N` is the length of the dataset).
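A short sketch of the split, again using the hypothetical `ds` from above; the data part remains an `IterableDataset` suitable for a `DataLoader`, while the indices are materialized eagerly:

```python
data_ds, idx = ds.extract_idx()
print(idx.shape)  # torch.Size([N]), one index per sample
```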
```python
def load(self, batch_size: int = 256, num_workers: int = 0, tqdm_style: TqdmStyle = None) -> tuple[Tensor, Tensor]
```
Loads a batched tensor in one go. See `BatchedTensorDataset.save`.

Arguments:
- `batch_size` (int, optional): Defaults to 256. Does not impact the actual result.
- `num_workers` (int, optional): Defaults to 0, meaning single-process data loading.
- `tqdm_style` (TqdmStyle, optional):

Returns:
A tuple `(data, idx)` where `data` is a `(N, ...)` tensor and `idx` is a `(N,)` int tensor.
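A sketch of materializing the whole dataset in memory in one call, with the hypothetical `ds` from above:

```python
data, idx = ds.load(batch_size=512)
assert data.shape[0] == idx.shape[0]
```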
```python
@staticmethod
def save(x: ArrayLike, output_dir: str | Path, prefix: str = "batch", extension: str = "st", key: str = "", batch_size: int = 256, tqdm_style: TqdmStyle = None) -> None
```
Saves a tensor in batches of `batch_size` elements. The files will be named

    output_dir/<prefix>.<batch_idx>.<extension>

The batches are saved using [Safetensors](https://huggingface.co/docs/safetensors/index). Safetensors files are essentially dictionaries, and each batch file is structured as follows:
- `key` (see below): some `(batch_size, ...)` slice from `x`,
- `"_idx"`: a `(batch_size,)` integer tensor containing the indices in `x`.

The `batch_idx` string is 4 digits long, so it's best to choose a batch size that produces fewer than 10000 batches.

Arguments:
- `x` (ArrayLike):
- `output_dir` (str):
- `prefix` (str, optional):
- `extension` (str, optional): Without the leading `.`. Defaults to `st`.
- `key` (str, optional): The key to use when saving the file. Batches are stored in safetensor files, which are essentially dictionaries, and this arg specifies which key contains the data of interest. Cannot be `"_idx"`.
- `batch_size` (int, optional): Defaults to `256`.
- `tqdm_style` (TqdmStyle, optional): Progress bar style.
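A sketch of the save side, assuming a hypothetical `(10_000, 128)` float tensor and an `out/embeddings` directory that already exists:

```python
import torch
from lcc.datasets import BatchedTensorDataset

x = torch.randn(10_000, 128)
BatchedTensorDataset.save(x, "out/embeddings", key="embeddings")
# -> out/embeddings/batch.0000.st ... batch.0039.st (40 files, <= 256 rows each)
```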
```python
def dl_head(dl: DataLoader, n: int) -> list[Batch]:
    """
    Returns the first `n` samples of a DataLoader **as a list of batches**.

    Args:
        dl (DataLoader):
        n (int):

    Warning:
        Only supports batches that are dicts of tensors.
    """

    def _n() -> int:
        if not batches:
            return 0
        k = list(batches[0].keys())[0]
        return sum(map(lambda b: len(b[k]), batches))

    batches: list[Batch] = []
    it = iter(dl)
    try:
        while _n() < n:
            batches.append(next(it))
    except StopIteration:
        logging.warning(
            "Tried to extract {} elements from dataset but only found {}",
            n,
            _n(),
        )
    if (r := _n() - n) > 0:
        for k in batches[-1].keys():
            batches[-1][k] = batches[-1][k][:-r]
    return batches
```
Returns the first `n` samples of a DataLoader **as a list of batches**.

Arguments:
- `dl` (DataLoader):
- `n` (int):

Warning: Only supports batches that are dicts of tensors.
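A small sketch, where `dl` is a hypothetical `DataLoader` whose batches are dicts of tensors:

```python
from lcc.datasets import dl_head

head = dl_head(dl, 100)  # list of batches, trimmed to exactly 100 rows total
```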
```python
def flatten_batches(batches: list[Batch]) -> Batch:
    """
    Flattens a list of batches into a single batch.

    Args:
        batches (list[Batch]):
    """
    return {
        k: torch.concat([b[k] for b in batches]) for k in batches[0].keys()
    }
```
Flattens a list of batches into a single batch.

Arguments:
- `batches` (list[Batch]):
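Combined with `dl_head` above (same hypothetical `dl`), this yields a single dict whose tensors each have 100 rows:

```python
from lcc.datasets import dl_head, flatten_batches

batch = flatten_batches(dl_head(dl, 100))
```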
```python
def get_dataset(
    dataset_name: str,
    image_processor: str | Callable | None = None,
    batch_size: int | None = None,
    num_workers: int | None = None,
    **kwargs: Any,
) -> tuple[HuggingFaceDataset, dict[str, Any]]:
    """
    Returns a `HuggingFaceDataset` instance from the `dataset_name` key in
    the `DATASET_PRESETS_CONFIGURATIONS` dictionary.

    Example:

        >>> ds, _ = get_dataset(
        ...     "cifar10",
        ...     "microsoft/resnet-18",
        ...     batch_size=64,
        ...     num_workers=4,
        ... )

    Args:
        dataset_name (str): See `DATASET_PRESETS_CONFIGURATIONS`
        image_processor (str | Callable | None): The image processor to use.
            If a str is provided, it is assumed to be a model name, and the
            image processor will be constructed accordingly using
            `lcc.classifiers.get_classifier_cls` and
            `lcc.classifiers.BaseClassifier.get_image_processor`.
        batch_size (int | None): If provided, will be added to the dataloader
            parameters for all dataset splits (train, val, test, and
            predict).
        num_workers (int | None): If provided, will be added to the
            dataloader parameters for all dataset splits (train, val, test,
            and predict).
        **kwargs: Passed to the `HuggingFaceDataset` constructor

    Returns:
        The `HuggingFaceDataset` instance and the configuration dictionary
    """
    config = DATASET_PRESETS_CONFIGURATIONS[dataset_name]
    if isinstance(image_processor, str):
        from ..classifiers import get_classifier_cls

        cls = get_classifier_cls(image_processor)
        config["image_processor"] = cls.get_image_processor(image_processor)
    else:
        config["image_processor"] = image_processor
    dl_kw = {}
    if batch_size is not None:
        dl_kw["batch_size"] = batch_size
    if num_workers is not None:
        dl_kw["num_workers"] = num_workers
    if dl_kw:
        config["train_dl_kwargs"] = dl_kw
        config["val_dl_kwargs"] = dl_kw.copy()
        config["test_dl_kwargs"] = dl_kw.copy()
        config["predict_dl_kwargs"] = dl_kw.copy()
    config.update(kwargs)
    return HuggingFaceDataset(**config), config
```
Returns a `HuggingFaceDataset` instance from the `dataset_name` key in the `DATASET_PRESETS_CONFIGURATIONS` dictionary.

Example:

    >>> ds, _ = get_dataset(
    ...     "cifar10",
    ...     "microsoft/resnet-18",
    ...     batch_size=64,
    ...     num_workers=4,
    ... )

Arguments:
- `dataset_name` (str): See `DATASET_PRESETS_CONFIGURATIONS`
- `image_processor` (str | Callable | None): The image processor to use. If a str is provided, it is assumed to be a model name, and the image processor will be constructed accordingly using `lcc.classifiers.get_classifier_cls` and `lcc.classifiers.BaseClassifier.get_image_processor`.
- `batch_size` (int | None): If provided, will be added to the dataloader parameters for all dataset splits (train, val, test, and predict).
- `num_workers` (int | None): If provided, will be added to the dataloader parameters for all dataset splits (train, val, test, and predict).
- `**kwargs`: Passed to the `HuggingFaceDataset` constructor

Returns:
The `HuggingFaceDataset` instance and the configuration dictionary.
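Since the returned object is a `LightningDataModule`, it can be prepared and queried directly; a sketch continuing the example above:

```python
ds, config = get_dataset("cifar10", "microsoft/resnet-18", batch_size=64)
ds.prepare_data()                 # downloads / loads the splits
train_dl = ds.train_dataloader()  # DataLoader with batch_size=64
```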
```python
class HuggingFaceDataset(WrappedDataset):
    """
    A Hugging Face image classification dataset wrapped inside a
    `lcc.datasets.WrappedDataset`, which is itself a
    [`LightningDataModule`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html).

    Hugging Face image datasets are dict datasets where the image is a PIL
    image object. Here, images are converted to tensors using the
    `image_processor` (if provided), which brings this closer to the
    torchvision API. In this case, you can load and call Hugging Face models
    directly. If you do not provide an `image_processor`, then it is
    recommended that you use a Hugging Face pipeline instead.

    Since Hugging Face datasets are dict datasets, batches are dicts of
    tensors (see the Hugging Face dataset hub for the list of keys).
    `HuggingFaceDataset` adds an extra key `_idx` that contains the index of
    the samples in the dataset.

    See also:
        https://huggingface.co/datasets?task_categories=task_categories:image-classification
    """

    label_key: str

    _datasets: dict[str, Dataset] = {}  # Cache

    def __init__(
        self,
        dataset_name: str,
        fit_split: str = "training",
        val_split: str = "validation",
        test_split: str | None = None,
        predict_split: str | None = None,
        image_processor: Callable | None = None,
        train_dl_kwargs: dict[str, Any] | None = None,
        val_dl_kwargs: dict[str, Any] | None = None,
        test_dl_kwargs: dict[str, Any] | None = None,
        predict_dl_kwargs: dict[str, Any] | None = None,
        cache_dir: Path | str = DEFAULT_CACHE_DIR,
        classes: ArrayLike | None = None,
        label_key: str = "label",
    ) -> None:
        """
        Args:
            dataset_name (str): Name of the Hugging Face image
                classification dataset, as in the [Hugging Face dataset
                hub](https://huggingface.co/datasets?task_categories=task_categories:image-classification).
            fit_split (str, optional): Name of the split containing the
                training data. See also
                https://huggingface.co/docs/datasets/en/loading#slice-splits
            val_split (str, optional): Name of the split containing the
                validation data.
            test_split (str | None, optional): Name of the split containing
                the test data. If left to `None`, setting up this datamodule
                at the `test` stage will raise a `RuntimeError`.
            predict_split (str | None, optional): Name of the split
                containing the prediction samples. If left to `None`,
                setting up this datamodule at the `predict` stage will raise
                a `RuntimeError`.
            image_processor (Callable | None, optional):
            train_dl_kwargs (dict[str, Any] | None, optional): Dataloader
                parameters. See also
                https://pytorch.org/docs/stable/data.html.
            val_dl_kwargs (dict[str, Any] | None, optional): Dataloader
                parameters.
            test_dl_kwargs (dict[str, Any] | None, optional): Dataloader
                parameters.
            predict_dl_kwargs (dict[str, Any] | None, optional): Dataloader
                parameters.
            cache_dir (Path | str, optional): Path to the cache directory,
                where the Hugging Face `datasets` package will download the
                dataset files. This should not be a temporary directory and
                should be consistent between runs.
            classes (ArrayLike | None, optional): List of classes to keep.
                For example, if `classes=[1, 2]`, only those samples whose
                label is `1` or `2` will be present in the dataset. If
                `None`, all classes are kept. Note that this only applies to
                the `train` and `val` splits; the `test` and `predict`
                splits (if they exist) will not be filtered.
            label_key (str, optional): Name of the column containing the
                label. Only relevant if `classes` is not `None`.
        """

        classes = classes if classes is None else to_int_array(classes)

        def ds_split_factory(
            split: str, apply_filter: bool = True
        ) -> Callable[[], Dataset]:
            """
            Returns a function that loads the dataset split.

            Args:
                split (str): Name of the split, see constructor arguments
                apply_filter (bool, optional): If `True` and the constructor
                    has `classes` set, then the dataset is filtered to only
                    keep those samples whose label is in `classes`. You
                    probably want to set this to `False` for the prediction
                    split and maybe even the test split.

            Returns:
                Callable[[], Dataset]: The dataset factory
            """

            def wrapped() -> Dataset:
                ds = load_dataset(
                    dataset_name,
                    split=split,
                    cache_dir=str(cache_dir),
                    trust_remote_code=True,
                )
                if image_processor is not None:
                    ds.set_transform(image_processor)
                ds = ds.add_column("_idx", range(len(ds)))
                if (
                    apply_filter
                    and isinstance(classes, np.ndarray)
                    and len(classes) > 0
                ):
                    ds = ds.filter(
                        lambda lbl: lbl in classes, input_columns=label_key
                    )
                return ds

            return wrapped

        super().__init__(
            train=ds_split_factory(fit_split),
            val=ds_split_factory(val_split),
            test=(ds_split_factory(test_split) if test_split else None),
            predict=(
                ds_split_factory(predict_split, False)
                if predict_split
                else None
            ),
            train_dl_kwargs=train_dl_kwargs,
            val_dl_kwargs=val_dl_kwargs,
            test_dl_kwargs=test_dl_kwargs,
            predict_dl_kwargs=predict_dl_kwargs,
        )
        self.label_key = label_key

    def __len__(self) -> int:
        """
        Returns the size of the train split. See `HuggingFaceDataset.size`.
        """
        return self.size("train")

    def _get_dataset(self, split: Literal["train", "val", "test"]) -> Dataset:
        if split not in self._datasets:
            if split == "train":
                factory = self.train
            elif split == "val":
                factory = self.val
            elif split == "test":
                factory = self.test
            else:
                raise ValueError(f"Unknown split: {split}")
            assert callable(factory)
            self._datasets[split] = factory()
        return self._datasets[split]

    def n_classes(
        self, split: Literal["train", "val", "test"] = "train"
    ) -> int:
        """
        Returns the number of classes in a given split.

        Args:
            split (Literal["train", "val", "test"], optional): Not the true
                name of the split (as specified on the dataset's HuggingFace
                page), just either `train`, `val`, or `test`. Defaults to
                `train`.
        """
        ds = self._get_dataset(split)
        return len(ds.unique(self.label_key))

    def size(self, split: Literal["train", "val", "test"] = "train") -> int:
        """
        Returns the number of samples in a given split. If the split hasn't
        been loaded, this will load it.

        Args:
            split (Literal["train", "val", "test"], optional): Not the true
                name of the split (as specified on the dataset's HuggingFace
                page), just either `train`, `val`, or `test`. Defaults to
                `train`.
        """
        return len(self._get_dataset(split))

    def y_true(
        self, split: Literal["train", "val", "test"] = "train"
    ) -> Tensor:
        """
        Gets the vector of true labels of a given split.

        Args:
            split (Literal["train", "val", "test"], optional): Not the true
                name of the split (as specified on the dataset's HuggingFace
                page), just either `train`, `val`, or `test`. Defaults to
                `train`.

        Returns:
            An `int` tensor
        """
        return Tensor(self._get_dataset(split)[self.label_key]).int()
```
A Hugging Face image classification dataset wrapped inside a `lcc.datasets.WrappedDataset`, which is itself a [`LightningDataModule`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html).

Hugging Face image datasets are dict datasets where the image is a PIL image object. Here, images are converted to tensors using the `image_processor` (if provided), which brings this closer to the torchvision API. In this case, you can load and call Hugging Face models directly. If you do not provide an `image_processor`, then it is recommended that you use a Hugging Face pipeline instead.

Since Hugging Face datasets are dict datasets, batches are dicts of tensors (see the Hugging Face dataset hub for the list of keys). `HuggingFaceDataset` adds an extra key `_idx` that contains the index of the samples in the dataset.

See also: https://huggingface.co/datasets?task_categories=task_categories:image-classification
```python
def __init__(
    self,
    dataset_name: str,
    fit_split: str = "training",
    val_split: str = "validation",
    test_split: str | None = None,
    predict_split: str | None = None,
    image_processor: Callable | None = None,
    train_dl_kwargs: dict[str, Any] | None = None,
    val_dl_kwargs: dict[str, Any] | None = None,
    test_dl_kwargs: dict[str, Any] | None = None,
    predict_dl_kwargs: dict[str, Any] | None = None,
    cache_dir: Path | str = DEFAULT_CACHE_DIR,
    classes: ArrayLike | None = None,
    label_key: str = "label",
) -> None
```
Arguments:
- `dataset_name` (str): Name of the Hugging Face image classification dataset, as in the [Hugging Face dataset hub](https://huggingface.co/datasets?task_categories=task_categories:image-classification).
- `fit_split` (str, optional): Name of the split containing the training data. See also https://huggingface.co/docs/datasets/en/loading#slice-splits
- `val_split` (str, optional): Name of the split containing the validation data.
- `test_split` (str | None, optional): Name of the split containing the test data. If left to `None`, setting up this datamodule at the `test` stage will raise a `RuntimeError`.
- `predict_split` (str | None, optional): Name of the split containing the prediction samples. If left to `None`, setting up this datamodule at the `predict` stage will raise a `RuntimeError`.
- `image_processor` (Callable | None, optional):
- `train_dl_kwargs` (dict[str, Any] | None, optional): Dataloader parameters. See also https://pytorch.org/docs/stable/data.html.
- `val_dl_kwargs` (dict[str, Any] | None, optional): Dataloader parameters.
- `test_dl_kwargs` (dict[str, Any] | None, optional): Dataloader parameters.
- `predict_dl_kwargs` (dict[str, Any] | None, optional): Dataloader parameters.
- `cache_dir` (Path | str, optional): Path to the cache directory, where the Hugging Face `datasets` package will download the dataset files. This should not be a temporary directory and should be consistent between runs.
- `classes` (ArrayLike | None, optional): List of classes to keep. For example, if `classes=[1, 2]`, only those samples whose label is `1` or `2` will be present in the dataset. If `None`, all classes are kept. Note that this only applies to the `train` and `val` splits; the `test` and `predict` splits (if they exist) will not be filtered.
- `label_key` (str, optional): Name of the column containing the label. Only relevant if `classes` is not `None`.
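A hedged construction sketch: a CIFAR-10 datamodule keeping only two classes. The split names (`"train"`/`"test"`) and the label column are assumptions about the `cifar10` dataset on the Hugging Face hub:

```python
from lcc.datasets import HuggingFaceDataset

dm = HuggingFaceDataset(
    "cifar10",
    fit_split="train",
    val_split="test",
    classes=[0, 1],
    label_key="label",
)
print(len(dm))         # size of the (filtered) train split
print(dm.n_classes())  # 2, after filtering
```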
```python
def n_classes(self, split: Literal["train", "val", "test"] = "train") -> int
```

Returns the number of classes in a given split.

Arguments:
- `split` (Literal["train", "val", "test"], optional): Not the true name of the split (as specified on the dataset's HuggingFace page), just either `train`, `val`, or `test`. Defaults to `train`.
```python
def size(self, split: Literal["train", "val", "test"] = "train") -> int
```

Returns the number of samples in a given split. If the split hasn't been loaded, this will load it.

Arguments:
- `split` (Literal["train", "val", "test"], optional): Not the true name of the split (as specified on the dataset's HuggingFace page), just either `train`, `val`, or `test`. Defaults to `train`.
```python
def y_true(self, split: Literal["train", "val", "test"] = "train") -> Tensor
```

Gets the vector of true labels of a given split.

Arguments:
- `split` (Literal["train", "val", "test"], optional): Not the true name of the split (as specified on the dataset's HuggingFace page), just either `train`, `val`, or `test`. Defaults to `train`.

Returns:
An `int` tensor.
```python
class WrappedDataset(pl.LightningDataModule):
    """
    A dataset wrapped inside a
    [`LightningDataModule`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html)
    """

    train: DatasetOrDatasetFactory
    val: DatasetOrDatasetFactory
    test: DatasetOrDatasetFactory | None
    predict: DatasetOrDatasetFactory | None
    train_dl_kwargs: dict[str, Any]
    val_dl_kwargs: dict[str, Any]
    test_dl_kwargs: dict[str, Any]
    predict_dl_kwargs: dict[str, Any]

    _prepared: bool = False

    def __init__(
        self,
        train: DatasetOrDatasetFactory,
        val: DatasetOrDatasetFactory | None = None,
        test: DatasetOrDatasetFactory | None = None,
        predict: DatasetOrDatasetFactory | None = None,
        train_dl_kwargs: dict[str, Any] | None = None,
        val_dl_kwargs: dict[str, Any] | None = None,
        test_dl_kwargs: dict[str, Any] | None = None,
        predict_dl_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Args:
            train (DatasetOrDatasetFactory): A dataset or dataloader for
                training. Can be a callable without arguments that returns
                said dataset or dataloader. In this case, it will be called
                during the preparation phase (so only on rank 0 in the case
                of multi-device or multi-node training).
            val (DatasetOrDatasetFactory | None, optional): Defaults to
                whatever was passed to the `train` argument.
            test (DatasetOrDatasetFactory | None, optional): Defaults to
                `None`.
            predict (DatasetOrDatasetFactory | None, optional): Defaults
                to `None`.
            train_dl_kwargs (dict[str, Any] | None, optional): If `train` is
                a dataset or a callable that returns a dataset, this
                dictionary will be passed to the dataloader constructor.
                Defaults to `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
            val_dl_kwargs (dict[str, Any] | None, optional): Analogous to
                `train_dl_kwargs`, but defaults to (a copy of)
                `train_dl_kwargs` instead of
                `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
            test_dl_kwargs (dict[str, Any] | None, optional): Analogous to
                `train_dl_kwargs`, but defaults to (a copy of)
                `train_dl_kwargs` instead of
                `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
            predict_dl_kwargs (dict[str, Any] | None, optional): Analogous
                to `train_dl_kwargs`, but defaults to (a copy of)
                `train_dl_kwargs` instead of
                `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
        """
        super().__init__()
        self.train, self.val = train, val if val is not None else train
        self.test, self.predict = test, predict
        self.train_dl_kwargs = (
            train_dl_kwargs or DEFAULT_DATALOADER_KWARGS.copy()
        )
        self.val_dl_kwargs = val_dl_kwargs or self.train_dl_kwargs.copy()
        self.test_dl_kwargs = test_dl_kwargs or self.train_dl_kwargs.copy()
        self.predict_dl_kwargs = (
            predict_dl_kwargs or self.train_dl_kwargs.copy()
        )

    def get_dataloader(
        self, split: Literal["train", "val", "test", "predict"]
    ) -> DataLoader:
        """
        Get a dataloader by name. The usual methods `train_dataloader` etc.
        call this. Make sure you called `prepare_data` before calling this.
        """
        if split == "train":
            obj, kw = self.train, self.train_dl_kwargs
        elif split == "val":
            obj, kw = self.val, self.val_dl_kwargs
        elif split == "test":
            if self.test is None:
                raise RuntimeError(
                    "Cannot get test dataloader: no test dataset or "
                    "dataset factory has been specified"
                )
            obj, kw = self.test, self.test_dl_kwargs
        elif split == "predict":
            if self.predict is None:
                raise RuntimeError(
                    "Cannot get prediction dataloader: no prediction "
                    "dataset or dataset factory has been specified"
                )
            obj, kw = self.predict, self.predict_dl_kwargs
        else:
            raise ValueError(f"Unknown split: '{split}'")
        obj = obj() if callable(obj) else obj
        return DataLoader(obj, **kw)

    def predict_dataloader(self) -> DataLoader:
        """
        Self-explanatory. Make sure you called `prepare_data` before calling
        this.
        """
        return self.get_dataloader("predict")

    def prepare_data(self) -> None:
        """
        Overrides
        [pl.LightningDataModule.prepare_data](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data).
        This is automatically called so don't worry about it.
        """
        if self._prepared:
            return
        if callable(self.train):
            r0_debug("Preparing the training dataset/split")
            self.train = self.train()
        if callable(self.val):
            r0_debug("Preparing the validation dataset/split")
            self.val = self.val()
        if callable(self.test):
            r0_debug("Preparing the testing dataset/split")
            self.test = self.test()
        if callable(self.predict):
            r0_debug("Preparing the prediction dataset/split")
            self.predict = self.predict()
        self._prepared = True

    def test_dataloader(self) -> DataLoader:
        """
        Self-explanatory. Make sure you called `prepare_data` before calling
        this.
        """
        return self.get_dataloader("test")

    def train_dataloader(self) -> DataLoader:
        """
        Self-explanatory. Make sure you called `prepare_data` before calling
        this.
        """
        return self.get_dataloader("train")

    def val_dataloader(self) -> DataLoader:
        """
        Self-explanatory. Make sure you called `prepare_data` before calling
        this.
        """
        return self.get_dataloader("val")
```
A dataset wrapped inside a [`LightningDataModule`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html).
```python
def __init__(
    self,
    train: DatasetOrDatasetFactory,
    val: DatasetOrDatasetFactory | None = None,
    test: DatasetOrDatasetFactory | None = None,
    predict: DatasetOrDatasetFactory | None = None,
    train_dl_kwargs: dict[str, Any] | None = None,
    val_dl_kwargs: dict[str, Any] | None = None,
    test_dl_kwargs: dict[str, Any] | None = None,
    predict_dl_kwargs: dict[str, Any] | None = None,
) -> None
```
Arguments:
- `train` (DatasetOrDatasetFactory): A dataset or dataloader for training. Can be a callable without arguments that returns said dataset or dataloader. In this case, it will be called during the preparation phase (so only on rank 0 in the case of multi-device or multi-node training).
- `val` (DatasetOrDatasetFactory | None, optional): Defaults to whatever was passed to the `train` argument.
- `test` (DatasetOrDatasetFactory | None, optional): Defaults to `None`.
- `predict` (DatasetOrDatasetFactory | None, optional): Defaults to `None`.
- `train_dl_kwargs` (dict[str, Any] | None, optional): If `train` is a dataset or a callable that returns a dataset, this dictionary will be passed to the dataloader constructor. Defaults to `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
- `val_dl_kwargs` (dict[str, Any] | None, optional): Analogous to `train_dl_kwargs`, but defaults to (a copy of) `train_dl_kwargs` instead of `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
- `test_dl_kwargs` (dict[str, Any] | None, optional): Analogous to `train_dl_kwargs`, but defaults to (a copy of) `train_dl_kwargs` instead of `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
- `predict_dl_kwargs` (dict[str, Any] | None, optional): Analogous to `train_dl_kwargs`, but defaults to (a copy of) `train_dl_kwargs` instead of `lcc.datasets.DEFAULT_DATALOADER_KWARGS`.
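A sketch of wrapping zero-argument dataset factories so that loading is deferred to `prepare_data`; `make_train_set` and `make_val_set` are hypothetical callables:

```python
from lcc.datasets import WrappedDataset

dm = WrappedDataset(
    train=make_train_set,
    val=make_val_set,
    train_dl_kwargs={"batch_size": 128, "num_workers": 4},
)
dm.prepare_data()               # calls the factories (rank 0 only under Lightning)
loader = dm.train_dataloader()
```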
```python
def get_dataloader(self, split: Literal["train", "val", "test", "predict"]) -> DataLoader
```
Get a dataloader by name. The usual methods `train_dataloader` etc. call this. Make sure you called `prepare_data` before calling this.
```python
def predict_dataloader(self) -> DataLoader
```
Self-explanatory. Make sure you called `prepare_data` before calling this.
```python
def prepare_data(self) -> None
```
Overrides [pl.LightningDataModule.prepare_data](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data). This is automatically called so don't worry about it.
```python
def test_dataloader(self) -> DataLoader
```
Self-explanatory. Make sure you called `prepare_data` before calling this.
```python
def train_dataloader(self) -> DataLoader
```
Self-explanatory. Make sure you called `prepare_data` before calling this.
```python
def val_dataloader(self) -> DataLoader
```
Self-explanatory. Make sure you called `prepare_data` before calling this.