import logging
import hydra
from fair_dp_sgd.accounting import get_number_of_steps_for_target_epsilon
from fair_dp_sgd.data import get_data_stream
from omegaconf import DictConfig
from jax import random
from fair_dp_sgd.models import get_model
from fair_dp_sgd.training.training_routine import train_and_evaluate
from fair_dp_sgd.utils.cache import cache_results
from matplotlib import pyplot as plt

def compute_mean_gap(metric1, metric2):
    """Return the absolute difference of the two sums, scaled by ``len(metric1)``.

    When both sequences have the same length this equals the absolute gap
    between their means. ``metric1`` must be non-empty because its length is
    used as the divisor.
    """
    total_difference = sum(metric1) - sum(metric2)
    return abs(total_difference) / len(metric1)

@hydra.main(version_base=None, config_path="conf", config_name="dpraco.yaml")
def main(cfg: DictConfig):
    """Train and evaluate one model, then plot and cache its constraint metrics.

    Steps, in order:
      1. Derive data/model/training PRNG keys from ``cfg.training_params.seed``.
      2. Build the data streams and store the number of training steps
         returned by ``get_number_of_steps_for_target_epsilon`` in
         ``cfg.training_params.number_of_steps``.
      3. Build the model and run ``train_and_evaluate``. If that raises, the
         traceback is logged and the process exits.
      4. Log the train soft/hard constraint gap, save ``constraints.png`` in
         Hydra's runtime output directory, and cache the metrics.

    Args:
        cfg: Hydra/OmegaConf configuration. This function reads
            ``training_params.seed``, ``training_params.target_epsilon``,
            ``eval.eval_every_k`` and ``algorithm.gamma``; the helper
            functions receive the whole config.
    """
    import sys

    key = random.PRNGKey(cfg.training_params.seed)
    data_key, model_key, training_key = random.split(key, num=3)
    (train_stream, val_data, test_data) = get_data_stream(cfg, data_key, seed=cfg.training_params.seed)
    cfg.training_params.number_of_steps = get_number_of_steps_for_target_epsilon(cfg)

    try:
        state = get_model(cfg, model_key)
        metrics = train_and_evaluate(
            cfg=cfg,
            state=state,
            train_stream=train_stream,
            rng=training_key,
            test_data=test_data,
            val_data=val_data
        )
    except Exception:
        # Catch Exception rather than everything so that KeyboardInterrupt and
        # SystemExit propagate instead of being reported as a failed config.
        # logging.exception records the traceback through the logging
        # handlers, not only on stderr.
        logging.exception(f"{cfg} has failed")
        # NOTE(review): exit status 0 is kept from the original code,
        # presumably so that an outer driver does not stop on a failed
        # configuration - confirm whether a non-zero status is wanted.
        sys.exit(0)

    # Each metrics entry must expose ``.to_numpy()`` (presumably pandas
    # columns - verify against train_and_evaluate).
    test_hard_constraint = metrics["test_hard_constraint"].to_numpy()
    test_soft_constraint = metrics["test_soft_constraint"].to_numpy()
    train_hard_constraint = metrics["train_hard_constraint"].to_numpy()
    train_soft_constraint = metrics["train_soft_constraint"].to_numpy()

    train_soft_hard_gap = compute_mean_gap(train_soft_constraint, train_hard_constraint)
    logging.info(f"Train Soft-Hard Gap: {train_soft_hard_gap}")

    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    out_dir = hydra_cfg["runtime"]["output_dir"]

    # Plot Constraints
    fig = plt.figure(figsize=(12, 8))
    try:
        k = cfg.eval.eval_every_k
        # The i-th recorded value is plotted at training step (i + 1) * k.
        steps = [(i + 1) * k for i in range(len(test_hard_constraint))]
        plt.plot(steps, test_hard_constraint, label='Test Hard Constraint')
        plt.plot(steps, test_soft_constraint, label='Test Soft Constraint')
        plt.plot(steps, train_hard_constraint, label='Train Hard Constraint')
        plt.plot(steps, train_soft_constraint, label='Train Soft Constraint')
        plt.axhline(cfg.algorithm.gamma, color='r', linestyle='--', label='Gamma')
        plt.xlabel('Training Step')
        plt.ylabel('Constraint Value')
        plt.title('Constraints Over Training Steps')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'{out_dir}/constraints.png')
    finally:
        # Release the figure so repeated invocations in one process do not
        # accumulate open figures.
        plt.close(fig)

    cache_results(f"metrics_eps={cfg.training_params.target_epsilon}", metrics)


# Run the experiment only when this file is executed as a script. main() is
# called without arguments because the @hydra.main decorator supplies `cfg`.
if __name__ == "__main__":
    main()
