🦖 Rax: Learning-to-Rank using JAX
Rax is a Learning-to-Rank (LTR) library built on top of JAX.
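For example, to compute a pairwise hinge loss and the NDCG metric for a single list of three items:

```python
import jax.numpy as jnp
import rax

# Model scores and graded relevance labels for a list of three items.
scores = jnp.array([0.3, 0.8, 0.21])
labels = jnp.array([1., 0., 2.])

# A ranking loss (lower is better) and a ranking metric (higher is better).
print(rax.pairwise_hinge_loss(scores, labels))
print(rax.ndcg_metric(scores, labels))
```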
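To illustrate roughly what these two quantities measure, here is a simplified, from-scratch sketch in `jax.numpy` of a pairwise hinge loss and NDCG for a single list. It is not Rax's implementation. The names `pairwise_hinge_sketch` and `ndcg_sketch` are hypothetical. The sketch assumes the hinge loss is summed over all pairs where one item has a higher label than the other, and that NDCG uses the common gain `2**label - 1` with discount `1 / log2(rank + 1)`. Rax's own functions may reduce differently (for example, by averaging) and accept extra options, so their outputs need not match these numbers.

```python
import jax
import jax.numpy as jnp


def pairwise_hinge_sketch(scores, labels):
    """Hypothetical pairwise hinge loss for one list (not Rax's code).

    For every pair (i, j) with labels[i] > labels[j], add
    max(0, 1 - (scores[i] - scores[j])); pairs are summed, not averaged.
    """
    # score_diff[i, j] = scores[i] - scores[j]
    score_diff = scores[:, None] - scores[None, :]
    # valid[i, j] is True when item i should be ranked above item j.
    valid = labels[:, None] > labels[None, :]
    pair_losses = jnp.maximum(0.0, 1.0 - score_diff)
    return jnp.sum(jnp.where(valid, pair_losses, 0.0))


def ndcg_sketch(scores, labels):
    """Hypothetical NDCG for one list (not Rax's code).

    Assumes gain 2**label - 1 and discount 1 / log2(rank + 1),
    with ranks starting at 1.
    """
    gains = 2.0 ** labels - 1.0
    positions = jnp.arange(1, scores.shape[0] + 1)
    discounts = 1.0 / jnp.log2(positions + 1.0)
    # DCG: items ordered by descending model score.
    dcg = jnp.sum(gains[jnp.argsort(-scores)] * discounts)
    # Ideal DCG: items ordered by descending label.
    idcg = jnp.sum(gains[jnp.argsort(-labels)] * discounts)
    # Return 0 for lists with no relevant items instead of dividing by 0.
    return jnp.where(idcg > 0.0, dcg / jnp.maximum(idcg, 1e-12), 0.0)


scores = jnp.array([0.3, 0.8, 0.21])
labels = jnp.array([1., 0., 2.])
print(pairwise_hinge_sketch(scores, labels))  # approx. 4.18
print(ndcg_sketch(scores, labels))            # approx. 0.5869

# Because the sketch is plain jax.numpy, JAX can differentiate it.
print(jax.grad(pairwise_hinge_sketch)(scores, labels))  # approx. [0, 2, -2]
```

The gradient pushes the score of the most relevant item (label 2) up and the score of the irrelevant item (label 0) down, which is the behavior a pairwise ranking loss is meant to induce.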
Installation
Rax requires JAX. See https://github.com/google/jax#pip-installation for instructions on installing JAX.
We suggest installing the latest stable version of Rax by running:
```shell
$ pip install rax
```
Contribute
Issue tracker: https://github.com/google/rax/issues
Source code: https://github.com/google/rax/tree/master
Support
If you run into problems, please let us know by filing an issue on our issue tracker.
License
Rax is licensed under the Apache 2.0 License.