import os

from fair_dp_sgd.accounting import get_number_of_steps_for_target_epsilon

os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true " "--xla_gpu_triton_gemm_any=True "
)
import hydra

from fair_dp_sgd.data import get_data_stream
from omegaconf import DictConfig
from jax import random

from fair_dp_sgd.models import get_model
from fair_dp_sgd.training.training_routine import train_and_evaluate
from fair_dp_sgd.utils.cache import cache_results


@hydra.main(version_base=None, config_path="conf", config_name="dpsgd_fnr.yaml")
def main(cfg: DictConfig):
    key = random.PRNGKey(cfg.training_params.seed)
    cfg.training_params.number_of_steps = 5000
    Cs = [0.1, 0.25, 0.5, 1, 2, 3, 4, 6, 8, 12, 14, 16, 20]
    cfg.algorithm.sigma = 0
    cfg.algorithm.use_non_private_histogram = True
    # split key
    for C in Cs:
        cfg.algorithm.C = C
        data_key, model_key, training_key = random.split(key, num=3)
        (train_stream, regularizer_stream, test_data) = get_data_stream(cfg, data_key)
        state = get_model(cfg, model_key)
        metrics = train_and_evaluate(
            cfg=cfg,
            state=state,
            train_stream=train_stream,
            rng=training_key,
            test_data=test_data,
            regularizer_stream=regularizer_stream,
        )
        cache_results(
            f"metrics_eps={cfg.training_params.target_epsilon}_C={C}", metrics
        )


if __name__ == "__main__":
    main()
