import logging
import os

# Configure XLA before jax is imported below so the flags are in the
# environment when its backend starts up. Each flag is emitted followed by a
# single space. Note this replaces, rather than extends, any XLA_FLAGS
# already present in the environment.
os.environ["XLA_FLAGS"] = "".join(
    f"{flag} "
    for flag in (
        "--xla_gpu_enable_triton_softmax_fusion=true",
        "--xla_gpu_triton_gemm_any=True",
    )
)

import hydra
from fair_dp_sgd.data import get_data_stream
from omegaconf import DictConfig
from jax import random
from fair_dp_sgd.models import get_model
from fair_dp_sgd.training.benchmark_time import train_and_evaluate
import numpy as np

def compute_mean_gap(metric1, metric2):
    """Return the absolute difference between the means of two metric sequences.

    Args:
        metric1: Finite iterable of numeric values.
        metric2: Finite iterable of numeric values; may differ in length
            from ``metric1``.

    Returns:
        ``|mean(metric1) - mean(metric2)|``.

    Raises:
        ZeroDivisionError: If either input is empty (its mean is undefined).
    """
    # Materialise first: sum() would exhaust a generator before len() runs.
    values1 = list(metric1)
    values2 = list(metric2)
    if len(values1) == len(values2):
        # Equal lengths share one denominator; a single division keeps the
        # historical arithmetic (and its exact float results) unchanged.
        return abs(sum(values1) - sum(values2)) / len(values1)
    # Different lengths: each sum must be normalised by its own length,
    # otherwise the result is not a difference of means.
    return abs(sum(values1) / len(values1) - sum(values2) / len(values2))

@hydra.main(version_base=None, config_path="conf", config_name="dpraco.yaml")
def main(cfg: DictConfig):
    """Benchmark training time for the model and data configured in ``cfg``.

    Steps:
    1. Seed the JAX PRNG from ``cfg.training_params.seed`` and split it into
       data, model and training keys.
    2. Build the data streams and the model.
    3. Force ``cfg.training_params.number_of_steps`` to a fixed value.
    4. Call ``train_and_evaluate``, which returns a sequence of timings.
    5. Drop the leading warm-up entries and log the mean and standard
       deviation of the rest.

    Args:
        cfg: Hydra config composed from ``conf/dpraco.yaml`` plus any
            command-line overrides. It is mutated in place: the step count
            is overwritten.
    """
    # Number of training steps to run; overrides the config file's value.
    num_steps = 510
    # Leading timing entries to discard.
    # NOTE(review): presumably these absorb JIT compilation / warm-up cost,
    # and train_and_evaluate returns one entry per step. Confirm against
    # fair_dp_sgd.training.benchmark_time.
    num_warmup = 10

    seed = cfg.training_params.seed
    key = random.PRNGKey(seed)
    data_key, model_key, training_key = random.split(key, num=3)
    train_stream, test_data, val_data = get_data_stream(cfg, data_key, seed=seed)
    # Overridden after the data stream is built and before the model is
    # created, matching the original ordering.
    cfg.training_params.number_of_steps = num_steps
    state = get_model(cfg, model_key)
    times = train_and_evaluate(
        cfg=cfg,
        state=state,
        train_stream=train_stream,
        rng=training_key,
        test_data=test_data,
        val_data=val_data,
    )
    times = times[num_warmup:]
    # len() rather than truthiness so this also works if times is an array.
    if len(times) == 0:
        # np.mean/np.std of an empty input return nan with a RuntimeWarning;
        # report the real problem instead.
        logging.warning(
            "No timings left after discarding the first %d entries", num_warmup
        )
        return
    mean = np.mean(times)
    # np.std is the standard deviation (not the variance).
    std = np.std(times)
    logging.info("Mean - %s Std - %s", mean, std)
    logging.info(times)

if __name__ == "__main__":
    # @hydra.main composes the config from conf/ and the command line and
    # passes it in as cfg, so main is invoked without arguments here.
    main()  # pylint: disable=no-value-for-parameter
