🦖 Rax Documentation
Rax is a Learning-to-Rank (LTR) library built on top of JAX.
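import rax
import jax.numpy as jnp

# Model scores for three items in a single list, and their graded relevance labels.
scores = jnp.array([0.3, 0.8, 0.21])
labels = jnp.array([1., 0., 2.])

# A pairwise ranking loss (for training) and a ranking metric (for evaluation).
rax.pairwise_hinge_loss(scores, labels)
rax.ndcg_metric(scores, labels)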
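To make the two calls above less opaque, here is a minimal plain-JAX sketch of what a pairwise hinge loss and NDCG compute on the same toy list. It is not Rax's implementation: the helper names `pairwise_hinge_sketch` and `ndcg_sketch` are hypothetical, the loss is summed over pairs (Rax's own reduction, normalization, masking and tie handling may differ), and NDCG assumes the common gain `2**label - 1` with discount `1 / log2(rank + 1)`. It also shows that such a loss is differentiable with `jax.grad`.

```python
import jax
import jax.numpy as jnp


def pairwise_hinge_sketch(scores, labels):
    """Hypothetical helper: hinge loss summed over all pairs (i, j) with labels[i] > labels[j]."""
    # score_diff[i, j] = scores[i] - scores[j]
    score_diff = scores[:, None] - scores[None, :]
    # Only pairs where item i is strictly more relevant than item j contribute.
    valid = labels[:, None] > labels[None, :]
    # Penalize pairs where the more relevant item does not win by a margin of 1.
    hinge = jnp.maximum(0.0, 1.0 - score_diff)
    return jnp.sum(jnp.where(valid, hinge, 0.0))


def ndcg_sketch(scores, labels):
    """Hypothetical helper: NDCG with gain 2**label - 1 and discount 1 / log2(rank + 1)."""
    # Rank items by descending score (rank 1 = highest score).
    order = jnp.argsort(-scores)
    ranks = jnp.arange(1, scores.shape[0] + 1)
    discounts = 1.0 / jnp.log2(ranks + 1.0)
    gains = 2.0 ** labels - 1.0
    dcg = jnp.sum(gains[order] * discounts)
    # Ideal DCG: same discounts, with gains sorted from most to least relevant.
    ideal_dcg = jnp.sum(jnp.sort(gains)[::-1] * discounts)
    return dcg / ideal_dcg


scores = jnp.array([0.3, 0.8, 0.21])
labels = jnp.array([1.0, 0.0, 2.0])

loss = pairwise_hinge_sketch(scores, labels)
ndcg = ndcg_sketch(scores, labels)
# Gradient of the loss with respect to the scores (labels are treated as constants).
grads = jax.grad(pairwise_hinge_sketch)(scores, labels)
print(loss, ndcg, grads)
```

Here every relevant pair is mis-ordered, so the gradient pushes the score of the most relevant item up and the score of the irrelevant top-ranked item down, which is the signal a ranking model is trained on.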
Rax depends on JAX. See https://github.com/google/jax#pip-installation for instructions on installing JAX.
We suggest installing the latest stable version of Rax by running:
$ pip install rax
If you run into problems, please let us know by filing an issue on our issue tracker.
Rax is licensed under the Apache 2.0 License.