import logging
import os

# Export the XLA GPU flags ahead of the jax import further down; presumably
# they have to be in the environment before the backend starts — TODO confirm.
# NOTE(review): plain assignment replaces any XLA_FLAGS value that was already
# set in the caller's environment.
os.environ["XLA_FLAGS"] = "".join(
    f"{flag} "
    for flag in (
        "--xla_gpu_enable_triton_softmax_fusion=true",
        "--xla_gpu_triton_gemm_any=True",
    )
)

import hydra
from fair_dp_sgd.data import get_data_stream
from omegaconf import DictConfig
from jax import random
from fair_dp_sgd.models import get_model
from fair_dp_sgd.training.training_routine import train_and_evaluate
from fair_dp_sgd.utils.cache import cache_results
from matplotlib import pyplot as plt

def compute_mean_gap(metric1, metric2):
    """Return the absolute difference between the means of two metric series.

    The value is ``abs(sum(metric1) - sum(metric2)) / len(metric1)``: the
    magnitude of the *average signed* gap, so per-entry gaps of opposite sign
    cancel each other out (this is not a mean of absolute gaps).

    Args:
        metric1: Non-empty sized collection of numeric values.
        metric2: Sized collection of numeric values, same length as
            ``metric1``.

    Returns:
        The absolute gap between the two series, averaged over their length.

    Raises:
        ValueError: If the series differ in length or are empty.
    """
    count = len(metric1)
    # Dividing by len(metric1) only yields a mean gap when both series have
    # the same number of entries; reject mismatches instead of returning a
    # silently skewed number.
    if count != len(metric2):
        raise ValueError(
            f"metric series must have equal length, got {count} and {len(metric2)}"
        )
    # Fail with a descriptive error rather than an opaque ZeroDivisionError.
    if count == 0:
        raise ValueError("cannot compute the mean gap of empty metric series")
    return abs(sum(metric1) - sum(metric2)) / count

@hydra.main(version_base=None, config_path="conf", config_name="dpraco.yaml")
def main(cfg: DictConfig):
    """Sweep the softmax temperature and record the soft/hard constraint gaps.

    For every temperature from 11 to 20 the model is built and trained with
    ``cfg.algorithm.sigma = 0`` and ``cfg.algorithm.use_non_private_histogram
    = True``. After each run the gap between the soft and the hard constraint
    metrics is computed for the train and the test entries, logged, and
    collected. The list of ``(train_gap, test_gap)`` tuples, in temperature
    order, is finally passed to ``cache_results`` under the name
    ``"results"``.

    Args:
        cfg: Hydra configuration. Fields under ``algorithm`` and
            ``training_params`` are overwritten in place by this function.
    """
    key = random.PRNGKey(cfg.training_params.seed)

    cfg.algorithm.sigma = 0
    cfg.algorithm.use_non_private_histogram = True

    # The base key is never advanced between runs, so the split is
    # loop-invariant: every temperature receives the same data, model and
    # training keys.
    # NOTE(review): if independent randomness per temperature was intended,
    # the key must be re-split or folded per iteration — confirm.
    data_key, model_key, training_key = random.split(key, num=3)

    results = []
    for temperature in range(11, 21):
        # Fixed step budget rather than one derived via
        # get_number_of_steps_for_target_epsilon(cfg).
        cfg.training_params.number_of_steps = 25000
        cfg.training_params.softmax_temperature = temperature

        (train_stream, regularizer_stream, test_data) = get_data_stream(cfg, data_key)
        state = get_model(cfg, model_key)
        # The second return value is not needed for the gap computation.
        metrics, _ = train_and_evaluate(
            cfg=cfg,
            state=state,
            train_stream=train_stream,
            rng=training_key,
            test_data=test_data,
            regularizer_stream=regularizer_stream,
        )

        def metric_series(name):
            # Collect one named metric across all entries of this run.
            return [entry[name] for entry in metrics]

        # Gap between the soft constraint and the hard constraint per split.
        train_soft_hard_gap = compute_mean_gap(
            metric_series('train_soft_constraint'),
            metric_series('train_hard_constraint'),
        )
        test_soft_hard_gap = compute_mean_gap(
            metric_series('test_soft_constraint'),
            metric_series('test_hard_constraint'),
        )
        logging.info(f"For temperature: {cfg.training_params.softmax_temperature}: train soft-hard Gap: {train_soft_hard_gap}, test soft-hard gap: {test_soft_hard_gap}")
        results.append((train_soft_hard_gap, test_soft_hard_gap))

    cache_results("results", results)

# Script entry point: main() is called without arguments; its ``cfg`` parameter
# is expected to be supplied by the @hydra.main decorator.
if __name__ == "__main__":
    main()
