lcc.plotting
Plotting utilities
"""Plotting utilities"""

from pathlib import Path
from typing import Any

import bokeh.layouts as bkl
import bokeh.models as bkm
import bokeh.palettes as bkp
import bokeh.plotting as bk
import numpy as np
from loguru import logger as logging
from numpy.typing import ArrayLike
from sklearn.preprocessing import RobustScaler

from .correction import otm_matching_predicates
from .correction.utils import Matching, to_int_matching
from .utils import to_array, to_int_array

BK_PALETTE_FUNCTIONS = {
    "cividis": bkp.cividis,
    "gray": bkp.gray,
    "grey": bkp.grey,
    "inferno": bkp.inferno,
    "magma": bkp.magma,
    "viridis": bkp.viridis,
}
"""
Supported bokeh palettes. See also
https://docs.bokeh.org/en/latest/docs/reference/palettes.html#functions.
"""


def class_scatter(
    figure: bk.figure,
    x: ArrayLike,
    y: ArrayLike,
    palette: bkp.Palette | list[str] | str | None = None,
    size: float = 3,
    rescale: bool = True,
    axis_visible: bool = False,
    grid_visible: bool = True,
    outliers: bool = True,
) -> None:
    """
    Scatter plot where each class has a different color. Points in negative
    classes (those for which the `y` value is strictly less than 0), called
    *outliers* here, are all plotted black.

    Example:

        (this example doesn't have outliers but I'm sure you can use your
        imagination)

    Warning:
        By default, this method hides the figure's axes (see `axis_visible`).
        Grid lines stay visible unless `grid_visible` is `False`.

    Args:
        figure (bk.figure): The figure to draw on.
        x (np.ndarray): A `(N, 2)` array
        y (np.ndarray): A `(N,)` int array. Each unique value corresponds to a
            class
        palette (Palette | list[str] | str | None, optional): Either a
            palette object (see
            https://docs.bokeh.org/en/latest/docs/reference/palettes.html#bokeh-palettes
            ),
            a list of HTML colors (at least as many as the number of classes),
            or a name in `lcc.plotting.BK_PALETTE_FUNCTIONS`. The `i`-th color
            is used for the `i`-th smallest non-negative label. Defaults to
            viridis.
        size (float, optional): Dot size. The outliers' dot size will be half
            that.
        rescale (bool, optional): Whether to rescale the `x` values with
            scikit-learn's `RobustScaler` (median removed, scaled by the
            interquartile range). This does *not* map the values to $[0, 1]$.
        axis_visible (bool, optional): Whether to show the figure's axes.
        grid_visible (bool, optional): Whether to show the figure's grid
            lines.
        outliers (bool, optional): Whether to plot the outliers (those samples
            with a label < 0).

    Raises:
        `ValueError` if the palette is unknown
    """
    x, y = to_array(x), to_int_array(y)
    if rescale:
        x = RobustScaler().fit_transform(x)
    assert isinstance(x, np.ndarray)  # for typechecking
    # At most 256 classes are plotted (in increasing label order); bokeh's
    # palette functions such as `bkp.viridis` cannot produce more colors.
    classes = np.unique(y[y >= 0])[:256]
    n_classes = len(classes)
    if palette is None:
        palette = bkp.viridis(n_classes)
    if isinstance(palette, str):
        if palette not in BK_PALETTE_FUNCTIONS:
            raise ValueError(f"Unknown palette '{palette}'")
        palette = BK_PALETTE_FUNCTIONS[palette](n_classes)
    # `classes` comes from `np.unique(y)`, so every class has at least one
    # sample; no emptiness check is needed.
    for i, label in enumerate(classes):
        a = x[y == label]
        assert isinstance(a, np.ndarray)  # for typechecking...
        figure.scatter(
            a[:, 0],
            a[:, 1],
            color=palette[i],
            line_width=0,
            size=size,
        )
    if outliers and (y < 0).any():
        a = x[y < 0]
        assert isinstance(a, np.ndarray)  # for typechecking...
        figure.scatter(
            a[:, 0],
            a[:, 1],
            color="black",
            line_width=0,
            size=size / 2,
        )
    figure.axis.visible = axis_visible
    figure.xgrid.visible = figure.ygrid.visible = grid_visible


def export_png(obj: Any, filename: str | Path) -> Path:
    """
    A replacement for `bokeh.io.export_png` which can sometimes be a bit buggy.
    Instantiates its own headless Firefox webdriver. A bit slower but more
    reliable.

    If Selenium is not installed, or if the Firefox webdriver is not installed,
    or if any other error occurs, this method logs an error and falls back to
    the default bokeh implementation.

    Args:
        obj: The bokeh object to export.
        filename (str | Path): Path of the PNG file to write.

    Returns:
        `filename`, as a `Path`.
    """
    from bokeh.io import export_png as _export_png

    webdriver: Any = None
    try:
        from selenium.webdriver import Firefox, FirefoxOptions

        opts = FirefoxOptions()
        opts.add_argument("--headless")
        webdriver = Firefox(options=opts)
        _export_png(obj, filename=str(filename), webdriver=webdriver)
    except Exception as e:
        if isinstance(e, ModuleNotFoundError):
            logging.error(
                "Selenium is not installed. Falling back to default bokeh "
                "implementation"
            )
        else:
            logging.error(
                f"Failed to export PNG using explicit Selenium driver: {e}\n"
                "Falling back to default bokeh implementation"
            )
        _export_png(obj, filename=str(filename))
    finally:
        if webdriver is not None:
            # `quit()` (unlike `close()`, which only closes the current
            # window) ends the session and stops the geckodriver process.
            webdriver.quit()
    return Path(filename)


def class_matching_plot(
    x: ArrayLike,
    y_true: ArrayLike,
    y_clst: ArrayLike,
    matching: Matching,
    size: int = 400,
) -> bkm.GridBox:
    """
    Given a dataset `x` and two labellings `y_true` and `y_clst`, this method
    makes a scatter plot detailing the situation. Labels in `y_true` are
    considered to be ground truth.

    For each key `a` of `matching`, a row of four plots is produced: the
    ground truth class `a`, its matched classes, their intersection, and
    their symmetric difference (misses in red, excess in blue). If no sample
    falls in the matched classes, only the first plot of that row is drawn.
    The plots of a row share the same x and y ranges.

    Example:

    Warning:
        The array `x` is always rescaled with scikit-learn's `RobustScaler`
        (median removed, scaled by the interquartile range); it is *not*
        mapped to $[0, 1]$.

    Args:
        x (ArrayLike): A `(N, 2)` array
        y_true (ArrayLike): A `(N,)` integer array with
            values in $\\\\{ 0, 1, ..., c_a - 1 \\\\}$ for some $c_a > 0$.
        y_clst (ArrayLike): A `(N,)` integer array with
            values in $\\\\{ 0, 1, ..., c_b - 1 \\\\}$ for some $c_b > 0$.
        matching (Matching): Matching between
            the labels of `y_true` and the labels of `y_clst`. If some keys are
            strings, they must be convertible to ints. Probably generated from
            `lcc.correction.class_otm_matching`.
        size (int, optional): The size of each scatter plot. Defaults to 400.

    Returns:
        A bokeh grid with one row per key of `matching`.
    """
    x, y_true, y_clst = to_array(x), to_int_array(y_true), to_int_array(y_clst)
    x = RobustScaler().fit_transform(x)
    assert isinstance(x, np.ndarray)  # for typechecking
    matching = to_int_matching(matching)
    p1, p2, p3, p4 = otm_matching_predicates(y_true, y_clst, matching)

    # Palettes are indexed by label, hence the contiguous-label requirement
    # stated in the docstring.
    palt_true = bkp.viridis(len(np.unique(y_true)))
    palt_clst = bkp.viridis(len(np.unique(y_clst)))
    # Color of each label of `y_diff` below: 1 = miss, 2 = excess
    diff_colors = {1: "#ff0000", 2: "#0000ff"}
    figures = []
    kw = {"width": size, "height": size}
    for a, bs in matching.items():
        n_true, n_matched = p1[a].sum(), p2[a].sum()
        n_inter = (p1[a] & p2[a]).sum()
        n_miss, n_exc = p3[a].sum(), p4[a].sum()
        fig_a = bk.figure(title=f"Ground truth, class {a}; n = {n_true}", **kw)
        class_scatter(
            fig_a,
            x[p1[a]],
            y_true[p1[a]],
            rescale=False,
            palette=[palt_true[a]],
        )
        if n_matched == 0:
            figures.append([fig_a, None, None, None])
            continue
        fig_b = bk.figure(
            title=(
                f"{len(bs)} matched classes: "
                + ", ".join(map(str, bs))
                + f"; n = {n_matched}"
            ),
            **kw,
        )
        class_scatter(
            fig_b,
            x[p2[a]],
            y_clst[p2[a]],
            rescale=False,
            palette=[palt_clst[b] for b in np.unique(y_clst[p2[a]])],
        )
        fig_match = bk.figure(title=f"Intersection; n = {n_inter}", **kw)
        class_scatter(
            fig_match,
            x[p1[a] & p2[a]],
            y_clst[p1[a] & p2[a]],
            rescale=False,
            palette=[palt_clst[b] for b in np.unique(y_clst[p1[a] & p2[a]])],
        )
        # NOTE(review): assumes a sample is never both a miss and an excess
        # (label 3); confirm against `otm_matching_predicates`.
        y_diff = p3[a] + 2 * p4[a]
        fig_diff = bk.figure(
            title=(
                f"Symmetric difference; n = {n_miss + n_exc}\n"
                f"Misses (red) = {n_miss}; excess (blue) = {n_exc}"
            ),
            **kw,
        )
        # `class_scatter` assigns palette colors by rank of the labels that
        # are present, so the palette is built from those labels only.
        # Otherwise excess points would be drawn red when there are no misses.
        class_scatter(
            fig_diff,
            x[y_diff > 0],
            y_diff[y_diff > 0],
            palette=[diff_colors[int(k)] for k in np.unique(y_diff[y_diff > 0])],
            rescale=False,
        )
        make_same_xy_range(fig_a, fig_b, fig_match, fig_diff)
        figures.append([fig_a, fig_b, fig_match, fig_diff])

    return bkl.grid(figures)  # type: ignore


def make_same_xy_range(*args: bk.figure) -> None:
    """
    Makes sure all figures share the same `x_range` and `y_range`. As a result,
    if a figure is zoomed in or dragged, all others will be too and by the same
    amount.
    """
    for f in args[1:]:
        f.x_range, f.y_range = args[0].x_range, args[0].y_range
Supported bokeh palettes. See also https://docs.bokeh.org/en/latest/docs/reference/palettes.html#functions.
def class_scatter(
    figure: bk.figure,
    x: ArrayLike,
    y: ArrayLike,
    palette: bkp.Palette | list[str] | str | None = None,
    size: float = 3,
    rescale: bool = True,
    axis_visible: bool = False,
    grid_visible: bool = True,
    outliers: bool = True,
) -> None:
    """
    Scatter plot where each class has a different color. Points in negative
    classes (those for which the `y` value is strictly less than 0), called
    *outliers* here, are all plotted black.

    Example:

        (this example doesn't have outliers but I'm sure you can use your
        imagination)

    Warning:
        By default, this method hides the figure's axes (see `axis_visible`).
        Grid lines stay visible unless `grid_visible` is `False`.

    Args:
        figure (bk.figure): The figure to draw on.
        x (np.ndarray): A `(N, 2)` array
        y (np.ndarray): A `(N,)` int array. Each unique value corresponds to a
            class
        palette (Palette | list[str] | str | None, optional): Either a
            palette object (see
            https://docs.bokeh.org/en/latest/docs/reference/palettes.html#bokeh-palettes
            ),
            a list of HTML colors (at least as many as the number of classes),
            or a name in `lcc.plotting.BK_PALETTE_FUNCTIONS`. The `i`-th color
            is used for the `i`-th smallest non-negative label. Defaults to
            viridis.
        size (float, optional): Dot size. The outliers' dot size will be half
            that.
        rescale (bool, optional): Whether to rescale the `x` values with
            scikit-learn's `RobustScaler` (median removed, scaled by the
            interquartile range). This does *not* map the values to $[0, 1]$.
        axis_visible (bool, optional): Whether to show the figure's axes.
        grid_visible (bool, optional): Whether to show the figure's grid
            lines.
        outliers (bool, optional): Whether to plot the outliers (those samples
            with a label < 0).

    Raises:
        `ValueError` if the palette is unknown
    """
    x, y = to_array(x), to_int_array(y)
    if rescale:
        x = RobustScaler().fit_transform(x)
    assert isinstance(x, np.ndarray)  # for typechecking
    # At most 256 classes are plotted (in increasing label order); bokeh's
    # palette functions such as `bkp.viridis` cannot produce more colors.
    classes = np.unique(y[y >= 0])[:256]
    n_classes = len(classes)
    if palette is None:
        palette = bkp.viridis(n_classes)
    if isinstance(palette, str):
        if palette not in BK_PALETTE_FUNCTIONS:
            raise ValueError(f"Unknown palette '{palette}'")
        palette = BK_PALETTE_FUNCTIONS[palette](n_classes)
    # `classes` comes from `np.unique(y)`, so every class has at least one
    # sample; no emptiness check is needed.
    for i, label in enumerate(classes):
        a = x[y == label]
        assert isinstance(a, np.ndarray)  # for typechecking...
        figure.scatter(
            a[:, 0],
            a[:, 1],
            color=palette[i],
            line_width=0,
            size=size,
        )
    if outliers and (y < 0).any():
        a = x[y < 0]
        assert isinstance(a, np.ndarray)  # for typechecking...
        figure.scatter(
            a[:, 0],
            a[:, 1],
            color="black",
            line_width=0,
            size=size / 2,
        )
    figure.axis.visible = axis_visible
    figure.xgrid.visible = figure.ygrid.visible = grid_visible
Scatter plot where each class has a different color. Points in negative classes (those for which the `y` value is strictly less than 0), called *outliers* here, are all plotted black.
Example:
(this example doesn't have outliers but I'm sure you can use your imagination)
Warning:
By default, this method hides the figure's axes; grid lines stay visible unless `grid_visible` is `False`.
Arguments:
- figure (bk.figure): The figure to draw on.
- x (np.ndarray): A `(N, 2)` array.
- y (np.ndarray): A `(N,)` int array. Each unique value corresponds to a class.
- palette (Palette | list[str] | str | None, optional): Either a palette object (see https://docs.bokeh.org/en/latest/docs/reference/palettes.html#bokeh-palettes), a list of HTML colors (at least as many as the number of classes), or a name in `lcc.plotting.BK_PALETTE_FUNCTIONS`.
- size (float, optional): Dot size. The outliers' dot size will be half that.
- rescale (bool, optional): Whether to rescale the `x` values with scikit-learn's `RobustScaler` (median removed, scaled by the interquartile range). This does not map the values to $[0, 1]$.
- axis_visible (bool, optional): Whether to show the figure's axes.
- grid_visible (bool, optional): Whether to show the figure's grid lines.
- outliers (bool, optional): Whether to plot the outliers (those samples with a label < 0).
Raises:
`ValueError` if the palette is unknown.
def export_png(obj: Any, filename: str | Path) -> Path:
    """
    A replacement for `bokeh.io.export_png` which can sometimes be a bit buggy.
    Instantiates its own headless Firefox webdriver. A bit slower but more
    reliable.

    If Selenium is not installed, or if the Firefox webdriver is not installed,
    or if any other error occurs, this method logs an error and falls back to
    the default bokeh implementation.

    Args:
        obj: The bokeh object to export.
        filename (str | Path): Path of the PNG file to write.

    Returns:
        `filename`, as a `Path`.
    """
    from bokeh.io import export_png as _export_png

    webdriver: Any = None
    try:
        from selenium.webdriver import Firefox, FirefoxOptions

        opts = FirefoxOptions()
        opts.add_argument("--headless")
        webdriver = Firefox(options=opts)
        _export_png(obj, filename=str(filename), webdriver=webdriver)
    except Exception as e:
        if isinstance(e, ModuleNotFoundError):
            logging.error(
                "Selenium is not installed. Falling back to default bokeh "
                "implementation"
            )
        else:
            logging.error(
                f"Failed to export PNG using explicit Selenium driver: {e}\n"
                "Falling back to default bokeh implementation"
            )
        _export_png(obj, filename=str(filename))
    finally:
        if webdriver is not None:
            # `quit()` (unlike `close()`, which only closes the current
            # window) ends the session and stops the geckodriver process.
            webdriver.quit()
    return Path(filename)
A replacement for `bokeh.io.export_png`, which can sometimes be a bit buggy. Instantiates its own Firefox webdriver. A bit slower but more reliable.
If Selenium is not installed, or if the Firefox webdriver is not installed, or if any other error occurs, this method logs an error and falls back to the default bokeh implementation.
def class_matching_plot(
    x: ArrayLike,
    y_true: ArrayLike,
    y_clst: ArrayLike,
    matching: Matching,
    size: int = 400,
) -> bkm.GridBox:
    """
    Given a dataset `x` and two labellings `y_true` and `y_clst`, this method
    makes a scatter plot detailing the situation. Labels in `y_true` are
    considered to be ground truth.

    For each key `a` of `matching`, a row of four plots is produced: the
    ground truth class `a`, its matched classes, their intersection, and
    their symmetric difference (misses in red, excess in blue). If no sample
    falls in the matched classes, only the first plot of that row is drawn.
    The plots of a row share the same x and y ranges.

    Example:

    Warning:
        The array `x` is always rescaled with scikit-learn's `RobustScaler`
        (median removed, scaled by the interquartile range); it is *not*
        mapped to $[0, 1]$.

    Args:
        x (ArrayLike): A `(N, 2)` array
        y_true (ArrayLike): A `(N,)` integer array with
            values in $\\\\{ 0, 1, ..., c_a - 1 \\\\}$ for some $c_a > 0$.
        y_clst (ArrayLike): A `(N,)` integer array with
            values in $\\\\{ 0, 1, ..., c_b - 1 \\\\}$ for some $c_b > 0$.
        matching (Matching): Matching between
            the labels of `y_true` and the labels of `y_clst`. If some keys are
            strings, they must be convertible to ints. Probably generated from
            `lcc.correction.class_otm_matching`.
        size (int, optional): The size of each scatter plot. Defaults to 400.

    Returns:
        A bokeh grid with one row per key of `matching`.
    """
    x, y_true, y_clst = to_array(x), to_int_array(y_true), to_int_array(y_clst)
    x = RobustScaler().fit_transform(x)
    assert isinstance(x, np.ndarray)  # for typechecking
    matching = to_int_matching(matching)
    p1, p2, p3, p4 = otm_matching_predicates(y_true, y_clst, matching)

    # Palettes are indexed by label, hence the contiguous-label requirement
    # stated in the docstring.
    palt_true = bkp.viridis(len(np.unique(y_true)))
    palt_clst = bkp.viridis(len(np.unique(y_clst)))
    # Color of each label of `y_diff` below: 1 = miss, 2 = excess
    diff_colors = {1: "#ff0000", 2: "#0000ff"}
    figures = []
    kw = {"width": size, "height": size}
    for a, bs in matching.items():
        n_true, n_matched = p1[a].sum(), p2[a].sum()
        n_inter = (p1[a] & p2[a]).sum()
        n_miss, n_exc = p3[a].sum(), p4[a].sum()
        fig_a = bk.figure(title=f"Ground truth, class {a}; n = {n_true}", **kw)
        class_scatter(
            fig_a,
            x[p1[a]],
            y_true[p1[a]],
            rescale=False,
            palette=[palt_true[a]],
        )
        if n_matched == 0:
            figures.append([fig_a, None, None, None])
            continue
        fig_b = bk.figure(
            title=(
                f"{len(bs)} matched classes: "
                + ", ".join(map(str, bs))
                + f"; n = {n_matched}"
            ),
            **kw,
        )
        class_scatter(
            fig_b,
            x[p2[a]],
            y_clst[p2[a]],
            rescale=False,
            palette=[palt_clst[b] for b in np.unique(y_clst[p2[a]])],
        )
        fig_match = bk.figure(title=f"Intersection; n = {n_inter}", **kw)
        class_scatter(
            fig_match,
            x[p1[a] & p2[a]],
            y_clst[p1[a] & p2[a]],
            rescale=False,
            palette=[palt_clst[b] for b in np.unique(y_clst[p1[a] & p2[a]])],
        )
        # NOTE(review): assumes a sample is never both a miss and an excess
        # (label 3); confirm against `otm_matching_predicates`.
        y_diff = p3[a] + 2 * p4[a]
        fig_diff = bk.figure(
            title=(
                f"Symmetric difference; n = {n_miss + n_exc}\n"
                f"Misses (red) = {n_miss}; excess (blue) = {n_exc}"
            ),
            **kw,
        )
        # `class_scatter` assigns palette colors by rank of the labels that
        # are present, so the palette is built from those labels only.
        # Otherwise excess points would be drawn red when there are no misses.
        class_scatter(
            fig_diff,
            x[y_diff > 0],
            y_diff[y_diff > 0],
            palette=[diff_colors[int(k)] for k in np.unique(y_diff[y_diff > 0])],
            rescale=False,
        )
        make_same_xy_range(fig_a, fig_b, fig_match, fig_diff)
        figures.append([fig_a, fig_b, fig_match, fig_diff])

    return bkl.grid(figures)  # type: ignore
Given a dataset `x` and two labellings `y_true` and `y_clst`, this method makes a scatter plot detailing the situation. Labels in `y_true` are considered to be ground truth.
Example:
Warning:
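The array `x` is always rescaled with scikit-learn's `RobustScaler` (median removed, scaled by the interquartile range); it is not mapped to the $[0, 1]$ range.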
Arguments:
- x (ArrayLike): A `(N, 2)` array.
- y_true (ArrayLike): A `(N,)` integer array with values in $\{ 0, 1, ..., c_a - 1 \}$ for some $c_a > 0$.
- y_clst (ArrayLike): A `(N,)` integer array with values in $\{ 0, 1, ..., c_b - 1 \}$ for some $c_b > 0$.
- matching (Matching): Matching between the labels of `y_true` and the labels of `y_clst`. If some keys are strings, they must be convertible to ints. Probably generated from `lcc.correction.class_otm_matching`.
- size (int, optional): The size of each scatter plot. Defaults to 400.
def make_same_xy_range(*args: bk.figure) -> None:
    """
    Links every given figure to the `x_range` and `y_range` objects of the
    first one. Because the range objects are then shared, zooming in on or
    dragging any of the figures moves all the others by the same amount.
    """
    if not args:
        return
    shared_x, shared_y = args[0].x_range, args[0].y_range
    for follower in args[1:]:
        follower.x_range = shared_x
        follower.y_range = shared_y
Makes sure all figures share the same `x_range` and `y_range`. As a result, if a figure is zoomed in or dragged, all others will be too and by the same amount.