Ranking Losses (rax.*_loss)

Implementations of common ranking losses in JAX.

A ranking loss is a differentiable function that expresses the cost of a ranking induced by item scores compared to a ranking induced from relevance labels. Rax provides a number of ranking losses as JAX functions that are implemented according to the LossFn interface.

Loss functions are designed to operate on the last dimension of their inputs. The leading dimensions are considered batch dimensions. To compute per-list losses, for example to apply per-list weighting or for distributed computing of losses across devices, please use standard JAX transformations such as jax.vmap() or jax.pmap().

Standalone usage:

>>> import jax
>>> import jax.numpy as jnp
>>> import rax
>>> scores = jnp.array([2., 1., 3.])
>>> labels = jnp.array([1., 0., 0.])
>>> rax.softmax_loss(scores, labels)
DeviceArray(1.4076059, dtype=float32)

Usage with a batch of data and a mask to indicate valid items:

>>> scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> rax.pairwise_hinge_loss(
...     scores, labels, where=where, reduce_fn=jnp.mean)
DeviceArray(0.16666667, dtype=float32)

To compute gradients of each loss function, please use standard JAX transformations such as jax.grad() or jax.value_and_grad():

>>> scores = jnp.asarray([[0., 1., 3.], [1., 2., 0.]])
>>> labels = jnp.asarray([[0., 0., 1.], [1., 0., 0.]])
>>> jax.grad(rax.softmax_loss)(scores, labels, reduce_fn=jnp.mean)
DeviceArray([[ 0.02100503,  0.0570976 , -0.07810265],
             [-0.37763578,  0.33262047,  0.04501529]], dtype=float32)
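
The gradient values above can be cross-checked by hand. For a single list, the gradient of the softmax loss with respect to the scores is softmax(s) * sum(y) - y, and averaging over the two lists divides each row by 2. The following is a dependency-free sketch in plain Python, not Rax code; the helper name softmax_loss_grad is hypothetical, and the averaging over lists is an assumption that happens to reproduce the documented numbers.

```python
import math


def softmax_loss_grad(scores, labels):
    """Gradient of -sum_i y_i * log softmax(s)_i with respect to s (one list)."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    total = sum(labels)
    return [total * e / z - y for e, y in zip(exps, labels)]


batch_scores = [[0., 1., 3.], [1., 2., 0.]]
batch_labels = [[0., 0., 1.], [1., 0., 0.]]

# Assumption: the mean is taken over the per-list losses, so each per-list
# gradient is divided by the number of lists.
grads = [[g / len(batch_scores) for g in softmax_loss_grad(s, y)]
         for s, y in zip(batch_scores, batch_labels)]
```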

pointwise_mse_loss(scores, labels, *[, ...])

Mean squared error loss.

pointwise_sigmoid_loss(scores, labels, *[, ...])

Sigmoid cross entropy loss.

pairwise_hinge_loss(scores, labels, *[, ...])

Pairwise hinge loss.

pairwise_logistic_loss(scores, labels, *[, ...])

Pairwise logistic loss.

pairwise_mse_loss(scores, labels, *[, ...])

Pairwise mean squared error loss.

softmax_loss(scores, labels, *[, where, ...])

Softmax loss.

rax.pointwise_mse_loss(scores, labels, *, where=None, weights=None, reduce_fn=<function sum>)

Mean squared error loss.

Definition:

\[\ell(s, y) = \sum_i (y_i - s_i)^2 \]
Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.

  • weights (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating the weight for each item.

  • reduce_fn (ReduceFn) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The mean squared error loss.

rax.pointwise_sigmoid_loss(scores, labels, *, where=None, weights=None, reduce_fn=<function sum>)

Sigmoid cross entropy loss.

Definition:

\[\ell(s, y) = \sum_i \left[ -y_i \log(\operatorname{sigmoid}(s_i)) - (1 - y_i) \log(1 - \operatorname{sigmoid}(s_i)) \right] \]

This loss converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.

Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.

  • weights (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating the weight for each item.

  • reduce_fn (ReduceFn) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The sigmoid cross entropy loss.
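
As a sketch of the definition and the label binarization described above, the following plain-Python function evaluates the sum for a single list. It is not Rax's implementation; the names softplus and sigmoid_cross_entropy are hypothetical, and masking, weights and reduction are left out.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid_cross_entropy(scores, labels):
    """Sum of per-item sigmoid cross entropy terms for one list."""
    total = 0.0
    for s, y in zip(scores, labels):
        y_bin = 1.0 if y >= 1 else 0.0  # graded -> binary relevance
        # -log(sigmoid(s)) == softplus(-s); -log(1 - sigmoid(s)) == softplus(s)
        total += y_bin * softplus(-s) + (1.0 - y_bin) * softplus(s)
    return total


# A label of 2 counts as relevant, a label of 0.5 as non-relevant.
loss = sigmoid_cross_entropy([0., 0.], [2., 0.5])
```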

rax.pairwise_hinge_loss(scores, labels, *, where=None, weights=None, reduce_fn=<function sum>)

Pairwise hinge loss.

Definition:

\[\ell(s, y) = \sum_i \sum_j I[y_i > y_j] \max(0, 1 - (s_i - s_j)) \]
Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.

  • weights (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating the weight for each item.

  • reduce_fn (ReduceFn) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The pairwise hinge loss.
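
The batch example near the top of this section (with a mask and reduce_fn=jnp.mean) can be reproduced with a short plain-Python sketch of the definition. This is not Rax code: the helper name is hypothetical, and it assumes that masked items drop out of every pair and that the mean is taken over the remaining pairs with y_i > y_j, which matches the documented value 0.16666667.

```python
def pairwise_hinge_terms(scores, labels, where=None):
    """Hinge terms for every valid pair (i, j) with labels[i] > labels[j]."""
    n = len(scores)
    valid = where if where is not None else [True] * n
    return [max(0.0, 1.0 - (scores[i] - scores[j]))
            for i in range(n) for j in range(n)
            if valid[i] and valid[j] and labels[i] > labels[j]]


scores = [[2., 1., 0.], [1., 0.5, 1.5]]
labels = [[1., 0., 0.], [0., 0., 1.]]
where = [[True, True, False], [True, True, True]]

# Collect the terms of both lists, then average them (assumed mean reduction).
terms = [t for s, y, w in zip(scores, labels, where)
         for t in pairwise_hinge_terms(s, y, w)]
mean_loss = sum(terms) / len(terms)
```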

rax.pairwise_logistic_loss(scores, labels, *, where=None, weights=None, reduce_fn=<function sum>)

Pairwise logistic loss.

Definition [Burges et al., 2005]:

\[\ell(s, y) = \sum_i \sum_j I[y_i > y_j] \log(1 + \exp(-(s_i - s_j))) \]
Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.

  • weights (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating the weight for each item.

  • reduce_fn (ReduceFn) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The pairwise logistic loss.
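
A minimal plain-Python sketch of the definition for a single list follows. It is an illustration only, not Rax's implementation; the function names are hypothetical, and masking, weights and reduction are omitted.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_logistic_loss(scores, labels):
    """Sum of log(1 + exp(-(s_i - s_j))) over pairs with y_i > y_j."""
    n = len(scores)
    return sum(softplus(-(scores[i] - scores[j]))
               for i in range(n) for j in range(n)
               if labels[i] > labels[j])


# Item 0 is the only relevant item, so the pairs are (0, 1) and (0, 2).
loss = pairwise_logistic_loss([2., 1., 3.], [1., 0., 0.])
```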

rax.pairwise_mse_loss(scores, labels, *, where=None, weights=None, reduce_fn=<function sum>)

Pairwise mean squared error loss.

Definition:

\[\ell(s, y) = \sum_i \sum_j ((y_i - y_j) - (s_i - s_j))^2 \]
Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.

  • weights (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating the weight for each item.

  • reduce_fn (ReduceFn) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The pairwise mean squared error loss.

rax.softmax_loss(scores, labels, *, where=None, weights=None, label_fn=<function <lambda>>, reduce_fn=<function sum>)

Softmax loss.

Definition:

\[\ell(s, y) = -\sum_i y_i \log \frac{\exp(s_i)}{\sum_j \exp(s_j)} \]
Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.

  • weights (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating the weight for each item.

  • label_fn (Callable[..., ndarray]) – A label function that maps labels to probabilities. Default keeps labels as-is.

  • reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The softmax loss.
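
The standalone usage example at the top of this section can be checked with a few lines of plain Python. The sketch below evaluates the negative of the label-weighted log-softmax for one list, which is the sign convention that yields the positive value 1.4076059 shown there. It is not Rax code; the function name is reused only for readability.

```python
import math


def softmax_loss(scores, labels):
    """Negative label-weighted log-softmax for a single list."""
    m = max(scores)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return -sum(y * (s - log_z) for s, y in zip(scores, labels))


loss = softmax_loss([2., 1., 3.], [1., 0., 0.])
```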

Ranking Metrics (rax.*_metric)

Implementations of common ranking metrics in JAX.

A ranking metric expresses how well a ranking induced by item scores matches a ranking induced from relevance labels. Rax provides a number of ranking metrics as JAX functions that are implemented according to the MetricFn interface.

Metric functions are designed to operate on the last dimension of their inputs. The leading dimensions are considered batch dimensions. To compute per-list metrics, for example to apply per-list weighting or for distributed computing of metrics across devices, please use standard JAX transformations such as jax.vmap() or jax.pmap().

Standalone usage of a metric:

>>> import jax
>>> import jax.numpy as jnp
>>> import rax
>>> scores = jnp.array([2., 1., 3.])
>>> labels = jnp.array([2., 0., 1.])
>>> rax.ndcg_metric(scores, labels)
DeviceArray(0.79670763, dtype=float32)

Usage with a batch of data:

>>> scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[2., 0., 1.], [0., 0., 1.]])
>>> rax.ndcg_metric(scores, labels)
DeviceArray(0.8983538, dtype=float32)

Usage with jax.vmap() batching and a mask to indicate valid items:

>>> scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> jax.vmap(rax.ndcg_metric)(scores, labels, where=where)
DeviceArray([1., 1.], dtype=float32)

mrr_metric(scores, labels, *[, where, topn, ...])

Mean Reciprocal Rank (MRR).

precision_metric(scores, labels, *[, where, ...])

Precision.

recall_metric(scores, labels, *[, where, ...])

Recall.

ap_metric(scores, labels, *[, where, topn, ...])

Average Precision.

dcg_metric(scores, labels, *[, where, topn, ...])

Discounted cumulative gain (DCG).

ndcg_metric(scores, labels, *[, where, ...])

Normalized discounted cumulative gain (NDCG).

rax.mrr_metric(scores, labels, *, where=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)

Mean Reciprocal Rank (MRR).

Note

This metric converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.

Definition:

\[\operatorname{mrr}(s, y) = \max_i \frac{y_i}{\operatorname{rank}(s_i)} \]

where \(\operatorname{rank}(s_i)\) indicates the rank of item \(i\) after sorting all scores \(s\) using rank_fn.

Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.

  • topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.

  • key (Optional[ndarray]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.

  • rank_fn (RankFn) – A function that maps scores to 1-based ranks.

  • cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items are cutoff.

  • reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The MRR metric.
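
The definition can be illustrated with a small plain-Python sketch for a single list. It is not Rax's implementation: hard_ranks and mrr are hypothetical helpers, ties keep their order of appearance, and masking, topn and reduction are left out.

```python
def hard_ranks(scores):
    """1-based ranks by descending score; ties keep their original order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])  # stable sort
    ranks = [0] * len(scores)
    for rank, i in enumerate(order, start=1):
        ranks[i] = rank
    return ranks


def mrr(scores, labels):
    """max_i y_i / rank(s_i) after binarizing labels at >= 1."""
    ranks = hard_ranks(scores)
    return max((1.0 if y >= 1 else 0.0) / r for y, r in zip(labels, ranks))


# The best-ranked relevant item (score 2.) sits at rank 2.
value = mrr([0., 1., 3., 2.], [0., 1., 0., 1.])
```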

rax.precision_metric(scores, labels, *, where=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)

Precision.

Note

This metric converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.

Definition:

\[\operatorname{precision@n}(s, y) = \frac{1}{n} \sum_i y_i \cdot \mathbb{I}\left[\operatorname{rank}(s_i) \leq n\right] \]

where \(\operatorname{rank}(s_i)\) indicates the rank of item \(i\) after sorting all scores \(s\) using rank_fn.

Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.

  • topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.

  • key (Optional[ndarray]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.

  • rank_fn (RankFn) – A function that maps scores to 1-based ranks.

  • cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items are cutoff.

  • reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The precision metric.

rax.recall_metric(scores, labels, *, where=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)

Recall.

Note

This metric converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.

Definition:

\[\operatorname{recall@n}(s, y) = \frac{1}{\sum_i y_i} \sum_i y_i \cdot \mathbb{I}\left[\operatorname{rank}(s_i) \leq n\right] \]

where \(\operatorname{rank}(s_i)\) indicates the rank of item \(i\) after sorting all scores \(s\) using rank_fn.

Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.

  • topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.

  • key (Optional[ndarray]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.

  • rank_fn (RankFn) – A function that maps scores to 1-based ranks.

  • cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items are cutoff.

  • reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The recall metric.
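
Precision and recall at n differ only in their denominator, which a short plain-Python sketch for one list makes explicit. The helpers below are hypothetical and are not Rax code. Returning 0 for recall when a list has no relevant items is an assumption of this sketch, not documented behavior.

```python
def hard_ranks(scores):
    """1-based ranks by descending score; ties keep their original order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for rank, i in enumerate(order, start=1):
        ranks[i] = rank
    return ranks


def precision_recall_at_n(scores, labels, n):
    """Returns (precision@n, recall@n) for a single list."""
    ranks = hard_ranks(scores)
    rel = [1.0 if y >= 1 else 0.0 for y in labels]  # binarize labels
    hits = sum(y for y, r in zip(rel, ranks) if r <= n)
    total_rel = sum(rel)
    # Assumption: recall is defined as 0 when there are no relevant items.
    recall = hits / total_rel if total_rel > 0 else 0.0
    return hits / n, recall


precision, recall = precision_recall_at_n([3., 2., 1., 0.], [1., 0., 1., 1.], n=2)
```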

rax.ap_metric(scores, labels, *, where=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)

Average Precision.

Note

This metric converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.

Definition:

\[\operatorname{ap}(s, y) = \frac{1}{\sum_i y_i} \sum_i y_i \operatorname{precision@rank}_{s_i}(s, y) \]

where \(\operatorname{precision@rank}_{s_i}(s, y)\) indicates the precision at the rank of item \(i\).

Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.

  • topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.

  • key (Optional[ndarray]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.

  • rank_fn (RankFn) – A function that maps scores to 1-based ranks.

  • cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items are cutoff.

  • reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The average precision metric.
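
A worked instance of the definition, as a plain-Python sketch for one list: for every relevant item, compute the precision at that item's rank, then average over the relevant items. The helper names are hypothetical, this is not Rax code, and returning 0 for a list without relevant items is an assumption.

```python
def hard_ranks(scores):
    """1-based ranks by descending score; ties keep their original order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for rank, i in enumerate(order, start=1):
        ranks[i] = rank
    return ranks


def average_precision(scores, labels):
    """Mean over relevant items of precision at that item's rank."""
    ranks = hard_ranks(scores)
    rel = [1.0 if y >= 1 else 0.0 for y in labels]
    total_rel = sum(rel)
    if total_rel == 0:
        return 0.0  # assumption: AP is 0 without relevant items
    ap = 0.0
    for y_i, r_i in zip(rel, ranks):
        if y_i:
            # precision at rank r_i: relevant items ranked at or above r_i
            ap += sum(y for y, r in zip(rel, ranks) if r <= r_i) / r_i
    return ap / total_rel


# Relevant items sit at ranks 1 and 3: (1/1 + 2/3) / 2.
value = average_precision([3., 2., 1., 0.], [1., 0., 1., 0.])
```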

rax.dcg_metric(scores, labels, *, where=None, topn=None, weights=None, key=None, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)

Discounted cumulative gain (DCG).

Definition [Järvelin and Kekäläinen, 2002]:

\[\operatorname{dcg}(s, y) = \sum_i \operatorname{gain}(y_i) \cdot \operatorname{discount}(\operatorname{rank}(s_i)) \]

where \(\operatorname{rank}(s_i)\) indicates the 1-based rank of item \(i\) as computed by rank_fn, \(\operatorname{gain}(y)\) indicates the per-item gains as computed by gain_fn, and \(\operatorname{discount}(r)\) indicates the per-item rank discounts as computed by discount_fn.

Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.

  • topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.

  • weights (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating the per-item weights.

  • key (Optional[ndarray]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.

  • gain_fn (Callable[[ndarray], ndarray]) – A function that maps relevance label to gain values.

  • discount_fn (Callable[[ndarray], ndarray]) – A function that maps 1-based ranks to discount values.

  • rank_fn (RankFn) – A function that maps scores to 1-based ranks.

  • cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items are cutoff.

  • reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The DCG metric.

rax.ndcg_metric(scores, labels, *, where=None, topn=None, weights=None, key=None, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)

Normalized discounted cumulative gain (NDCG).

Definition [Järvelin and Kekäläinen, 2002]:

\[\operatorname{ndcg}(s, y) = \operatorname{dcg}(s, y) / \operatorname{dcg}(y, y) \]

where \(\operatorname{dcg}\) is the discounted cumulative gain metric.

Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.

  • labels (ndarray) – A [..., list_size]-ndarray, indicating the relevance label for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.

  • topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.

  • weights (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating the per-item weights.

  • key (Optional[ndarray]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.

  • gain_fn (Callable[[ndarray], ndarray]) – A function that maps relevance label to gain values.

  • discount_fn (Callable[[ndarray], ndarray]) – A function that maps 1-based ranks to discount values.

  • rank_fn (RankFn) – A function that maps scores to 1-based ranks.

  • cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items are cutoff.

  • reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Return type

ndarray

Returns

The NDCG metric.
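
The DCG and NDCG definitions can be checked against the usage examples at the top of this section with a plain-Python sketch. The gain 2^y - 1 and the discount 1 / log2(1 + rank) used here are assumptions about the defaults; they are the conventional choices and they reproduce the documented values 0.79670763 and 0.8983538. The helper names are hypothetical and this is not Rax's implementation.

```python
import math


def hard_ranks(scores):
    """1-based ranks by descending score; ties keep their original order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for rank, i in enumerate(order, start=1):
        ranks[i] = rank
    return ranks


def dcg(scores, labels):
    """sum_i gain(y_i) * discount(rank(s_i)) with the assumed defaults."""
    return sum((2.0 ** y - 1.0) / math.log2(1.0 + r)
               for y, r in zip(labels, hard_ranks(scores)))


def ndcg(scores, labels):
    """dcg(s, y) / dcg(y, y); 0 when the ideal DCG is 0 (assumption)."""
    ideal = dcg(labels, labels)
    return dcg(scores, labels) / ideal if ideal > 0 else 0.0


single = ndcg([2., 1., 3.], [2., 0., 1.])
# Mean over a batch of two lists, without a mask.
batch = (single + ndcg([1., 0.5, 1.5], [0., 0., 1.])) / 2
```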

Function Transformations (rax.*_t12n)

Function transformations for ranking losses and metrics.

These function transformations can be used to transform the ranking metrics and losses. An example is approx_t12n which transforms a given ranking metric into a ranking loss by plugging in differentiable approximations to the rank and cutoff functions.

Example usage:

>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> approx_ndcg_loss_fn = rax.approx_t12n(rax.ndcg_metric)
>>> approx_ndcg_loss_fn(scores, labels)
DeviceArray(-0.71789175, dtype=float32)

approx_t12n(metric_fn[, temperature])

Transforms metric_fn into an approximate differentiable loss.

bound_t12n(metric_fn)

Transforms metric_fn into a lower-bound differentiable loss.

gumbel_t12n(loss_or_metric_fn, *[, samples, ...])

Transforms loss_or_metric_fn to operate on Gumbel-sampled scores.

rax.approx_t12n(metric_fn, temperature=1.0)

Transforms metric_fn into an approximate differentiable loss.

This transformation uses a sigmoid approximation to compute ranks and indicators in metrics [Qin et al., 2010]. The returned approximate metric is mapped to negative values to be used as a loss.

Example usage:

>>> approx_mrr = rax.approx_t12n(rax.mrr_metric)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> approx_mrr(scores, labels)
DeviceArray(-0.6965873, dtype=float32)

Example usage together with rax.gumbel_t12n():

>>> gumbel_approx_mrr = rax.gumbel_t12n(rax.approx_t12n(rax.mrr_metric))
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> gumbel_approx_mrr(scores, labels, key=jax.random.PRNGKey(42))
DeviceArray(-0.71880937, dtype=float32)
Parameters
  • metric_fn (MetricFn) – The metric function to convert to an approximate loss.

  • temperature (float) – The temperature parameter to use for the sigmoid approximation.

Return type

LossFn

Returns

A loss function that computes the approximate version of metric_fn.
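
To make the sigmoid approximation concrete, the sketch below computes approximate ranks as 1 plus the sum of sigmoid((s_j - s_i) / temperature) over all other items j, plugs them into the MRR definition, and negates the result. This is a plain-Python illustration under stated assumptions, not Rax's implementation: the exact rank formula and the way temperature enters are assumptions, although with temperature 1.0 the sketch reproduces the documented value -0.6965873 from the first example above.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def approx_ranks(scores, temperature=1.0):
    """Smooth 1-based ranks: 1 + sum_{j != i} sigmoid((s_j - s_i) / T)."""
    return [1.0 + sum(sigmoid((s_j - s_i) / temperature)
                      for j, s_j in enumerate(scores) if j != i)
            for i, s_i in enumerate(scores)]


def approx_mrr_loss(scores, labels, temperature=1.0):
    """Negated MRR computed on the smooth ranks (labels binarized at >= 1)."""
    ranks = approx_ranks(scores, temperature)
    return -max((1.0 if y >= 1 else 0.0) / r for y, r in zip(labels, ranks))


loss = approx_mrr_loss([0., 1., 3., 2.], [0., 0., 1., 2.])
```

A lower temperature sharpens the sigmoid, so the smooth ranks move closer to the hard ranks at the cost of smaller gradients away from ties.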

rax.bound_t12n(metric_fn)

Transforms metric_fn into a lower-bound differentiable loss.

This transformation uses a hinge bound to compute ranks and indicators in metrics. The returned lower-bound of the metric is mapped to negative values to be used as a loss.

Example usage:

>>> bound_mrr = rax.bound_t12n(rax.mrr_metric)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 1., 0., 1.])
>>> bound_mrr(scores, labels)
DeviceArray(-0.33333334, dtype=float32)

Example usage together with rax.gumbel_t12n():

>>> gumbel_bound_mrr = rax.gumbel_t12n(rax.bound_t12n(rax.mrr_metric))
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 1., 0., 1.])
>>> gumbel_bound_mrr(scores, labels, key=jax.random.PRNGKey(42))
DeviceArray(-0.31619418, dtype=float32)
Parameters
  • metric_fn (MetricFn) – The metric function to convert to a lower-bound loss.

Returns

A loss function that computes the lower-bound version of metric_fn.
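
The hinge bound can be sketched in plain Python by replacing the step in the rank computation with max(0, 1 + s_j - s_i), which never underestimates the rank and therefore never overestimates the metric. The formula below is an assumption about how the bound is formed, not a copy of Rax's code, but it reproduces the documented value -0.33333334 from the first example above. The helper names are hypothetical.

```python
def hinge_ranks(scores):
    """Upper bound on 1-based ranks: 1 + sum_{j != i} max(0, 1 + s_j - s_i)."""
    return [1.0 + sum(max(0.0, 1.0 + s_j - s_i)
                      for j, s_j in enumerate(scores) if j != i)
            for i, s_i in enumerate(scores)]


def bound_mrr_loss(scores, labels):
    """Negated MRR computed on the hinge-bounded ranks."""
    ranks = hinge_ranks(scores)
    return -max((1.0 if y >= 1 else 0.0) / r for y, r in zip(labels, ranks))


loss = bound_mrr_loss([0., 1., 3., 2.], [0., 1., 0., 1.])
```

For these inputs the exact MRR is 0.5 (the best relevant item is at rank 2), so the bounded value 1/3 is indeed a lower bound on the metric.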

rax.gumbel_t12n(loss_or_metric_fn, *, samples=8, beta=1.0)

Transforms loss_or_metric_fn to operate on Gumbel-sampled scores.

This transformation changes the given loss_or_metric_fn so that it samples scores from a Gumbel distribution prior to computing the loss or metric [Bruch et al., 2020]. The returned function requires a new key keyword argument.

Example usage:

>>> loss_fn = rax.gumbel_t12n(rax.softmax_loss)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> loss_fn(scores, labels, key=jax.random.PRNGKey(42))
DeviceArray(49.65323, dtype=float32)
>>> loss_fn(scores, labels, key=jax.random.PRNGKey(79))
DeviceArray(40.102238, dtype=float32)
Parameters
  • loss_or_metric_fn (TypeVar(LossOrMetricFn, LossFn, MetricFn)) – A Rax loss or metric function.

  • samples (int) – Number of Gumbel samples to create.

  • beta (float) – Shape of the Gumbel distribution (default 1.0).

Return type

TypeVar(LossOrMetricFn, LossFn, MetricFn)

Returns

A new function that behaves the same as loss_or_metric_fn but which requires an additional key argument that will be used to randomly sample the scores from a Gumbel distribution.
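
The idea behind the transformation can be sketched without JAX. The plain-Python wrapper below perturbs the scores with Gumbel noise before calling the wrapped function, once per sample. Everything in it is an assumption made for illustration: the names gumbel_noise and gumbel_transform are hypothetical, the noise is added as scores + beta * noise, a random.Random instance stands in for the key argument, and the per-sample values are returned unreduced because this section does not state how Rax combines them.

```python
import math
import random


def gumbel_noise(rng):
    """One standard Gumbel sample via inverse transform sampling."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep log() finite
    return -math.log(-math.log(u))


def gumbel_transform(loss_or_metric_fn, samples=8, beta=1.0):
    """Wraps a function so it is evaluated on Gumbel-perturbed scores."""
    def wrapped(scores, labels, *, rng):
        values = []
        for _ in range(samples):
            # Assumption: beta scales the noise added to each score.
            noisy = [s + beta * gumbel_noise(rng) for s in scores]
            values.append(loss_or_metric_fn(noisy, labels))
        return values  # reduction over samples is left to the caller
    return wrapped


def squared_error(scores, labels):
    return sum((s - y) ** 2 for s, y in zip(scores, labels))


noisy_loss_fn = gumbel_transform(squared_error)
values = noisy_loss_fn([0., 1., 3., 2.], [0., 0., 1., 2.], rng=random.Random(42))
```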

Utilities

ranks(scores, *[, where, axis, key])

Computes the ranks for given scores.

cutoff(a[, n, where, step_fn])

Computes a binary array to select the largest n values of a.

approx_ranks(scores, *[, where, key, step_fn])

Computes approximate ranks.

approx_cutoff(a[, n, where, step_fn])

Computes a binary array to select the largest n values of a.

rax.utils.ranks(scores, *, where=None, axis=-1, key=None)

Computes the ranks for given scores.

Note that the ranks returned by this function are not differentiable due to the sort operation having no gradients.

Parameters
  • scores (ndarray) – A [..., list_size]-ndarray, indicating the score for each item.

  • where (Optional[ndarray]) – An optional [..., list_size]-ndarray, indicating which items are valid.

  • axis (int) – The axis to sort on, by default this is the last axis.

  • key (Optional[ndarray]) – An optional jax.random.PRNGKey(). If provided, ties will be broken randomly using this key. If not provided, ties will retain the order of their appearance in the scores array.

Return type

ndarray

Returns

A tensor with the same shape as scores that indicates the 1-based rank of each item.
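
The deterministic tie-breaking rule described for key (ties retain their order of appearance) can be illustrated with a stable sort. This plain-Python sketch handles a single list only; the name hard_ranks is hypothetical, and masking and random tie-breaking are omitted.

```python
def hard_ranks(scores):
    """1-based ranks by descending score; ties keep their order of appearance."""
    # sorted() is stable, so equal scores stay in their original order.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for rank, i in enumerate(order, start=1):
        ranks[i] = rank
    return ranks


ranks = hard_ranks([2., 1., 3.])
tied = hard_ranks([1., 1., 0.])
```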

rax.utils.cutoff(a, n=None, *, where=None, step_fn=<function <lambda>>)

Computes a binary array to select the largest n values of a.

This function computes a binary jax.numpy.ndarray that selects the n largest values of a across its last dimension.

Note that the returned indicator may select more than n items if a has ties.

Parameters
Return type

ndarray

Returns

A jax.numpy.ndarray of the same shape as a, where the n largest values are set to 1, and the smaller values are set to 0.
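
The note about ties can be seen in a small plain-Python sketch that selects every value greater than or equal to the n-th largest one. This thresholding approach is an assumption chosen to illustrate the documented behavior, not Rax's implementation; treating n=None as "select everything" is likewise an assumption, and the where and step_fn arguments are omitted.

```python
def cutoff(a, n=None):
    """Binary indicator of the n largest values of a (ties may exceed n)."""
    if n is None:
        return [1.0] * len(a)  # assumption: no cutoff selects everything
    if n <= 0:
        return [0.0] * len(a)
    threshold = sorted(a, reverse=True)[min(n, len(a)) - 1]
    return [1.0 if x >= threshold else 0.0 for x in a]


top2 = cutoff([3., 1., 2.], n=2)
with_ties = cutoff([2., 2., 1., 2.], n=1)  # three items tie for first place
```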

rax.utils.approx_ranks(scores, *, where=None, key=None, step_fn=<function sigmoid>)

Computes approximate ranks.

This can be used to construct differentiable approximations of metrics. For example:

>>> import functools
>>> approx_ndcg = functools.partial(
...     rax.ndcg_metric, rank_fn=rax.utils.approx_ranks)
>>> scores = jnp.asarray([-1., 1., 0.])
>>> labels = jnp.asarray([0., 0., 1.])
>>> approx_ndcg(scores, labels)
DeviceArray(0.63092977, dtype=float32)
>>> jax.grad(approx_ndcg)(scores, labels)
DeviceArray([-0.03763788, -0.03763788,  0.07527576], dtype=float32)
Parameters
Return type

ndarray

Returns

An ndarray of the same shape as scores, indicating the 1-based approximate rank of each item.
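
The example above can be reproduced without JAX, which also shows why the approximation is differentiable. The plain-Python sketch below assumes approximate ranks of the form 1 + sum over j != i of sigmoid(s_j - s_i), a gain of 2^y - 1 and a discount of 1 / log2(1 + rank). These are assumptions rather than statements about Rax's code, but they reproduce both the value 0.63092977 and the gradient shown above, here estimated with central finite differences. All helper names are hypothetical.

```python
import math


def approx_ranks(scores):
    """Smooth 1-based ranks: 1 + sum_{j != i} sigmoid(s_j - s_i)."""
    return [1.0 + sum(1.0 / (1.0 + math.exp(-(s_j - s_i)))
                      for j, s_j in enumerate(scores) if j != i)
            for i, s_i in enumerate(scores)]


def approx_ndcg(scores, labels):
    """NDCG evaluated on smooth ranks; the ideal DCG uses exact ranks."""
    dcg = sum((2.0 ** y - 1.0) / math.log2(1.0 + r)
              for y, r in zip(labels, approx_ranks(scores)))
    ideal = sum((2.0 ** y - 1.0) / math.log2(1.0 + r)
                for r, y in enumerate(sorted(labels, reverse=True), start=1))
    return dcg / ideal


def finite_diff_grad(f, x, eps=1e-5):
    """Central finite-difference estimate of the gradient of f at x."""
    grad = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grad.append((f(hi) - f(lo)) / (2 * eps))
    return grad


scores, labels = [-1., 1., 0.], [0., 0., 1.]
value = approx_ndcg(scores, labels)
grad = finite_diff_grad(lambda s: approx_ndcg(s, labels), scores)
```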

rax.utils.approx_cutoff(a, n=None, *, where=None, step_fn=<function <lambda>>)

Computes a binary array to select the largest n values of a.

This function computes a binary jax.numpy.ndarray that selects the n largest values of a across its last dimension.

Note that the returned indicator may select more than n items if a has ties.

Parameters
Return type

ndarray

Returns

A jax.numpy.ndarray of the same shape as a, where the n largest values are set to 1, and the smaller values are set to 0.

Types

Rax-specific types and protocols.

Note

Types and protocols are provided for type-checking convenience only. You do not need to instantiate, subclass or extend them.

CutoffFn(*args, **kwds)

typing.Protocol for cutoff functions.

LossFn(*args, **kwds)

typing.Protocol for loss functions.

MetricFn(*args, **kwds)

typing.Protocol for metric functions.

RankFn(*args, **kwds)

typing.Protocol for rank functions.

ReduceFn(*args, **kwds)

typing.Protocol for reduce functions.

class rax.types.CutoffFn(*args, **kwds)

typing.Protocol for cutoff functions.

__call__(a, n)

Computes cutoffs based on the given array.

Parameters
  • a (ndarray) – The array for which to compute the cutoffs.

  • n (Optional[int]) – The position of the cutoff.

Return type

ndarray

Returns

A binary jax.numpy.ndarray of the same shape as a that represents which elements of a should be selected for the topn cutoff.

class rax.types.LossFn(*args, **kwds)

typing.Protocol for loss functions.

__call__(scores, labels, *, where, **kwargs)

Computes a loss.

Parameters
  • scores (ndarray) – The score of each item.

  • labels (ndarray) – The label of each item.

  • where (Optional[ndarray]) – An optional jax.numpy.ndarray of the same shape as scores that indicates which elements to include in the loss.

  • **kwargs – Optional loss-specific keyword arguments.

Return type

ndarray

Returns

A jax.numpy.ndarray that represents the loss computed on the given scores and labels.

class rax.types.MetricFn(*args, **kwds)

typing.Protocol for metric functions.

__call__(scores, labels, *, where, **kwargs)

Computes a metric.

Parameters
  • scores (ndarray) – The score of each item.

  • labels (ndarray) – The label of each item.

  • where (Optional[ndarray]) – An optional jax.numpy.ndarray of the same shape as scores that indicates which elements to include in the metric.

  • **kwargs – Optional metric-specific keyword arguments.

Return type

ndarray

Returns

A jax.numpy.ndarray that represents the metric computed on the given scores and labels.

class rax.types.RankFn(*args, **kwds)

typing.Protocol for rank functions.

__call__(scores, where, key)

Computes 1-based ranks based on the given scores.

Parameters
Return type

ndarray

Returns

A jax.numpy.ndarray of the same shape as scores that represents the 1-based ranks.

class rax.types.ReduceFn(*args, **kwds)

typing.Protocol for reduce functions.

__call__(a, where, axis)

Reduces an array across one or more dimensions.

Parameters
  • a (ndarray) – The array to reduce.

  • where (Optional[ndarray]) – An optional jax.numpy.ndarray of the same shape as a that indicates which elements to include in the reduction.

  • axis (Union[int, Tuple[int, ...], None]) – One or more axes to use for the reduction. If None this reduces across all available axes.

Return type

ndarray

Returns

A jax.numpy.ndarray that represents the reduced result of a over given axis.

References

BHBN20

Sebastian Bruch, Shuguang Han, Michael Bendersky, and Marc Najork. A stochastic treatment of learning to rank scoring functions. In Proceedings of the 13th International Conference on Web Search and Data Mining, 61–69. 2020.

BSR+05

Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. Learning to rank using gradient descent. In Proceedings of the 22nd International Conference on Machine Learning, 89–96. 2005.

JarvelinKekalainen02

Kalervo Järvelin and Jaana Kekäläinen. Cumulated gain-based evaluation of IR techniques. ACM Transactions on Information Systems (TOIS), 20(4):422–446, 2002.

QLL10

Tao Qin, Tie-Yan Liu, and Hang Li. A general approximation framework for direct optimization of information retrieval measures. Information Retrieval, 13(4):375–397, 2010.