lcc.correction

Everything related to LCC.

 1"""Everything related to LCC."""
 2
 3from .choice import (
 4    LCC_CLASS_SELECTIONS,
 5    GraphTotallyDisconnected,
 6    LCCClassSelection,
 7    choose_classes,
 8    confusion_graph,
 9    heaviest_connected_subgraph,
10    max_connected_confusion_choice,
11    top_confusion_pairs,
12)
13from .clustering import (
14    class_otm_matching,
15    otm_matching_predicates,
16)
17from .loss import ExactLCCLoss, LCCLoss, RandomizedLCCLoss
18from .louvain import louvain_clustering
19from .peer_pressure import peer_pressure_clustering
20from .utils import Matching
21
22__all__ = [
23    "choose_classes",
24    "class_otm_matching",
25    "confusion_graph",
26    "ExactLCCLoss",
27    "GraphTotallyDisconnected",
28    "heaviest_connected_subgraph",
29    "LCC_CLASS_SELECTIONS",
30    "LCCClassSelection",
31    "LCCLoss",
32    "louvain_clustering",
33    "Matching",
34    "max_connected_confusion_choice",
35    "otm_matching_predicates",
36    "peer_pressure_clustering",
37    "RandomizedLCCLoss",
38    "top_confusion_pairs",
39]
def choose_classes( y_true: ArrayLike, y_pred: ArrayLike, policy: LCCClassSelection | Literal['all'] | None = None) -> list[int] | None:
 95def choose_classes(
 96    y_true: ArrayLike,
 97    y_pred: ArrayLike,
 98    policy: LCCClassSelection | Literal["all"] | None = None,
 99) -> list[int] | None:
100    """
101    Given true and predicted labels, select classes whose samples should undergo
102    LCC based on some policy. See `lcc.correction.LCCClassSelection`.
103
104    For convenience, this method returns `None` if all classes should be
105    considered.
106
107    Warning:
108        When selecting a `"top_<N>"` policy, the returned list may have fewer
109        than `N` elements. For example, this happens when there are fewer than
110        `N` classes in the dataset.
111    """
112    y_true, y_pred = to_int_tensor(y_true), to_int_tensor(y_pred)
113    if policy is None:
114        return None
115    if policy == "all":
116        logging.warning(
117            "LCC class selection policy 'all' is deprecated. Use `None` instead (which has the same effect)"
118        )
119        return None
120    n_classes = y_true.unique().numel()
121    if policy.startswith("top_pair_"):
122        n = int(policy[9:])
123        if n >= n_classes:
124            return None
125        pairs = top_confusion_pairs(y_pred, y_true, n_classes, n_pairs=n)
126        return list(set(sum(pairs, ())))
127    try:
128        if policy.startswith("top_connected_"):
129            n = int(policy[14:])
130            if n >= n_classes:
131                return None
132            return max_connected_confusion_choice(
133                y_pred, y_true, n_classes, n
134            )[0]
135        return max_connected_confusion_choice(y_pred, y_true, n_classes)[0]
136    except GraphTotallyDisconnected:
137        return None

Given true and predicted labels, select classes whose samples should undergo LCC based on some policy. See lcc.correction.LCCClassSelection.

For convenience, this method returns None if all classes should be considered.

Warning:

When selecting a "top_<N>" policy, the returned list may have fewer than N elements. For example, this happens when there are fewer than N classes in the dataset.
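
For illustration, here is a minimal usage sketch with made-up labels (the label values below are hypothetical):

    import numpy as np
    from lcc.correction import choose_classes

    # Hypothetical 4-class problem with imperfect predictions
    y_true = np.array([0, 0, 1, 1, 2, 2, 3, 3])
    y_pred = np.array([0, 1, 1, 0, 2, 3, 3, 2])

    # Select the two classes of the single most-confused pair
    classes = choose_classes(y_true, y_pred, policy="top_pair_1")

    # No policy means "all classes should be considered"
    assert choose_classes(y_true, y_pred, policy=None) is None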

def class_otm_matching( y_a: ArrayLike, y_b: ArrayLike) -> Matching:
 67def class_otm_matching(y_a: ArrayLike, y_b: ArrayLike) -> Matching:
 68    """
 69    Let `y_a` and `y_b` be `(N,)` integer arrays. We think of them as classes on
 70    some dataset, say `x`, which we call respectively *$a$-classes* and
 71    *$b$-classes*. This method performs a one-to-many matching from the classes
 72    in `y_a` to the classes in `y_b` so as to maximize the overall cardinality
 73    of the intersection between each $a$-class and its matched $b$-classes.
 74
 75    Example:
 76
 77        >>> y_a = np.array([1, 1, 1, 1, 2, 2, 3, 4, 4])
 78        >>> y_b = np.array([10, 50, 10, 20, 20, 20, 30, 30, 30])
 79        >>> class_otm_matching(y_a, y_b)
 80        {1: {10, 50}, 2: {20}, 3: set(), 4: {30}}
 81
 82        Here, `y_a` assigns class `1` to samples 0 to 3, label `2` to samples 4
 83        and 5 etc. On the other hand, `y_b` assigns its own classes to the
 84        dataset. What is the best way to regroup classes of `y_b` to
 85        approximate the labelling of `y_a`? The `class_otm_matching` return
 86        value argues that classes `10` and `50` should be regrouped under `1`
 87        (they fit neatly), label `20` should be renamed to `2` (even though it
 88        "leaks" a little, in that sample 3 is labeled with `1` and `20`), and
 89        class `30` should be renamed to `4`. No class in `y_b` is assigned to
 90        class `3` in this matching.
 91
 92    Note:
 93        There are no restrictions on the values of `y_a` and `y_b`. In
 94        particular, they need not be disjoint: the following works fine
 95
 96        >>> y_a = np.array([1, 1, 1, 1, 2, 2, 3, 4, 4])
 97        >>> y_b = np.array([1, 5, 1, 2, 2, 2, 3, 3, 3])
 98        >>> class_otm_matching(y_a, y_b)
 99        {1: {1, 5}, 2: {2}, 3: set(), 4: {3}}
100
101    Warning:
102        Negative labels in `y_a` or `y_b` are ignored. So the output matching
103        dict will never have negative keys, and the sets will never have
104        negative values either.
105
106    Args:
107        y_a (ArrayLike): A `(N,)` integer array.
108        y_b (ArrayLike): A `(N,)` integer array.
109
110    Returns:
111        A dict that maps each class in `y_a` to the set of classes in `y_b` that
112        it has matched.
113    """
114    y_a, y_b, match_graph = to_int_array(y_a), to_int_array(y_b), nx.DiGraph()
115    for i, j in product(np.unique(y_a), np.unique(y_b)):
116        if i < 0 or j < 0:
117            continue
118        n = np.sum((y_a == i) & (y_b == j))
119        match_graph.add_edge(f"a_{i}", f"b_{j}", weight=n)
120    matching = _otm_matching(
121        match_graph,
122        [f"a_{i}" for i in np.unique(y_a)],
123        [f"b_{i}" for i in np.unique(y_b)],
124        mode="max",
125    )
126    return {
127        int(a_i.split("_")[-1]): {int(b_j.split("_")[-1]) for b_j in b_js}
128        for a_i, b_js in matching.items()
129    }

Let y_a and y_b be (N,) integer arrays. We think of them as classes on some dataset, say x, which we call respectively $a$-classes and $b$-classes. This method performs a one-to-many matching from the classes in y_a to the classes in y_b so as to maximize the overall cardinality of the intersection between each $a$-class and its matched $b$-classes.

Example:
>>> y_a = np.array([1, 1, 1, 1, 2, 2, 3, 4, 4])
>>> y_b = np.array([10, 50, 10, 20, 20, 20, 30, 30, 30])
>>> class_otm_matching(y_a, y_b)
{1: {10, 50}, 2: {20}, 3: set(), 4: {30}}

Here, y_a assigns class 1 to samples 0 to 3, label 2 to samples 4 and 5, etc. On the other hand, y_b assigns its own classes to the dataset. What is the best way to regroup classes of y_b to approximate the labelling of y_a? The class_otm_matching return value argues that classes 10 and 50 should be regrouped under 1 (they fit neatly), label 20 should be renamed to 2 (even though it "leaks" a little, in that sample 3 is labeled with 1 and 20), and class 30 should be renamed to 4. No class in y_b is assigned to class 3 in this matching.

Note:

There are no restrictions on the values of y_a and y_b. In particular, they need not be disjoint: the following works fine

>>> y_a = np.array([1, 1, 1, 1, 2, 2, 3, 4, 4])
>>> y_b = np.array([1, 5, 1, 2, 2, 2, 3, 3, 3])
>>> class_otm_matching(y_a, y_b)
{1: {1, 5}, 2: {2}, 3: set(), 4: {3}}
Warning:

Negative labels in y_a or y_b are ignored. So the output matching dict will never have negative keys, and the sets will never have negative values either.

Arguments:
  • y_a (ArrayLike): A (N,) integer array.
  • y_b (ArrayLike): A (N,) integer array.
Returns:

A dict that maps each class in y_a to the set of classes in y_b that it has matched.
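
Assuming, as the examples suggest, that each $b$-class is matched to exactly one $a$-class, the optimal matching decomposes per $b$-class: each $b$-class simply goes to the $a$-class it overlaps the most. The following naive reference sketch reproduces the first example under that assumption (the real implementation goes through a weighted bipartite graph and `_otm_matching` instead; ties are broken arbitrarily here):

    import numpy as np

    def naive_otm_matching(y_a: np.ndarray, y_b: np.ndarray) -> dict[int, set[int]]:
        # Assign each (non-negative) b-class to the a-class with the
        # largest sample overlap
        matching: dict[int, set[int]] = {
            int(a): set() for a in np.unique(y_a) if a >= 0
        }
        for b in np.unique(y_b):
            if b < 0:
                continue
            overlaps = {a: np.sum((y_a == a) & (y_b == b)) for a in matching}
            matching[max(overlaps, key=overlaps.get)].add(int(b))
        return matching

    y_a = np.array([1, 1, 1, 1, 2, 2, 3, 4, 4])
    y_b = np.array([10, 50, 10, 20, 20, 20, 30, 30, 30])
    assert naive_otm_matching(y_a, y_b) == {1: {10, 50}, 2: {20}, 3: set(), 4: {30}}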

def confusion_graph( y_pred: ArrayLike, y_true: ArrayLike, n_classes: int, threshold: int = 10) -> networkx.classes.graph.Graph:
62def confusion_graph(
63    y_pred: ArrayLike, y_true: ArrayLike, n_classes: int, threshold: int = 10
64) -> nx.Graph:
65    """
66    Create a confusion graph from predicted and true labels. Two labels $a$ and
67    $b$ are confused if there is at least one sample belonging to class $a$ that
68    is predicted as class $b$, or vice versa. The nodes of the confusion graph
69    are labels, two labels are connected by an edge if they are confused by at
70    least `threshold` samples, and each edge's weight is the number of times
71    the two labels are confused for each other.
72
73    Args:
74        y_pred (ArrayLike): A `(N,)` int tensor or an `(N, n_classes)`
75            probabilities/logits float tensor
76        y_true (ArrayLike): A `(N,)` int tensor
77        n_classes (int):
78        threshold (int, optional): Minimum number of times two classes must be
79            confused (in either direction) to be included in the graph.
80
81    Warning:
82        There are no loops, i.e. correct predictions are not reported in the
83        graph unlike in usual confusion matrices.
84    """
85    y_pred, y_true = to_int_tensor(y_pred), to_int_tensor(y_true)
86    cm = multiclass_confusion_matrix(y_pred, y_true, num_classes=n_classes)
87    cm = cm + cm.T  # Confusion in either direction
88    cg = nx.Graph()
89    for i, j in combinations(list(range(n_classes)), 2):
90        if (w := cm[i, j].item()) >= threshold:
91            cg.add_edge(i, j, weight=w)
92    return cg

Create a confusion graph from predicted and true labels. Two labels $a$ and $b$ are confused if there is at least one sample belonging to class $a$ that is predicted as class $b$, or vice versa. The nodes of the confusion graph are labels, two labels are connected by an edge if they are confused by at least threshold samples, and each edge's weight is the number of times the two labels are confused for each other.

Arguments:
  • y_pred (ArrayLike): A (N,) int tensor or an (N, n_classes) probabilities/logits float tensor
  • y_true (ArrayLike): A (N,) int tensor
  • n_classes (int):
  • threshold (int, optional): Minimum number of times two classes must be confused (in either direction) to be included in the graph.
Warning:

There are no loops, i.e. correct predictions are not reported in the graph unlike in usual confusion matrices.
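
A small worked sketch (the labels are synthetic; the confusion counts are annotated in the comments):

    import numpy as np
    from lcc.correction import confusion_graph

    # Classes 0 and 1 are confused for each other, class 2 never is
    y_true = np.array([0] * 10 + [1] * 10 + [2] * 10)
    y_pred = np.array([1] * 5 + [0] * 5 + [0] * 5 + [1] * 5 + [2] * 10)

    cg = confusion_graph(y_pred, y_true, n_classes=3, threshold=5)
    # 0 and 1 are confused by 10 samples in total (5 in each direction),
    # so the only edge is (0, 1) with weight 10
    assert set(cg.edges) == {(0, 1)}
    assert cg.edges[0, 1]["weight"] == 10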

class ExactLCCLoss(lcc.correction.LCCLoss):
 29class ExactLCCLoss(LCCLoss):
 30    """LCC loss that corrects missclustered samples using their CC KNNs"""
 31
 32    k: int
 33    n_classes: int
 34    tqdm_style: TqdmStyle
 35    matching: dict[int, set[int]]
 36
 37    # ↓ i_clst -> (knn idx, tensor of all CC samples in that clst)
 38    data: dict[int, tuple[faiss.IndexFlatL2, Tensor]] = {}
 39
 40    def __call__(
 41        self, z: Tensor, y_true: ArrayLike, y_clst: ArrayLike
 42    ) -> Tensor:
 43        _, _, p_mc, _ = otm_matching_predicates(
 44            y_true, y_clst, self.matching, c_a=self.n_classes
 45        )  # p_mc: (n_classes, len(z))
 46        z = z.flatten(1)
 47        terms = []
 48        for i_true in np.unique(to_int_array(y_true)):
 49            if not p_mc[i_true].any():
 50                # No MC samples in this class and batch
 51                continue
 52            u = z[p_mc[i_true]]  # MC samples in class i_true
 53            d = []  # Distances of MC samples to candidate targets
 54            for i_clst in self.matching[i_true]:
 55                if i_clst not in self.data:
 56                    continue
 57                knn, cc = self.data[i_clst]
 58                # Find a candidate target in i_clst for each sample in u
 59                _, j = knn.search(to_array(u).astype(np.float32), self.k)
 60                j = to_int_tensor(j)  # (len(u), k)
 61                t = cc[j].mean(dim=1).to(u.device)  # (n, n_features)
 62                # Save distance to these candidate targets
 63                d.append(
 64                    torch.norm(u - t, dim=1) / sqrt(u.shape[-1])  # (len(u),)
 65                )
 66            if d:
 67                terms.append(
 68                    torch.stack(d)  # (?, len(u))
 69                    .min(dim=0)
 70                    .values  # (len(u),)
 71                )
 72        if not terms:
 73            return torch.tensor(0.0, requires_grad=True).to(z.device)
 74        return torch.cat(terms).mean()
 75
 76    def __init__(
 77        self,
 78        n_classes: int,
 79        k: int = 5,
 80        tqdm_style: TqdmStyle = None,
 81        strategy: Strategy | Fabric | None = None,
 82    ) -> None:
 83        super().__init__(strategy=strategy)
 84        self.k, self.n_classes = k, n_classes
 85        self.tqdm_style = tqdm_style
 86
 87    def sync(self, **kwargs: Any) -> None:
 88        """
 89        Remember that every rank has its own subset of clusters to manage.
 90        Before sync, every rank's `self.data` only contains data pertaining
 91        to that rank's clusters.
 92
 93        This method works in two steps. First, every rank writes its data to
 94        some temporary directory. Then, every rank loads data from that
 95        directory.
 96
 97        EZPZ
 98        """
 99        if self.strategy is None:
100            return
101        path = self._get_tmp_dir()
102        gr = self.strategy.global_rank
103        st.save_file(
104            {str(i_clst): cc for i_clst, (_, cc) in self.data.items()},
105            path / f"cc.{gr}",
106        )
107        for i_clst, (idx, _) in self.data.items():
108            idx = faiss.index_gpu_to_cpu(idx)
109            faiss.write_index(idx, str(path / f"knn.{i_clst}.{gr}"))
110        self.strategy.barrier()
111        for r in range(self.strategy.world_size):
112            if r == gr:
113                continue  # data from this rank is already in self.data
114            ccs = st.load_file(path / f"cc.{r}")
115            gpu = faiss.StandardGpuResources()
116            for i_clst, cc in ccs.items():  # type: ignore
117                knn = faiss.read_index(str(path / f"knn.{i_clst}.{r}"))
118                knn = faiss.index_cpu_to_gpu(gpu, gr, knn)
119                self.data[int(i_clst)] = (knn, cc)
120        return super().sync(**kwargs)
121
122    def update(
123        self,
124        dl: DataLoader,
125        y_true: ArrayLike,
126        y_clst: ArrayLike,
127        matching: Matching,
128    ) -> None:
129        """
130        Reminder:
131            `dl` has to iterate over the whole dataset, even if this method is
132            called in a distributed environment. The label vectors must also
133            cover the whole dataset.
134        """
135        self.matching = to_int_matching(matching)
136        y_clst = to_int_array(y_clst)
137        n_features = next(iter(dl))[0].flatten(1).shape[-1]
138        p1, p2, _, _ = otm_matching_predicates(
139            y_true, y_clst, self.matching, c_a=self.n_classes
140        )
141        p_cc = (p1 & p2).sum(axis=0).astype(bool)  # (n_samples,)
142
143        # Cluster labels that this rank has to manage
144        clsts = self._distribute_labels(y_clst)
145        # ↓ i_clst -> (knn idx, list of batches CC samples in this clst)
146        data: dict[int, tuple[faiss.IndexFlatL2, list[Tensor]]] = {
147            i_clst: (faiss.IndexFlatL2(n_features), []) for i_clst in clsts
148        }
149        if self.strategy is not None:
150            gpu = faiss.StandardGpuResources()
151            gr = self.strategy.global_rank
152            for i_clst, (knn, cc) in data.items():
153                knn = faiss.index_cpu_to_gpu(gpu, gr, knn)
154                data[i_clst] = (knn, cc)
155
156        tqdm, n_seen = make_tqdm(self.tqdm_style), 0
157        for z, *_ in tqdm(dl, f"Building {len(data)} KNN indices"):
158            z = z.flatten(1)  # (bs, n_feat.)
159            _y_clst = y_clst[n_seen : n_seen + len(z)]  # (bs,)
160            _p_cc = p_cc[n_seen : n_seen + len(z)]  # (bs,)
161            for i_clst in np.unique(_y_clst):
162                if i_clst not in data:
163                    continue  # Cluster not managed by this rank
164                # ↓ Mask for smpls in this batch that are CC and in i_clsts
165                _p_cc_i_clst = _p_cc & (_y_clst == i_clst)
166                if not _p_cc_i_clst.any():
167                    continue  # No CC sample in cluster i_clst in this batch
168                _z = z[_p_cc_i_clst]
169                data[i_clst][0].add(to_array(_z).astype(np.float32))
170                data[i_clst][1].append(_z)
171            n_seen += len(z)
172
173        # ↓ i_clst -> (knn idx, tensor of all CC samples in this clst)
174        #   IF i_clst has at least one CC sample
175        self.data = {
176            i_clst: (idx, torch.cat(lst))
177            for i_clst, (idx, lst) in data.items()
178            if lst
179        }

LCC loss that corrects misclustered samples using their CC KNNs

ExactLCCLoss( n_classes: int, k: int = 5, tqdm_style: Optional[Literal['notebook', 'console', 'none']] = None, strategy: pytorch_lightning.strategies.strategy.Strategy | lightning_fabric.fabric.Fabric | None = None)
76    def __init__(
77        self,
78        n_classes: int,
79        k: int = 5,
80        tqdm_style: TqdmStyle = None,
81        strategy: Strategy | Fabric | None = None,
82    ) -> None:
83        super().__init__(strategy=strategy)
84        self.k, self.n_classes = k, n_classes
85        self.tqdm_style = tqdm_style
k: int
n_classes: int
tqdm_style: Optional[Literal['notebook', 'console', 'none']]
matching: dict[int, set[int]]
data: dict[int, tuple[faiss.swigfaiss.IndexFlatL2, torch.Tensor]] = {}
def sync(self, **kwargs: Any) -> None:
 87    def sync(self, **kwargs: Any) -> None:
 88        """
 89        Remember that every rank has its own subset of clusters to manage.
 90        Before sync, every rank's `self.data` only contains data pertaining
 91        to that rank's clusters.
 92
 93        This method works in two steps. First, every rank writes its data to
 94        some temporary directory. Then, every rank loads data from that
 95        directory.
 96
 97        EZPZ
 98        """
 99        if self.strategy is None:
100            return
101        path = self._get_tmp_dir()
102        gr = self.strategy.global_rank
103        st.save_file(
104            {str(i_clst): cc for i_clst, (_, cc) in self.data.items()},
105            path / f"cc.{gr}",
106        )
107        for i_clst, (idx, _) in self.data.items():
108            idx = faiss.index_gpu_to_cpu(idx)
109            faiss.write_index(idx, str(path / f"knn.{i_clst}.{gr}"))
110        self.strategy.barrier()
111        for r in range(self.strategy.world_size):
112            if r == gr:
113                continue  # data from this rank is already in self.data
114            ccs = st.load_file(path / f"cc.{r}")
115            gpu = faiss.StandardGpuResources()
116            for i_clst, cc in ccs.items():  # type: ignore
117                knn = faiss.read_index(str(path / f"knn.{i_clst}.{r}"))
118                knn = faiss.index_cpu_to_gpu(gpu, gr, knn)
119                self.data[int(i_clst)] = (knn, cc)
120        return super().sync(**kwargs)

Remember that every rank has its own subset of clusters to manage. Before sync, every rank's self.data only contains data pertaining to that rank's clusters.

This method works in two steps. First, every rank writes its data to some temporary directory. Then, every rank loads data from that directory.

EZPZ

def update( self, dl: torch.utils.data.dataloader.DataLoader, y_true: ArrayLike, y_clst: ArrayLike, matching: Matching) -> None:
122    def update(
123        self,
124        dl: DataLoader,
125        y_true: ArrayLike,
126        y_clst: ArrayLike,
127        matching: Matching,
128    ) -> None:
129        """
130        Reminder:
131            `dl` has to iterate over the whole dataset, even if this method is
132            called in a distributed environment. The label vectors must also
133            cover the whole dataset.
134        """
135        self.matching = to_int_matching(matching)
136        y_clst = to_int_array(y_clst)
137        n_features = next(iter(dl))[0].flatten(1).shape[-1]
138        p1, p2, _, _ = otm_matching_predicates(
139            y_true, y_clst, self.matching, c_a=self.n_classes
140        )
141        p_cc = (p1 & p2).sum(axis=0).astype(bool)  # (n_samples,)
142
143        # Cluster labels that this rank has to manage
144        clsts = self._distribute_labels(y_clst)
145        # ↓ i_clst -> (knn idx, list of batches CC samples in this clst)
146        data: dict[int, tuple[faiss.IndexFlatL2, list[Tensor]]] = {
147            i_clst: (faiss.IndexFlatL2(n_features), []) for i_clst in clsts
148        }
149        if self.strategy is not None:
150            gpu = faiss.StandardGpuResources()
151            gr = self.strategy.global_rank
152            for i_clst, (knn, cc) in data.items():
153                knn = faiss.index_cpu_to_gpu(gpu, gr, knn)
154                data[i_clst] = (knn, cc)
155
156        tqdm, n_seen = make_tqdm(self.tqdm_style), 0
157        for z, *_ in tqdm(dl, f"Building {len(data)} KNN indices"):
158            z = z.flatten(1)  # (bs, n_feat.)
159            _y_clst = y_clst[n_seen : n_seen + len(z)]  # (bs,)
160            _p_cc = p_cc[n_seen : n_seen + len(z)]  # (bs,)
161            for i_clst in np.unique(_y_clst):
162                if i_clst not in data:
163                    continue  # Cluster not managed by this rank
164                # ↓ Mask for smpls in this batch that are CC and in i_clsts
165                _p_cc_i_clst = _p_cc & (_y_clst == i_clst)
166                if not _p_cc_i_clst.any():
167                    continue  # No CC sample in cluster i_clst in this batch
168                _z = z[_p_cc_i_clst]
169                data[i_clst][0].add(to_array(_z).astype(np.float32))
170                data[i_clst][1].append(_z)
171            n_seen += len(z)
172
173        # ↓ i_clst -> (knn idx, tensor of all CC samples in this clst)
174        #   IF i_clst has at least one CC sample
175        self.data = {
176            i_clst: (idx, torch.cat(lst))
177            for i_clst, (idx, lst) in data.items()
178            if lst
179        }
Reminder:

dl has to iterate over the whole dataset, even if this method is called in a distributed environment. The label vectors must also cover the whole dataset.
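
Putting it together, a single-process lifecycle sketch (the dataset and labels are synthetic; the point is the update → sync → __call__ order, and a working faiss install is assumed):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from lcc.correction import ExactLCCLoss, class_otm_matching

    n_samples, n_features, n_classes = 256, 32, 4
    z_all = torch.randn(n_samples, n_features)           # latent representations
    y_true = np.random.randint(0, n_classes, n_samples)  # true labels
    y_clst = np.random.randint(0, 8, n_samples)          # cluster labels
    matching = class_otm_matching(y_true, y_clst)

    loss_fn = ExactLCCLoss(n_classes=n_classes, k=5)
    # The dataloader must iterate over the WHOLE dataset, in order
    dl = DataLoader(TensorDataset(z_all), batch_size=64)
    loss_fn.update(dl, y_true, y_clst, matching)  # build per-cluster KNN indices
    loss_fn.sync()                                # no-op without a strategy

    # During training, on a batch of latent representations:
    loss = loss_fn(z_all[:64], y_true[:64], y_clst[:64])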

class GraphTotallyDisconnected(builtins.ValueError):
55class GraphTotallyDisconnected(ValueError):
56    """
57    Raised in `lcc.correction.heaviest_connected_subgraph` when a graph is
58    totally disconnected (has no edges).
59    """

Raised in lcc.correction.heaviest_connected_subgraph when a graph is totally disconnected (has no edges).

def heaviest_connected_subgraph( graph: networkx.classes.graph.Graph, max_size: int | None = None, strict: bool = False, key: str = 'weight') -> tuple[networkx.classes.graph.Graph, float]:
140def heaviest_connected_subgraph(
141    graph: nx.Graph,
142    max_size: int | None = None,
143    strict: bool = False,
144    key: str = "weight",
145) -> tuple[nx.Graph, float]:
146    """
147    Find the heaviest connected full subgraph of an undirected graph with
148    weighted edges. When `max_size` is `None`, this is simply the connected
149    component whose total edge weight is the largest.
150
151    Under the hood, this function maintains a list of connected full subgraphs
152    and iteratively adds the heaviest remaining edge to those subgraphs that
153    contain one of its endpoints. Note that:
154    - if no subgraph touches the current edge, then the edge is added to the
155      list as its own subgraph;
156    - if a subgraph has both endpoints of the current edge, then the edge was
157      already part of that subgraph, which is left unmodified;
158    - graphs in the list that have already reached `max_size` are not modified.
159
160    Finally, the heaviest graph is returned.
161
162    Warning:
163        Setting `strict` to `True` can make the problem impossible, e.g. if
164        `graph` doesn't have a large enough connected component. In such cases,
165        a `RuntimeError` is raised.
166
167    Warning:
168        If the graph is totally disconnected (i.e. has no edges), then a
169        `GraphTotallyDisconnected` exception is raised, rather than returning a
170        subgraph with a single node.
171
172    Args:
173        graph (nx.Graph): Most likely the confusion graph returned by
174            `lcc.choice.confusion_graph` eh?
175        max_size (int | None, optional): If left to `None`, returns the
176            heaviest connected component.
177        strict (bool, optional): If `True`, the returned graph is guaranteed to
178            have exactly `max_size` nodes. If `False`, the returned graph may
179            have fewer (but never more) nodes.
180        key (str, optional): The edge attribute to use as weight.
181
182    Returns:
183        A connected subgraph and its total weight (see also `total_weight`).
184    """
185    if not graph.edges:
186        raise GraphTotallyDisconnected()
187    _total_weight = partial(total_weight, key=key)
188    if max_size is None:
189        subgraphs = sorted(
190            map(graph.subgraph, nx.connected_components(graph)),
191            key=_total_weight,
192            reverse=True,
193        )
194        return subgraphs[0], _total_weight(subgraphs[0])
195    edges = sorted(
196        graph.edges(data=True), key=lambda e: e[2][key], reverse=True
197    )
198    subgraphs = []
199    for u, v, _ in tqdm(edges, desc="Finding heaviest subgraph"):
200        if not subgraphs:
201            subgraphs.append(graph.subgraph([u, v]))
202            continue
203        has_u = np.array([g.has_node(u) for g in subgraphs], dtype=bool)
204        has_v = np.array([g.has_node(v) for g in subgraphs], dtype=bool)
205        if np.any(has_u & has_v):  # a graph already contains edge (u, v)
206            continue
207        p = (has_u | has_v) & (np.array(list(map(len, subgraphs))) < max_size)
208        to_extend = [g for i, g in enumerate(subgraphs) if p[i]]
209        subgraphs = [g for i, g in enumerate(subgraphs) if ~p[i]]
210        if to_extend:
211            for g in to_extend:
212                subgraphs.append(graph.subgraph(list(g.nodes) + [u, v]))
213        else:
214            subgraphs.append(graph.subgraph([u, v]))
215    if strict:
216        subgraphs = list(filter(lambda g: len(g) >= max_size, subgraphs))
217    subgraphs = sorted(
218        subgraphs,
219        key=_total_weight,
220        reverse=True,
221    )
222    if not subgraphs:
223        raise RuntimeError(
224            "Could not find heaviest subgraph with given size constraint. "
225            "Try again with strict=False."
226        )
227    return subgraphs[0], _total_weight(subgraphs[0])

Find the heaviest connected full subgraph of an undirected graph with weighted edges. When max_size is None, this is simply the connected component whose total edge weight is the largest.

Under the hood, this function maintains a list of connected full subgraphs and iteratively adds the heaviest remaining edge to those subgraphs that contain one of its endpoints. Note that:

  • if no subgraph touches the current edge, then the edge is added to the list as its own subgraph;
  • if a subgraph has both endpoints of the current edge, then the edge was already part of that subgraph, which is left unmodified;
  • graphs in the list that have already reached max_size are not modified.

Finally, the heaviest graph is returned.

Warning:

Setting strict to True can make the problem impossible, e.g. if graph doesn't have a large enough connected component. In such cases, a RuntimeError is raised.

Warning:

If the graph is totally disconnected (i.e. has no edges), then a GraphTotallyDisconnected exception is raised, rather than returning a subgraph with a single node.

Arguments:
  • graph (nx.Graph): Most likely the confusion graph returned by lcc.choice.confusion_graph eh?
  • max_size (int | None, optional): If left to None, returns the heaviest connected component.
  • strict (bool, optional): If True, the returned graph is guaranteed to have exactly max_size nodes. If False, the returned graph may have fewer (but never more) nodes.
  • key (str, optional): The edge attribute to use as weight.
Returns:

A connected subgraph and its total weight (see also total_weight).
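
A toy usage sketch:

    import networkx as nx
    from lcc.correction import heaviest_connected_subgraph

    # Two components: {0, 1, 2} with total weight 12, {10, 11} with 20
    g = nx.Graph()
    g.add_edge(0, 1, weight=5)
    g.add_edge(1, 2, weight=7)
    g.add_edge(10, 11, weight=20)

    # Without a size constraint, the heaviest component wins
    sub, w = heaviest_connected_subgraph(g)
    assert set(sub.nodes) == {10, 11} and w == 20

    # With max_size=2 and strict=False, subgraphs are grown edge by
    # edge; the (10, 11) edge still carries the most weight
    sub, w = heaviest_connected_subgraph(g, max_size=2)
    assert set(sub.nodes) == {10, 11} and w == 20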

LCC_CLASS_SELECTIONS = ['top_pair_1', 'top_pair_5', 'top_pair_10', 'top_connected_2', 'top_connected_5', 'top_connected_10', 'max_connected']
LCCClassSelection = typing.Literal['top_pair_1', 'top_pair_5', 'top_pair_10', 'top_connected_2', 'top_connected_5', 'top_connected_10', 'max_connected']
class LCCLoss(abc.ABC):
 20class LCCLoss(ABC):
 21    """Abstract class that encapsulates a loss function for LCC."""
 22
 23    strategy: ParallelStrategy | Fabric | None
 24
 25    _tmp_dir: TemporaryDirectory | None = None
 26    """
 27    A temporary directory that can be created if needed to save/load data
 28    before/after sync. Cleaned up automatically in
 29    `on_after_sync`.
 30    """
 31
 32    @abstractmethod
 33    def __call__(
 34        self, z: Tensor, y_true: ArrayLike, y_clst: ArrayLike
 35    ) -> Tensor:
 36        pass
 37
 38    def __init__(self, strategy: Strategy | Fabric | None = None) -> None:
 39        # TODO: Log a warning if strategy is a Strategy but not a
 40        # ParallelStrategy
 41        self.strategy = (
 42            strategy
 43            if isinstance(strategy, (ParallelStrategy, Fabric))
 44            else None
 45        )
 46
 47    def _distribute_labels(self, y: ArrayLike) -> set[int]:
 48        """
 49        Given an array of labels, distributes unique labels across ranks.
 50
 51        Example:
 52            Let's say the world size is 2.
 53
 54            >>> y = np.array([0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5])
 55            >>> self._distribute_labels(y)
 56            {0, 2, 4}  # on rank 0
 57            {1, 3, 5}  # on rank 1
 58        """
 59        y = to_int_array(y)
 60        if self.strategy is not None and self.strategy.world_size > 1:
 61            ws, gr = self.strategy.world_size, self.strategy.global_rank
 62            return {i for i in np.unique(y) if i % ws == gr}
 63        return set(np.unique(y))
 64
 65    def _get_tmp_dir(self) -> Path:
 66        """
 67        On rank 0, acquires a temporary directory, broadcasts the handle to
 68        all ranks, and returns its path.
 69
 70        Warning:
 71            In a distributed environment, this method must be called from all
 72            ranks "at the same time".
 73
 74
 75
 76        """
 77        if self._tmp_dir is not None:
 78            return Path(self._tmp_dir.name)
 79        if self.strategy is None:
 80            self._tmp_dir = TemporaryDirectory(prefix="lcc-")
 81            return Path(self._tmp_dir.name)
 82        if self.strategy.global_rank == 0:
 83            self._tmp_dir = TemporaryDirectory(prefix="lcc-")
 84        else:
 85            self._tmp_dir = None
 86        self._tmp_dir = self.strategy.broadcast(self._tmp_dir, src=0)
 87        assert self._tmp_dir is not None
 88        return Path(self._tmp_dir.name)
 89
 90    def on_after_sync(self, **kwargs: Any) -> None:
 91        """
 92        Should be called on all ranks after the loss object has been synced.
 93        """
 94        if self.strategy is not None:
 95            self.strategy.barrier()
 96        if self._tmp_dir is not None and (
 97            self.strategy is None or self.strategy.global_rank == 0
 98        ):
 99            self._tmp_dir.cleanup()
100
101    def on_before_sync(self, **kwargs: Any) -> None:
102        """
103        Should be called on all ranks before syncing the loss object from rank 0
104        to other ranks.
105        """
106
107    def sync(self, **kwargs: Any) -> None:
108        """
109        Distributes or shares this object's data across all ranks. Call this
110        after every rank has called `update`. This is just a barrier by default.
111        """
112        if self.strategy is not None:
113            self.strategy.barrier()
114
115    @abstractmethod
116    def update(
117        self,
118        dl: DataLoader,
119        y_true: ArrayLike,
120        y_clst: ArrayLike,
121        matching: Matching,
122    ) -> None:
123        """
124        Updates the internal state of the loss function. Presumably called at
125        the beginning of each epoch where LCC is to be applied.
126
127        The dataloader has to yield batches that are tuples of tensors, the
128        first of which is a 2D tensor of samples.
129
130        Warning:
131            If the construction of the loss object is distributed across
132            multiple ranks, make sure that `dl` iterates over the WHOLE dataset
133            (no distributed sampling).
134        """

Abstract class that encapsulates a loss function for LCC.

strategy: pytorch_lightning.strategies.parallel.ParallelStrategy | lightning_fabric.fabric.Fabric | None
def on_after_sync(self, **kwargs: Any) -> None:
90    def on_after_sync(self, **kwargs: Any) -> None:
91        """
92        Should be called on all ranks after the loss object has been synced.
93        """
94        if self.strategy is not None:
95            self.strategy.barrier()
96        if self._tmp_dir is not None and (
97            self.strategy is None or self.strategy.global_rank == 0
98        ):
99            self._tmp_dir.cleanup()

Should be called on all ranks after the loss object has been synced.

def on_before_sync(self, **kwargs: Any) -> None:
101    def on_before_sync(self, **kwargs: Any) -> None:
102        """
103        Should be called on all ranks before syncing the loss object from rank 0
104        to other ranks.
105        """

Should be called on all ranks before syncing the loss object from rank 0 to other ranks.

def sync(self, **kwargs: Any) -> None:
107    def sync(self, **kwargs: Any) -> None:
108        """
109        Distributes or shares this object's data across all ranks. Call this
110        after every rank has called `update`. This is just a barrier by default.
111        """
112        if self.strategy is not None:
113            self.strategy.barrier()

Distributes or shares this object's data across all ranks. Call this after every rank has called update. This is just a barrier by default.

@abstractmethod
def update( self, dl: torch.utils.data.dataloader.DataLoader, y_true: ArrayLike, y_clst: ArrayLike, matching: Matching) -> None:
115    @abstractmethod
116    def update(
117        self,
118        dl: DataLoader,
119        y_true: ArrayLike,
120        y_clst: ArrayLike,
121        matching: Matching,
122    ) -> None:
123        """
124        Updates the internal state of the loss function. Presumably called at
125        the beginning of each epoch where LCC is to be applied.
126
127        The dataloader has to yield batches that are tuples of tensors, the
128        first of which is a 2D tensor of samples.
129
130        Warning:
131            If the construction of the loss object is distributed across
132            multiple ranks, make sure that `dl` iterates over the WHOLE dataset
133            (no distributed sampling).
134        """

Updates the internal state of the loss function. Presumably called at the beginning of each epoch where LCC is to be applied.

The dataloader has to yield batches that are tuples of tensors, the first of which is a 2D tensor of samples.

Warning:

If the construction of the loss object is distributed across multiple ranks, make sure that dl iterates over the WHOLE dataset (no distributed sampling).
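
To make the abstract contract concrete, here is a minimal hypothetical subclass; real subclasses such as ExactLCCLoss and RandomizedLCCLoss do actual work in update and __call__:

    import torch
    from torch import Tensor
    from torch.utils.data import DataLoader
    from lcc.correction import LCCLoss

    class ZeroLCCLoss(LCCLoss):
        """Hypothetical no-op loss illustrating the LCCLoss interface."""

        def update(self, dl: DataLoader, y_true, y_clst, matching) -> None:
            # A real implementation would scan dl (the WHOLE dataset)
            # here and build whatever state __call__ needs
            pass

        def __call__(self, z: Tensor, y_true, y_clst) -> Tensor:
            # Contributes nothing to the training loss
            return torch.tensor(0.0, requires_grad=True).to(z.device)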

def louvain_clustering( ds: lcc.datasets.BatchedTensorDataset, k: int, strategy: pytorch_lightning.strategies.strategy.Strategy | lightning_fabric.fabric.Fabric | None = None, n_features: int | None = None, tqdm_style: Optional[Literal['notebook', 'console', 'none']] = None, device: Any = 'cpu') -> numpy.ndarray:
49def louvain_clustering(
50    ds: BatchedTensorDataset,
51    k: int,
52    strategy: Strategy | Fabric | None = None,
53    n_features: int | None = None,
54    tqdm_style: TqdmStyle = None,
55    device: Any = "cpu",
56) -> np.ndarray:
57    """
58    Args:
59        ds (BatchedTensorDataset):
60        k (int):
61        strategy (Strategy | Fabric | None, optional): Defaults to `None`,
62            meaning that the algorithm will not be parallelized.
63        n_features (int | None, optional):
64        tqdm_style (TqdmStyle, optional):
65        device (Any, optional):
66    """
67    graph = knn_graph(
68        ds, k, strategy=strategy, n_features=n_features, tqdm_style=tqdm_style
69    )
70    communities = _louvain_or_leiden(graph, device)
71    y_clst = [0] * graph.number_of_nodes()
72    for i_clst, clst in enumerate(communities):
73        for smpl in clst:
74            y_clst[smpl] = i_clst
75    return np.array(y_clst)
Arguments:
  • ds (BatchedTensorDataset):
  • k (int):
  • strategy (Strategy | Fabric | None, optional): Defaults to None, meaning that the algorithm will not be parallelized.
  • n_features (int | None, optional):
  • tqdm_style (TqdmStyle, optional):
  • device (Any, optional):
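
Conceptually, this mirrors the following standalone pipeline, sketched here with scikit-learn's kneighbors_graph and networkx's louvain_communities in place of the internal knn_graph and _louvain_or_leiden helpers (edge weighting and other details differ):

    import networkx as nx
    import numpy as np
    from sklearn.neighbors import kneighbors_graph

    def louvain_clustering_sketch(x: np.ndarray, k: int) -> np.ndarray:
        # Symmetrized k-nearest-neighbor graph over the samples
        adj = kneighbors_graph(x, n_neighbors=k, mode="connectivity")
        graph = nx.from_scipy_sparse_array(adj.maximum(adj.T))
        # Community detection, then communities -> cluster label vector
        communities = nx.community.louvain_communities(graph, seed=0)
        y_clst = np.zeros(graph.number_of_nodes(), dtype=int)
        for i_clst, clst in enumerate(communities):
            y_clst[list(clst)] = i_clst
        return y_clst

    y_clst = louvain_clustering_sketch(np.random.randn(100, 8), k=10)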
Matching = dict[int, set[int]] | dict[str, set[int]] | dict[int, set[str]] | dict[str, set[str]]
def max_connected_confusion_choice( y_pred: ArrayLike, y_true: ArrayLike, n_classes: int, n: int | None = None, threshold: int = 0) -> tuple[list[int], int]:
235def max_connected_confusion_choice(
236    y_pred: ArrayLike,
237    y_true: ArrayLike,
238    n_classes: int,
239    n: int | None = None,
240    threshold: int = 0,
241) -> tuple[list[int], int]:
242    """
243    Chooses the classes that are most confused for each other according to
244    some confusion graph scheme.
245
246    Args:
247        y_pred (Tensor): A `(N,)` int tensor or an `(N, n_classes)`
248            probabilities/logits float tensor
249        y_true (Tensor): A `(N,)` int tensor
250        n_classes (int): Number of classes in the dataset.
251        n (int, optional): Number of classes to choose. If `None`, returns the
252            classes in the largest connected component of the confusion graph.
253        threshold (int, optional): Ignore pairs of classes that are confused by
254            less than that number of samples. See also
255            `lcc.choice.confusion_graph`.
256
257    Returns:
258        An `int` list of `n` classes **or fewer**, and the total number of
259        samples confused among these classes.
260    """
261    cg = confusion_graph(y_pred, y_true, n_classes, threshold=threshold)
262    hcg, w = heaviest_connected_subgraph(cg, max_size=n, strict=False)
263    return Tensor(list(hcg.nodes)).int().tolist(), int(w)

Chooses the classes that are most confused for each other according to some confusion graph scheme.

Arguments:
  • y_pred (Tensor): A (N,) int tensor or an (N, n_classes) probabilities/logits float tensor
  • y_true (Tensor): A (N,) int tensor
  • n_classes (int): Number of classes in the dataset.
  • n (int, optional): Number of classes to choose. If None, returns the classes in the largest connected component of the confusion graph.
  • threshold (int, optional): Ignore pairs of classes that are confused by less than that number of samples. See also lcc.choice.confusion_graph.
Returns:

An int list of n classes or fewer, and the total number of samples confused among these classes.
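
Reusing the synthetic labels from the confusion_graph example above:

    import numpy as np
    from lcc.correction import max_connected_confusion_choice

    y_true = np.array([0] * 10 + [1] * 10 + [2] * 10)
    y_pred = np.array([1] * 5 + [0] * 5 + [0] * 5 + [1] * 5 + [2] * 10)

    # Classes 0 and 1 form the heaviest connected pair, with 10
    # confused samples in total
    classes, n_confused = max_connected_confusion_choice(
        y_pred, y_true, n_classes=3, n=2
    )
    assert sorted(classes) == [0, 1] and n_confused == 10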

def otm_matching_predicates( y_a: ArrayLike, y_b: ArrayLike, matching: Matching, c_a: int | None = None) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]:
132def otm_matching_predicates(
133    y_a: ArrayLike,
134    y_b: ArrayLike,
135    matching: Matching,
136    c_a: int | None = None,
137) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
138    """
139    Let `y_a` be a `(N,)` integer array with values in $\\\\{ 0, 1, ..., c_a - 1
140    \\\\}$ (if the argument `c_a` is `None`, it is inferred to be `y_a.max() +
141    1`). If `y_a[i] == j`, then it is understood that the $i$-th sample (in
142    some dataset, say `x`) is in class $j$, which for disambiguation we'll call
143    the $a$-class $j$.
144
145    Likewise, let `y_b` be a `(N,)` integer array with values in $\\\\{ 0, 1, ...,
146    c_b - 1 \\\\}$. If `y_b[i] == j`, then it is understood that the $i$-th
147    sample `x[i]` is in $b$-class $j$.
148
149    Finally, let `matching` be a (possibly one-to-many) matching between the
150    $a$-classes and the $b$-classes. In other words, each $a$-class corresponds to
151    some set of $b$-classes.
152
153    This method returns four boolean arrays with shape `(c_a, N)`, which in my
154    head I call *"true-louvain-miss-excess"*:
155
156    1. `p1` is simply given by `p1[a] = (y_a == a)`, or in other words, `p1[a,
157       i]` is `True` if and only if the $i$-th sample is in $a$-class `a`.
158    2. `p2[a, i]` is `True` if and only if the $i$-th sample is in a $b$-class
159       that has matched to $a$-class `a`.
160    3. `p3` is (informally) given by `p3[a] = (p1[a] and not p2[a])`. In other
161       words, `p3[a, i]` is `True` if sample $i$ is in $a$-class `a` but not in
162       any $b$-class matched with `a`.
163    4. `p4` is the "dual" of `p3`: `p4[a] = (p2[a] and not p1[a])`. In other
164       words, `p4[a, i]` is `True` if sample $i$ is not in $a$-class `a`, but is
165       in a $b$-class matched with `a`.
166
167    I hope this all makes sense.
168
169    Args:
170        y_a (np.ndarray): A `(N,)` integer array with values in $\\\\{ 0, 1,
171            ..., c_a - 1 \\\\}$ for some $c_a > 0$.
172        y_b (np.ndarray): A `(N,)` integer array with values in $\\\\{ 0, 1,
173            ..., c_b - 1 \\\\}$ for some $c_b > 0$.
174        matching (Matching): A partition of
175            $\\\\{ 0, ..., c_b - 1 \\\\}$ into $c_a$ sets. The $i$-th set is
176            understood to be the set of all classes of `y_b` that matched with
177            the $i$-th class of `y_a`. If some keys are strings, they must be
178            convertible to ints. This has probably been produced by
179            `lcc.correction.class_otm_matching`.
180        c_a (int | None, optional): Number of $a$-classes. Useful if `y_a`
181            does not contain all the possible classes of the dataset at hand.
182            If `None`, then `y_a` is assumed to contain all classes, and so `c_a
183            = y_a.max() + 1`.
184    """
185    y_a, y_b = to_int_array(y_a), to_int_array(y_b)
186    matching = to_int_matching(matching)
187    if (la := len(y_a)) != (lb := len(y_b)):
188        raise ValueError(
189            f"y_a and y_b must have the same length, got {la} and {lb}"
190        )
191    c_a = c_a or int(y_a.max() + 1)
192    p1 = [y_a == a for a in range(c_a)]
193    p2 = [
194        (
195            np.sum(
196                [np.zeros_like(y_b)] + [y_b == b for b in matching.get(a, [])],
197                axis=0,
198            )
199            > 0
200            if a in matching
201            else np.full_like(y_b, False, dtype=bool)  # a isn't matched in m
202        )
203        for a in range(c_a)
204    ]
205    p3 = [p1[a] & ~p2[a] for a in range(c_a)]
206    p4 = [p2[a] & ~p1[a] for a in range(c_a)]
207    return np.array(p1), np.array(p2), np.array(p3), np.array(p4)

Let y_a be a (N,) integer array with values in $\{ 0, 1, ..., c_a - 1 \}$ (if the argument c_a is None, it is inferred to be y_a.max() + 1). If y_a[i] == j, then it is understood that the $i$-th sample (in some dataset, say x) is in class $j$, which for disambiguation we'll call the $a$-class $j$.

Likewise, let y_b be a (N,) integer array with values in $\{ 0, 1, ..., c_b - 1 \}$. If y_b[i] == j, then it is understood that the $i$-th sample x[i] is in $b$-class $j$.

Finally, let matching be a (possibly one-to-many) matching between the $a$-classes and the $b$-classes. In other words, each $a$-class corresponds to some set of $b$-classes.

This method returns four boolean arrays with shape (c_a, N), which in my head I call "true-louvain-miss-excess":

  1. p1 is simply given by p1[a] = (y_a == a), or in other words, p1[a, i] is True if and only if the $i$-th sample is in $a$-class a.
  2. p2[a, i] is True if and only if the $i$-th sample is in a $b$-class that has matched to $a$-class a.
  3. p3 is (informally) given by p3[a] = (p1[a] and not p2[a]). In other words, p3[a, i] is True if sample $i$ is in $a$-class a but not in any $b$-class matched with a.
  4. p4 is the "dual" of p3: p4[a] = (p2[a] and not p1[a]). In other words, p4[a, i] is True if sample $i$ is not in $a$-class a, but is in a $b$-class matched with a.

I hope this all makes sense.

Arguments:
  • y_a (np.ndarray): A (N,) integer array with values in $\{ 0, 1, ..., c_a - 1 \}$ for some $c_a > 0$.
  • y_b (np.ndarray): A (N,) integer array with values in $\{ 0, 1, ..., c_b - 1 \}$ for some $c_b > 0$.
  • matching (Matching): A partition of $\{ 0, ..., c_b - 1 \}$ into $c_a$ sets. The $i$-th set is understood to be the set of all classes of y_b that matched with the $i$-th class of y_a. If some keys are strings, they must be convertible to ints. This has probably been produced by lcc.correction.class_otm_matching.
  • c_a (int | None, optional): Number of $a$-classes. Useful if y_a does not contain all the possible classes of the dataset at hand. If None, then y_a is assumed to contain all classes, and so c_a = y_a.max() + 1.
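
Continuing the example from class_otm_matching, the four predicates pick out exactly the "miss" and "excess" samples:

    import numpy as np
    from lcc.correction import otm_matching_predicates

    y_a = np.array([1, 1, 1, 1, 2, 2, 3, 4, 4])
    y_b = np.array([10, 50, 10, 20, 20, 20, 30, 30, 30])
    matching = {1: {10, 50}, 2: {20}, 3: set(), 4: {30}}

    p1, p2, p3, p4 = otm_matching_predicates(y_a, y_b, matching)
    # Sample 3 is in a-class 1 but in b-class 20, which matched a-class 2:
    assert p3[1].nonzero()[0].tolist() == [3]  # missed by a-class 1
    assert p4[2].nonzero()[0].tolist() == [3]  # excess for a-class 2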
def peer_pressure_clustering( ds: lcc.datasets.BatchedTensorDataset, k: int, strategy: pytorch_lightning.strategies.strategy.Strategy | lightning_fabric.fabric.Fabric | None = None, n_features: int | None = None, tqdm_style: Optional[Literal['notebook', 'console', 'none']] = None) -> numpy.ndarray:
100def peer_pressure_clustering(
101    ds: BatchedTensorDataset,
102    k: int,
103    strategy: Strategy | Fabric | None = None,
104    n_features: int | None = None,
105    tqdm_style: TqdmStyle = None,
106) -> np.ndarray:
107    """
108    Nearest-neighbor peer pressure clustering with weighted mean cluster
109    conductance as an objective function.
110
111    Args:
112        ds (BatchedTensorDataset):
113        k (int):
114        strategy (Strategy | Fabric | None, optional): Defaults to `None`,
115            meaning that the algorithm will not be parallelized.
116        n_features (int | None, optional):
117        tqdm_style (TqdmStyle, optional):
118    """
119    graph = knn_graph(
120        ds, k, strategy=strategy, n_features=n_features, tqdm_style=tqdm_style
121    )
122    a = nx.to_scipy_sparse_array(graph, format="csc")
123    y_clst, _ = _ppc(a, tqdm_style=tqdm_style)
124    return y_clst

Nearest-neighbor peer pressure clustering with weighted mean cluster conductance as an objective function.

Arguments:
  • ds (BatchedTensorDataset):
  • k (int):
  • strategy (Strategy | Fabric | None, optional): Defaults to None, meaning that the algorithm will not be parallelized.
  • n_features (int | None, optional):
  • tqdm_style (TqdmStyle, optional):
class RandomizedLCCLoss(lcc.correction.LCCLoss):
 27class RandomizedLCCLoss(LCCLoss):
 28    """
 29    An LCC loss function that pulls misclustered samples towards a CC sample in
 30    the same class.
 31
 32    In principle, this implies some sort of exhaustive search since an MC sample
 33    has to be compared to *every* CC sample in the same class. This is what
 34    `lcc.correction.ExactLCCLoss` does. Here, to save on compute and time, only
 35    a few CC samples are randomly selected in each cluster and used as potential
 36    targets.
 37    """
 38
 39    ccspc: int
 40    n_classes: int
 41    targets: dict[int, Tensor] = {}
 42    tqdm_style: TqdmStyle
 43    matching: Matching
 44
 45    def __call__(
 46        self, z: Tensor, y_true: ArrayLike, y_clst: ArrayLike
 47    ) -> Tensor:
 48        """
 49        Derives the clustering correction loss from a tensor of latent
 50        representations `z` and a dict of targets (see
 51        `lcc.correction.lcc_targets`).
 52
 53        First, recall that the values of `targets` (as produced by
 54        `lcc.correction.lcc_targets`) are `(k, d)` tensors, for some length
 55        `k`.
 56
 57        Let's say `a` is a misclustered latent sample (a.k.a. a row of `z`) in
 58        true class `i_true`, and that `(b_1, ..., b_k)` are the rows of
 59        `targets[i_true]`. Then `a` contributes a term to the LCC loss equal to
 60        the distance between `a` and the closest `b_j`, divided by
 61        $\\\\sqrt{d}$.
 62
 63        It is possible that `i_true` is not in the keys of `targets`, in which
 64        case the contribution of `a` to the LCC loss is zero. In particular, if
 65        `targets` is empty, then the LCC loss is zero.
 66
 67        Args:
 68            z (Tensor): The tensor of latent representations. *Do not* mask it
 69                before passing it to this method. The correctly clustered samples
 70                and the misclustered samples are automatically separated.
 71            y_true (ArrayLike): A `(N,)` integer array of true labels.
 72            y_clst (ArrayLike): A `(N,)` integer array of the cluster labels.
 73        """
 74        if not self.targets:
 75            # ↓ actually need grad?
 76            return torch.tensor(0.0, requires_grad=True).to(z.device)
 77        z, y_true = z.flatten(1), to_int_tensor(y_true)
 78        p_mc, _ = _mc_cc_predicates(
 79            y_true, y_clst, self.matching, n_classes=self.n_classes
 80        )
 81        sqrt_d, losses = sqrt(z.shape[-1]), []
 82        for i_true, p_mc_i_true in enumerate(p_mc):
 83            if not (
 84                i_true in self.targets and len(self.targets[i_true]) > 0
 85            ):  # no targets in this true class
 86                continue
 87            if not p_mc_i_true.any():  # every sample is correctly clustered
 88                continue
 89            d = (
 90                torch.cdist(z[p_mc_i_true], self.targets[i_true].to(z.device))
 91                / sqrt_d
 92            )
 93            losses.append(d.min(dim=-1).values)
 94        if not losses:
 95            return torch.tensor(0.0, requires_grad=True).to(z.device)
 96        return torch.concat(losses).mean()
 97
 98    def __init__(
 99        self,
100        n_classes: int,
101        ccspc: int = 100,
102        tqdm_style: TqdmStyle = None,
103        strategy: Strategy | Fabric | None = None,
104    ) -> None:
105        super().__init__(strategy=strategy)
106        self.n_classes, self.ccspc = n_classes, ccspc
107        self.tqdm_style = tqdm_style
108
109    def sync(self, **kwargs: Any) -> None:
110        if self.strategy is None:
111            return
112        path = self._get_tmp_dir()
113        gr = self.strategy.global_rank
114        st.save_file(
115            {str(k): v for k, v in self.targets.items()},
116            path / f"targets.{gr}",
117        )
118        self.strategy.barrier()
119        for r in range(self.strategy.world_size):
120            if r == self.strategy.global_rank:
121                continue  # data from this rank is already in self.targets
122            data = st.load_file(path / f"targets.{r}")
123            self.targets.update({int(k): v for k, v in data.items()})
124        return super().sync(**kwargs)
125
126    def update(
127        self,
128        dl: DataLoader,
129        y_true: ArrayLike,
130        y_clst: ArrayLike,
131        matching: Matching,
132    ) -> None:
133        """
134        This method updates the `targets` attribute of this instance. It is a
135        dict containing the following:
136        - the keys are *among* true classes (unique values of `y_true`); let's
137          say that `i_true` is a key that owns `k` clusters;
138        - the associated value is a `(n, d)` tensor, where `d` is the latent
139          dimension, whose rows are among correctly clustered samples in true
140          class `i_true`.  For example, if `ccspc` is $1$, then `n` is the
141          number of clusters matched with `i_true`, say `k`. Otherwise, `n <= k
142          * ccspc`.
143
144        Under the hood, this method first choose the samples by their index
145        based on the "correctly clustered" predicate of `_mc_cc_predicates`.
146        Then, the whole dataset is iterated to collect the actual samples.
147
148        Args:
149            dl (DataLoader): An unsharded dataloader over a tensor dataset.
150            y_true (ArrayLike): A `(N,)` integer array.
151            y_clst (ArrayLike): A `(N,)` integer array.
152            matching (Matching): Produced by
153               `lcc.correction.class_otm_matching`.
154        """
155        self.matching = to_int_matching(matching)
156        _, p_cc = _mc_cc_predicates(
157            y_true, y_clst, self.matching, self.n_classes
158        )
159        # i_true (assigned to this rank) -> some indices of CC samples
160        indices: dict[int, set[int]] = {
161            i_true: set() for i_true in self._distribute_labels(y_true)
162        }
163        for i_true in indices:
164            for j_clst in self.matching[i_true]:
165                # p: (N,) CC in i_true and j_clst
166                p = p_cc[i_true] & (y_clst == j_clst)
167                s = np.random.choice(np.where(p)[0], self.ccspc)
168                indices[i_true].update(s)
169        n_seen, n_todo = 0, sum(len(v) for v in indices.values())
170        # i_true (assigned to this rank) -> some CC samples
171        result: dict[int, list[Tensor]] = defaultdict(list)
172        tqdm = make_tqdm(self.tqdm_style)
173        progress = tqdm(dl, f"Finding correction targets (ccspc={self.ccspc})")
174        for z, *_ in progress:
175            for i_true, idxs in indices.items():
176                lst = [idx for idx in idxs if n_seen <= idx < n_seen + len(z)]
177                for idx in lst:
178                    result[i_true].append(z[idx - n_seen])
179                    n_todo -= 1
180            if n_todo <= 0:
181                break
182            n_seen += len(z)
183        if n_todo > 0:
184            logging.warning(
185                "Some correction targets could not be found "
186                "(n_seen={}, n_todo={})",
187                n_seen,
188                n_todo,
189            )
190        self.targets = {
191            k: torch.stack(v).flatten(1)
192            for k, v in result.items()
193            if v  # should already be is non-empty but just to make sure...
194        }

An LCC loss function that pulls misclustered samples towards a correctly clustered (CC) sample of the same class.

In principle, this implies some sort of exhaustive search, since an MC (misclustered) sample has to be compared to every CC sample in the same class. This is what lcc.correction.ExactLCCLoss does. Here, to save compute and time, only a few CC samples are randomly selected in each cluster and used as potential targets.
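
For orientation, a minimal usage sketch (the data, the clustering step, and the dataloader are placeholders; only the RandomizedLCCLoss calls come from this page):

    import torch
    from lcc.correction import RandomizedLCCLoss

    # Placeholder data: 1000 samples, latent dimension 64, 10 classes.
    z = torch.randn(1000, 64, requires_grad=True)
    y_true = torch.randint(0, 10, (1000,))
    y_clst = torch.zeros(1000, dtype=torch.long)  # stand-in cluster labels

    loss_fn = RandomizedLCCLoss(n_classes=10, ccspc=100)
    # In a real run, one would cluster the latent space (e.g. with
    # louvain_clustering), match clusters to classes with class_otm_matching,
    # and populate the correction targets before computing the loss:
    #     loss_fn.update(dl, y_true, y_clst, matching)
    #     loss_fn.sync()  # only relevant in distributed settings
    # Without update, targets is empty and the loss is zero by convention:
    loss = loss_fn(z, y_true, y_clst)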

RandomizedLCCLoss( n_classes: int, ccspc: int = 100, tqdm_style: Optional[Literal['notebook', 'console', 'none']] = None, strategy: pytorch_lightning.strategies.strategy.Strategy | lightning_fabric.fabric.Fabric | None = None)
 98    def __init__(
 99        self,
100        n_classes: int,
101        ccspc: int = 100,
102        tqdm_style: TqdmStyle = None,
103        strategy: Strategy | Fabric | None = None,
104    ) -> None:
105        super().__init__(strategy=strategy)
106        self.n_classes, self.ccspc = n_classes, ccspc
107        self.tqdm_style = tqdm_style
ccspc: int
n_classes: int
targets: dict[int, torch.Tensor] = {}
tqdm_style: Optional[Literal['notebook', 'console', 'none']]
matching: dict[int, set[int]] | dict[str, set[int]] | dict[int, set[str]] | dict[str, set[str]]
def sync(self, **kwargs: Any) -> None:
109    def sync(self, **kwargs: Any) -> None:
110        if self.strategy is None:
111            return
112        path = self._get_tmp_dir()
113        gr = self.strategy.global_rank
114        st.save_file(
115            {str(k): v for k, v in self.targets.items()},
116            path / f"targets.{gr}",
117        )
118        self.strategy.barrier()
119        for r in range(self.strategy.world_size):
120            if r == self.strategy.global_rank:
121                continue  # data from this rank is already in self.targets
122            data = st.load_file(path / f"targets.{r}")
123            self.targets.update({int(k): v for k, v in data.items()})
124        return super().sync(**kwargs)

Distributes or shares this object's data across all ranks. Call this after every rank has called update. By default, this is just a barrier.
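
The exchange itself is a simple file-based all-gather: each rank writes its shard of targets to a shared directory, waits at a barrier, then merges every other rank's file. A schematic version (assuming a shared filesystem; rank, world_size, and barrier stand in for the strategy object):

    from pathlib import Path

    import safetensors.torch as st
    import torch

    def share_targets(
        targets: dict[int, torch.Tensor],
        path: Path,
        rank: int,
        world_size: int,
        barrier,
    ) -> None:
        # safetensors keys must be strings, hence the str/int round-trips
        st.save_file(
            {str(k): v for k, v in targets.items()}, path / f"targets.{rank}"
        )
        barrier()  # every rank must have written before anyone reads
        for r in range(world_size):
            if r == rank:
                continue  # this rank's data is already in `targets`
            data = st.load_file(path / f"targets.{r}")
            targets.update({int(k): v for k, v in data.items()})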

def update( self, dl: torch.utils.data.dataloader.DataLoader, y_true: Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], y_clst: Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], matching: dict[int, set[int]] | dict[str, set[int]] | dict[int, set[str]] | dict[str, set[str]]) -> None:
126    def update(
127        self,
128        dl: DataLoader,
129        y_true: ArrayLike,
130        y_clst: ArrayLike,
131        matching: Matching,
132    ) -> None:
133        """
134        This method updates the `targets` attribute of this instance. It is a
135        dict containing the following:
136        - the keys are *among* true classes (unique values of `y_true`); let's
137          say that `i_true` is a key that owns `k` clusters;
138        - the associated value is an `(n, d)` tensor, where `d` is the latent
139          dimension, whose rows are correctly clustered samples from true
140          class `i_true`. For example, if `ccspc` is $1$, then `n` is the
141          number of clusters matched with `i_true`, say `k`. Otherwise, `n <=
142          k * ccspc`.
143
144        Under the hood, this method first chooses the samples by their index
145        based on the "correctly clustered" predicate of `_mc_cc_predicates`.
146        Then, the whole dataset is iterated to collect the actual samples.
147
148        Args:
149            dl (DataLoader): An unsharded dataloader over a tensor dataset.
150            y_true (ArrayLike): A `(N,)` integer array.
151            y_clst (ArrayLike): A `(N,)` integer array.
152            matching (Matching): Produced by
153               `lcc.correction.class_otm_matching`.
154        """
155        self.matching = to_int_matching(matching)
156        _, p_cc = _mc_cc_predicates(
157            y_true, y_clst, self.matching, self.n_classes
158        )
159        # i_true (assigned to this rank) -> some indices of CC samples
160        indices: dict[int, set[int]] = {
161            i_true: set() for i_true in self._distribute_labels(y_true)
162        }
163        for i_true in indices:
164            for j_clst in self.matching[i_true]:
165                # p: (N,) CC in i_true and j_clst
166                p = p_cc[i_true] & (y_clst == j_clst)
167                s = np.random.choice(np.where(p)[0], self.ccspc)
168                indices[i_true].update(s)
169        n_seen, n_todo = 0, sum(len(v) for v in indices.values())
170        # i_true (assigned to this rank) -> some CC samples
171        result: dict[int, list[Tensor]] = defaultdict(list)
172        tqdm = make_tqdm(self.tqdm_style)
173        progress = tqdm(dl, f"Finding correction targets (ccspc={self.ccspc})")
174        for z, *_ in progress:
175            for i_true, idxs in indices.items():
176                lst = [idx for idx in idxs if n_seen <= idx < n_seen + len(z)]
177                for idx in lst:
178                    result[i_true].append(z[idx - n_seen])
179                    n_todo -= 1
180            if n_todo <= 0:
181                break
182            n_seen += len(z)
183        if n_todo > 0:
184            logging.warning(
185                "Some correction targets could not be found "
186                "(n_seen={}, n_todo={})",
187                n_seen,
188                n_todo,
189            )
190        self.targets = {
191            k: torch.stack(v).flatten(1)
192            for k, v in result.items()
193            if v  # should already be non-empty, but just to make sure...
194        }

This method updates the targets attribute of this instance. It is a dict containing the following:

  • the keys are among true classes (unique values of y_true); let's say that i_true is a key that owns k clusters;
  • the associated value is an (n, d) tensor, where d is the latent dimension, whose rows are correctly clustered samples from true class i_true. For example, if ccspc is $1$, then n is the number of clusters matched with i_true, say k. Otherwise, n <= k * ccspc.

Under the hood, this method first chooses the samples by their index based on the "correctly clustered" predicate of _mc_cc_predicates. Then, the whole dataset is iterated to collect the actual samples (sketched after the argument list below).

Arguments:
  • dl (DataLoader): An unsharded dataloader over a tensor dataset.
  • y_true (ArrayLike): A (N,) integer array.
  • y_clst (ArrayLike): A (N,) integer array.
  • matching (Matching): Produced by lcc.correction.class_otm_matching.
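
The second phase is a single streaming pass over the dataloader that translates global dataset indices into per-batch offsets. A self-contained sketch of that pattern on toy data (not the class itself):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    data = torch.arange(40.0).reshape(10, 4)   # 10 samples of dimension 4
    dl = DataLoader(TensorDataset(data), batch_size=3)
    wanted = {2, 7}                            # global indices to collect

    collected, n_seen = [], 0
    for (z,) in dl:
        for idx in wanted:
            # idx falls in this batch iff n_seen <= idx < n_seen + len(z)
            if n_seen <= idx < n_seen + len(z):
                collected.append(z[idx - n_seen])
        n_seen += len(z)

    assert all(
        torch.equal(c, data[i]) for c, i in zip(collected, sorted(wanted))
    )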
def top_confusion_pairs( y_pred: Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], y_true: Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], n_classes: int, n_pairs: int | None = None, threshold: int = 0) -> list[tuple[int, int]]:
266def top_confusion_pairs(
267    y_pred: ArrayLike,
268    y_true: ArrayLike,
269    n_classes: int,
270    n_pairs: int | None = None,
271    threshold: int = 0,
272) -> list[tuple[int, int]]:
273    """
274    Returns the top `n_pairs` pairs of labels that exhibit the most
275    confusion. The confusion between two labels $a$ and $b$ is the number of
276    samples in true class $a$ that are predicted as class $b$, plus the number
277    of samples in true class $b$ that are predicted as class $a$.
278
279    Example:
280        >>> y_pred, y_true = [0, 0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 0]
281        >>> top_confusion_pairs(y_pred, y_true, n_classes=3, n_pairs=2)
282        [(1, 2), (0, 2)]
283
284    Args:
285        y_pred (Tensor): A `(N,)` int tensor or an `(N, n_classes)`
286            probabilities/logits float tensor
287        y_true (Tensor): A `(N,)` int tensor
288        n_classes (int): Number of classes in the dataset
289        n_pairs (int | None, optional): Number of desired pairs. The actual
290            result might have fewer. If `None`, returns all pairs of labels
291            that have more than `threshold` confused samples.
292        threshold (int, optional): A pair of labels must have strictly more
293            than `threshold` confused samples to be included in the list.
294
295    Returns:
296        The top `n_pairs` pairs **or fewer** of labels that exhibit the most
297        confusion.
298    """
299    y_pred, y_true = to_int_tensor(y_pred), to_int_tensor(y_true)
300    cm = multiclass_confusion_matrix(y_pred, y_true, n_classes).numpy()
301    cm = cm + cm.T  # Confusion in either direction
302    cm = cm * (1 - np.eye(len(cm)))  # Remove the diagonal
303    idx = cm.argsort(axis=None)  # Flat indices
304    idx = np.flip(idx)
305    cp = np.stack(np.unravel_index(idx, cm.shape)).T
306    lst = [(i, j) for i, j in cp if cm[i, j] > threshold and i < j]
307    return lst if n_pairs is None else lst[:n_pairs]

Returns the top n_pairs pairs of labels that exhibit the most confusion. The confusion between two labels $a$ and $b$ is the number of samples in true class $a$ that are predicted as class $b$, plus the number of samples in true class $b$ that are predicted as class $a$.

Example:
>>> y_pred, y_true = [0, 0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 0]
>>> top_confusion_pairs(y_pred, y_true, n_classes=3, n_pairs=2)
[(1, 2), (0, 2)]
Arguments:
  • y_pred (Tensor): A (N,) int tensor or an (N, n_classes) probabilities/logits float tensor
  • y_true (Tensor): A (N,) int tensor
  • n_classes (int): Number of classes in the dataset
  • n_pairs (int | None, optional): Number of desired pairs. The actual result might have fewer. If None, returns all pairs of labels that have more than threshold confused samples.
  • threshold (int, optional): A pair of labels must have strictly more than threshold confused samples to be included in the list.
Returns:

The top n_pairs pairs or fewer of labels that exhibit the most confusion.
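
To make the ranking concrete, here is the symmetrized confusion count computed by hand on made-up labels (NumPy only; this mirrors the cm + cm.T and diagonal-removal steps in the source above):

    import numpy as np

    # Made-up labels: classes 0 and 1 are confused three times, class 2 never.
    y_true = [0, 0, 0, 1, 1, 2]
    y_pred = [1, 1, 0, 0, 1, 2]

    cm = np.zeros((3, 3), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1             # rows: true class, columns: predicted class

    sym = cm + cm.T               # confusion in either direction
    np.fill_diagonal(sym, 0)      # correct predictions do not count as confusion
    print(sym[0, 1])              # 3 = two 0->1 errors plus one 1->0 error

With these labels, top_confusion_pairs(y_pred, y_true, n_classes=3, n_pairs=1) would accordingly return [(0, 1)].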