🦖 Rax: Learning-to-Rank using JAX

Rax is a Learning-to-Rank (LTR) library built on top of JAX.

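import rax
import jax.numpy as jnp

# Model scores and graded relevance labels for the three items of one list.
scores = jnp.array([0.3, 0.8, 0.21])
labels = jnp.array([1., 0., 2.])

# Compute a pairwise hinge loss and the NDCG metric for this ranking.
print(rax.pairwise_hinge_loss(scores, labels))
print(rax.ndcg_metric(scores, labels))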
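The two calls above can look like black boxes, so here is a minimal sketch of the textbook definitions behind a pairwise hinge loss and NDCG, written in plain jax.numpy without Rax. It assumes a hinge of max(0, 1 - (s_i - s_j)) for every pair where item i has the higher label, averaged over those pairs. For NDCG it assumes a gain of 2**label - 1 and a discount of 1 / log2(rank + 1). The names `pairwise_hinge_sketch`, `dcg` and `ndcg_sketch` are hypothetical helpers, not part of the Rax API. Rax's own defaults (reduction, gain, discount) may differ, so these values are not guaranteed to match what the Rax functions return.

```python
import jax.numpy as jnp


def pairwise_hinge_sketch(scores, labels, margin=1.0):
    """Mean hinge loss over pairs (i, j) with labels[i] > labels[j].

    Assumption: each such pair costs max(0, margin - (s_i - s_j)) and the
    loss is averaged over valid pairs. Rax's own reduction may differ.
    """
    # score_diff[i, j] = s_i - s_j for every ordered pair of items.
    score_diff = scores[:, None] - scores[None, :]
    # A pair is valid when item i should be ranked above item j.
    valid = labels[:, None] > labels[None, :]
    hinge = jnp.maximum(0.0, margin - score_diff)
    # Zero out invalid pairs, then average over the valid ones.
    return jnp.sum(jnp.where(valid, hinge, 0.0)) / jnp.sum(valid)


def dcg(scores, labels):
    """DCG with gain 2**label - 1 and discount 1 / log2(rank + 1)."""
    # Sort items by descending score; rank 1 is the top position.
    order = jnp.argsort(-scores)
    gains = 2.0 ** labels[order] - 1.0
    ranks = jnp.arange(1, scores.shape[0] + 1)
    return jnp.sum(gains / jnp.log2(ranks + 1.0))


def ndcg_sketch(scores, labels):
    """DCG of the predicted order divided by DCG of the ideal order."""
    # Using the labels themselves as scores gives the ideal ordering.
    return dcg(scores, labels) / dcg(labels, labels)


scores = jnp.array([0.3, 0.8, 0.21])
labels = jnp.array([1., 0., 2.])
print(pairwise_hinge_sketch(scores, labels))  # about 1.3933 (4.18 over 3 pairs)
print(ndcg_sketch(scores, labels))            # about 0.5869
```

With these scores the model places the item labelled 0 first and the item labelled 2 last, so every valid pair is mis-ordered and NDCG falls well below 1.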

Installation

See https://github.com/google/jax#pip-installation for instructions on installing JAX.

We suggest installing the latest stable version of Rax by running:

$ pip install rax

Contribute

Support

If you are having issues, please let us know by filing an issue on our issue tracker.

License

Rax is licensed under the Apache 2.0 License.