import logging

import hydra

from fair_dp_sgd.accounting import get_number_of_steps_for_target_epsilon
from fair_dp_sgd.data import get_data_stream
from omegaconf import DictConfig
from jax import random
import numpy as np
from fair_dp_sgd.models import get_model
from fair_dp_sgd.training.training_routine import train_and_evaluate
import gc

def _best_test_accuracy(results, gamma):
    """Return the test accuracy of the best fair iteration, or 0.0 if none.

    An iteration is treated as fair when its ``val_hard_constraint`` value is
    strictly below ``gamma``. Among the fair iterations, the one with the
    highest ``val_accuracy`` is selected, and its ``test_accuracy`` is
    returned. Selection therefore looks only at validation columns.

    Args:
        results: Per-iteration evaluation table returned by
            ``train_and_evaluate``. Presumably a pandas DataFrame with columns
            ``val_hard_constraint``, ``val_accuracy`` and ``test_accuracy``
            (TODO confirm).
        gamma: Tolerance on the validation constraint.

    Returns:
        The selected test accuracy as a Python float, or 0.0 when no
        iteration satisfies the constraint.
    """
    # Boolean mask indexing returns a new object; nothing is written back
    # into ``results``.
    fair_results = results[results["val_hard_constraint"] < gamma]
    if len(fair_results) == 0:
        return 0.0  # No fair iteration found
    best_idx = fair_results["val_accuracy"].idxmax()
    return float(fair_results.loc[best_idx, "test_accuracy"])


@hydra.main(
    version_base=None, config_path="conf", config_name="finetune_equal_opportunity"
)
def main(cfg: DictConfig):
    """Fine-tune over several seeds and return the mean selected test accuracy.

    For each seed, this builds the data stream and model, sets the number of
    training steps, and then trains and evaluates. The step count is 30000
    when ``cfg.algorithm.sigma == 0``; otherwise it comes from
    ``get_number_of_steps_for_target_epsilon``. The per-seed score is the test
    accuracy at the best fair validation iteration (see
    ``_best_test_accuracy``). A seed that raises is logged and counted once
    as 0.0.

    Args:
        cfg: Hydra config composed from ``conf/finetune_equal_opportunity``
            plus any command-line overrides.
            ``cfg.training_params.number_of_steps`` is overwritten in place.

    Returns:
        The mean per-seed score. Returns 0 immediately if the computed
        number of training steps is zero.
    """
    test_acc_results = []
    gamma = cfg.algorithm.gamma
    seeds = [42, 139, 563, 2189, 7104]

    for seed in seeds:
        try:
            key = random.PRNGKey(seed)
            data_key, model_key, training_key = random.split(key, num=3)
            train_stream, val_data, test_data = get_data_stream(cfg, data_key, seed)

            if cfg.algorithm.sigma == 0:
                # sigma == 0 presumably means no DP noise, so no privacy
                # budget limits training; use a fixed step count.
                cfg.training_params.number_of_steps = 30000
            else:
                cfg.training_params.number_of_steps = get_number_of_steps_for_target_epsilon(
                    cfg
                )

            if cfg.training_params.number_of_steps == 0:
                # The step count is computed from cfg alone, so it would be
                # zero for every seed; stop the whole run here.
                return 0

            logging.info(f"Config:  {cfg}")

            state = get_model(cfg, model_key)
            logging.info(f"Training with {cfg.training_params.number_of_steps} steps")
            results = train_and_evaluate(
                cfg=cfg,
                state=state,
                train_stream=train_stream,
                rng=training_key,
                test_data=test_data,
                val_data=val_data,
            )
            test_acc_results.append(_best_test_accuracy(results, gamma))

            # Release large per-seed objects before the next seed allocates
            # its own.
            del results, state, train_stream, test_data, val_data
            gc.collect()
        except Exception:
            # Catch Exception rather than using a bare except, so that Ctrl-C
            # and SystemExit still stop the run. A failed seed is counted
            # exactly once as 0.0, so it is not over-weighted in the mean and
            # standard error below.
            logging.exception(f"{cfg} has failed on seed {seed}")
            test_acc_results.append(0.0)

    # Aggregate statistics across seeds
    mean_test_acc = np.mean(test_acc_results)
    se_test_acc = (
        np.std(test_acc_results, ddof=1) / np.sqrt(len(test_acc_results))
        if len(test_acc_results) > 1
        else 0.0
    )

    logging.info(
        f"Mean test accuracy (at best val accuracy) across seeds: {mean_test_acc:.4f}, "
        f"Standard error: {se_test_acc:.4f}"
    )

    return mean_test_acc


if __name__ == "__main__":
    # Run as a script. The @hydra.main decorator on ``main`` composes and
    # supplies ``cfg``, which is why ``main`` is called without arguments.
    main()
