Ranking Losses (rax.*_loss)
Implementations of common ranking losses in JAX.
A ranking loss is a differentiable function that expresses the cost of a ranking
induced by item scores compared to a ranking induced from relevance labels. Rax
provides a number of ranking losses as JAX functions that are implemented
according to the LossFn interface.
Loss functions are designed to operate on the last dimension of their inputs. The
leading dimensions are considered batch dimensions. To compute per-list losses,
for example to apply per-list weighting or for distributed computing of losses
across devices, please use standard JAX transformations such as jax.vmap()
or jax.pmap().
Standalone usage:
>>> import jax
>>> import jax.numpy as jnp
>>> import rax
>>> scores = jnp.array([2., 1., 3.])
>>> labels = jnp.array([1., 0., 0.])
>>> loss = rax.softmax_loss(scores, labels)
>>> print(f"{loss:.5f}")
1.40761
Usage with a batch of data and a mask to indicate valid items:
>>> scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> loss = rax.pairwise_hinge_loss(
... scores, labels, where=where, reduce_fn=jnp.mean
... )
>>> print(f"{loss:.5f}")
0.16667
To compute gradients of each loss function, please use standard JAX
transformations such as jax.grad() or jax.value_and_grad():
>>> scores = jnp.asarray([[0., 1., 3.], [1., 2., 0.]])
>>> labels = jnp.asarray([[0., 0., 1.], [1., 0., 0.]])
>>> grads = jax.grad(rax.softmax_loss)(scores, labels, reduce_fn=jnp.mean)
>>> for row in grads:
...     print("[" + ", ".join(f"{grad:.5f}" for grad in row) + "]")
[0.02101, 0.05710, -0.07810]
[-0.37764, 0.33262, 0.04502]
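| Function | Description |
|---|---|
| rax.pointwise_mse_loss | Mean squared error loss. |
| rax.pointwise_sigmoid_loss | Sigmoid cross entropy loss. |
| rax.pairwise_hinge_loss | Pairwise hinge loss. |
| rax.pairwise_logistic_loss | Pairwise logistic loss. |
| rax.pairwise_soft_zero_one_loss | Pairwise soft zero-one loss. |
| rax.pairwise_mse_loss | Pairwise mean squared error loss. |
| rax.pairwise_qr_loss | Pairwise quantile regression loss. |
| rax.softmax_loss | Softmax loss. |
| rax.listmle_loss | ListMLE loss. |
| rax.poly1_softmax_loss | Poly1 softmax loss. |
| rax.unique_softmax_loss | Unique softmax loss. |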
- rax.pointwise_mse_loss(scores, labels, *, where=None, segments=None, weights=None, reduce_fn=<function mean>)
Mean squared error loss.
Definition:
\[\ell(s, y) = \sum_i (y_i - s_i)^2 \]
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The mean squared error loss.
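The definition can be checked with a few lines of jax.numpy. This is a minimal sketch of the formula only, not the Rax implementation. It assumes a single unbatched list, models where by zeroing out masked items, and ignores segments and weights. mse_sum_sketch is a hypothetical name. Per the signature above, the Rax function also reduces its per-item values with reduce_fn (jnp.mean by default), so its output need not equal this plain sum.

```python
import jax.numpy as jnp


def mse_sum_sketch(scores, labels, where=None):
    """Plain sum of squared errors for one list (hypothetical helper, not Rax)."""
    per_item = jnp.square(labels - scores)
    if where is not None:
        # Masked-out items (where == False) contribute zero to the sum.
        per_item = jnp.where(where, per_item, 0.0)
    return jnp.sum(per_item)


scores = jnp.array([2.0, 1.0, 3.0])
labels = jnp.array([1.0, 0.0, 0.0])
print(mse_sum_sketch(scores, labels))  # (1-2)^2 + (0-1)^2 + (0-3)^2 = 11.0
print(mse_sum_sketch(scores, labels, where=jnp.array([True, True, False])))  # 2.0
```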
- rax.pointwise_sigmoid_loss(scores, labels, *, where=None, segments=None, weights=None, reduce_fn=<function mean>)
Sigmoid cross entropy loss.
Note: This loss clips label values so that 0 <= label <= 1.
Definition:
\[\ell(s, y) = \sum_i \left[ -y_i \log(\op{sigmoid}(s_i)) - (1 - y_i) \log(1 - \op{sigmoid}(s_i)) \right] \]
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The sigmoid cross entropy loss.
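Here is a sketch of the definition, including the label clipping from the note. It is not the Rax code. It assumes a single list and sums the per-item terms. It uses jax.nn.log_sigmoid with the identity log(1 - sigmoid(s)) = log(sigmoid(-s)) for numerical stability. sigmoid_ce_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def sigmoid_ce_sketch(scores, labels):
    """Summed sigmoid cross entropy for one list (hypothetical helper, not Rax)."""
    # Clip labels into [0, 1], mirroring the note above.
    labels = jnp.clip(labels, 0.0, 1.0)
    # log(sigmoid(s)) and log(1 - sigmoid(s)) == log(sigmoid(-s)), both computed stably.
    log_p = jax.nn.log_sigmoid(scores)
    log_not_p = jax.nn.log_sigmoid(-scores)
    return jnp.sum(-labels * log_p - (1.0 - labels) * log_not_p)


scores = jnp.array([0.0, 0.0])
# With zero scores, sigmoid(s) = 0.5 for every item, so each term equals log(2).
print(sigmoid_ce_sketch(scores, jnp.array([1.0, 0.0])))   # ~1.3863 (= 2 * log 2)
# Labels outside [0, 1] are clipped first, so this gives the same value.
print(sigmoid_ce_sketch(scores, jnp.array([2.0, -1.0])))  # ~1.3863
```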
- rax.pairwise_hinge_loss(scores, labels, *, where=None, segments=None, weights=None, lambdaweight_fn=None, reduce_fn=<function mean>)
Pairwise hinge loss.
Definition:
\[\ell(s, y) = \sum_i \sum_j \II{y_i > y_j} \max(0, 1 - (s_i - s_j)) \]
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - lambdaweight_fn (Optional[LambdaweightFn]) – An optional function that outputs lambdaweights.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The pairwise hinge loss.
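The double sum maps naturally onto broadcasting. Entry [i, j] of a [list_size, list_size] matrix holds the quantity for pair (i, j). The sketch below sums the hinge over pairs with y_i > y_j for one list. It ignores masks, weights, lambdaweights, and reduction, and pairwise_hinge_sketch is a hypothetical name. The example reuses the second list from the batched pairwise_hinge_loss example near the top of this section.

```python
import jax.numpy as jnp


def pairwise_hinge_sketch(scores, labels):
    """Sum of max(0, 1 - (s_i - s_j)) over pairs with y_i > y_j (hypothetical helper)."""
    # Entry [i, j] holds s_i - s_j and whether y_i > y_j.
    score_diff = scores[:, None] - scores[None, :]
    label_gt = labels[:, None] > labels[None, :]
    pair_loss = jnp.maximum(0.0, 1.0 - score_diff)
    # Only pairs where item i is strictly more relevant than item j count.
    return jnp.sum(jnp.where(label_gt, pair_loss, 0.0))


scores = jnp.array([1.0, 0.5, 1.5])
labels = jnp.array([0.0, 0.0, 1.0])
print(pairwise_hinge_sketch(scores, labels))  # 0.5: pair (2, 0) gives 0.5, pair (2, 1) gives 0
```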
- rax.pairwise_logistic_loss(scores, labels, *, where=None, segments=None, weights=None, lambdaweight_fn=None, reduce_fn=<function mean>)
Pairwise logistic loss.
Definition [Burges et al., 2005]:
\[\ell(s, y) = \sum_i \sum_j \II{y_i > y_j} \log(1 + \exp(-(s_i - s_j))) \]
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - lambdaweight_fn (Optional[LambdaweightFn]) – An optional function that outputs lambdaweights.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The pairwise logistic loss.
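The logistic term log(1 + exp(x)) is the softplus function. jax.nn.softplus evaluates it stably. Here is a one-list sketch of the definition, which ignores masks, weights, and reduction. pairwise_logistic_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def pairwise_logistic_sketch(scores, labels):
    """Sum of log(1 + exp(-(s_i - s_j))) over pairs with y_i > y_j (hypothetical)."""
    score_diff = scores[:, None] - scores[None, :]
    label_gt = labels[:, None] > labels[None, :]
    # softplus(x) = log(1 + exp(x)), evaluated in a numerically stable way.
    pair_loss = jax.nn.softplus(-score_diff)
    return jnp.sum(jnp.where(label_gt, pair_loss, 0.0))


labels = jnp.array([1.0, 0.0])
print(pairwise_logistic_sketch(jnp.array([0.0, 0.0]), labels))  # log(2) ~ 0.6931
print(pairwise_logistic_sketch(jnp.array([3.0, 0.0]), labels))  # log(1 + e^-3) ~ 0.0486
```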
- rax.pairwise_soft_zero_one_loss(scores, labels, *, where=None, segments=None, weights=None, lambdaweight_fn=None, reduce_fn=<function mean>)
Pairwise soft zero-one loss.
Definition:
\[\ell(s, y) = \sum_i \sum_j \II{y_i > y_j} \op{sigmoid}(-(s_i - s_j)) \]
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - lambdaweight_fn (Optional[LambdaweightFn]) – An optional function that outputs lambdaweights.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The pairwise soft zero-one loss value.
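The sigmoid makes this loss a smooth count of misordered pairs. Each pair contributes close to 1 when the less relevant item scores higher, and close to 0 when the pair is ordered correctly by a wide margin. The sketch below illustrates this for one list and ignores masks, weights, and reduction. pairwise_soft_zero_one_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def pairwise_soft_zero_one_sketch(scores, labels):
    """Sum of sigmoid(-(s_i - s_j)) over pairs with y_i > y_j (hypothetical helper)."""
    score_diff = scores[:, None] - scores[None, :]
    label_gt = labels[:, None] > labels[None, :]
    # A smooth stand-in for the indicator that the pair is misordered (s_j > s_i).
    pair_loss = jax.nn.sigmoid(-score_diff)
    return jnp.sum(jnp.where(label_gt, pair_loss, 0.0))


labels = jnp.array([1.0, 0.0])
print(pairwise_soft_zero_one_sketch(jnp.array([0.0, 0.0]), labels))  # 0.5 (tied scores)
print(pairwise_soft_zero_one_sketch(jnp.array([0.0, 2.0]), labels))  # ~0.8808 (misordered)
print(pairwise_soft_zero_one_sketch(jnp.array([5.0, 0.0]), labels))  # ~0.0067 (well ordered)
```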
- rax.pairwise_mse_loss(scores, labels, *, where=None, segments=None, weights=None, lambdaweight_fn=None, reduce_fn=<function mean>)
Pairwise mean squared error loss.
Definition:
\[\ell(s, y) = \sum_i \sum_j ((y_i - y_j) - (s_i - s_j))^2 \]
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - lambdaweight_fn (Optional[LambdaweightFn]) – An optional function that outputs lambdaweights.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The pairwise mean squared error loss.
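The definition above has no \(\II{y_i > y_j}\) factor, so every ordered pair contributes. The loss depends only on score differences, which makes it invariant to adding a constant to all scores. Here is a one-list sketch of the formula, ignoring masks, weights, and reduction. pairwise_mse_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def pairwise_mse_sketch(scores, labels):
    """Sum over all ordered pairs of ((y_i - y_j) - (s_i - s_j))^2 (hypothetical helper)."""
    label_diff = labels[:, None] - labels[None, :]
    score_diff = scores[:, None] - scores[None, :]
    # Unlike the hinge-style losses, every ordered pair (i, j) contributes.
    return jnp.sum(jnp.square(label_diff - score_diff))


scores = jnp.array([1.0, 2.0])
labels = jnp.array([1.0, 0.0])
print(pairwise_mse_sketch(scores, labels))         # 8.0: each off-diagonal pair gives (+-2)^2
print(pairwise_mse_sketch(scores + 10.0, labels))  # still 8.0: a shift changes no difference
```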
- rax.pairwise_qr_loss(scores, labels, *, where=None, segments=None, weights=None, tau=0.5, squared=False, lambdaweight_fn=None, reduce_fn=<function mean>)
Pairwise quantile regression loss.
Definition:
\[\ell(s, y) = \sum_i \sum_j \II{y_i > y_j} \op{loss}_{ij} \]
\[\op{loss}_{ij} = \tau \max(0, (y_i - y_j) - (s_i - s_j)) + (1-\tau) \max(0, (s_i - s_j) - (y_i - y_j)) \]
When squared is True, each hinge term is squared. Only pairs with different labels are considered. When tau = 0.5, this reduces to median regression (or, when squared is True, to an MSE-style loss), apart from a small difference in how tied labels are handled.
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - tau (float) – A float in (0, 1.0] that defines the quantile. When tau = 0.5, it becomes median regression.
  - squared (bool) – If True, square each individual pairwise loss value.
  - lambdaweight_fn (Optional[LambdaweightFn]) – An optional function that outputs lambdaweights.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The pairwise quantile regression loss.
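The two hinge terms penalize under- and over-estimating the label difference asymmetrically. Under-estimation is weighted by tau and over-estimation by 1 - tau. The sketch below is a one-list version of the definition and ignores masks, weights, and reduction. It makes one assumption that should be verified against Rax: with squared=True, each max(...) term is squared before being weighted by tau. pairwise_qr_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def pairwise_qr_sketch(scores, labels, tau=0.5, squared=False):
    """Pairwise quantile regression loss for one list (hypothetical helper, not Rax)."""
    label_diff = labels[:, None] - labels[None, :]
    score_diff = scores[:, None] - scores[None, :]
    # Hinge terms for under- and over-estimating the label difference.
    under = jnp.maximum(0.0, label_diff - score_diff)
    over = jnp.maximum(0.0, score_diff - label_diff)
    if squared:
        # Assumption: squaring is applied to each hinge term before weighting by tau.
        under, over = jnp.square(under), jnp.square(over)
    pair_loss = tau * under + (1.0 - tau) * over
    # Only pairs with y_i > y_j are considered; tied labels are skipped.
    return jnp.sum(jnp.where(label_diff > 0, pair_loss, 0.0))


labels = jnp.array([2.0, 0.0])
zeros = jnp.zeros(2)
print(pairwise_qr_sketch(zeros, labels))                # 0.5 * 2 = 1.0 (under-estimate)
print(pairwise_qr_sketch(zeros, labels, tau=0.9))       # 0.9 * 2 = 1.8
print(pairwise_qr_sketch(zeros, labels, squared=True))  # 0.5 * 2^2 = 2.0
print(pairwise_qr_sketch(jnp.array([3.0, 0.0]), labels, tau=0.9))  # 0.1 * 1 (over-estimate)
```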
- rax.softmax_loss(scores, labels, *, where=None, segments=None, weights=None, label_fn=<function <lambda>>, reduce_fn=<function mean>)
Softmax loss.
Definition:
\[\ell(s, y) = -\sum_i y_i \log \frac{\exp(s_i)}{\sum_j \exp(s_j)} \]
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - label_fn (Callable[..., Array]) – A label function that maps labels to probabilities. By default, labels are kept as-is. See rax.utils.normalize_probabilities() for an example.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The softmax loss.
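Since \(\log \frac{\exp(s_i)}{\sum_j \exp(s_j)}\) is exactly the log-softmax, the definition is a one-liner in JAX. This sketch covers a single list with the default identity label_fn. It reproduces the 1.40761 printed by the standalone softmax_loss example at the top of this section. softmax_sketch is a hypothetical name, not part of Rax.

```python
import jax
import jax.numpy as jnp


def softmax_sketch(scores, labels):
    """-sum_i y_i * log_softmax(s)_i for one list (hypothetical helper, not Rax)."""
    # log_softmax computes s_i - log(sum_j exp(s_j)) in a numerically stable way.
    return -jnp.sum(labels * jax.nn.log_softmax(scores))


scores = jnp.array([2.0, 1.0, 3.0])
labels = jnp.array([1.0, 0.0, 0.0])
print(f"{float(softmax_sketch(scores, labels)):.5f}")  # 1.40761
```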
- rax.listmle_loss(scores, labels, *, key=None, where=None, segments=None, reduce_fn=<function mean>)
ListMLE Loss.
Note: This loss sorts items using the given labels. If the labels contain multiple identical values, you should provide a PRNGKey() to the key argument so that ties are broken randomly during the sorting operation.
Definition [Xia et al., 2008]:
\[\ell(s, y) = -\sum_i \log \frac{\exp(s_i)} {\sum_j \II{\op{rank}(y_j) \ge \op{rank}(y_i)} \exp(s_j)} \]
where \(\op{rank}(y_i)\) indicates the rank of item \(i\) after sorting all labels \(y\).
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - key (Optional[Array]) – An optional PRNGKey() to perform random tie-breaking.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The ListMLE loss.
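Once items are sorted by label, the denominator for the item at position i sums exp(s_k) over positions k >= i. That is a reverse cumulative log-sum-exp. Here is a sketch for a single list. It assumes distinct labels, because it does not break ties randomly (which is what the key argument is for). listmle_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def listmle_sketch(scores, labels):
    """ListMLE for one list with distinct labels (hypothetical helper, not Rax)."""
    # Order items from most to least relevant according to the labels.
    order = jnp.argsort(-labels)
    s = scores[order]
    # log of sum_{k >= i} exp(s_k): a reverse cumulative log-sum-exp,
    # shifted by the maximum score for numerical stability.
    m = jnp.max(s)
    tail_lse = m + jnp.log(jnp.cumsum(jnp.exp(s - m)[::-1])[::-1])
    return -jnp.sum(s - tail_lse)


scores = jnp.array([1.0, 2.0])
labels = jnp.array([1.0, 0.0])
# -log(e^1 / (e^1 + e^2)) - log(e^2 / e^2) = log(1 + e) ~ 1.3133
print(listmle_sketch(scores, labels))
```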
- rax.poly1_softmax_loss(scores, labels, *, epsilon=1.0, where=None, segments=None, weights=None, reduce_fn=<function mean>)
Poly1 softmax loss.
Definition [Leng et al., 2022]:
\[\ell(s, y) = \op{softmax}(s, y) + \epsilon * (1 - \op{pt}) \]
where \(\op{softmax}\) is the standard softmax loss as implemented in softmax_loss() and \(\op{pt}\) is the target softmax probability, defined as:
\[\op{pt} = \sum_i \frac{y_i}{\sum_j y_j} \frac{\exp(s_i)}{\sum_j \exp(s_j)} \]
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - epsilon (float) – A float hyperparameter indicating the weight of the leading polynomial coefficient in the poly loss.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The poly1 softmax loss.
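The two formulas above combine directly. Here is a one-list sketch of the definition, in which pt weights the softmax probabilities by the normalized labels. It ignores masks, weights, and reduction, and poly1_softmax_sketch is a hypothetical name. Setting epsilon to 0 recovers the plain softmax loss.

```python
import jax
import jax.numpy as jnp


def poly1_softmax_sketch(scores, labels, epsilon=1.0):
    """softmax loss + epsilon * (1 - pt) for one list (hypothetical helper, not Rax)."""
    log_probs = jax.nn.log_softmax(scores)
    softmax_term = -jnp.sum(labels * log_probs)
    # pt: softmax probabilities averaged with label-normalized weights.
    pt = jnp.sum(labels / jnp.sum(labels) * jnp.exp(log_probs))
    return softmax_term + epsilon * (1.0 - pt)


scores = jnp.zeros(2)
labels = jnp.array([1.0, 0.0])
print(poly1_softmax_sketch(scores, labels))               # log(2) + 0.5 ~ 1.1931
print(poly1_softmax_sketch(scores, labels, epsilon=0.0))  # log(2) ~ 0.6931
```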
- rax.unique_softmax_loss(scores, labels, *, where=None, segments=None, weights=None, gain_fn=<function default_gain_fn>, reduce_fn=<function mean>)
Unique softmax loss.
Definition [Zhu and Klabjan, 2020]:
\[\ell(s, y) = -\sum_i \op{gain}(y_i) \log \frac{\exp(s_i)}{\exp(s_i) + \sum_{j : y_j < y_i} \exp(s_j)} \]
where \(\op{gain}(y_i)\) is a user-specified gain function applied to label \(y_i\) to boost items with higher relevance.
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
  - gain_fn (Optional[Callable[[Array], Array]]) – An optional function that maps relevance labels to gain values. If provided, the per-item losses are multiplied by gain_fn(label) to boost the importance of relevant items.
  - reduce_fn (ReduceFn) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The unique softmax loss.
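Each item's denominator contains only itself and the strictly less relevant items. A pairwise mask combined with a masked log-sum-exp captures this. Here is a one-list sketch that ignores where, segments, weights, and reduction. The default gain 2^y - 1 is an assumption (a common DCG-style gain) and may not match Rax's default_gain_fn. unique_softmax_sketch is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def unique_softmax_sketch(scores, labels, gain_fn=lambda y: 2.0 ** y - 1.0):
    """Unique softmax for one list (hypothetical helper; the default gain is an assumption)."""
    n = scores.shape[-1]
    # include[i, j]: item j enters item i's denominator (j == i or y_j < y_i).
    include = (labels[None, :] < labels[:, None]) | jnp.eye(n, dtype=bool)
    masked = jnp.where(include, scores[None, :], -jnp.inf)
    log_denom = logsumexp(masked, axis=1)
    return -jnp.sum(gain_fn(labels) * (scores - log_denom))


scores = jnp.array([1.0, 0.0])
labels = jnp.array([1.0, 0.0])
# Only item 0 has non-zero gain (2^1 - 1 = 1): -log(e^1 / (e^1 + e^0)) = log(1 + e^-1)
print(unique_softmax_sketch(scores, labels))  # ~0.3133
```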
Ranking Metrics (rax.*_metric)
Implementations of common ranking metrics in JAX.
A ranking metric expresses how well a ranking induced by item scores matches a
ranking induced from relevance labels. Rax provides a number of ranking metrics
as JAX functions that are implemented according to the
MetricFn interface.
Metric functions are designed to operate on the last dimension of their inputs.
The leading dimensions are considered batch dimensions. To compute per-list
metrics, for example to apply per-list weighting or for distributed computing of
metrics across devices, please use standard JAX transformations such as
jax.vmap() or jax.pmap().
Standalone usage of a metric:
>>> import jax
>>> import jax.numpy as jnp
>>> import rax
>>> scores = jnp.array([2., 1., 3.])
>>> labels = jnp.array([2., 0., 1.])
>>> metric = rax.ndcg_metric(scores, labels)
>>> print(f"{metric:.5f}")
0.79671
Usage with a batch of data and a mask to indicate valid items:
>>> scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[2., 0., 1.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> metric = rax.ndcg_metric(scores, labels, where=where)
>>> print(f"{metric:.5f}")
1.00000
Usage with jax.vmap() batching and a mask to indicate valid items:
>>> scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> print(jax.vmap(rax.ndcg_metric)(scores, labels, where=where))
[1. 1.]
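| Function | Description |
|---|---|
| rax.mrr_metric | Mean Reciprocal Rank (MRR). |
| rax.precision_metric | Precision. |
| rax.recall_metric | Recall. |
| rax.ap_metric | Average Precision. |
| rax.dcg_metric | Discounted cumulative gain (DCG). |
| rax.ndcg_metric | Normalized discounted cumulative gain (NDCG). |
| rax.opa_metric | Ordered Pair Accuracy (OPA). |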
- rax.mrr_metric(scores, labels, *, where=None, segments=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Mean Reciprocal Rank (MRR).
Note: This metric converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.
Definition:
\[\op{mrr}(s, y) = \max_i \frac{y_i}{\op{rank}(s_i)} \]
where \(\op{rank}(s_i)\) indicates the rank of item \(i\) after sorting all scores \(s\) using rank_fn.
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the metric.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The metric will only be computed on items that share the same segment.
  - topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
  - key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
  - rank_fn (RankFn) – A function that maps scores to 1-based ranks.
  - cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The MRR metric.
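For a single list, the definition is straightforward to compute: derive 1-based ranks from descending scores, binarize the labels as described in the note, and take the best reciprocal rank. The sketch below is not Rax's rank_fn or cutoff_fn. It assumes distinct scores and implements topn as a simple rank <= topn mask. ranks_sketch and mrr_sketch are hypothetical names.

```python
import jax.numpy as jnp


def ranks_sketch(scores):
    """1-based ranks after sorting distinct scores in descending order (hypothetical)."""
    # argsort of argsort turns a sort order into each item's position in it.
    return jnp.argsort(jnp.argsort(-scores)) + 1


def mrr_sketch(scores, labels, topn=None):
    ranks = ranks_sketch(scores)
    # Binarize graded labels as in the note: label >= 1 is relevant.
    relevant = (labels >= 1).astype(jnp.float32)
    reciprocal = relevant / ranks
    if topn is not None:
        # Items ranked below topn are cut off and contribute nothing.
        reciprocal = jnp.where(ranks <= topn, reciprocal, 0.0)
    return jnp.max(reciprocal)


scores = jnp.array([0.0, 1.0, 3.0, 2.0])
labels = jnp.array([0.0, 1.0, 0.0, 1.0])
print(ranks_sketch(scores))                 # [4 3 1 2]
print(mrr_sketch(scores, labels))           # 0.5: first relevant item is at rank 2
print(mrr_sketch(scores, labels, topn=1))   # 0.0
```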
- rax.precision_metric(scores, labels, *, where=None, segments=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Precision.
Note: This metric converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.
Definition:
\[\op{precision@n}(s, y) = \frac{1}{n} \sum_i y_i \cdot \II{\op{rank}(s_i) \leq n} \]
where \(\op{rank}(s_i)\) indicates the rank of item \(i\) after sorting all scores \(s\) using rank_fn.
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the metric.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The metric will only be computed on items that share the same segment.
  - topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
  - key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
  - rank_fn (RankFn) – A function that maps scores to 1-based ranks.
  - cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The precision metric.
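precision@n counts the relevant items among the top n ranks and divides by n. Here is a one-list sketch of the formula. It assumes distinct scores and an explicit n, and it does not claim to reproduce how Rax handles topn=None. precision_at_n_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def precision_at_n_sketch(scores, labels, n):
    """precision@n for one list with distinct scores (hypothetical helper, not Rax)."""
    ranks = jnp.argsort(jnp.argsort(-scores)) + 1   # 1-based, descending scores
    relevant = (labels >= 1).astype(jnp.float32)    # binarized labels
    return jnp.sum(relevant * (ranks <= n)) / n


scores = jnp.array([0.0, 1.0, 3.0, 2.0])
labels = jnp.array([0.0, 1.0, 0.0, 1.0])
for n in (1, 2, 3):
    print(n, precision_at_n_sketch(scores, labels, n))  # 0.0, 0.5, ~0.6667
```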
- rax.recall_metric(scores, labels, *, where=None, segments=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Recall.
Note: This metric converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.
Definition:
\[\op{recall@n}(s, y) = \frac{1}{\sum_i y_i} \sum_i y_i \cdot \II{\op{rank}(s_i) \leq n} \]
where \(\op{rank}(s_i)\) indicates the rank of item \(i\) after sorting all scores \(s\) using rank_fn.
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the metric.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The metric will only be computed on items that share the same segment.
  - topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
  - key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
  - rank_fn (RankFn) – A function that maps scores to 1-based ranks.
  - cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The recall metric.
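Recall uses the same numerator as precision but divides by the total number of relevant items instead of n. Here is a one-list sketch, assuming distinct scores and at least one relevant item. recall_at_n_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def recall_at_n_sketch(scores, labels, n):
    """recall@n for one list with distinct scores (hypothetical helper, not Rax)."""
    ranks = jnp.argsort(jnp.argsort(-scores)) + 1   # 1-based, descending scores
    relevant = (labels >= 1).astype(jnp.float32)    # binarized labels
    # Assumes at least one relevant item, otherwise this divides by zero.
    return jnp.sum(relevant * (ranks <= n)) / jnp.sum(relevant)


scores = jnp.array([0.0, 1.0, 3.0, 2.0])
labels = jnp.array([0.0, 1.0, 0.0, 1.0])
for n in (1, 2, 3):
    print(n, recall_at_n_sketch(scores, labels, n))  # 0.0, 0.5, 1.0
```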
- rax.ap_metric(scores, labels, *, where=None, segments=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Average Precision.
Note: This metric converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.
Definition:
\[\op{ap}(s, y) = \frac{1}{\sum_i y_i} \sum_i y_i \op{precision@rank}_{s_i}(s, y) \]
where \(\op{precision@rank}_{s_i}(s, y)\) indicates the precision at the rank of item \(i\).
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the metric.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The metric will only be computed on items that share the same segment.
  - topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
  - key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
  - rank_fn (RankFn) – A function that maps scores to 1-based ranks.
  - cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The average precision metric.
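The precision at item i's rank is the number of relevant items ranked at or above i, divided by that rank. AP averages this quantity over the relevant items. Here is a one-list sketch of the definition without topn. It assumes distinct scores and at least one relevant item. ap_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def ap_sketch(scores, labels):
    """Average precision for one list with distinct scores (hypothetical helper)."""
    ranks = jnp.argsort(jnp.argsort(-scores)) + 1      # 1-based, descending scores
    relevant = (labels >= 1).astype(jnp.float32)       # binarized labels
    # hits[i]: number of relevant items ranked at or above item i.
    at_or_above = ranks[None, :] <= ranks[:, None]
    hits = jnp.sum(relevant[None, :] * at_or_above, axis=1)
    precision_at_rank = hits / ranks
    return jnp.sum(relevant * precision_at_rank) / jnp.sum(relevant)


scores = jnp.array([0.0, 1.0, 3.0, 2.0])
labels = jnp.array([0.0, 1.0, 0.0, 1.0])
# Relevant items sit at ranks 2 and 3: (1/2 + 2/3) / 2 = 7/12
print(ap_sketch(scores, labels))  # ~0.5833
```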
- rax.dcg_metric(scores, labels, *, where=None, segments=None, topn=None, weights=None, key=None, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Discounted cumulative gain (DCG).
Definition [Järvelin and Kekäläinen, 2002]:
\[\op{dcg}(s, y) = \sum_i \op{gain}(y_i) \cdot \op{discount}(\op{rank}(s_i)) \]
where \(\op{rank}(s_i)\) indicates the 1-based rank of item \(i\) as computed by rank_fn, \(\op{gain}(y)\) indicates the per-item gains as computed by gain_fn, and \(\op{discount}(r)\) indicates the per-item rank discounts as computed by discount_fn.
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the metric.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The metric will only be computed on items that share the same segment.
  - topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the per-item weights.
  - key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
  - gain_fn (Callable[[Array], Array]) – A function that maps relevance labels to gain values.
  - discount_fn (Callable[[Array], Array]) – A function that maps 1-based ranks to discount values.
  - rank_fn (RankFn) – A function that maps scores to 1-based ranks.
  - cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The DCG metric.
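The sketch below computes DCG for one list. It makes two assumptions: a gain of 2^y - 1 and a discount of 1 / log2(rank + 1). These are the common textbook choices, and they do reproduce the ndcg_metric value printed in the standalone example above, but Rax's defaults should be confirmed in its source. It also assumes distinct scores and a simple rank <= topn cutoff. dcg_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def dcg_sketch(scores, labels, topn=None):
    """DCG for one list with assumed gain 2^y - 1 and discount 1 / log2(rank + 1)."""
    ranks = jnp.argsort(jnp.argsort(-scores)) + 1
    gains = 2.0 ** labels - 1.0                # assumed gain_fn
    discounts = 1.0 / jnp.log2(ranks + 1.0)    # assumed discount_fn
    if topn is not None:
        discounts = jnp.where(ranks <= topn, discounts, 0.0)
    return jnp.sum(gains * discounts)


scores = jnp.array([2.0, 1.0, 3.0])
labels = jnp.array([2.0, 0.0, 1.0])
print(dcg_sketch(scores, labels))          # 1 + 3 / log2(3) ~ 2.8928
print(dcg_sketch(scores, labels, topn=1))  # 1.0
```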
- rax.ndcg_metric(scores, labels, *, where=None, segments=None, topn=None, weights=None, key=None, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Normalized discounted cumulative gain (NDCG).
Definition [Järvelin and Kekäläinen, 2002]:
\[\op{ndcg}(s, y) = \op{dcg}(s, y) / \op{dcg}(y, y) \]
where \(\op{dcg}\) is the discounted cumulative gain metric.
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the metric.
  - segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The metric will only be computed on items that share the same segment.
  - topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
  - weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the per-item weights.
  - key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
  - gain_fn (Callable[[Array], Array]) – A function that maps relevance labels to gain values.
  - discount_fn (Callable[[Array], Array]) – A function that maps 1-based ranks to discount values.
  - rank_fn (RankFn) – A function that maps scores to 1-based ranks.
  - cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The NDCG metric.
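The ideal DCG in the denominator is the DCG obtained when items are ranked by their own labels, \(\op{dcg}(y, y)\). The sketch below assumes the same gain 2^y - 1 and discount 1 / log2(rank + 1) as the DCG sketch, and it does not guard against lists with no relevant items. With those assumptions, it reproduces the 0.79671 from the standalone ndcg_metric example at the top of this section. ndcg_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def ndcg_sketch(scores, labels):
    """NDCG for one list under assumed gain/discount functions (hypothetical helper)."""
    def dcg(ranking_scores):
        ranks = jnp.argsort(jnp.argsort(-ranking_scores)) + 1
        return jnp.sum((2.0 ** labels - 1.0) / jnp.log2(ranks + 1.0))

    # The ideal DCG ranks items by their labels instead of their scores.
    return dcg(scores) / dcg(labels)


scores = jnp.array([2.0, 1.0, 3.0])
labels = jnp.array([2.0, 0.0, 1.0])
print(f"{float(ndcg_sketch(scores, labels)):.5f}")  # 0.79671
```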
- rax.opa_metric(scores, labels, *, where=None, reduce_fn=<function mean>)
Ordered Pair Accuracy (OPA).
Definition:
\[\op{opa}(s, y) = \frac{1}{\sum_i \sum_j \II{y_i > y_j}} \sum_i \sum_j \II{s_i > s_j} \II{y_i > y_j} \]
Note: Pairs with equal labels (\(y_i = y_j\)) are always ignored. Pairs with equal scores (\(s_i = s_j\)) are considered incorrectly ordered.
- Parameters:
  - scores (Array) – A [..., list_size]-Array, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
  - labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
  - where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the metric.
  - reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type: Array
- Returns: The Ordered Pair Accuracy (OPA).
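OPA is the fraction of label-ordered pairs whose scores are ordered the same way. Tied scores fail the strict s_i > s_j test, which matches the note above. Here is a one-list sketch without masking. It assumes the list has at least one pair with different labels. opa_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def opa_sketch(scores, labels):
    """Ordered pair accuracy for one list (hypothetical helper, not Rax)."""
    label_gt = labels[:, None] > labels[None, :]   # pairs with y_i > y_j
    score_gt = scores[:, None] > scores[None, :]   # pairs with s_i > s_j
    # Tied scores make score_gt False, so such pairs count as incorrect.
    return jnp.sum(label_gt & score_gt) / jnp.sum(label_gt)


scores = jnp.array([0.0, 1.0, 3.0, 2.0])
labels = jnp.array([0.0, 0.0, 1.0, 2.0])
print(opa_sketch(scores, labels))  # 0.8: only the pair (3, 2) of 5 is misordered
print(opa_sketch(jnp.array([1.0, 1.0]), jnp.array([1.0, 0.0])))  # 0.0 (tied scores)
```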
Function Transformations (rax.*_t12n)
Function transformations for ranking losses and metrics.
These function transformations can be used to transform the ranking metrics and
losses. An example is approx_t12n, which transforms a given ranking metric
into a ranking loss by plugging in differentiable approximations to the rank and
cutoff functions.
Example usage:
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> approx_ndcg_loss_fn = rax.approx_t12n(rax.ndcg_metric)
>>> loss = approx_ndcg_loss_fn(scores, labels)
>>> print(f"{loss:.5f}")
-0.71789
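| Function | Description |
|---|---|
| rax.approx_t12n | Transforms metric_fn into an approximate differentiable loss. |
| rax.bound_t12n | Transforms metric_fn into a lower-bound differentiable loss. |
| rax.gumbel_t12n | Transforms loss_or_metric_fn to operate on Gumbel-sampled scores. |
| rax.segment_t12n | Transforms loss_or_metric_fn to operate on segmented inputs. |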
- rax.approx_t12n(metric_fn, temperature=1.0)
Transforms metric_fn into an approximate differentiable loss.
This transformation uses a sigmoid approximation to compute ranks and indicators in metrics [Qin et al., 2010]. The returned approximate metric is mapped to negative values so that it can be used as a loss.
Example usage:
>>> approx_mrr = rax.approx_t12n(rax.mrr_metric)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> loss = approx_mrr(scores, labels)
>>> print(f"{loss:.5f}")
-0.69659
Example usage together with rax.gumbel_t12n():
>>> gumbel_approx_mrr = rax.gumbel_t12n(rax.approx_t12n(rax.mrr_metric))
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> key = jax.random.PRNGKey(42)
>>> loss = gumbel_approx_mrr(scores, labels, key=key)
>>> print(f"{loss:.5f}")
-0.75971
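In the spirit of the sigmoid approximation cited above, the exact indicator [s_j > s_i] inside a rank can be replaced by sigmoid((s_j - s_i) / temperature). The result is a smooth rank with useful gradients. The sketch below is an illustration, not Rax's internal code, and approx_ranks_sketch is a hypothetical name. With temperature 1 on the example scores, 1 / approx_rank of the top relevant item (index 2) comes out near 0.69659. That is consistent with the approx_mrr output above, although Rax's details (such as how temperature enters) may differ.

```python
import jax
import jax.numpy as jnp


def approx_ranks_sketch(scores, temperature=1.0):
    """Smooth 1-based ranks: 1 + sum_{j != i} sigmoid((s_j - s_i) / temperature)."""
    # Entry [i, j] is (s_j - s_i) / temperature; it is large when item j outscores item i.
    diff = (scores[None, :] - scores[:, None]) / temperature
    not_self = 1.0 - jnp.eye(scores.shape[-1])
    # Drop the j == i terms, which would each add sigmoid(0) = 0.5.
    return 1.0 + jnp.sum(jax.nn.sigmoid(diff) * not_self, axis=1)


scores = jnp.array([0.0, 1.0, 3.0, 2.0])
print(approx_ranks_sketch(scores))                    # ~[3.56, 2.88, 1.44, 2.12]
print(approx_ranks_sketch(scores, temperature=0.01))  # ~[4, 3, 1, 2], the exact ranks
# Unlike exact ranks, the approximation has non-zero gradients w.r.t. the scores.
print(jax.grad(lambda s: 1.0 / approx_ranks_sketch(s)[2])(scores))
```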
- rax.bound_t12n(metric_fn)
Transforms
metric_fninto a lower-bound differentiable loss.This transformation uses a hinge bound to compute ranks and indicators in metrics. The returned lower-bound of the metric is mapped to negative values to be used as a loss.
Example usage:
>>> bound_mrr = rax.bound_t12n(rax.mrr_metric) >>> scores = jnp.asarray([0., 1., 3., 2.]) >>> labels = jnp.asarray([0., 1., 0., 1.]) >>> loss = bound_mrr(scores, labels) >>> print(f"{loss:.5f}") -0.33333
Example usage together with
rax.gumbel_t12n():>>> gumbel_bound_mrr = rax.gumbel_t12n(rax.bound_t12n(rax.mrr_metric)) >>> scores = jnp.asarray([0., 1., 3., 2.]) >>> labels = jnp.asarray([0., 1., 0., 1.]) >>> loss = gumbel_bound_mrr(scores, labels, key=jax.random.PRNGKey(42)) >>> print(f"{loss:.5f}") -0.40368
- Parameters:
metric_fn (MetricFn) – The metric function to convert to a lower-bound loss.
- Returns:
A loss function that computes the lower-bound version of metric_fn.
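The plain-Python sketch below shows why a hinge yields a lower bound on the metric. It is not Rax's code, and the helper names are hypothetical. We assume the hinge step max(0, 1 + x) in place of the exact step function. Each hinge term is at least the indicator that item j is placed at or above item i, so the summed bound never falls below the exact rank, and 1/bound never exceeds the exact reciprocal rank. As in the previous sketch, items with label >= 1 are assumed relevant. On the first example above, this sketch reproduces -0.33333.

```python
def hinge_rank_bounds(scores: list[float]) -> list[float]:
    """Upper bounds on 1-based ranks: 1 + sum over j != i of max(0, 1 + s_j - s_i)."""
    # Assumption: the hinge step is max(0, 1 + x); Rax's exact form may differ.
    return [
        1.0 + sum(max(0.0, 1.0 + s_j - s_i) for j, s_j in enumerate(scores) if j != i)
        for i, s_i in enumerate(scores)
    ]


def bound_mrr_loss(scores: list[float], labels: list[float]) -> float:
    """Negated lower bound on MRR, computed from the rank upper bounds."""
    bounds = hinge_rank_bounds(scores)
    # Assumption (hypothetical): items with label >= 1 count as relevant.
    reciprocal = [1.0 / b for b, y in zip(bounds, labels) if y >= 1.0]
    return -max(reciprocal, default=0.0)


print(f"{bound_mrr_loss([0., 1., 3., 2.], [0., 1., 0., 1.]):.5f}")  # -0.33333
```

For these scores the bounds are [10, 6, 1, 3], while the exact ranks are [4, 3, 1, 2].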
- rax.gumbel_t12n(loss_or_metric_fn, *, samples=8, beta=1.0, smoothing_factor=None)
Transforms loss_or_metric_fn to operate on Gumbel-sampled scores.
This transformation changes the given loss_or_metric_fn so that it samples scores from a Gumbel distribution prior to computing the loss or metric [Bruch et al., 2020]. The returned function requires a new key keyword argument.
Example usage:
>>> loss_fn = rax.gumbel_t12n(rax.softmax_loss)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> loss = loss_fn(scores, labels, key=jax.random.PRNGKey(42))
>>> print(f"{loss:.5f}")
3.45703
>>> loss = loss_fn(scores, labels, key=jax.random.PRNGKey(79))
>>> print(f"{loss:.5f}")
4.12491
- Parameters:
loss_or_metric_fn (TypeVar(LossOrMetricFn, LossFn, MetricFn)) – A Rax loss or metric function.
samples (int) – Number of Gumbel samples to create.
beta (float) – Shape of the Gumbel distribution (default 1.0).
smoothing_factor (Optional[float]) – If supplied, this will apply an extra log(softmax(scores) + smoothing_factor) transformation to the scores. If set to 1e-20, this effectively makes the loss compatible with the TF-Ranking versions of Gumbel losses. If smoothing_factor <= 0, this may produce NaN values.
- Return type:
TypeVar(LossOrMetricFn, LossFn, MetricFn)
- Returns:
A new function that behaves the same as loss_or_metric_fn but requires an additional key argument, which is used to randomly sample the scores from a Gumbel distribution.
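The stdlib sketch below shows the score perturbation that rax.gumbel_t12n is built around. Python's random.Random stands in for jax.random, and the helper names are hypothetical. Standard Gumbel noise is drawn by inverse-CDF sampling, g = -log(-log(u)). We assume that beta scales the noise. The optional smoothing_factor transform follows the parameter description above. How Rax combines the per-sample losses into one value is not shown.

```python
import math
import random


def _standard_gumbel(rng: random.Random) -> float:
    """One standard Gumbel sample via the inverse CDF: g = -log(-log(u))."""
    # Keep u strictly inside (0, 1) so that both logarithms stay finite.
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def gumbel_perturb(scores, *, samples=8, beta=1.0, smoothing_factor=None, seed=0):
    """Returns `samples` copies of `scores`, each with independent Gumbel noise added."""
    rng = random.Random(seed)
    if smoothing_factor is not None:
        # Optional log(softmax(scores) + smoothing_factor) transform.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        scores = [math.log(e / total + smoothing_factor) for e in exps]
    # Assumption: beta scales the noise, i.e. noise ~ Gumbel(0, beta).
    return [[s + beta * _standard_gumbel(rng) for s in scores] for _ in range(samples)]


perturbed = gumbel_perturb([0., 1., 3., 2.], samples=8, seed=42)
print(len(perturbed), len(perturbed[0]))  # 8 4
```

A loss or metric would then be evaluated on each perturbed row instead of on the raw scores.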
- rax.segment_t12n(loss_or_metric_fn)
Transforms loss_or_metric_fn to operate on segmented inputs.
Warning
This transformation incurs an additional \(O(n^2)\) computational cost (where \(n\) is the list size) if no specialized segmented implementation is available for the given loss_or_metric_fn.
This changes loss_or_metric_fn to accept an additional keyword argument segments, which indicates the segment each item belongs to so that several lists can be grouped together in one array.
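To make the segments argument concrete, the plain-Python sketch below computes ranks separately within each segment. It uses hypothetical helpers, not Rax code. Every pair of items is compared through an n x n same-segment mask. Materializing such a pairwise mask is one natural source of an O(n²) cost. Whether Rax's generic fallback works exactly this way is an assumption.

```python
def same_segment_mask(segments: list[int]) -> list[list[bool]]:
    """n x n mask: entry [i][j] is True when items i and j share a segment."""
    # Touches every (i, j) pair, hence quadratic time and memory in the list size.
    return [[s_i == s_j for s_j in segments] for s_i in segments]


def segmented_ranks(scores: list[float], segments: list[int]) -> list[int]:
    """1-based ranks computed independently within each segment.

    Ties are broken by order of appearance (the earlier item ranks higher).
    """
    mask = same_segment_mask(segments)
    n = len(scores)
    return [
        1 + sum(
            1
            for j in range(n)
            if j != i
            and mask[i][j]
            and (scores[j] > scores[i] or (scores[j] == scores[i] and j < i))
        )
        for i in range(n)
    ]


print(segmented_ranks([3., 1., 2., 5., 4.], [0, 0, 1, 0, 2]))  # [2, 3, 1, 1, 1]
```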
Lambdaweights (rax.*_lambdaweight)
Implementations of lambdaweight functions for Rax pairwise losses.
Lambdaweight functions dynamically adjust the weights of a pairwise loss based
on the scores and labels. Rax provides a number of lambdaweight functions as JAX
functions that are implemented according to the
LambdaweightFn interface.
Example usage:
>>> scores = jnp.array([1.2, 0.4, 1.9])
>>> labels = jnp.array([1.0, 2.0, 0.0])
>>> loss = rax.pairwise_logistic_loss(
... scores, labels, lambdaweight_fn=rax.labeldiff_lambdaweight)
>>> print(f"{loss:.5f}")
1.89237
| rax.labeldiff_lambdaweight | Absolute label difference lambdaweights. |
| rax.dcg_lambdaweight | DCG lambdaweights. |
| rax.dcg2_lambdaweight | DCG v2 ("lambdaloss") lambdaweights. |
- rax.labeldiff_lambdaweight(scores, labels, *, where=None, segments=None, weights=None)
Absolute label difference lambdaweights.
Definition:
\[\lambda_{ij}(s, y) = |y_i - y_j| \]
- Parameters:
scores (Array) – A [..., list_size]-Array, indicating the score of each item.
labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the lambdaweights. Items for which this is False will be ignored when computing the lambdaweights.
segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The lambdaweights will only be computed on items that share the same segment.
weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
- Return type:
Array
- Returns:
Absolute label difference lambdaweights.
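The plain-Python sketch below evaluates the definition λ_ij = |y_i - y_j| for a single list, using the labels from the lambdaweights example above. The function name is hypothetical. Two choices are assumptions made for illustration: the n x n nested-list layout, and setting pairs that involve a masked item to 0. Rax itself works on [..., list_size] arrays.

```python
def labeldiff_lambdaweights(labels, where=None):
    """Pairwise weights |y_i - y_j| as an n x n matrix; pairs with a masked item get 0."""
    n = len(labels)
    mask = where if where is not None else [True] * n
    return [
        [abs(labels[i] - labels[j]) if mask[i] and mask[j] else 0.0 for j in range(n)]
        for i in range(n)
    ]


for row in labeldiff_lambdaweights([1.0, 2.0, 0.0]):
    print(row)
# [0.0, 1.0, 1.0]
# [1.0, 0.0, 2.0]
# [1.0, 2.0, 0.0]
```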
- rax.dcg_lambdaweight(scores, labels, *, where=None, segments=None, weights=None, topn=None, normalize=False, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>)
DCG lambdaweights.
Definition [Burges et al., 2006]:
\[\lambda_{ij}(s, y) = |\op{gain}(y_i) - \op{gain}(y_j)| \cdot |\op{discount}(\op{rank}(s_i)) - \op{discount}(\op{rank}(s_j))| \]
- Parameters:
scores (Array) – A [..., list_size]-Array, indicating the score of each item.
labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the lambdaweights. Items for which this is False will be ignored when computing the lambdaweights.
segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The lambdaweights will only be computed on items that share the same segment.
weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
topn (Optional[int]) – The topn cutoff. If None, no cutoff is performed.
normalize (bool) – Whether to use the normalized DCG formulation.
gain_fn (Callable[[Array], Array]) – A function mapping labels to gain values.
discount_fn (Callable[[Array], Array]) – A function mapping ranks to discount values.
- Return type:
Array
- Returns:
DCG lambdaweights.
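The plain-Python sketch below evaluates the DCG lambdaweight definition term by term on a single list. It uses hypothetical helper names and is not Rax's implementation. It assumes the conventional gain 2^y - 1 and discount 1/log2(1 + rank) (our assumption about the defaults), and it computes exact ranks by sorting. topn, normalize, where, segments and weights are all left out.

```python
import math


def gain(label: float) -> float:
    # Assumed gain function: 2^y - 1.
    return 2.0 ** label - 1.0


def discount(rank: float) -> float:
    # Assumed discount function: 1 / log2(1 + rank).
    return 1.0 / math.log2(1.0 + rank)


def exact_ranks(scores: list[float]) -> list[int]:
    """1-based ranks by descending score; ties keep their order of appearance."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])  # stable sort
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def dcg_lambdaweights(scores: list[float], labels: list[float]) -> list[list[float]]:
    """lambda_ij = |gain(y_i) - gain(y_j)| * |discount(rank_i) - discount(rank_j)|."""
    r = exact_ranks(scores)
    n = len(scores)
    return [
        [
            abs(gain(labels[i]) - gain(labels[j])) * abs(discount(r[i]) - discount(r[j]))
            for j in range(n)
        ]
        for i in range(n)
    ]


w = dcg_lambdaweights([1.2, 0.4, 1.9], [1.0, 2.0, 0.0])
print(w[1][2])  # 1.5: gains 3 and 0, ranks 3 and 1, so 3 * |0.5 - 1.0|
```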
- rax.dcg2_lambdaweight(scores, labels, *, where=None, segments=None, weights=None, topn=None, normalize=False, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>, light_discount=False)
DCG v2 (“lambdaloss”) lambdaweights.
Definition [Wang et al., 2018]:
\[\lambda_{ij}(s, y) = |\op{gain}(y_i) - \op{gain}(y_j)| \cdot |\op{discount}(|\op{rank}(s_i) - \op{rank}(s_j)|) - \op{discount}(|\op{rank}(s_i) - \op{rank}(s_j)|+1)| \]
Or the following when light_discount is True:
\[\lambda_{ij}(s, y) = |\op{gain}(y_i) - \op{gain}(y_j)| \cdot |\op{discount}(|\op{rank}(s_i) - \op{rank}(s_j)|)| \]
- Parameters:
scores (Array) – A [..., list_size]-Array, indicating the score of each item.
labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the lambdaweights. Items for which this is False will be ignored when computing the lambdaweights.
segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The lambdaweights will only be computed on items that share the same segment.
weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
topn (Optional[int]) – The topn cutoff. If None, no cutoff is performed. The topn cutoff uses the method described in [Jagerman et al., 2022].
normalize (bool) – Whether to use the normalized DCG formulation.
gain_fn (Callable[[Array], Array]) – A function mapping labels to gain values.
discount_fn (Callable[[Array], Array]) – A function mapping ranks to discount values.
light_discount (bool) – If True, use the light position discount shown above.
- Return type:
Array
- Returns:
DCG v2 ("lambdaloss") lambdaweights.
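The plain-Python sketch below evaluates both DCG v2 formulas above on one list. Helper names are hypothetical, and this is not Rax's code. Unlike DCG lambdaweights, the position term depends only on the rank distance δ = |rank_i - rank_j|. We again assume gain 2^y - 1 and discount 1/log2(1 + rank), and we omit topn and normalize. Diagonal pairs are set to 0, because δ = 0 would make the assumed discount undefined.

```python
import math


def gain(label: float) -> float:
    # Assumed gain function: 2^y - 1.
    return 2.0 ** label - 1.0


def discount(rank: float) -> float:
    # Assumed discount function: 1 / log2(1 + rank).
    return 1.0 / math.log2(1.0 + rank)


def exact_ranks(scores: list[float]) -> list[int]:
    """1-based ranks by descending score; ties keep their order of appearance."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def dcg2_lambdaweights(scores, labels, light_discount=False):
    """LambdaLoss-style weights driven by the rank distance between i and j."""
    r = exact_ranks(scores)
    n = len(scores)
    weights = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # delta would be 0, where the discount is undefined
            delta = abs(r[i] - r[j])
            if light_discount:
                position_term = discount(delta)
            else:
                position_term = abs(discount(delta) - discount(delta + 1))
            weights[i][j] = abs(gain(labels[i]) - gain(labels[j])) * position_term
    return weights


print(dcg2_lambdaweights([1.2, 0.4, 1.9], [1.0, 2.0, 0.0])[1][2])
```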
Utilities
Utility functions for Rax.
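| rax.utils.ranks | Computes the ranks for given scores. |
| rax.utils.cutoff | Computes a binary array to select the largest n values of a. |
| rax.utils.approx_ranks | Computes approximate ranks. |
| rax.utils.approx_cutoff | Approximately selects the largest n values of a. |
| rax.utils.pairwise_loss | Generic pairwise loss. |
| rax.utils.compute_pairs | Computes pairs based on values of a and the given pairwise op. |
| rax.utils.safe_reduce | Reduces the values of the given array while preventing NaN in the output. |
| rax.utils.normalize_probabilities | Normalizes given unscaled probabilities so they sum to one along the given axis. |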
- rax.utils.ranks(scores, *, where=None, segments=None, axis=-1, key=None)
Computes the ranks for given scores.
Note that the ranks returned by this function are not differentiable due to the sort operation having no gradients.
- Parameters:
scores (Array) – A [..., list_size]-Array, indicating the score for each item.
where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid.
segments (Optional[Array]) – An Array to indicate segments of items that should be grouped together, like [0, 0, 1, 0, 2]. The segments may or may not be sorted.
axis (int) – The axis to sort on; by default this is the last axis.
key (Optional[Array]) – An optional jax.random.PRNGKey(). If provided, ties will be broken randomly using this key. If not provided, ties will retain the order of their appearance in the scores array.
- Return type:
Array
- Returns:
An Array with the same shape as scores that indicates the 1-based rank of each item.
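The stdlib sketch below reproduces the described tie-breaking behaviour for one list. The function name and the rng parameter are hypothetical, and random.Random stands in for a PRNGKey. A stable sort keeps tied items in order of appearance, and a random secondary key breaks ties randomly when an rng is given. Ranking masked items last is taken from the RankFn protocol description further below. Treating it as this function's behaviour is an assumption.

```python
import random


def ranks(scores, *, where=None, rng=None):
    """1-based ranks by descending score.

    Ties keep their order of appearance unless `rng` is given, in which case
    they are broken randomly. Items masked out by `where` are ranked last.
    """
    n = len(scores)
    mask = where if where is not None else [True] * n
    tiebreak = [rng.random() for _ in range(n)] if rng is not None else list(range(n))
    # Sort key: valid items first, then higher score first, then the tie-breaker.
    order = sorted(range(n), key=lambda i: (not mask[i], -scores[i], tiebreak[i]))
    result = [0] * n
    for position, i in enumerate(order, start=1):
        result[i] = position
    return result


print(ranks([2., 1., 3.]))  # [2, 3, 1]
```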
- rax.utils.cutoff(a, n=None, where=None, segments=None)
Computes a binary array to select the largest n values of a.
This function computes a binary Array that selects the n largest values of a across its last dimension.
- Parameters:
a (Array) – The values from which the n largest are selected.
n (Optional[int]) – The cutoff value. If None, no cutoff is performed.
where (Optional[Array]) – A mask to indicate which values to include in the topn calculation.
segments (Optional[Array]) – An Array to indicate segments of items that should be grouped together, like [0, 0, 1, 0, 2]. The segments may or may not be sorted.
- Return type:
Array
- Returns:
An Array of the same shape as a, where the n largest values are set to 1 and the smaller values are set to 0.
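The plain-Python sketch below builds the binary top-n selection for a single list. It is a hypothetical helper, not Rax's code. Two behaviours are assumptions based on the descriptions above: it returns all ones when n is None ("no cutoff"), and it always gives masked items 0.

```python
def cutoff(a, n=None, where=None):
    """1.0 for the n largest valid values of `a`, 0.0 elsewhere."""
    size = len(a)
    if n is None:
        return [1.0] * size  # no cutoff is performed
    mask = where if where is not None else [True] * size
    # Indices of valid items sorted by descending value (ties: earlier item first).
    order = sorted((i for i in range(size) if mask[i]), key=lambda i: -a[i])
    selected = set(order[:n])
    return [1.0 if i in selected else 0.0 for i in range(size)]


print(cutoff([2., 1., 3.], n=2))  # [1.0, 0.0, 1.0]
```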
- rax.utils.approx_ranks(scores, *, where=None, segments=None, key=None, step_fn=<PjitFunction of <function sigmoid>>)
Computes approximate ranks.
This can be used to construct differentiable approximations of metrics. For example:
>>> import functools
>>> approx_ndcg = functools.partial(
... rax.ndcg_metric, rank_fn=rax.utils.approx_ranks)
>>> scores = jnp.asarray([-1., 1., 0.])
>>> labels = jnp.asarray([0., 0., 1.])
>>> print(f"{approx_ndcg(scores, labels):.5f}")
0.63093
>>> grads = jax.grad(approx_ndcg)(scores, labels)
>>> print("[" + ", ".join(f"{grad:.5f}" for grad in grads) + "]")
[-0.03764, -0.03764, 0.07528]
- Parameters:
scores (Array) – A [..., list_size]-Array, indicating the score for each item.
where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid.
segments (Optional[Array]) – An Array to indicate segments of items that should be grouped together, like [0, 0, 1, 0, 2]. The segments may or may not be sorted.
key (Optional[Array]) – An optional jax.random.PRNGKey(). Unused by approx_ranks.
step_fn (Callable[[Array], Array]) – A callable that approximates the step function x >= 0.
- Return type:
Array
- Returns:
An Array of the same shape as scores, indicating the 1-based approximate rank of each item.
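The plain-Python sketch below checks the approximate-rank idea against the NDCG example above. Names are hypothetical, and this is not Rax's implementation. It assumes that approximate ranks take the form 1 + Σ_{j≠i} step_fn(s_j - s_i), and that NDCG uses gain 2^y - 1 and discount 1/log2(1 + rank). Under those assumptions it reproduces 0.63093. Passing a hard step function instead of the sigmoid recovers exact ranks for tie-free scores.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def approx_ranks(scores, step_fn=sigmoid):
    """1 + sum over j != i of step_fn(s_j - s_i), a smooth count of items above i."""
    return [
        1.0 + sum(step_fn(s_j - s_i) for j, s_j in enumerate(scores) if j != i)
        for i, s_i in enumerate(scores)
    ]


def dcg(ranks, labels):
    # Assumed gain 2^y - 1 and discount 1 / log2(1 + rank).
    return sum((2.0 ** y - 1.0) / math.log2(1.0 + r) for r, y in zip(ranks, labels))


def approx_ndcg(scores, labels):
    ideal = dcg(range(1, len(labels) + 1), sorted(labels, reverse=True))
    return dcg(approx_ranks(scores), labels) / ideal if ideal > 0 else 0.0


print(f"{approx_ndcg([-1., 1., 0.], [0., 0., 1.]):.5f}")  # 0.63093
```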
- rax.utils.approx_cutoff(a, n=None, *, where=None, segments=None, step_fn=<PjitFunction of <function sigmoid>>)
Approximately selects the largest n values of a.
This function computes an Array that is the probability of an item being in the n largest values of a across its last dimension.
- Parameters:
a (Array) – The values from which the n largest are approximately selected.
n (Optional[int]) – The cutoff value. If None, no cutoff is performed.
where (Optional[Array]) – A mask to indicate which values to include in the topn calculation.
segments (Optional[Array]) – An Array to indicate segments of items that should be grouped together, like [0, 0, 1, 0, 2]. The segments may or may not be sorted.
step_fn (Callable[[Array], Array]) – A function that computes an approximation of x >= 0.
- Return type:
Array
- Returns:
An Array of the same shape as a.
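One simple way to soften a top-n selection is sketched below in plain Python. This construction is our own illustration and is not necessarily how rax.utils.approx_cutoff computes its output. It places a threshold halfway between the n-th and (n+1)-th largest values and applies a sigmoid step to each value minus that threshold. Items in the top n then score above 0.5, and the others score below. The function name is hypothetical, and n >= 1 is assumed.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid, used as the smooth step function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def soft_top_n(a, n=None):
    """Soft top-n membership in (0, 1); an illustrative construction, assumes n >= 1."""
    if n is None or n >= len(a):
        return [1.0] * len(a)  # nothing is cut off
    ordered = sorted(a, reverse=True)
    # Threshold halfway between the n-th and (n+1)-th largest values.
    threshold = 0.5 * (ordered[n - 1] + ordered[n])
    return [sigmoid(x - threshold) for x in a]


print([round(v, 3) for v in soft_top_n([2., 1., 3.], n=2)])
```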
- rax.utils.pairwise_loss(scores, labels, *, pair_loss_fn, lambdaweight_fn=None, where=None, segments=None, weights=None, reduce_fn=<function mean>)
Generic pairwise loss.
The pair_loss_fn takes (scores_diff, labels_diff) and returns the loss for each pair as well as the valid pairs considered in the loss.
- Parameters:
scores (Array) – A [..., list_size]-Array, indicating the score of each item.
labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
pair_loss_fn (Callable[[Array, Array], tuple[Array, Array]]) – A function that outputs (pair_losses, valid_pairs) given (scores_diff, labels_diff).
lambdaweight_fn (Optional[LambdaweightFn]) – An optional function that outputs lambdaweights.
where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
segments (Optional[Array]) – An optional [..., list_size]-Array, indicating segments within each list. The loss will only be computed on items that share the same segment.
weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type:
Array
- Returns:
The pairwise loss.
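The plain-Python sketch below follows the pair_loss_fn contract described above for one list. The names are hypothetical, and this is not Rax's implementation. For every ordered pair (i, j) of valid items, pair_loss_fn returns (loss, is_valid). Valid pair losses are optionally multiplied by a lambdaweight and then reduced. Two simplifications are assumed here: the lambdaweight function returns an n x n nested list, and an empty mean returns 0 rather than NaN. The example pair loss is a hinge on the score margin that counts a pair only when item i has the higher label.

```python
def mean_or_zero(values):
    """Mean that returns 0.0 for an empty list (no valid pairs) instead of failing."""
    return sum(values) / len(values) if values else 0.0


def hinge_pair_loss(scores_diff, labels_diff):
    """(loss, is_valid) for one pair: hinge on the score margin, valid if y_i > y_j."""
    return max(0.0, 1.0 - scores_diff), labels_diff > 0


def pairwise_loss(scores, labels, *, pair_loss_fn, lambdaweight_fn=None,
                  where=None, reduce_fn=mean_or_zero):
    n = len(scores)
    mask = where if where is not None else [True] * n
    lambdaweights = lambdaweight_fn(scores, labels) if lambdaweight_fn else None
    losses = []
    for i in range(n):
        for j in range(n):
            if i == j or not (mask[i] and mask[j]):
                continue  # skip self-pairs and pairs touching a masked item
            loss, is_valid = pair_loss_fn(scores[i] - scores[j], labels[i] - labels[j])
            if is_valid:
                weight = lambdaweights[i][j] if lambdaweights is not None else 1.0
                losses.append(weight * loss)
    return losses if reduce_fn is None else reduce_fn(losses)


print(pairwise_loss([1., 0.5, 1.5], [0., 0., 1.], pair_loss_fn=hinge_pair_loss))  # 0.25
```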
- rax.utils.compute_pairs(a, op)
Computes pairs based on values of a and the given pairwise op.
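- Parameters:
a (Array) – The array whose values are combined pairwise.
op – The pairwise operation, applied as op(a[..., i], a[..., j]).
- Return type:
Array
- Returns:
A jax.Array with the same leading dimensions as a, but with the last dimension expanded so that it includes all pairs op(a[..., i], a[..., j]).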
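The plain-Python sketch below enumerates all pairs op(a[i], a[j]) for a single list. It is a hypothetical helper, not Rax's code. Our reading of "the last dimension expanded so that it includes all pairs" is a flattened, row-major layout in which entry i * n + j holds the pair (i, j). That layout is an assumption, and the sketch does not handle leading batch dimensions.

```python
import operator


def compute_pairs(a, op):
    """All op(a[i], a[j]) for one list, row-major: entry i * n + j is pair (i, j)."""
    return [op(a_i, a_j) for a_i in a for a_j in a]


print(compute_pairs([1, 2, 3], operator.sub))  # [0, -1, -2, 1, 0, -1, 2, 1, 0]
```

With op = operator.sub, this yields the score and label differences that a pairwise loss consumes.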
- rax.utils.safe_reduce(a, where=None, reduce_fn=None)
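Reduces the values of the given array while preventing NaN in the output.
For jax.numpy.mean() reduction, this additionally prevents NaN in the output if all elements are masked. This can happen with pairwise losses where there are no valid pairs because all labels are the same. In this situation, 0 is returned instead.
When there is no reduce_fn, this will set elements of a to 0 according to the where mask.
- Parameters:
a (Array) – The array to reduce.
where (Optional[Array]) – An optional mask indicating which elements of a are valid.
reduce_fn – An optional reduction function, such as jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type:
Array
- Returns:
The result of reducing the values of a using the given reduce_fn.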
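The stdlib sketch below mirrors the behaviour described above for a flat list. It is a hypothetical helper, not Rax's code, and statistics.fmean stands in for jax.numpy.mean(). When every element is masked, it returns 0 instead of NaN. Applying this rule to any reduction, not only the mean, is an assumption of this sketch. Without a reduce_fn, masked elements are zeroed.

```python
import statistics


def safe_reduce(a, where=None, reduce_fn=None):
    """Masked reduction that returns 0.0 when no element is valid."""
    mask = where if where is not None else [True] * len(a)
    if reduce_fn is None:
        # No reduction: zero out masked elements.
        return [x if keep else 0.0 for x, keep in zip(a, mask)]
    kept = [x for x, keep in zip(a, mask) if keep]
    if not kept:
        # Everything is masked (e.g. no valid pairs): avoid NaN and return 0.
        return 0.0
    return reduce_fn(kept)


print(safe_reduce([1.0, 2.0, 3.0], where=[True, False, True], reduce_fn=statistics.fmean))  # 2.0
print(safe_reduce([1.0, 2.0], where=[False, False], reduce_fn=statistics.fmean))  # 0.0
```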
- rax.utils.normalize_probabilities(unscaled_probabilities, *, where=None, segments=None, axis=-1)
Normalizes given unscaled probabilities so they sum to one along the given axis.
This will scale the given unscaled probabilities such that its valid (non-masked) elements will sum to one along the given axis. Note that the array should have only non-negative elements.
For cases where all valid elements along the given axis are zero, this will return a uniform distribution over those valid elements.
For cases where all elements along the given axis are invalid (masked), this will return a uniform distribution over those invalid elements.
- Parameters:
unscaled_probabilities (Array) – The probabilities to normalize.
where (Optional[Array]) – An optional jax.Array indicating which elements to include in the normalization.
segments (Optional[Array]) – An optional jax.Array to indicate segments of items that should be grouped together, like [0, 0, 1, 0, 2]. The segments may or may not be sorted.
axis (int) – The axis to normalize on.
- Return type:
Array
- Returns:
Given unscaled probabilities normalized so they sum to one for the valid (non-masked) items in the given axis.
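The plain-Python sketch below implements the three cases described above for a single list. The function name is hypothetical. It covers the regular normalization, all valid elements being zero, and all elements being masked. Setting masked elements to 0 in the first two cases is an assumption, and the segments and axis arguments are omitted.

```python
def normalize_probabilities(unscaled, where=None):
    """Scales valid (non-masked) entries to sum to one, with the documented fallbacks."""
    n = len(unscaled)
    mask = where if where is not None else [True] * n
    n_valid = sum(mask)
    if n_valid == 0:
        # Every element is masked: uniform over all of the (masked) elements.
        return [1.0 / n] * n
    total = sum(x for x, keep in zip(unscaled, mask) if keep)
    if total == 0.0:
        # All valid elements are zero: uniform over the valid elements.
        return [1.0 / n_valid if keep else 0.0 for keep in mask]
    # Assumption: masked elements are set to 0 in the regular case.
    return [x / total if keep else 0.0 for x, keep in zip(unscaled, mask)]


print(normalize_probabilities([1.0, 3.0]))  # [0.25, 0.75]
```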
Types
Rax-specific types and protocols.
Note
Types and protocols are provided for type-checking convenience only. You do not need to instantiate, subclass or extend them.
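| rax.types.CutoffFn | typing.Protocol for cutoff functions. |
| rax.types.LambdaweightFn | typing.Protocol for lambdaweight functions. |
| rax.types.LossFn | typing.Protocol for loss functions. |
| rax.types.MetricFn | typing.Protocol for metric functions. |
| rax.types.RankFn | typing.Protocol for rank functions. |
| rax.types.ReduceFn | typing.Protocol for reduce functions. |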
- class rax.types.CutoffFn(*args, **kwargs)
typing.Protocol for cutoff functions.
- class rax.types.LambdaweightFn(*args, **kwargs)
typing.Protocol for lambdaweight functions.
- __call__(scores, labels, *, where, weights, **kwargs)
Computes lambdaweights.
- Parameters:
scores (Array) – A [..., list_size]-Array, indicating the score of each item.
labels (Array) – A [..., list_size]-Array, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-Array, indicating which items are valid for computing the lambdaweights. Items for which this is False will be ignored when computing the lambdaweights.
weights (Optional[Array]) – An optional [..., list_size]-Array, indicating the weight for each item.
**kwargs – Optional lambdaweight-specific keyword arguments.
- Return type:
Array
- Returns:
An Array that represents the lambdaweights.
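The stdlib sketch below shows how a callable can satisfy a protocol shaped like the documented __call__ signature. The protocol LambdaweightFnLike is a hypothetical stand-in, not Rax's definition. It mirrors the documented signature: scores and labels positionally, then keyword-only where and weights, plus **kwargs. The example function uniform_lambdaweight is also hypothetical. It gives every pair of distinct valid items weight 1, and it accepts weights but ignores it. A static type checker would accept the annotated assignment at the end.

```python
from typing import Any, Optional, Protocol, Sequence


class LambdaweightFnLike(Protocol):
    """Hypothetical stand-in mirroring the documented __call__ signature."""

    def __call__(
        self,
        scores: Sequence[float],
        labels: Sequence[float],
        *,
        where: Optional[Sequence[bool]],
        weights: Optional[Sequence[float]],
        **kwargs: Any,
    ) -> list[list[float]]: ...


def uniform_lambdaweight(
    scores: Sequence[float],
    labels: Sequence[float],
    *,
    where: Optional[Sequence[bool]] = None,
    weights: Optional[Sequence[float]] = None,
    **kwargs: Any,
) -> list[list[float]]:
    """Weight 1 for every pair of distinct valid items; `weights` is ignored here."""
    n = len(scores)
    mask = where if where is not None else [True] * n
    return [
        [1.0 if i != j and mask[i] and mask[j] else 0.0 for j in range(n)]
        for i in range(n)
    ]


fn: LambdaweightFnLike = uniform_lambdaweight
print(fn([0.0, 1.0], [1.0, 0.0], where=None, weights=None))  # [[0.0, 1.0], [1.0, 0.0]]
```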
- class rax.types.LossFn(*args, **kwargs)
typing.Protocol for loss functions.
- class rax.types.MetricFn(*args, **kwargs)
typing.Protocol for metric functions.
- class rax.types.RankFn(*args, **kwargs)
typing.Protocol for rank functions.
- __call__(scores, where, key, segments=None)
Computes 1-based ranks based on the given scores.
- Parameters:
scores (Array) – The scores to compute the 1-based ranks for.
where (Optional[Array]) – An optional Array of the same shape as scores that indicates which elements to rank. Other elements will be ranked last.
key (Optional[Array]) – An optional PRNGKey() used for random operations.
segments (Optional[Array]) – An optional Array of the same shape as scores that indicates which elements to group together.
- Return type:
Array
- Returns:
An Array of the same shape as scores that represents the 1-based ranks.
- class rax.types.ReduceFn(*args, **kwargs)
typing.Protocol for reduce functions.
References
Sebastian Bruch, Shuguang Han, Michael Bendersky, and Marc Najork. A stochastic treatment of learning to rank scoring functions. In Proceedings of the 13th International Conference on Web Search and Data Mining, 61–69. 2020.
Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. Learning to rank using gradient descent. In Proceedings of the 22nd International Conference on Machine Learning, 89–96. 2005.
Christopher Burges, Robert Ragno, and Quoc Le. Learning to rank with nonsmooth cost functions. In Advances in Neural Information Processing Systems, volume 19, 193–200. 2006.
Rolf Jagerman, Zhen Qin, Xuanhui Wang, Mike Bendersky, and Marc Najork. On optimizing top-k metrics for neural ranking models. In Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval, 2303–2307. 2022.
Kalervo Järvelin and Jaana Kekäläinen. Cumulated gain-based evaluation of IR techniques. ACM Transactions on Information Systems (TOIS), 20(4):422–446, 2002.
Zhaoqi Leng, Mingxing Tan, Chenxi Liu, Ekin Dogus Cubuk, Jay Shi, Shuyang Cheng, and Dragomir Anguelov. PolyLoss: a polynomial expansion perspective of classification loss functions. In International Conference on Learning Representations. 2022.
Tao Qin, Tie-Yan Liu, and Hang Li. A general approximation framework for direct optimization of information retrieval measures. Information Retrieval, 13(4):375–397, 2010.
Xuanhui Wang, Cheng Li, Nadav Golbandi, Mike Bendersky, and Marc Najork. The LambdaLoss framework for ranking metric optimization. In Proceedings of the 27th ACM International Conference on Information and Knowledge Management, 1313–1322. 2018.
Fen Xia, Tie-Yan Liu, Jue Wang, Wensheng Zhang, and Hang Li. Listwise approach to learning to rank: theory and algorithm. In Proceedings of the 25th International Conference on Machine Learning, 1192–1199. 2008.
Xiaofeng Zhu and Diego Klabjan. Listwise learning to rank by exploring unique ratings. In Proceedings of the 13th International Conference on Web Search and Data Mining, 798–806. 2020.