🦖 Rax Documentation

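Rax is a Learning-to-Rank (LTR) library built on top of JAX. It provides ranking losses and ranking metrics that take per-item scores and relevance labels. For example, the following computes a pairwise hinge loss and the NDCG (normalized discounted cumulative gain) metric for a single list of three items: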

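import rax
import jax.numpy as jnp

# Model scores for three items in a single list, and their graded relevance labels.
scores = jnp.array([0.3, 0.8, 0.21])
labels = jnp.array([1., 0., 2.])

# A pairwise ranking loss (lower is better) and a ranking metric (higher is better).
print(rax.pairwise_hinge_loss(scores, labels))
print(rax.ndcg_metric(scores, labels))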
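
To make explicit what the two calls above compute, here is a minimal plain-Python sketch of one common formulation of each. It assumes that the pairwise hinge loss applies max(0, 1 - (s_i - s_j)) to every pair where item i has a higher label than item j and averages over those pairs, and that NDCG uses gain 2^label - 1 and discount 1/log2(rank + 1). These are assumptions for illustration: Rax's defaults for reduction, gains, discounts, ties and masking may differ, so consult its API reference for exact behavior. The helper names `pairwise_hinge`, `dcg` and `ndcg` are hypothetical and not part of Rax.

```python
import math

# Illustrative reference versions of a pairwise hinge loss and NDCG in plain
# Python. These are NOT Rax's implementations; the reduction, gain and discount
# choices below are assumptions.


def pairwise_hinge(scores, labels):
    """Mean hinge penalty over pairs (i, j) where item i is more relevant than j."""
    penalties = []
    for s_i, y_i in zip(scores, labels):
        for s_j, y_j in zip(scores, labels):
            if y_i > y_j:
                # Penalize unless s_i beats s_j by a margin of at least 1.
                penalties.append(max(0.0, 1.0 - (s_i - s_j)))
    return sum(penalties) / len(penalties) if penalties else 0.0


def dcg(scores, labels):
    """DCG with exponential gain 2**label - 1 and discount 1 / log2(rank + 1)."""
    # Rank items by descending score; rank 1 is the top of the list.
    order = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
    return sum(
        (2.0 ** labels[k] - 1.0) / math.log2(rank + 1)
        for rank, k in enumerate(order, start=1)
    )


def ndcg(scores, labels):
    """DCG of the predicted ranking divided by the DCG of the ideal ranking."""
    ideal = dcg(labels, labels)  # Ranking by the labels themselves is ideal.
    return dcg(scores, labels) / ideal if ideal > 0 else 0.0


scores = [0.3, 0.8, 0.21]
labels = [1.0, 0.0, 2.0]
print(round(pairwise_hinge(scores, labels), 4))  # prints 1.3933
print(round(ndcg(scores, labels), 4))            # prints 0.5869
```

Spelling the definitions out shows where each signal comes from: the loss only penalizes pairs that are mis-ordered or separated by less than the margin, while NDCG rewards placing highly relevant items near the top of the list. In practice you would use the Rax functions, which operate on JAX arrays so that, for example, the losses can be differentiated with `jax.grad`.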


See https://github.com/google/jax#pip-installation for instructions on installing JAX.

We suggest installing the latest stable version of Rax by running:

$ pip install rax

If you run into problems, please let us know by filing an issue on our issue tracker.


Rax is licensed under the Apache 2.0 License.