lcc.plotting

Plotting utilities

  1"""Plotting utilities"""
  2
  3from pathlib import Path
  4from typing import Any
  5
  6import bokeh.layouts as bkl
  7import bokeh.models as bkm
  8import bokeh.palettes as bkp
  9import bokeh.plotting as bk
 10import numpy as np
 11from loguru import logger as logging
 12from numpy.typing import ArrayLike
 13from sklearn.preprocessing import RobustScaler
 14
 15from .correction import otm_matching_predicates
 16from .correction.utils import Matching, to_int_matching
 17from .utils import to_array, to_int_array
 18
 19BK_PALETTE_FUNCTIONS = {
 20    "cividis": bkp.cividis,
 21    "gray": bkp.gray,
 22    "grey": bkp.grey,
 23    "inferno": bkp.inferno,
 24    "magma": bkp.magma,
 25    "viridis": bkp.viridis,
 26}
 27"""
 28Supported bokeh palettes. See also
 29https://docs.bokeh.org/en/latest/docs/reference/palettes.html#functions.
 30"""
 31
 32
 33def class_scatter(
 34    figure: bk.figure,
 35    x: ArrayLike,
 36    y: ArrayLike,
 37    palette: bkp.Palette | list[str] | str | None = None,
 38    size: float = 3,
 39    rescale: bool = True,
 40    axis_visible: bool = False,
 41    grid_visible: bool = True,
 42    outliers: bool = True,
 43) -> None:
 44    """
 45    Scatter plot where each class has a different color. Points in negative
 46    classes (those for which the `y` value is strictly less than 0), called
 47    *outliers* here, are all plotted black.
 48
 49    Example:
 50
 51        ![Example 1](../docs/imgs/class_scatter.png)
 52
 53        (this example does't have outliers but I'm sure you can use your
 54        imagination)
 55
 56    Warning:
 57        This method hides the figure's axis and grid lines.
 58
 59    Args:
 60        figure (bk.figure):
 61        x (np.ndarray): A `(N, 2)` array
 62        y (np.ndarray): A `(N,)` int array. Each unique value corresponds to a
 63            class
 64        palette (Palette | list[str] | str | None, optional): Either a
 65            palette object (see
 66            https://docs.bokeh.org/en/latest/docs/reference/palettes.html#bokeh-palettes
 67            ),
 68            a list of HTML colors (at least as many as the number of classes),
 69            or a name in `lcc.plotting.BK_PALETTE_FUNCTIONS`.
 70        size (float, optional): Dot size. The outlier's dot size will be half
 71            that.
 72        rescale (bool, optional): Whether to rescale the `x` values to $[0, 1]$.
 73        outliers (bool, optional): Whether to plot the outliers (those samples
 74            with a label < 0).
 75
 76    Raises:
 77        `ValueError` if the palette is unknown
 78    """
 79    x, y = to_array(x), to_int_array(y)
 80    if rescale:
 81        x = RobustScaler().fit_transform(x)
 82    assert isinstance(x, np.ndarray)  # for typechecking
 83    n_classes = min(len(np.unique(y[y >= 0])), 256)
 84    if palette is None:
 85        palette = bkp.viridis(n_classes)
 86    if isinstance(palette, str):
 87        if palette not in BK_PALETTE_FUNCTIONS:
 88            raise ValueError(f"Unknown palette '{palette}'")
 89        palette = BK_PALETTE_FUNCTIONS[palette](n_classes)
 90    for i, j in enumerate(np.unique(y[y >= 0])[:n_classes]):
 91        if not (y == j).any():
 92            continue
 93        a = x[y == j]
 94        assert isinstance(a, np.ndarray)  # for typechecking...
 95        figure.scatter(
 96            a[:, 0],
 97            a[:, 1],
 98            color=palette[i],
 99            line_width=0,
100            size=size,
101        )
102    if (y < 0).any() and outliers:
103        a = x[y < 0]
104        assert isinstance(a, np.ndarray)  # for typechecking...
105        figure.scatter(
106            a[:, 0],
107            a[:, 1],
108            color="black",
109            line_width=0,
110            size=size / 2,
111        )
112    figure.axis.visible = axis_visible
113    figure.xgrid.visible = figure.ygrid.visible = grid_visible
114
115
def export_png(obj: Any, filename: str | Path) -> Path:
    """
    A replacement for `bokeh.io.export_png` which can sometimes be a bit buggy.
    Instantiates its own headless Firefox webdriver. A bit slower but more
    reliable.

    If Selenium is not installed, or if the Firefox webdriver is not installed,
    or if any other error occurs, the error is logged and this method falls
    back to the default bokeh implementation. Errors raised by that fallback
    are propagated to the caller.

    Args:
        obj (Any): The bokeh object to export; passed as is to
            `bokeh.io.export_png`.
        filename (str | Path): Path of the PNG file to write.

    Returns:
        `filename` as a `Path`.
    """
    from bokeh.io import export_png as _export_png

    webdriver: Any = None
    try:
        from selenium.webdriver import Firefox, FirefoxOptions

        opts = FirefoxOptions()
        opts.add_argument("--headless")
        webdriver = Firefox(options=opts)
        _export_png(obj, filename=str(filename), webdriver=webdriver)
    except Exception as e:
        if isinstance(e, ModuleNotFoundError):
            logging.error(
                "Selenium is not installed. Falling back to default bokeh "
                "implementation"
            )
        else:
            logging.error(
                f"Failed to export PNG using explicit Selenium driver: {e}\n"
                "Falling back to default bokeh implementation"
            )
        _export_png(obj, filename=str(filename))
    finally:
        if webdriver is not None:
            # quit() (rather than close()) ends the whole session, so that
            # the browser and its driver process do not outlive this call.
            webdriver.quit()
    return Path(filename)
151
152
def class_matching_plot(
    x: ArrayLike,
    y_true: ArrayLike,
    y_clst: ArrayLike,
    matching: Matching,
    size: int = 400,
) -> bkm.GridBox:
    """
    Given a dataset `x` and two labellings `y_true` and `y_clst`, this method
    makes a grid of scatter plots detailing the situation. Labels in `y_true`
    are considered to be ground truth.

    There is one row per key `a` of `matching`, made of up to four figures:
    the ground truth class `a`, the classes of `y_clst` matched to `a`, their
    intersection, and their symmetric difference (misses in red, excess in
    blue). If no sample falls in a matched class, the row only contains the
    first figure.

    Example:
        ![Example 1](../docs/imgs/class_matching_plot.png)

    Warning:
        The array `x` is always rescaled with
        `sklearn.preprocessing.RobustScaler`. Note that the result is not
        guaranteed to lie in the $[0, 1]$ range.

    Args:
        x (ArrayLike): A `(N, 2)` array
        y_true (ArrayLike): A `(N,)` integer array with
            values in $\\\\{ 0, 1, ..., c_a - 1 \\\\}$ for some $c_a > 0$.
        y_clst (ArrayLike): A `(N,)` integer array with
            values in $\\\\{ 0, 1, ..., c_b - 1 \\\\}$ for some $c_b > 0$.
        matching (Matching): Matching between
            the labels of `y_true` and the labels of `y_clst`. If some keys are
            strings, they must be convertible to ints. Probably generated from
            `lcc.correction.class_otm_matching`.
        size (int, optional): The size of each scatter plot. Defaults to 400.

    Returns:
        A bokeh grid layout containing all the figures.
    """
    x, y_true, y_clst = to_array(x), to_int_array(y_true), to_int_array(y_clst)
    x = RobustScaler().fit_transform(x)
    assert isinstance(x, np.ndarray)  # for typechecking
    matching = to_int_matching(matching)
    # NOTE(review): judging from the figure titles below, p1[a] selects the
    # samples of true class a, p2[a] those in a class matched to a, p3[a] the
    # misses and p4[a] the excess -- confirm against otm_matching_predicates.
    p1, p2, p3, p4 = otm_matching_predicates(y_true, y_clst, matching)

    # Palettes are indexed by label, hence the requirement that labels be
    # 0, 1, ..., c - 1.
    palt_true = bkp.viridis(len(np.unique(y_true)))
    palt_clst = bkp.viridis(len(np.unique(y_clst)))
    figures = []
    kw = {"width": size, "height": size}
    for a, bs in matching.items():
        in_true, in_matched = p1[a], p2[a]
        in_both = in_true & in_matched
        n_true, n_matched = in_true.sum(), in_matched.sum()
        n_inter = in_both.sum()
        n_miss, n_exc = p3[a].sum(), p4[a].sum()
        fig_a = bk.figure(title=f"Ground truth, class {a}; n = {n_true}", **kw)
        class_scatter(
            fig_a,
            x[in_true],
            y_true[in_true],
            rescale=False,
            palette=[palt_true[a]],
        )
        if n_matched == 0:
            figures.append([fig_a, None, None, None])
            continue
        fig_b = bk.figure(
            title=(
                f"{len(bs)} matched classes: "
                + ", ".join(map(str, bs))
                + f"; n = {n_matched}"
            ),
            **kw,
        )
        # class_scatter assigns colors by position among the sorted labels
        # present, so the palette is restricted to those labels.
        class_scatter(
            fig_b,
            x[in_matched],
            y_clst[in_matched],
            rescale=False,
            palette=[palt_clst[b] for b in np.unique(y_clst[in_matched])],
        )
        fig_match = bk.figure(title=f"Intersection; n = {n_inter}", **kw)
        class_scatter(
            fig_match,
            x[in_both],
            y_clst[in_both],
            rescale=False,
            palette=[palt_clst[b] for b in np.unique(y_clst[in_both])],
        )
        # 1 = miss, 2 = excess (assumes a sample cannot be both).
        y_diff = p3[a] + 2 * p4[a]
        # Only keep the colors of the codes that actually occur. Otherwise,
        # when there are no misses, the excess points would take the first
        # color and be drawn red instead of blue.
        diff_palette = [
            color
            for code, color in ((1, "#ff0000"), (2, "#0000ff"))
            if (y_diff == code).any()
        ]
        fig_diff = bk.figure(
            title=(
                f"Symmetric difference; n = {n_miss + n_exc}\n"
                f"Misses (red) = {n_miss}; excess (blue) = {n_exc}"
            ),
            **kw,
        )
        class_scatter(
            fig_diff,
            x[y_diff > 0],
            y_diff[y_diff > 0],
            palette=diff_palette,
            rescale=False,
        )
        make_same_xy_range(fig_a, fig_b, fig_match, fig_diff)
        figures.append([fig_a, fig_b, fig_match, fig_diff])

    return bkl.grid(figures)  # type: ignore
254
255
def make_same_xy_range(*args: bk.figure) -> None:
    """
    Links the ranges of all given figures to those of the first one: every
    other figure is assigned the first figure's `x_range` and `y_range`
    objects. As a result, zooming or dragging one figure zooms or drags all
    the others by the same amount.
    """
    if len(args) < 2:
        # Nothing to link
        return
    reference, *followers = args
    for fig in followers:
        fig.x_range = reference.x_range
        fig.y_range = reference.y_range
BK_PALETTE_FUNCTIONS = {'cividis': <function cividis>, 'gray': <function gray>, 'grey': <function grey>, 'inferno': <function inferno>, 'magma': <function magma>, 'viridis': <function viridis>}
def class_scatter( figure: bokeh.plotting._figure.figure, x: Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], y: Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], palette: tuple[str, ...] | list[str] | str | None = None, size: float = 3, rescale: bool = True, axis_visible: bool = False, grid_visible: bool = True, outliers: bool = True) -> None:
 34def class_scatter(
 35    figure: bk.figure,
 36    x: ArrayLike,
 37    y: ArrayLike,
 38    palette: bkp.Palette | list[str] | str | None = None,
 39    size: float = 3,
 40    rescale: bool = True,
 41    axis_visible: bool = False,
 42    grid_visible: bool = True,
 43    outliers: bool = True,
 44) -> None:
 45    """
 46    Scatter plot where each class has a different color. Points in negative
 47    classes (those for which the `y` value is strictly less than 0), called
 48    *outliers* here, are all plotted black.
 49
 50    Example:
 51
 52        ![Example 1](../docs/imgs/class_scatter.png)
 53
 54        (this example does't have outliers but I'm sure you can use your
 55        imagination)
 56
 57    Warning:
 58        This method hides the figure's axis and grid lines.
 59
 60    Args:
 61        figure (bk.figure):
 62        x (np.ndarray): A `(N, 2)` array
 63        y (np.ndarray): A `(N,)` int array. Each unique value corresponds to a
 64            class
 65        palette (Palette | list[str] | str | None, optional): Either a
 66            palette object (see
 67            https://docs.bokeh.org/en/latest/docs/reference/palettes.html#bokeh-palettes
 68            ),
 69            a list of HTML colors (at least as many as the number of classes),
 70            or a name in `lcc.plotting.BK_PALETTE_FUNCTIONS`.
 71        size (float, optional): Dot size. The outlier's dot size will be half
 72            that.
 73        rescale (bool, optional): Whether to rescale the `x` values to $[0, 1]$.
 74        outliers (bool, optional): Whether to plot the outliers (those samples
 75            with a label < 0).
 76
 77    Raises:
 78        `ValueError` if the palette is unknown
 79    """
 80    x, y = to_array(x), to_int_array(y)
 81    if rescale:
 82        x = RobustScaler().fit_transform(x)
 83    assert isinstance(x, np.ndarray)  # for typechecking
 84    n_classes = min(len(np.unique(y[y >= 0])), 256)
 85    if palette is None:
 86        palette = bkp.viridis(n_classes)
 87    if isinstance(palette, str):
 88        if palette not in BK_PALETTE_FUNCTIONS:
 89            raise ValueError(f"Unknown palette '{palette}'")
 90        palette = BK_PALETTE_FUNCTIONS[palette](n_classes)
 91    for i, j in enumerate(np.unique(y[y >= 0])[:n_classes]):
 92        if not (y == j).any():
 93            continue
 94        a = x[y == j]
 95        assert isinstance(a, np.ndarray)  # for typechecking...
 96        figure.scatter(
 97            a[:, 0],
 98            a[:, 1],
 99            color=palette[i],
100            line_width=0,
101            size=size,
102        )
103    if (y < 0).any() and outliers:
104        a = x[y < 0]
105        assert isinstance(a, np.ndarray)  # for typechecking...
106        figure.scatter(
107            a[:, 0],
108            a[:, 1],
109            color="black",
110            line_width=0,
111            size=size / 2,
112        )
113    figure.axis.visible = axis_visible
114    figure.xgrid.visible = figure.ygrid.visible = grid_visible

Scatter plot where each class has a different color. Points in negative classes (those for which the y value is strictly less than 0), called outliers here, are all plotted black.

Example:

Example 1

(this example doesn't have outliers but I'm sure you can use your imagination)

Warning:

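This method overwrites the visibility of the figure's axes and grid lines. With the default arguments, the axes are hidden (axis_visible=False) and the grid lines are shown (grid_visible=True).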

Arguments:
  • figure (bk.figure):
  • x (np.ndarray): A (N, 2) array
  • y (np.ndarray): A (N,) int array. Each unique value corresponds to a class
  • palette (Palette | list[str] | str | None, optional): Either a palette object (see https://docs.bokeh.org/en/latest/docs/reference/palettes.html#bokeh-palettes ), a list of HTML colors (at least as many as the number of classes), or a name in lcc.plotting.BK_PALETTE_FUNCTIONS.
  • size (float, optional): Dot size. The outlier's dot size will be half that.
  • rescale (bool, optional): Whether to rescale the x values with scikit-learn's RobustScaler (the result is not guaranteed to lie in $[0, 1]$).
  • outliers (bool, optional): Whether to plot the outliers (those samples with a label < 0).
Raises:
  • ValueError if the palette is unknown
def export_png(obj: Any, filename: str | pathlib.Path) -> pathlib.Path:
def export_png(obj: Any, filename: str | Path) -> Path:
    """
    A replacement for `bokeh.io.export_png` which can sometimes be a bit buggy.
    Instantiates its own headless Firefox webdriver. A bit slower but more
    reliable.

    If Selenium is not installed, or if the Firefox webdriver is not installed,
    or if any other error occurs, the error is logged and this method falls
    back to the default bokeh implementation. Errors raised by that fallback
    are propagated to the caller.

    Args:
        obj (Any): The bokeh object to export; passed as is to
            `bokeh.io.export_png`.
        filename (str | Path): Path of the PNG file to write.

    Returns:
        `filename` as a `Path`.
    """
    from bokeh.io import export_png as _export_png

    webdriver: Any = None
    try:
        from selenium.webdriver import Firefox, FirefoxOptions

        opts = FirefoxOptions()
        opts.add_argument("--headless")
        webdriver = Firefox(options=opts)
        _export_png(obj, filename=str(filename), webdriver=webdriver)
    except Exception as e:
        if isinstance(e, ModuleNotFoundError):
            logging.error(
                "Selenium is not installed. Falling back to default bokeh "
                "implementation"
            )
        else:
            logging.error(
                f"Failed to export PNG using explicit Selenium driver: {e}\n"
                "Falling back to default bokeh implementation"
            )
        _export_png(obj, filename=str(filename))
    finally:
        if webdriver is not None:
            # quit() (rather than close()) ends the whole session, so that
            # the browser and its driver process do not outlive this call.
            webdriver.quit()
    return Path(filename)

A replacement for bokeh.io.export_png which can sometimes be a bit buggy. Instantiates its own Firefox webdriver. A bit slower but more reliable.

If Selenium is not installed, or if the Firefox webdriver is not installed, or if any other error occurs, this method will log the error and fall back to the default bokeh implementation instead of raising.

def class_matching_plot( x: Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], y_true: Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], y_clst: Union[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[Union[bool, int, float, complex, str, bytes]]], matching: dict[int, set[int]] | dict[str, set[int]] | dict[int, set[str]] | dict[str, set[str]], size: int = 400) -> bokeh.models.layouts.GridBox:
def class_matching_plot(
    x: ArrayLike,
    y_true: ArrayLike,
    y_clst: ArrayLike,
    matching: Matching,
    size: int = 400,
) -> bkm.GridBox:
    """
    Given a dataset `x` and two labellings `y_true` and `y_clst`, this method
    makes a grid of scatter plots detailing the situation. Labels in `y_true`
    are considered to be ground truth.

    There is one row per key `a` of `matching`, made of up to four figures:
    the ground truth class `a`, the classes of `y_clst` matched to `a`, their
    intersection, and their symmetric difference (misses in red, excess in
    blue). If no sample falls in a matched class, the row only contains the
    first figure.

    Example:
        ![Example 1](../docs/imgs/class_matching_plot.png)

    Warning:
        The array `x` is always rescaled with
        `sklearn.preprocessing.RobustScaler`. Note that the result is not
        guaranteed to lie in the $[0, 1]$ range.

    Args:
        x (ArrayLike): A `(N, 2)` array
        y_true (ArrayLike): A `(N,)` integer array with
            values in $\\\\{ 0, 1, ..., c_a - 1 \\\\}$ for some $c_a > 0$.
        y_clst (ArrayLike): A `(N,)` integer array with
            values in $\\\\{ 0, 1, ..., c_b - 1 \\\\}$ for some $c_b > 0$.
        matching (Matching): Matching between
            the labels of `y_true` and the labels of `y_clst`. If some keys are
            strings, they must be convertible to ints. Probably generated from
            `lcc.correction.class_otm_matching`.
        size (int, optional): The size of each scatter plot. Defaults to 400.

    Returns:
        A bokeh grid layout containing all the figures.
    """
    x, y_true, y_clst = to_array(x), to_int_array(y_true), to_int_array(y_clst)
    x = RobustScaler().fit_transform(x)
    assert isinstance(x, np.ndarray)  # for typechecking
    matching = to_int_matching(matching)
    # NOTE(review): judging from the figure titles below, p1[a] selects the
    # samples of true class a, p2[a] those in a class matched to a, p3[a] the
    # misses and p4[a] the excess -- confirm against otm_matching_predicates.
    p1, p2, p3, p4 = otm_matching_predicates(y_true, y_clst, matching)

    # Palettes are indexed by label, hence the requirement that labels be
    # 0, 1, ..., c - 1.
    palt_true = bkp.viridis(len(np.unique(y_true)))
    palt_clst = bkp.viridis(len(np.unique(y_clst)))
    figures = []
    kw = {"width": size, "height": size}
    for a, bs in matching.items():
        in_true, in_matched = p1[a], p2[a]
        in_both = in_true & in_matched
        n_true, n_matched = in_true.sum(), in_matched.sum()
        n_inter = in_both.sum()
        n_miss, n_exc = p3[a].sum(), p4[a].sum()
        fig_a = bk.figure(title=f"Ground truth, class {a}; n = {n_true}", **kw)
        class_scatter(
            fig_a,
            x[in_true],
            y_true[in_true],
            rescale=False,
            palette=[palt_true[a]],
        )
        if n_matched == 0:
            figures.append([fig_a, None, None, None])
            continue
        fig_b = bk.figure(
            title=(
                f"{len(bs)} matched classes: "
                + ", ".join(map(str, bs))
                + f"; n = {n_matched}"
            ),
            **kw,
        )
        # class_scatter assigns colors by position among the sorted labels
        # present, so the palette is restricted to those labels.
        class_scatter(
            fig_b,
            x[in_matched],
            y_clst[in_matched],
            rescale=False,
            palette=[palt_clst[b] for b in np.unique(y_clst[in_matched])],
        )
        fig_match = bk.figure(title=f"Intersection; n = {n_inter}", **kw)
        class_scatter(
            fig_match,
            x[in_both],
            y_clst[in_both],
            rescale=False,
            palette=[palt_clst[b] for b in np.unique(y_clst[in_both])],
        )
        # 1 = miss, 2 = excess (assumes a sample cannot be both).
        y_diff = p3[a] + 2 * p4[a]
        # Only keep the colors of the codes that actually occur. Otherwise,
        # when there are no misses, the excess points would take the first
        # color and be drawn red instead of blue.
        diff_palette = [
            color
            for code, color in ((1, "#ff0000"), (2, "#0000ff"))
            if (y_diff == code).any()
        ]
        fig_diff = bk.figure(
            title=(
                f"Symmetric difference; n = {n_miss + n_exc}\n"
                f"Misses (red) = {n_miss}; excess (blue) = {n_exc}"
            ),
            **kw,
        )
        class_scatter(
            fig_diff,
            x[y_diff > 0],
            y_diff[y_diff > 0],
            palette=diff_palette,
            rescale=False,
        )
        make_same_xy_range(fig_a, fig_b, fig_match, fig_diff)
        figures.append([fig_a, fig_b, fig_match, fig_diff])

    return bkl.grid(figures)  # type: ignore

Given a dataset x and two labellings y_true and y_clst, this method makes a scatter plot detailing the situation. Labels in y_true are considered to be ground truth.

Example:

Example 1

Warning:

The array x is always rescaled with scikit-learn's RobustScaler before plotting; the result is not guaranteed to lie in the $[0, 1]$ range.

Arguments:
  • x (ArrayLike): A (N, 2) array
  • y_true (ArrayLike): A (N,) integer array with values in $\{ 0, 1, ..., c_a - 1 \}$ for some $c_a > 0$.
  • y_clst (ArrayLike): A (N,) integer array with values in $\{ 0, 1, ..., c_b - 1 \}$ for some $c_b > 0$.
  • matching (Matching): Matching between the labels of y_true and the labels of y_clst. If some keys are strings, they must be convertible to ints. Probably generated from lcc.correction.class_otm_matching.
  • size (int, optional): The size of each scatter plot. Defaults to 400.
def make_same_xy_range(*args: bokeh.plotting._figure.figure) -> None:
def make_same_xy_range(*args: bk.figure) -> None:
    """
    Links the ranges of all given figures to those of the first one: every
    other figure is assigned the first figure's `x_range` and `y_range`
    objects. As a result, zooming or dragging one figure zooms or drags all
    the others by the same amount.
    """
    if len(args) < 2:
        # Nothing to link
        return
    reference, *followers = args
    for fig in followers:
        fig.x_range = reference.x_range
        fig.y_range = reference.y_range

Makes sure all figures share the same x_range and y_range. As a result, if a figure is zoomed in or dragged, all others will be too and by the same amount.