lcc.correction
Everything related to LCC.
1"""Everything related to LCC.""" 2 3from .choice import ( 4 LCC_CLASS_SELECTIONS, 5 GraphTotallyDisconnected, 6 LCCClassSelection, 7 choose_classes, 8 confusion_graph, 9 heaviest_connected_subgraph, 10 max_connected_confusion_choice, 11 top_confusion_pairs, 12) 13from .clustering import ( 14 class_otm_matching, 15 otm_matching_predicates, 16) 17from .loss import ExactLCCLoss, LCCLoss, RandomizedLCCLoss 18from .louvain import louvain_clustering 19from .peer_pressure import peer_pressure_clustering 20from .utils import Matching 21 22__all__ = [ 23 "choose_classes", 24 "class_otm_matching", 25 "confusion_graph", 26 "ExactLCCLoss", 27 "GraphTotallyDisconnected", 28 "heaviest_connected_subgraph", 29 "LCC_CLASS_SELECTIONS", 30 "LCCClassSelection", 31 "LCCLoss", 32 "louvain_clustering", 33 "Matching", 34 "max_connected_confusion_choice", 35 "otm_matching_predicates", 36 "peer_pressure_clustering", 37 "RandomizedLCCLoss", 38 "top_confusion_pairs", 39]
```python
def choose_classes(
    y_true: ArrayLike,
    y_pred: ArrayLike,
    policy: LCCClassSelection | Literal["all"] | None = None,
) -> list[int] | None:
    """
    Given true and predicted labels, select classes whose samples should
    undergo LCC based on some policy. See `lcc.correction.LCCClassSelection`.

    For convenience, this method returns `None` if all classes should be
    considered.

    Warning:
        When selecting a `"top_<N>"` policy, the returned list may have fewer
        than `N` elements. For example, this happens when there are fewer than
        `N` classes in the dataset.
    """
    y_true, y_pred = to_int_tensor(y_true), to_int_tensor(y_pred)
    if policy is None:
        return None
    if policy == "all":
        logging.warning(
            "LCC class selection policy 'all' is deprecated. Use `None` "
            "instead (which has the same effect)"
        )
        return None
    n_classes = y_true.unique().numel()
    if policy.startswith("top_pair_"):
        n = int(policy[9:])
        if n >= n_classes:
            return None
        pairs = top_confusion_pairs(y_pred, y_true, n_classes, n_pairs=n)
        return list(set(sum(pairs, ())))
    try:
        if policy.startswith("top_connected_"):
            n = int(policy[14:])
            if n >= n_classes:
                return None
            return max_connected_confusion_choice(
                y_pred, y_true, n_classes, n
            )[0]
        return max_connected_confusion_choice(y_pred, y_true, n_classes)[0]
    except GraphTotallyDisconnected:
        return None
```
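A minimal usage sketch (the labels below are made up for illustration):

```python
import numpy as np

# Classes 0 and 1 are the only confused pair, so a "top_pair_1" policy
# selects exactly those two classes.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 0, 2, 2])
print(choose_classes(y_true, y_pred, policy="top_pair_1"))  # [0, 1]
print(choose_classes(y_true, y_pred, policy=None))  # None, i.e. all classes
```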
```python
def class_otm_matching(y_a: ArrayLike, y_b: ArrayLike) -> Matching:
    """
    Let `y_a` and `y_b` be `(N,)` integer arrays. We think of them as classes
    on some dataset, say `x`, which we call respectively *$a$-classes* and
    *$b$-classes*. This method performs a one-to-many matching from the
    classes in `y_a` to the classes in `y_b` to overall maximize the
    cardinality of the intersection between samples labeled by $a$ and
    matched $b$-classes.

    Example:

        >>> y_a = np.array([1, 1, 1, 1, 2, 2, 3, 4, 4])
        >>> y_b = np.array([10, 50, 10, 20, 20, 20, 30, 30, 30])
        >>> class_otm_matching(y_a, y_b)
        {1: {10, 50}, 2: {20}, 3: set(), 4: {30}}

        Here, `y_a` assigns class `1` to samples 0 to 3, label `2` to samples
        4 and 5, etc. On the other hand, `y_b` assigns its own classes to the
        dataset. What is the best way to regroup classes of `y_b` to
        approximate the labelling of `y_a`? The `class_otm_matching` return
        value argues that classes `10` and `50` should be regrouped under `1`
        (they fit neatly), label `20` should be renamed to `2` (even though
        it "leaks" a little, in that sample 3 is labelled with `1` and `20`),
        and class `30` should be renamed to `4`. No class in `y_b` is
        assigned to class `3` in this matching.

    Note:
        There are no restrictions on the values of `y_a` and `y_b`. In
        particular, they need not be distinct: the following works fine

        >>> y_a = np.array([1, 1, 1, 1, 2, 2, 3, 4, 4])
        >>> y_b = np.array([1, 5, 1, 2, 2, 2, 3, 3, 3])
        >>> class_otm_matching(y_a, y_b)
        {1: {1, 5}, 2: {2}, 3: set(), 4: {3}}

    Warning:
        Negative labels in `y_a` or `y_b` are ignored. So the output matching
        dict will never have negative keys, and the sets will never have
        negative values either.

    Args:
        y_a (ArrayLike): A `(N,)` integer array.
        y_b (ArrayLike): A `(N,)` integer array.

    Returns:
        A dict that maps each class in `y_a` to the set of classes in `y_b`
        that it has matched.
    """
    y_a, y_b, match_graph = to_int_array(y_a), to_int_array(y_b), nx.DiGraph()
    for i, j in product(np.unique(y_a), np.unique(y_b)):
        if i < 0 or j < 0:
            continue
        n = np.sum((y_a == i) & (y_b == j))
        match_graph.add_edge(f"a_{i}", f"b_{j}", weight=n)
    matching = _otm_matching(
        match_graph,
        [f"a_{i}" for i in np.unique(y_a)],
        [f"b_{i}" for i in np.unique(y_b)],
        mode="max",
    )
    return {
        int(a_i.split("_")[-1]): {int(b_j.split("_")[-1]) for b_j in b_js}
        for a_i, b_js in matching.items()
    }
```
```python
def confusion_graph(
    y_pred: ArrayLike, y_true: ArrayLike, n_classes: int, threshold: int = 10
) -> nx.Graph:
    """
    Create a confusion graph from predicted and true labels. Two labels $a$
    and $b$ are confused if there is at least one sample belonging to class
    $a$ that is predicted as class $b$, or vice versa. The nodes of the
    confusion graph are labels; two labels are connected by an edge if they
    are confused by at least `threshold` samples, and each edge's weight is
    the number of times the two labels are confused for each other.

    Args:
        y_pred (ArrayLike): A `(N,)` int tensor or an `(N, n_classes)`
            probabilities/logits float tensor
        y_true (ArrayLike): A `(N,)` int tensor
        n_classes (int): Number of classes in the dataset
        threshold (int, optional): Minimum number of times two classes must
            be confused (in either direction) to be included in the graph.

    Warning:
        There are no loops, i.e. correct predictions are not reported in the
        graph, unlike in usual confusion matrices.
    """
    y_pred, y_true = to_int_tensor(y_pred), to_int_tensor(y_true)
    cm = multiclass_confusion_matrix(y_pred, y_true, num_classes=n_classes)
    cm = cm + cm.T  # Confusion in either direction
    cg = nx.Graph()
    for i, j in combinations(list(range(n_classes)), 2):
        if (w := cm[i, j].item()) >= threshold:
            cg.add_edge(i, j, weight=w)
    return cg
```
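A minimal sketch: with the toy labels below, classes 1 and 2 are confused twice (once in each direction), which is the only pair above the threshold:

```python
import numpy as np

y_true = np.array([0, 1, 1, 2, 2, 0])
y_pred = np.array([0, 1, 2, 2, 1, 0])
cg = confusion_graph(y_pred, y_true, n_classes=3, threshold=1)
print(list(cg.edges(data=True)))  # [(1, 2, {'weight': 2})]
```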
```python
class ExactLCCLoss(LCCLoss):
    """LCC loss that corrects misclustered samples using their CC KNNs."""

    k: int
    n_classes: int
    tqdm_style: TqdmStyle
    matching: dict[int, set[int]]

    # ↓ i_clst -> (knn idx, tensor of all CC samples in that clst)
    data: dict[int, tuple[faiss.IndexFlatL2, Tensor]] = {}

    def __call__(
        self, z: Tensor, y_true: ArrayLike, y_clst: ArrayLike
    ) -> Tensor:
        _, _, p_mc, _ = otm_matching_predicates(
            y_true, y_clst, self.matching, c_a=self.n_classes
        )  # p_mc: (n_classes, len(z))
        z = z.flatten(1)
        terms = []
        for i_true in np.unique(to_int_array(y_true)):
            if not p_mc[i_true].any():
                # No MC samples in this class and batch
                continue
            u = z[p_mc[i_true]]  # MC samples in class i_true
            d = []  # Distances of MC samples to candidate targets
            for i_clst in self.matching[i_true]:
                if i_clst not in self.data:
                    continue
                knn, cc = self.data[i_clst]
                # Find a candidate target in i_clst for each sample in u
                _, j = knn.search(to_array(u).astype(np.float32), self.k)
                j = to_int_tensor(j)  # (len(u), k)
                t = cc[j].mean(dim=1).to(u.device)  # (n, n_features)
                # Save distance to these candidate targets
                d.append(
                    torch.norm(u - t, dim=1) / sqrt(u.shape[-1])  # (len(u),)
                )
            if d:
                terms.append(
                    torch.stack(d)  # (?, len(u))
                    .min(dim=0)
                    .values  # (len(u),)
                )
        if not terms:
            return torch.tensor(0.0, requires_grad=True).to(z.device)
        return torch.cat(terms).mean()

    def __init__(
        self,
        n_classes: int,
        k: int = 5,
        tqdm_style: TqdmStyle = None,
        strategy: Strategy | Fabric | None = None,
    ) -> None:
        super().__init__(strategy=strategy)
        self.k, self.n_classes = k, n_classes
        self.tqdm_style = tqdm_style

    def sync(self, **kwargs: Any) -> None:
        """
        Remember that every rank has its own subset of clusters to manage.
        Before sync, every rank's `self.data` only contains data pertaining
        to this rank's clusters.

        This method works in two steps. First, every rank writes its data to
        some temporary directory. Then, every rank loads data from that
        directory.

        EZPZ
        """
        if self.strategy is None:
            return
        path = self._get_tmp_dir()
        gr = self.strategy.global_rank
        st.save_file(
            {str(i_clst): cc for i_clst, (_, cc) in self.data.items()},
            path / f"cc.{gr}",
        )
        for i_clst, (idx, _) in self.data.items():
            idx = faiss.index_gpu_to_cpu(idx)
            faiss.write_index(idx, str(path / f"knn.{i_clst}.{gr}"))
        self.strategy.barrier()
        for r in range(self.strategy.world_size):
            if r == gr:
                continue  # data from this rank is already in self.data
            ccs = st.load_file(path / f"cc.{r}")
            gpu = faiss.StandardGpuResources()
            for i_clst, cc in ccs.items():  # type: ignore
                knn = faiss.read_index(str(path / f"knn.{i_clst}.{r}"))
                knn = faiss.index_cpu_to_gpu(gpu, gr, knn)
                self.data[int(i_clst)] = (knn, cc)
        return super().sync(**kwargs)

    def update(
        self,
        dl: DataLoader,
        y_true: ArrayLike,
        y_clst: ArrayLike,
        matching: Matching,
    ) -> None:
        """
        Reminder:
            `dl` has to iterate over the whole dataset, even if this method
            is called in a distributed environment. The label vectors must
            also cover the whole dataset.
        """
        self.matching = to_int_matching(matching)
        y_clst = to_int_array(y_clst)
        n_features = next(iter(dl))[0].flatten(1).shape[-1]
        p1, p2, _, _ = otm_matching_predicates(
            y_true, y_clst, self.matching, c_a=self.n_classes
        )
        p_cc = (p1 & p2).sum(axis=0).astype(bool)  # (n_samples,)

        # Cluster labels that this rank has to manage
        clsts = self._distribute_labels(y_clst)
        # ↓ i_clst -> (knn idx, list of batches of CC samples in this clst)
        data: dict[int, tuple[faiss.IndexFlatL2, list[Tensor]]] = {
            i_clst: (faiss.IndexFlatL2(n_features), []) for i_clst in clsts
        }
        if self.strategy is not None:
            gpu = faiss.StandardGpuResources()
            gr = self.strategy.global_rank
            for i_clst, (knn, cc) in data.items():
                knn = faiss.index_cpu_to_gpu(gpu, gr, knn)
                data[i_clst] = (knn, cc)

        tqdm, n_seen = make_tqdm(self.tqdm_style), 0
        for z, *_ in tqdm(dl, f"Building {len(data)} KNN indices"):
            z = z.flatten(1)  # (bs, n_feat.)
            _y_clst = y_clst[n_seen : n_seen + len(z)]  # (bs,)
            _p_cc = p_cc[n_seen : n_seen + len(z)]  # (bs,)
            for i_clst in np.unique(_y_clst):
                if i_clst not in data:
                    continue  # Cluster not managed by this rank
                # ↓ Mask for smpls in this batch that are CC and in i_clst
                _p_cc_i_clst = _p_cc & (_y_clst == i_clst)
                if not _p_cc_i_clst.any():
                    continue  # No CC sample in cluster i_clst in this batch
                _z = z[_p_cc_i_clst]
                data[i_clst][0].add(to_array(_z).astype(np.float32))
                data[i_clst][1].append(_z)
            n_seen += len(z)

        # ↓ i_clst -> (knn idx, tensor of all CC samples in this clst)
        #   IF i_clst has at least one CC sample
        self.data = {
            i_clst: (idx, torch.cat(lst))
            for i_clst, (idx, lst) in data.items()
            if lst
        }
```
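A hedged sketch of the intended lifecycle, in a single process (`dl`, `y_true`, `y_clst`, `matching`, and the batch tensors are placeholders; as the `update` docstring requires, they must cover the whole dataset):

```python
# All inputs are hypothetical placeholders; only the calls on `loss_fn`
# come from this class.
loss_fn = ExactLCCLoss(n_classes=10, k=5)
loss_fn.update(dl, y_true, y_clst, matching)  # builds per-cluster KNN indices
loss_fn.sync()  # no-op here since no strategy was passed
loss = loss_fn(z_batch, y_true_batch, y_clst_batch)
loss.backward()
```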
```python
class GraphTotallyDisconnected(ValueError):
    """
    Raised in `lcc.correction.heaviest_connected_subgraph` when a graph is
    totally disconnected (has no edges).
    """
```
```python
def heaviest_connected_subgraph(
    graph: nx.Graph,
    max_size: int | None = None,
    strict: bool = False,
    key: str = "weight",
) -> tuple[nx.Graph, float]:
    """
    Find the heaviest connected full subgraph of an undirected graph with
    weighted edges. In other words, returns the connected component whose
    total edge weight is the largest.

    Under the hood, this function maintains a list of connected full
    subgraphs and iteratively adds the heaviest edge to those subgraphs that
    contain one of its endpoints. Note that:
    - if no subgraph touches the current edge, then the edge is added to the
      list as its own subgraph;
    - if a subgraph contains both endpoints of the current edge, then the
      edge is already part of that subgraph, which is not modified;
    - subgraphs in the list that have already reached `max_size` are not
      modified.

    Finally, the heaviest subgraph is returned.

    Warning:
        Setting `strict` to `True` can make the problem impossible, e.g. if
        `graph` doesn't have a large enough connected component. In such
        cases, a `RuntimeError` is raised.

    Warning:
        If the graph is totally disconnected (i.e. has no edges), then a
        `GraphTotallyDisconnected` exception is raised, rather than returning
        a subgraph with a single node.

    Args:
        graph (nx.Graph): Most likely the confusion graph returned by
            `lcc.choice.confusion_graph`.
        max_size (int | None, optional): If left to `None`, returns the
            heaviest connected component.
        strict (bool, optional): If `True`, the returned graph is guaranteed
            to have exactly `max_size` nodes. If `False`, the returned graph
            may have fewer (but never more) nodes.
        key (str, optional): The edge attribute to use as weight.

    Returns:
        A connected subgraph and its total weight (see also `total_weight`).
    """
    if not graph.edges:
        raise GraphTotallyDisconnected()
    _total_weight = partial(total_weight, key=key)
    if max_size is None:
        subgraphs = sorted(
            map(graph.subgraph, nx.connected_components(graph)),
            key=_total_weight,
            reverse=True,
        )
        return subgraphs[0], _total_weight(subgraphs[0])
    edges = sorted(
        graph.edges(data=True), key=lambda e: e[2][key], reverse=True
    )
    subgraphs = []
    for u, v, _ in tqdm(edges, desc="Finding heaviest subgraph"):
        if not subgraphs:
            subgraphs.append(graph.subgraph([u, v]))
            continue
        has_u = np.array([g.has_node(u) for g in subgraphs], dtype=bool)
        has_v = np.array([g.has_node(v) for g in subgraphs], dtype=bool)
        if np.any(has_u & has_v):  # a graph already contains edge (u, v)
            continue
        p = (has_u | has_v) & (np.array(list(map(len, subgraphs))) < max_size)
        to_extend = [g for i, g in enumerate(subgraphs) if p[i]]
        subgraphs = [g for i, g in enumerate(subgraphs) if ~p[i]]
        if to_extend:
            for g in to_extend:
                subgraphs.append(graph.subgraph(list(g.nodes) + [u, v]))
        else:
            subgraphs.append(graph.subgraph([u, v]))
    if strict:
        subgraphs = list(filter(lambda g: len(g) >= max_size, subgraphs))
    subgraphs = sorted(
        subgraphs,
        key=_total_weight,
        reverse=True,
    )
    if not subgraphs:
        raise RuntimeError(
            "Could not find heaviest subgraph with given size constraint. "
            "Try again with strict=False."
        )
    return subgraphs[0], _total_weight(subgraphs[0])
```
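A minimal sketch, assuming `total_weight` sums the `"weight"` edge attributes: the component `{0, 1, 2}` has total weight 5, heavier than `{3, 4}`:

```python
import networkx as nx

g = nx.Graph()
g.add_edge(0, 1, weight=3)
g.add_edge(1, 2, weight=2)
g.add_edge(3, 4, weight=4)
sub, w = heaviest_connected_subgraph(g)
print(sorted(sub.nodes), w)  # [0, 1, 2] 5
```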
```python
class LCCLoss(ABC):
    """Abstract class that encapsulates a loss function for LCC."""

    strategy: ParallelStrategy | Fabric | None

    _tmp_dir: TemporaryDirectory | None = None
    """
    A temporary directory that can be created if needed to save/load data
    before/after sync. Cleaned up automatically in `on_after_sync`.
    """

    @abstractmethod
    def __call__(
        self, z: Tensor, y_true: ArrayLike, y_clst: ArrayLike
    ) -> Tensor:
        pass

    def __init__(self, strategy: Strategy | Fabric | None = None) -> None:
        # TODO: Log a warning if strategy is a Strategy but not a
        # ParallelStrategy
        self.strategy = (
            strategy
            if isinstance(strategy, (ParallelStrategy, Fabric))
            else None
        )

    def _distribute_labels(self, y: ArrayLike) -> set[int]:
        """
        Given an array of labels, distributes unique labels across ranks.

        Example:
            Let's say the world size is 2.

            >>> y = np.array([0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5])
            >>> self._distribute_labels(y)
            {0, 2, 4}  # on rank 0
            {1, 3, 5}  # on rank 1
        """
        y = to_int_array(y)
        if self.strategy is not None and self.strategy.world_size > 1:
            ws, gr = self.strategy.world_size, self.strategy.global_rank
            return {i for i in np.unique(y) if i % ws == gr}
        return set(np.unique(y))

    def _get_tmp_dir(self) -> Path:
        """
        On rank 0, acquires a temporary directory, broadcasts its handle to
        all ranks, and returns its path.

        Warning:
            In a distributed environment, this method must be called from all
            ranks "at the same time".
        """
        if self._tmp_dir is not None:
            return Path(self._tmp_dir.name)
        if self.strategy is None:
            self._tmp_dir = TemporaryDirectory(prefix="lcc-")
            return Path(self._tmp_dir.name)
        if self.strategy.global_rank == 0:
            self._tmp_dir = TemporaryDirectory(prefix="lcc-")
        else:
            self._tmp_dir = None
        self._tmp_dir = self.strategy.broadcast(self._tmp_dir, src=0)
        assert self._tmp_dir is not None
        return Path(self._tmp_dir.name)

    def on_after_sync(self, **kwargs: Any) -> None:
        """
        Should be called on all ranks after the loss object has been synced.
        """
        if self.strategy is not None:
            self.strategy.barrier()
        if self._tmp_dir is not None and (
            self.strategy is None or self.strategy.global_rank == 0
        ):
            self._tmp_dir.cleanup()

    def on_before_sync(self, **kwargs: Any) -> None:
        """
        Should be called on all ranks before syncing the loss object from
        rank 0 to other ranks.
        """

    def sync(self, **kwargs: Any) -> None:
        """
        Distributes or shares this object's data across all ranks. Call this
        after each rank has called `update`. By default, this is just a
        barrier.
        """
        if self.strategy is not None:
            self.strategy.barrier()

    @abstractmethod
    def update(
        self,
        dl: DataLoader,
        y_true: ArrayLike,
        y_clst: ArrayLike,
        matching: Matching,
    ) -> None:
        """
        Updates the internal state of the loss function. Presumably called at
        the beginning of each epoch where LCC is to be applied.

        The dataloader has to yield batches that are tuples of tensors, the
        first of which is a 2D tensor of samples.

        Warning:
            If the construction of the loss object is distributed across
            multiple ranks, make sure that `dl` iterates over the WHOLE
            dataset (no distributed sampling).
        """
```
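For subclasses such as `lcc.correction.ExactLCCLoss` and `lcc.correction.RandomizedLCCLoss`, the call order suggested by the docstrings above is sketched below (a hedged outline; `loss_fn` and the `update` arguments are placeholders):

```python
# Each rank builds its share of the state, then the ranks exchange it.
loss_fn.update(dl, y_true, y_clst, matching)
loss_fn.on_before_sync()
loss_fn.sync()  # just a barrier by default
loss_fn.on_after_sync()  # barrier, then temp-dir cleanup on rank 0
```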
```python
def louvain_clustering(
    ds: BatchedTensorDataset,
    k: int,
    strategy: Strategy | Fabric | None = None,
    n_features: int | None = None,
    tqdm_style: TqdmStyle = None,
    device: Any = "cpu",
) -> np.ndarray:
    """
    Louvain (or Leiden) community clustering on the KNN graph of the dataset.

    Args:
        ds (BatchedTensorDataset):
        k (int): Number of neighbors in the underlying KNN graph.
        strategy (Strategy | Fabric | None, optional): Defaults to `None`,
            meaning that the algorithm will not be parallelized.
        n_features (int | None, optional):
        tqdm_style (TqdmStyle, optional):
        device (Any, optional):
    """
    graph = knn_graph(
        ds, k, strategy=strategy, n_features=n_features, tqdm_style=tqdm_style
    )
    communities = _louvain_or_leiden(graph, device)
    y_clst = [0] * graph.number_of_nodes()
    for i_clst, clst in enumerate(communities):
        for smpl in clst:
            y_clst[smpl] = i_clst
    return np.array(y_clst)
```
```python
def max_connected_confusion_choice(
    y_pred: ArrayLike,
    y_true: ArrayLike,
    n_classes: int,
    n: int | None = None,
    threshold: int = 0,
) -> tuple[list[int], int]:
    """
    Chooses the classes that are most confused for each other according to
    some confusion graph scheme.

    Args:
        y_pred (Tensor): A `(N,)` int tensor or an `(N, n_classes)`
            probabilities/logits float tensor
        y_true (Tensor): A `(N,)` int tensor
        n_classes (int): Number of classes in the dataset.
        n (int, optional): Number of classes to choose. If `None`, returns
            the classes in the largest connected component of the confusion
            graph.
        threshold (int, optional): Ignore pairs of classes that are confused
            by less than that number of samples. See also
            `lcc.choice.confusion_graph`.

    Returns:
        An `int` list of `n` classes **or fewer**, and the total number of
        confused samples among these classes.
    """
    cg = confusion_graph(y_pred, y_true, n_classes, threshold=threshold)
    hcg, w = heaviest_connected_subgraph(cg, max_size=n, strict=False)
    return Tensor(list(hcg.nodes)).int().tolist(), int(w)
```
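A minimal sketch (toy labels; classes 1 and 2 are confused four times in total, twice in each direction):

```python
import numpy as np

y_true = np.array([0, 1, 1, 2, 2, 1, 2, 0])
y_pred = np.array([0, 2, 2, 1, 1, 1, 2, 0])
classes, w = max_connected_confusion_choice(y_pred, y_true, n_classes=3, n=2)
print(classes, w)  # [1, 2] 4
```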
```python
def otm_matching_predicates(
    y_a: ArrayLike,
    y_b: ArrayLike,
    matching: Matching,
    c_a: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Let `y_a` be a `(N,)` integer array with values in $\\{ 0, 1, ..., c_a -
    1 \\}$ (if the argument `c_a` is `None`, it is inferred to be `y_a.max()
    + 1`). If `y_a[i] == j`, then it is understood that the $i$-th sample (in
    some dataset, say `x`) is in class $j$, which for disambiguation we'll
    call the $a$-class $j$.

    Likewise, let `y_b` be a `(N,)` integer array with values in $\\{ 0, 1,
    ..., c_b - 1 \\}$. If `y_b[i] == j`, then it is understood that the
    $i$-th sample `x[i]` is in $b$-class $j$.

    Finally, let `matching` be a (possibly one-to-many) matching between the
    $a$-classes and the $b$-classes. In other words, each $a$-class
    corresponds to some set of $b$-classes.

    This method returns four boolean arrays with shape `(c_a, N)`, which in
    my head I call *"true-louvain-miss-excess"*:

    1. `p1` is simply given by `p1[a] = (y_a == a)`, or in other words,
       `p1[a, i]` is `True` if and only if the $i$-th sample is in $a$-class
       `a`.
    2. `p2[a, i]` is `True` if and only if the $i$-th sample is in a
       $b$-class that has matched to $a$-class `a`.
    3. `p3` is (informally) given by `p3[a] = (p1[a] and not p2[a])`. In
       other words, `p3[a, i]` is `True` if sample $i$ is in $a$-class `a`
       but not in any $b$-class matched with `a`.
    4. `p4` is the "dual" of `p3`: `p4[a] = (p2[a] and not p1[a])`. In other
       words, `p4[a, i]` is `True` if sample $i$ is not in $a$-class `a`, but
       is in a $b$-class matched with `a`.

    I hope this all makes sense.

    Args:
        y_a (np.ndarray): A `(N,)` integer array with values in $\\{ 0, 1,
            ..., c_a - 1 \\}$ for some $c_a > 0$.
        y_b (np.ndarray): A `(N,)` integer array with values in $\\{ 0, 1,
            ..., c_b - 1 \\}$ for some $c_b > 0$.
        matching (Matching): A partition of $\\{ 0, ..., c_b - 1 \\}$ into
            $c_a$ sets. The $i$-th set is understood to be the set of all
            classes of `y_b` that matched with the $i$-th class of `y_a`. If
            some keys are strings, they must be convertible to ints. This has
            probably been produced by `lcc.correction.class_otm_matching`.
        c_a (int | None, optional): Number of $a$-classes. Useful if `y_a`
            does not contain all the possible classes of the dataset at hand.
            If `None`, then `y_a` is assumed to contain all classes, and so
            `c_a = y_a.max() + 1`.
    """
    y_a, y_b = to_int_array(y_a), to_int_array(y_b)
    matching = to_int_matching(matching)
    if (la := len(y_a)) != (lb := len(y_b)):
        raise ValueError(
            f"y_a and y_b must have the same length, got {la} and {lb}"
        )
    c_a = c_a or int(y_a.max() + 1)
    p1 = [y_a == a for a in range(c_a)]
    p2 = [
        (
            np.sum(
                [np.zeros_like(y_b)] + [y_b == b for b in matching.get(a, [])],
                axis=0,
            )
            > 0
            if a in matching
            else np.full_like(y_b, False, dtype=bool)  # a isn't matched in m
        )
        for a in range(c_a)
    ]
    p3 = [p1[a] & ~p2[a] for a in range(c_a)]
    p4 = [p2[a] & ~p1[a] for a in range(c_a)]
    return np.array(p1), np.array(p2), np.array(p3), np.array(p4)
```
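A minimal sketch, reusing the matching from the `class_otm_matching` example (labels shifted to start at 0 so that `c_a = y_a.max() + 1` holds):

```python
import numpy as np

y_a = np.array([0, 0, 0, 0, 1, 1, 2, 3, 3])
y_b = np.array([0, 4, 0, 1, 1, 1, 2, 2, 2])
matching = {0: {0, 4}, 1: {1}, 2: set(), 3: {2}}
p1, p2, p3, p4 = otm_matching_predicates(y_a, y_b, matching)
# Sample 3 is in a-class 0, but its b-class (1) is matched to a-class 1,
# so it is a "miss" for a-class 0:
print(p3[0].astype(int))  # [0 0 0 1 0 0 0 0 0]
```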
```python
def peer_pressure_clustering(
    ds: BatchedTensorDataset,
    k: int,
    strategy: Strategy | Fabric | None = None,
    n_features: int | None = None,
    tqdm_style: TqdmStyle = None,
) -> np.ndarray:
    """
    Nearest-neighbor peer pressure clustering with weighted mean cluster
    conductance as an objective function.

    Args:
        ds (BatchedTensorDataset):
        k (int): Number of neighbors in the underlying KNN graph.
        strategy (Strategy | Fabric | None, optional): Defaults to `None`,
            meaning that the algorithm will not be parallelized.
        n_features (int | None, optional):
        tqdm_style (TqdmStyle, optional):
    """
    graph = knn_graph(
        ds, k, strategy=strategy, n_features=n_features, tqdm_style=tqdm_style
    )
    a = nx.to_scipy_sparse_array(graph, format="csc")
    y_clst, _ = _ppc(a, tqdm_style=tqdm_style)
    return y_clst
```
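A hedged sketch tying the two clustering entry points to the rest of the module (`ds` is a `BatchedTensorDataset` of latent samples and `y_true` the true labels; both are placeholders whose construction is project-specific):

```python
# Cluster the latent space, then match clusters to true classes.
y_clst = louvain_clustering(ds, k=10)  # or peer_pressure_clustering(ds, k=10)
matching = class_otm_matching(y_true, y_clst)
```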
```python
class RandomizedLCCLoss(LCCLoss):
    """
    An LCC loss function that pulls misclustered samples towards a CC sample
    in the same class.

    In principle, this implies some sort of exhaustive search since a MC
    sample has to be compared to *every* CC sample in the same class. This is
    what `lcc.correction.ExactLCCLoss` does. Here, to save on compute and
    time, only a few CC samples are randomly selected in each cluster and
    used as potential targets.
    """

    ccspc: int
    n_classes: int
    targets: dict[int, Tensor] = {}
    tqdm_style: TqdmStyle
    matching: Matching

    def __call__(
        self, z: Tensor, y_true: ArrayLike, y_clst: ArrayLike
    ) -> Tensor:
        """
        Derives the clustering correction loss from a tensor of latent
        representations `z` and a dict of targets (see
        `lcc.correction.lcc_targets`).

        First, recall that the values of `targets` (as produced by
        `lcc.correction.lcc_targets`) are `(k, d)` tensors, for some length
        `k`.

        Let's say `a` is a misclustered latent sample (a.k.a. a row of `z`)
        in true class `i_true`, and that `(b_1, ..., b_k)` are the rows of
        `targets[i_true]`. Then `a` contributes a term to the LCC loss equal
        to the distance between `a` and the closest `b_j`, divided by
        $\\sqrt{d}$.

        It is possible that `i_true` is not in the keys of `targets`, in
        which case the contribution of `a` to the LCC loss is zero. In
        particular, if `targets` is empty, then the LCC loss is zero.

        Args:
            z (Tensor): The tensor of latent representations. *Do not* mask
                it before passing it to this method. The correctly clustered
                samples and the misclustered samples are separated
                automatically.
            y_true (ArrayLike): A `(N,)` integer array of true labels.
            y_clst (ArrayLike): A `(N,)` integer array of the cluster labels.
        """
        if not self.targets:
            # ↓ actually need grad?
            return torch.tensor(0.0, requires_grad=True).to(z.device)
        z, y_true = z.flatten(1), to_int_tensor(y_true)
        p_mc, _ = _mc_cc_predicates(
            y_true, y_clst, self.matching, n_classes=self.n_classes
        )
        sqrt_d, losses = sqrt(z.shape[-1]), []
        for i_true, p_mc_i_true in enumerate(p_mc):
            if not (
                i_true in self.targets and len(self.targets[i_true]) > 0
            ):  # no targets in this true class
                continue
            if not p_mc_i_true.any():  # every sample is correctly clustered
                continue
            d = (
                torch.cdist(z[p_mc_i_true], self.targets[i_true].to(z.device))
                / sqrt_d
            )
            losses.append(d.min(dim=-1).values)
        if not losses:
            return torch.tensor(0.0, requires_grad=True).to(z.device)
        return torch.concat(losses).mean()

    def __init__(
        self,
        n_classes: int,
        ccspc: int = 100,
        tqdm_style: TqdmStyle = None,
        strategy: Strategy | Fabric | None = None,
    ) -> None:
        super().__init__(strategy=strategy)
        self.n_classes, self.ccspc = n_classes, ccspc
        self.tqdm_style = tqdm_style

    def sync(self, **kwargs: Any) -> None:
        if self.strategy is None:
            return
        path = self._get_tmp_dir()
        gr = self.strategy.global_rank
        st.save_file(
            {str(k): v for k, v in self.targets.items()},
            path / f"targets.{gr}",
        )
        self.strategy.barrier()
        for r in range(self.strategy.world_size):
            if r == self.strategy.global_rank:
                continue  # data from this rank is already in self.targets
            data = st.load_file(path / f"targets.{r}")
            self.targets.update({int(k): v for k, v in data.items()})
        return super().sync(**kwargs)

    def update(
        self,
        dl: DataLoader,
        y_true: ArrayLike,
        y_clst: ArrayLike,
        matching: Matching,
    ) -> None:
        """
        This method updates the `targets` attribute of this instance. It is a
        dict containing the following:
        - the keys are *among* true classes (unique values of `y_true`);
          let's say that `i_true` is a key that owns `k` clusters;
        - the associated value is a `(n, d)` tensor, where `d` is the latent
          dimension, whose rows are among correctly clustered samples in true
          class `i_true`. For example, if `ccspc` is $1$, then `n` is the
          number of clusters matched with `i_true`, say `k`. Otherwise,
          `n <= k * ccspc`.

        Under the hood, this method first chooses the samples by their index
        based on the "correctly clustered" predicate of `_mc_cc_predicates`.
        Then, the whole dataset is iterated to collect the actual samples.

        Args:
            dl (DataLoader): An unsharded dataloader over a tensor dataset.
            y_true (ArrayLike): A `(N,)` integer array.
            y_clst (ArrayLike): A `(N,)` integer array.
            matching (Matching): Produced by
                `lcc.correction.class_otm_matching`.
        """
        self.matching = to_int_matching(matching)
        _, p_cc = _mc_cc_predicates(
            y_true, y_clst, self.matching, self.n_classes
        )
        # i_true (assigned to this rank) -> some indices of CC samples
        indices: dict[int, set[int]] = {
            i_true: set() for i_true in self._distribute_labels(y_true)
        }
        for i_true in indices:
            for j_clst in self.matching[i_true]:
                # p: (N,) CC in i_true and j_clst
                p = p_cc[i_true] & (y_clst == j_clst)
                s = np.random.choice(np.where(p)[0], self.ccspc)
                indices[i_true].update(s)
        n_seen, n_todo = 0, sum(len(v) for v in indices.values())
        # i_true (assigned to this rank) -> some CC samples
        result: dict[int, list[Tensor]] = defaultdict(list)
        tqdm = make_tqdm(self.tqdm_style)
        progress = tqdm(dl, f"Finding correction targets (ccspc={self.ccspc})")
        for z, *_ in progress:
            for i_true, idxs in indices.items():
                lst = [idx for idx in idxs if n_seen <= idx < n_seen + len(z)]
                for idx in lst:
                    result[i_true].append(z[idx - n_seen])
                    n_todo -= 1
            if n_todo <= 0:
                break
            n_seen += len(z)
        if n_todo > 0:
            logging.warning(
                "Some correction targets could not be found "
                "(n_seen={}, n_todo={})",
                n_seen,
                n_todo,
            )
        self.targets = {
            k: torch.stack(v).flatten(1)
            for k, v in result.items()
            if v  # should already be non-empty, but just to make sure...
        }
```
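The lifecycle is the same as for `lcc.correction.ExactLCCLoss`, trading exactness for speed via the `ccspc` randomly chosen CC targets per cluster (a hedged sketch; all inputs are placeholders covering the whole dataset):

```python
loss_fn = RandomizedLCCLoss(n_classes=10, ccspc=100)
loss_fn.update(dl, y_true, y_clst, matching)  # picks random CC targets
loss_fn.sync()  # no-op without a strategy
loss = loss_fn(z_batch, y_true_batch, y_clst_batch)
```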
```python
def top_confusion_pairs(
    y_pred: ArrayLike,
    y_true: ArrayLike,
    n_classes: int,
    n_pairs: int | None = None,
    threshold: int = 0,
) -> list[tuple[int, int]]:
    """
    Returns the top `n_pairs` pairs of labels that exhibit the most
    confusion. The confusion between two labels $a$ and $b$ is the number of
    samples in true class $a$ that are predicted as class $b$, plus the
    number of samples in true class $b$ that are predicted as class $a$.

    Example:
        >>> y_pred, y_true = [0, 0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 0]
        >>> top_confusion_pairs(y_pred, y_true, n_classes=3, n_pairs=2)
        [(1, 2), (0, 2)]

    Args:
        y_pred (Tensor): A `(N,)` int tensor or an `(N, n_classes)`
            probabilities/logits float tensor
        y_true (Tensor): A `(N,)` int tensor
        n_classes (int): Number of classes in the dataset
        n_pairs (int | None, optional): Number of desired pairs. The actual
            result might have fewer. If `None`, returns all pairs of labels
            that are confused strictly more than `threshold` times.
        threshold (int, optional): A pair of labels must be confused strictly
            more than this number of times to be included in the list

    Returns:
        The top `n_pairs` pairs **or fewer** of labels that exhibit the most
        confusion.
    """
    y_pred, y_true = to_int_tensor(y_pred), to_int_tensor(y_true)
    cm = multiclass_confusion_matrix(y_pred, y_true, n_classes).numpy()
    cm = cm + cm.T  # Confusion in either direction
    cm = cm * (1 - np.eye(len(cm)))  # Remove the diagonal
    idx = cm.argsort(axis=None)  # Flat indices
    idx = np.flip(idx)
    cp = np.stack(np.unravel_index(idx, cm.shape)).T
    lst = [(i, j) for i, j in cp if cm[i, j] > threshold and i < j]
    return lst if n_pairs is None else lst[:n_pairs]
```
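Since inclusion requires strictly more than `threshold` confusions, raising the threshold on the docstring's toy data (where each pair is confused exactly once) filters everything out:

```python
>>> y_pred, y_true = [0, 0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 0]
>>> top_confusion_pairs(y_pred, y_true, n_classes=3, threshold=1)
[]
```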