# SPDX-FileCopyrightText: Copyright © 2025 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Helpers to evaluate scikit-learn metrics at arbitrary thresholds."""
import typing
import fairlearn.metrics
import numpy
import sklearn.metrics
# Closed sets of metric identifiers accepted by this module.  Each entry is
# annotated with its full name, the optimisation direction, and (when stated)
# its range.
UtilityMetricsType: typing.TypeAlias = typing.Literal[
    "fpr",  # False Positive Rate (minimize, range: [0, 1])
    "tpr",  # True Positive Rate (maximize, range: [0, 1])
    "tnr",  # True Negative Rate (maximize, range: [0, 1])
    "fnr",  # False Negative Rate (minimize, range: [0, 1])
    "roc_auc",  # Area Under the Curve for Receiver Operating Characteristic (maximize, range: [0, 1])
    "prec",  # Precision (maximize, range: [0, 1])
    "rec",  # Recall (maximize, range: [0, 1])
    "avg_prec",  # Average precision for Precision Recall Curve (maximize, range: [0, 1])
    "f1",  # F1 Score (maximize, range: [0, 1])
    "acc",  # Accuracy (maximize, range: [0, 1])
    "bal_acc",  # Balanced Accuracy (maximize, range: [0, 1])
]
"""Supported utility metrics type for pareto front estimates."""
# Group-fairness metrics; in a metric string they are followed by "+<attr>".
FairnessMetricsType: typing.TypeAlias = typing.Literal[
    "dpd",  # Demographic Parity Difference (minimize, range: [0, 1])
    "dpr",  # Demographic Parity Ratio (maximize, range: [0, 1])
    "eod",  # Equalized Odds Difference (minimize, range: [0, 1])
    "eor",  # Equalized Odds Ratio (maximize, range: [0, 1])
]
"""Supported fairness metrics type for pareto front estimates."""
# Min-max fairness metrics; in a metric string they are followed by
# "+<util>+<attr>", where <util> is one of the utility metrics above.
MinMaxFairnessMetricsType: typing.TypeAlias = typing.Literal[
    "minmaxd",  # Min-Max (absolute) difference: max(util) - min(util) (minimize)
    "minmaxr",  # Min-Max ratio: min(util) / max(util) (maximize)
]
"""Supported min-max fairness metrics type for pareto front estimates."""
def parse_metric(
    name: str,
) -> (
    UtilityMetricsType
    | tuple[FairnessMetricsType, str]
    | tuple[MinMaxFairnessMetricsType, UtilityMetricsType, str]
):
    """Parse and validate a string supposed to carry a metric name.

    Valid metric names are the ones listed in :py:type:`UtilityMetricsType`
    (used alone, without any suffix), :py:type:`FairnessMetricsType` (followed by
    a "+<attr>"), or :py:type:`MinMaxFairnessMetricsType` (followed by a
    "+<util>+<attr>"), where ``<attr>`` corresponds to the protected attribute
    being measured by the fairness metric, and ``<util>`` corresponds to the
    :py:type:`UtilityMetricsType` to be used to measure min-max fairness
    difference or ratios.

    Parameters
    ----------
    name
        The string to be validated.

    Returns
    -------
        The parsed metric: a plain string for utility metrics, a
        ``(fairness, attr)`` tuple for fairness metrics, or a
        ``(minmax, util, attr)`` tuple for min-max fairness metrics.

    Raises
    ------
    ValueError
        If the metric expressed in ``name`` is invalid.
    """
    # At most three fields are meaningful.  ``maxsplit=2`` leaves any further
    # "+" inside the third field, i.e. inside the attribute name of a min-max
    # fairness metric.
    parts = name.split("+", 2)
    kind = parts[0]

    if kind in typing.get_args(UtilityMetricsType):
        # A utility metric stands alone: reject `fpr+something` instead of
        # silently dropping the suffix and returning `fpr`.
        if len(parts) != 1:
            raise ValueError(
                f"utility metric should be set like `{kind}`, without any "
                f"`+<...>` suffix (`{name}` is invalid)"
            )
        return typing.cast(UtilityMetricsType, kind)

    if kind in typing.get_args(FairnessMetricsType):
        if len(parts) != 2:
            raise ValueError(
                f"fairness metric should be set like `{parts[0]}+<attr>` "
                f"(`{name}` is invalid)"
            )
        return (typing.cast(FairnessMetricsType, kind), parts[1])

    if kind in typing.get_args(MinMaxFairnessMetricsType):
        if len(parts) != 3:
            raise ValueError(
                f"min-max fairness metric should be set like `{parts[0]}+<util>+<attr>` "
                f"(`{name}` is invalid)"
            )
        if parts[1] not in typing.get_args(UtilityMetricsType):
            raise ValueError(f"invalid utility metric name `{parts[1]}` at `{name}`")
        return (
            typing.cast(MinMaxFairnessMetricsType, kind),
            typing.cast(UtilityMetricsType, parts[1]),
            parts[2],
        )

    raise ValueError(f"Invalid metric specification: `{name!r}`")
def should_minimize(metric: str) -> bool:
    """Tell whether a metric is to be minimized or maximized.

    The following metrics are minimized:

    * the utility metrics "fpr" and "fnr";
    * the fairness metrics "dpd" and "eod";
    * the min-max fairness metric "minmaxd".

    All other supported metrics are maximized.

    Parameters
    ----------
    metric
        Metric name.

    Returns
    -------
        ``True``, if the metric should be minimized (instead of maximized).
        ``False`` otherwise.

    Raises
    ------
    ValueError
        If the metric is invalid.
    """
    parsed = parse_metric(metric)

    # plain utility metric: parsed as a bare string
    if not isinstance(parsed, tuple):
        return parsed in ("fpr", "fnr")

    kind = parsed[0]

    # (fairness, attr)
    if len(parsed) == 2:
        return kind in ("dpd", "eod")

    # (minmax, util, attr)
    return len(parsed) == 3 and kind == "minmaxd"
def supported_metrics() -> list[str]:
    """Enumerate the metric specifications accepted by this module.

    Utility metrics are listed by name. Fairness metrics are listed with a
    ``+<attr>`` placeholder suffix, and min-max fairness metrics with a
    ``+<util>+<attr>`` placeholder suffix.

    Returns
    -------
        A list of strings, one per supported metric (utility metrics first, then
        fairness metrics, then min-max fairness metrics).
    """
    names: list[str] = list(typing.get_args(UtilityMetricsType))
    names.extend(f"{k}+<attr>" for k in typing.get_args(FairnessMetricsType))
    names.extend(
        f"{k}+<util>+<attr>" for k in typing.get_args(MinMaxFairnessMetricsType)
    )
    return names
def calculate_metric(
    metric: str,
    y_true: typing.Sequence[int],
    y_score: typing.Sequence[float],
    thresholds: typing.Sequence[float],
    sensitive_attributes: typing.Mapping[str, typing.Sequence[int | str]] | None = None,
) -> list[float]:
    """Entry-point function to calculate arbitrary (supported) metrics.

    This function works as an entry-point to the metric calculation submodule. It
    can calculate arbitrary (supported) metrics provided input information for a
    system, consisting of ground-truth, scores, thresholds and (optionally)
    sensitive features.

    Parameters
    ----------
    metric
        The metric to calculate.
    y_true
        True binary labels (0 or 1).
    y_score
        Predicted continuous scores or probabilities.
    thresholds
        Threshold values at which to binarize ``y_score`` (:math:`score >=
        threshold` implies sample is classified as positive).
    sensitive_attributes
        Group membership for each sample, according to protected attribute. Only
        required if ``metric`` is a fairness metric. Each entry in the input
        dictionary should match the order of samples in ``y_true`` and
        ``y_score``. When ``metric`` refers to a particular sensitive attribute,
        it should be a key in this dictionary.

    Returns
    -------
        The metric over all considered thresholds. Threshold-free metrics
        ("roc_auc" and "avg_prec") yield a single value, which is repeated once
        per threshold so the output always has ``len(thresholds)`` entries.

    Raises
    ------
    ValueError
        In case of unknown metrics, or if a fairness metric is requested while
        ``sensitive_attributes`` is ``None`` or lacks the requested attribute.
    """
    parsed_metric = parse_metric(metric)

    # Utility metrics computed from binarized predictions, once per threshold.
    thresholded: dict[str, typing.Callable[..., typing.Any]] = {
        "fpr": fairlearn.metrics.false_positive_rate,
        "tpr": fairlearn.metrics.true_positive_rate,
        "fnr": fairlearn.metrics.false_negative_rate,
        "tnr": fairlearn.metrics.true_negative_rate,
        "acc": sklearn.metrics.accuracy_score,
        "bal_acc": sklearn.metrics.balanced_accuracy_score,
        "prec": sklearn.metrics.precision_score,
        "rec": sklearn.metrics.recall_score,
        "f1": sklearn.metrics.f1_score,
    }
    # Utility metrics computed directly from the continuous scores.
    threshold_free: dict[str, typing.Callable[..., typing.Any]] = {
        "roc_auc": sklearn.metrics.roc_auc_score,
        "avg_prec": sklearn.metrics.average_precision_score,
    }
    # Group-fairness metrics, also computed once per threshold.
    group_fairness: dict[str, typing.Callable[..., typing.Any]] = {
        "dpd": fairlearn.metrics.demographic_parity_difference,
        "dpr": fairlearn.metrics.demographic_parity_ratio,
        "eod": fairlearn.metrics.equalized_odds_difference,
        "eor": fairlearn.metrics.equalized_odds_ratio,
    }

    def _for_all_t(f, **kwargs) -> list[float]:
        # Evaluate ``f`` on the predictions binarized at every threshold;
        # ``numpy.nan_to_num`` then replaces non-finite results by finite ones.
        y_score_arr = numpy.asarray(y_score, dtype=float)
        return numpy.nan_to_num(
            [f(y_true, y_score_arr >= t, **kwargs) for t in thresholds]
        ).tolist()

    def _repeated(f, **kwargs) -> list[float]:
        # There is only one value for a threshold-free metric, where there are
        # multiple fpr or tpr points -- so we repeat it as many times as there
        # are thresholds to keep consistency.
        val = float(f(y_true, y_score, **kwargs))
        return len(thresholds) * [val]

    if isinstance(parsed_metric, str):
        # plain utility metric
        if parsed_metric in thresholded:
            return _for_all_t(thresholded[parsed_metric])
        if parsed_metric in threshold_free:
            return _repeated(threshold_free[parsed_metric])

    else:
        # Fairness metrics: the protected attribute is always the last field.
        # Validate explicitly (not via ``assert``, which is removed when Python
        # runs with ``-O``).
        attr = parsed_metric[-1]
        if sensitive_attributes is None:
            raise ValueError(
                f"metric `{metric}` requires `sensitive_attributes`, but none "
                f"were provided"
            )
        if attr not in sensitive_attributes:
            raise ValueError(
                f"sensitive attribute `{attr}` (required by metric `{metric}`) "
                f"is not available in `sensitive_attributes`"
            )
        groups = sensitive_attributes[attr]

        if len(parsed_metric) == 2:
            # (fairness, attr)
            fairness = parsed_metric[0]
            if fairness in group_fairness:
                return _for_all_t(group_fairness[fairness], sensitive_features=groups)

        elif len(parsed_metric) == 3:
            # (minmax, util, attr): derive a between-groups difference/ratio
            # from the requested utility metric.
            minmax, util = parsed_metric[0], parsed_metric[1]
            transform = "difference" if minmax == "minmaxd" else "ratio"

            if util in thresholded:
                derived = fairlearn.metrics.make_derived_metric(
                    metric=thresholded[util], transform=transform
                )
                return _for_all_t(derived, sensitive_features=groups)

            if util in threshold_free:
                derived = fairlearn.metrics.make_derived_metric(
                    metric=threshold_free[util], transform=transform
                )
                return _repeated(derived, sensitive_features=groups)

    # this should not occur, as metric is parsed from start
    raise ValueError(f"Invalid metric specification: `{metric!r}`")