Source code for fairical.plot

# SPDX-FileCopyrightText: Copyright © 2025 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Plotting utilities."""

import itertools
import typing

import matplotlib.axes
import matplotlib.colors
import matplotlib.figure
import matplotlib.pyplot
import numpy

from .metrics import FairnessMetricsType, MinMaxFairnessMetricsType, UtilityMetricsType
from .solutions import Solutions
from .utils import IndicatorType, extend_indicators, parse_indicator


def radar_chart(
    indicators: dict[str, dict[IndicatorType, float]],
    axes: dict[IndicatorType | str, str] | None = None,
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Generate radar chart for all systems under comparison.

    Each system in ``indicators`` is drawn as a closed polygon (line, light
    fill and a hatch pattern) whose vertices are the values of the
    indicators selected by ``axes``.  The legend entry of each system also
    reports its ``"area"`` indicator.

    Parameters
    ----------
    indicators
        Indicators organized in a single dictionary where keys represent
        system labels, and values are dictionaries holding the indicators
        listed in ``axes`` (after :py:func:`fairical.utils.parse_indicator`)
        and an ``"area"`` entry.  Before values are read, these dictionaries
        and the requested indicator keys are handed to
        :py:func:`fairical.utils.extend_indicators` (presumably to derive
        missing indicators in place -- see its documentation).
    axes
        A dictionary containing the indicator keys that will be drawn on the
        radar chart, and corresponding labels associated with each of those
        axes.  You can use LaTeX symbols and notations on the values of the
        dictionary.  At least 3 entries are required.  If not set, the axes
        ``relative-onvg``, ``onvgr``, ``ud``, ``as`` and ``hv`` are used.

    Returns
    -------
        A tuple containing both the matplotlib figure and axes used to
        create the radar chart.
    """
    if axes is None:
        # built per call: a dict literal as default value would be a single
        # mutable object shared by every call
        axes = {
            "relative-onvg": r"$\widehat{ONVG}$",
            "onvgr": r"$ONVGR$",
            "ud": r"$UD$",
            "as": r"$AS$",
            "hv": r"$HV$",
        }

    # validate inputs: a radar polygon needs at least 3 axes
    assert len(axes) >= 3

    # normalise the indicator keys; tick labels are taken from this same
    # mapping below so they always line up with the plotted values
    axes_ind: dict[IndicatorType, str] = {
        parse_indicator(k): v for k, v in axes.items()
    }
    extend_indicators(list(indicators.values()), list(axes_ind.keys()))

    # one row per system, one column per axis; ``reshape`` keeps the array
    # 2-D even when ``indicators`` is empty
    values = numpy.array(
        [[v[k] for k in axes_ind] for v in indicators.values()], dtype=float
    ).reshape(len(indicators), len(axes_ind))
    # repeat the first column at the end so each polygon closes on itself
    values = numpy.column_stack((values, values[:, 0]))

    # read before creating the figure so a missing "area" does not leave an
    # orphan figure behind
    areas = [v["area"] for v in indicators.values()]

    # len(axes_ind) + 1 angles from 0 to 2*pi: the last coincides with the
    # first, matching the repeated column above
    angles = numpy.linspace(0, 2 * numpy.pi, values.shape[1], endpoint=True)

    fig, ax = matplotlib.pyplot.subplots(subplot_kw=dict(polar=True))
    hatches = itertools.cycle(["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*"])

    for label, area, row in zip(indicators.keys(), areas, values):
        # draws the line around each system entry
        (line,) = ax.plot(angles, row, label=f"{label} ($\\triangle={area:.2f}$)")
        # fills the polygon
        ax.fill(angles, row, color=line.get_color(), alpha=0.25)
        # draws the hatching over the radar surface
        ax.fill(
            angles,
            row,
            color="none",
            edgecolor=line.get_color(),
            alpha=0.40,
            hatch=next(hatches),
        )

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(list(axes_ind.values()))
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1.0"])
    ax.tick_params(axis="both", which="major")
    ax.set_ylim(0, 1)

    ax.legend(
        loc="lower center",
        bbox_to_anchor=(0.5, -0.2),  # where the legend's anchor goes, in axes coords
        bbox_transform=ax.transAxes,
    )

    fig.tight_layout()

    return fig, ax


# Axis labels for utility metrics, keyed by metric name.
_AXES_LABELS_UTILITY: dict[UtilityMetricsType, str] = {
    "fpr": "Utility (FPR)",
    "tpr": "Utility (TPR)",
    "fnr": "Utility (FNR)",
    "tnr": "Utility (TNR)",
    "roc_auc": "Utility (AUROC)",
    "prec": "Utility (Precision)",
    "rec": "Utility (Recall)",
    "avg_prec": "Utility (Avg.Precision)",
    "f1": "Utility (F1-score)",
    "acc": "Utility (Accuracy)",
    "bal_acc": "Utility (Balanced Accuracy)",
}
"""Labels for pareto plots (utility)."""

# Fairness label templates: the single ``%s`` is filled with the second
# ``+``-separated component of a solution key when labels are resolved in
# ``pareto_plot``.  The closing parenthesis sits outside math mode.
_AXES_LABELS_FAIRNESS: dict[FairnessMetricsType, str] = {
    "dpd": "Fairness ($\\text{DPD}_\\text{%s}$)",
    "dpr": "Fairness ($\\text{DPR}_\\text{%s}$)",
    "eod": "Fairness ($\\text{EOD}_\\text{%s}$)",
    "eor": "Fairness ($\\text{EOR}_\\text{%s}$)",
}
"""Labels for pareto plots (fairness)."""

# Min-max fairness label templates: the two ``%s`` are filled with the
# second and third ``+``-separated components of a solution key.
_AXES_LABELS_MINMAX_FAIRNESS: dict[MinMaxFairnessMetricsType, str] = {
    "minmaxd": "Min-Max Fairness (Diff. $\\text{%s}_\\text{%s}$)",
    "minmaxr": "Min-Max Fairness (Ratio $\\text{%s}_\\text{%s}$)",
}
"""Labels for pareto plots (min-max fairness)."""


def _pareto_axis_label(key: str, axes_labels: dict[str, str]) -> str:
    """Resolve the axis label for one solution dimension ``key``.

    An explicit entry in ``axes_labels`` wins.  Otherwise ``key`` is split on
    ``"+"`` (into at most 3 parts); the first part selects a template from
    the module-level label tables and the remaining parts fill its ``%s``
    placeholders.  Falls back to ``key`` itself when no template matches or
    when ``key`` has fewer parts than the template has placeholders.
    """
    if key in axes_labels:
        return axes_labels[key]

    metric, *extra = key.split("+", 2)
    if metric in typing.get_args(UtilityMetricsType):
        template = _AXES_LABELS_UTILITY[typing.cast(UtilityMetricsType, metric)]
    elif metric in typing.get_args(FairnessMetricsType):
        template = _AXES_LABELS_FAIRNESS[typing.cast(FairnessMetricsType, metric)]
    elif metric in typing.get_args(MinMaxFairnessMetricsType):
        template = _AXES_LABELS_MINMAX_FAIRNESS[
            typing.cast(MinMaxFairnessMetricsType, metric)
        ]
    else:
        # unknown metric: show the raw key rather than dropping the label
        return key

    n_slots = template.count("%s")
    if n_slots == 0:
        return template
    if len(extra) < n_slots:
        # not enough key components to fill the template
        return key
    return template % tuple(extra[:n_slots])


def _plot_pareto_front(
    ax: matplotlib.axes.Axes,
    nds_arr: numpy.ndarray,
    color: typing.Any,
    marker: str,
    label: str,
    alpha: float,
) -> None:
    """Draw the Pareto front joining one system's non-dominated points.

    With 2 columns, the points are joined by a line sorted along the x-axis
    that also carries the legend ``label``.  With 3 columns, a triangulated
    surface is drawn in ``color``; it is skipped when there are fewer than 3
    points, which cannot be triangulated.
    """
    if nds_arr.shape[1] == 2:
        # sort by the first metric (x-axis) to get a sensible path
        order = numpy.argsort(nds_arr[:, 0])
        x_sorted, y_sorted = nds_arr[order].T
        ax.plot(
            x_sorted,
            y_sorted,
            linestyle="-",
            color=matplotlib.colors.to_rgba(color, alpha=min(2 * alpha, 1.0)),
            marker=marker,
            markerfacecolor=matplotlib.colors.to_rgba(color, alpha=1.0),
            linewidth=1,
            label=label,
        )
    elif len(nds_arr) >= 3:
        # 3 dimensional case; explicit color keeps the surface matched to
        # the system's markers
        ax.plot_trisurf(
            *nds_arr.T,
            color=color,
            antialiased=False,
            shade=False,
            alpha=0.5,
            edgecolor="none",
        )


def pareto_plot(
    solutions: dict[str, tuple[Solutions, Solutions]],
    axes_labels: dict[str, str] | None = None,
    alpha: float = 0.2,
    hide_ds: bool = False,
) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Generate pareto plot for all systems under comparison.

    Each system gets its own marker (cycled) and colour.  Non-dominated
    solutions are drawn fully opaque and joined by their Pareto front (a
    line for 2 metrics, a surface for 3 metrics); dominated solutions are
    drawn with transparency ``alpha``.

    Parameters
    ----------
    solutions
        A dictionary where keys represent system names (that will be used as
        labels), and values are tuples with non-dominated (nds) and dominated
        solutions (ds) respectively.  The number of metrics and the axis keys
        are taken from the first system's non-dominated solutions.
    axes_labels
        If specified, overwrites the default labels for dimensions in
        :py:class:`fairical.solutions.Solutions`.  Should be a dictionary
        that maps the keys in each :py:class:`fairical.solutions.Solutions`
        object to a single label.  Keys not listed use a default label
        derived from the metric name, or the raw key if none applies.
    alpha
        Alpha blend between non-dominated (fully opaque) and dominated
        solutions (partly transparent).
    hide_ds
        If true, hide the ds points for a-priori data from the plot.

    Returns
    -------
        A tuple containing the matplotlib figure and axes used to create the
        pareto plot (3-D axes when the solutions have 3 metrics).

    Raises
    ------
    ValueError
        If ``solutions`` is empty.
    """
    if not solutions:
        raise ValueError("pareto_plot() requires at least one system in `solutions`")

    reference = next(iter(solutions.values()))[0]
    ndim = reference.n_metrics()

    # validate inputs
    assert 2 <= ndim <= 3

    # resolved before any figure exists, so a failure here leaks nothing
    overrides = {} if axes_labels is None else axes_labels
    use_axes_labels = [_pareto_axis_label(k, overrides) for k in reference.keys()]

    fig = matplotlib.pyplot.figure()
    ax = fig.add_subplot(projection="3d") if ndim == 3 else fig.add_subplot()

    markers = itertools.cycle(["o", "s", "^", "v", "<", ">", "d", "P", "X", "*", "+"])
    for label, (nds, ds) in solutions.items():
        marker = next(markers)

        nds_arr = numpy.asarray(nds)
        if nds_arr.size == 0:
            # keep a (0, ndim) shape so unpacking and ``shape[1]`` still work
            nds_arr = nds_arr.reshape(0, ndim)

        # in 3-D there is no front line to carry the legend entry, so the
        # non-dominated points carry it instead
        legend_kwargs = {"label": label} if ndim == 3 else {}
        # plot non-dominated points with no transparency
        pc_nds = ax.scatter(*nds_arr.T, marker=marker, **legend_kwargs)
        facecolors = pc_nds.get_facecolors()
        color = facecolors[0] if len(facecolors) else pc_nds.get_edgecolors()[0]

        if not hide_ds:
            ds_arr = numpy.asarray(ds)
            # a system may have no dominated solutions at all
            if ds_arr.size:
                # plot dominated points using the same marker and color
                ax.scatter(*ds_arr.T, color=color, alpha=alpha, marker=marker)

        _plot_pareto_front(ax, nds_arr, color, marker, label, alpha)

    ax.set_xlabel(use_axes_labels[0])
    ax.set_ylabel(use_axes_labels[1])
    if ndim == 3:
        ax.set_zlabel(use_axes_labels[2])

    ax.grid()

    ax.legend(
        loc="lower center",
        bbox_to_anchor=(0.5, -0.3),  # where the legend's anchor goes, in axes coords
        bbox_transform=ax.transAxes,
    )

    fig.tight_layout()

    return fig, ax