Flax Integration
This example demonstrates how to use Flax to build and train a neural ranking model that is optimized with a Rax loss and evaluated with Rax metrics.
Instructions
Clone the Rax repo:
git clone git@github.com:google/rax.git
cd rax
Install the example dependencies:
pip install -r requirements/requirements-examples.txt
Then run the example:
python examples/flax_integration/main.py
You should see output similar to the following: a JSON list with one dictionary per epoch, each containing the loss and the ranking metrics.
[
  {
    "epoch": 1,
    "loss": 371.3304748535156,
    "metric/mrr": 0.8062829971313477,
    "metric/ndcg": 0.6677320003509521,
    "metric/ndcg@10": 0.4055347740650177
  },
  {
    "epoch": 2,
    "loss": 370.42974853515625,
    "metric/mrr": 0.8242350220680237,
    "metric/ndcg": 0.6812514662742615,
    "metric/ndcg@10": 0.43049752712249756
  },
  {
    "epoch": 3,
    "loss": 370.25244140625,
    "metric/mrr": 0.8261540532112122,
    "metric/ndcg": 0.6834192276000977,
    "metric/ndcg@10": 0.4342570900917053
  }
]
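To check a run programmatically, a minimal standard-library sketch can parse output of this shape and report how each value changed between the first and last epoch. The values below are copied from the expected output listed here; how you capture the JSON from main.py (for example by redirecting stdout) is an assumption, since the script may print other text as well. The helper name summarize is hypothetical.

```python
import json

# The expected output shown above, pasted as a string for illustration.
RAW_OUTPUT = """
[
  {"epoch": 1, "loss": 371.3304748535156, "metric/mrr": 0.8062829971313477,
   "metric/ndcg": 0.6677320003509521, "metric/ndcg@10": 0.4055347740650177},
  {"epoch": 2, "loss": 370.42974853515625, "metric/mrr": 0.8242350220680237,
   "metric/ndcg": 0.6812514662742615, "metric/ndcg@10": 0.43049752712249756},
  {"epoch": 3, "loss": 370.25244140625, "metric/mrr": 0.8261540532112122,
   "metric/ndcg": 0.6834192276000977, "metric/ndcg@10": 0.4342570900917053}
]
"""


def summarize(epochs):
    """Hypothetical helper: change in every reported value from the first to the last epoch."""
    first, last = epochs[0], epochs[-1]
    return {key: last[key] - first[key] for key in first if key != "epoch"}


epochs = json.loads(RAW_OUTPUT)
for key, delta in summarize(epochs).items():
    print(f"{key}: {delta:+.4f}")
```

For the values shown, the loss falls by roughly 1.08 over three epochs while MRR, NDCG and NDCG@10 all rise.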