import logging

import hydra

from fair_dp_sgd.accounting import get_number_of_steps_for_target_epsilon
from fair_dp_sgd.data import get_data_stream
from omegaconf import DictConfig
from jax import random
from fair_dp_sgd.models import get_model
from fair_dp_sgd.training.training_routine import train_and_evaluate
import gc
from tqdm import tqdm


import pandas
from omegaconf import OmegaConf
import warnings
from pandas.errors import SettingWithCopyWarning

# Suppress pandas' SettingWithCopyWarning for the whole process (this is a
# module-level side effect that takes effect as soon as the file is imported).
warnings.simplefilter(action='ignore', category=SettingWithCopyWarning)


def get_best_results(folder, max_runs=200):
    """Load the config of the sweep run that used the best hyperparameters.

    Reads ``best_params`` from ``{folder}/optimization_results.yaml`` and scans
    the run directories ``{folder}/0`` ... ``{folder}/{max_runs - 1}`` for the
    first one whose Hydra task overrides contain every best parameter. Keys
    and values are compared as strings.

    Args:
        folder: Directory holding ``optimization_results.yaml`` and the
            numbered run sub-directories.
        max_runs: Number of run indices to scan, starting at 0.

    Returns:
        The config loaded from the matching run's ``.hydra/config.yaml``, or
        ``None`` when none of the scanned runs matches.
    """
    optim_results = OmegaConf.load(f"{folder}/optimization_results.yaml")
    # NOTE(review): the comparison is purely textual, so a value rendered
    # differently in the two files (e.g. 1e-05 vs 0.00001) will not match.
    best_params = {
        (str(key), str(value))
        for key, value in dict(optim_results["best_params"]).items()
    }

    for i in range(max_runs):
        run_dir = f"{folder}/{i}"
        try:
            setting = OmegaConf.load(f"{run_dir}/.hydra/hydra.yaml")
        except FileNotFoundError:
            # The sweep may contain fewer than ``max_runs`` runs (or have gaps
            # in its numbering); a missing run simply cannot be the match.
            continue

        run_overrides = {}
        for override in setting["hydra"]["overrides"]["task"]:
            # Split on the first "=" only so values containing "=" stay intact;
            # entries without "=" carry no key/value pair to compare.
            key, separator, value = override.partition("=")
            if separator:
                run_overrides[key] = value

        if best_params <= set(run_overrides.items()):
            print(f"For {folder} the best results were found.")
            return OmegaConf.load(f"{run_dir}/.hydra/config.yaml")

    print(f"For {folder} the best results were not found.")
    return None

@hydra.main(version_base=None, config_path="conf", config_name="rerun_with_best_hparams.yaml")
def main(load_cfg: DictConfig):
    """Re-run training with the best hyperparameters of a finished sweep.

    The config of the best run is looked up in ``load_cfg.target_folder``.
    Training is then repeated for seeds ``0 .. load_cfg.num_runs - 1``. For
    each seed, the rows of the training results whose ``train_hard_constraint``
    is below ``cfg.algorithm.gamma`` are kept, and among those the rows with
    the highest ``val_accuracy`` are collected. All collected rows are written
    to ``{target_folder}/final_test_results.csv``.

    Args:
        load_cfg: Hydra config providing ``target_folder`` and ``num_runs``.

    Returns:
        0 on success, or when the stored config requests zero training steps.

    Raises:
        RuntimeError: If no run matching the best hyperparameters is found.
        SystemExit: With status 1 if training fails for any seed.
    """
    cfg = get_best_results(load_cfg.target_folder)
    if cfg is None:
        # get_best_results returns None when no run matches; fail with a clear
        # message instead of an AttributeError on the first attribute access.
        raise RuntimeError(
            f"No run matching the best hyperparameters was found in "
            f"{load_cfg.target_folder}"
        )

    if cfg.training_params.number_of_steps == 0:
        return 0

    logging.info(f"Config:  {cfg}")

    gamma = cfg.algorithm.gamma
    metrics_history = []
    del cfg.training_params.seeds
    cfg.training_params.poisson = True

    for seed in tqdm(range(load_cfg.num_runs)):
        try:
            key = random.PRNGKey(seed)
            data_key, model_key, training_key = random.split(key, num=3)
            (train_stream, val_data, test_data) = get_data_stream(cfg, data_key, seed)
            if cfg.algorithm.sigma == 0:
                # NOTE(review): presumably sigma == 0 means no noise is added,
                # so there is no epsilon target to derive a step count from;
                # a fixed budget is used instead -- confirm.
                cfg.training_params.number_of_steps = 5000
            else:
                cfg.training_params.number_of_steps = get_number_of_steps_for_target_epsilon(
                    cfg
                )
            state = get_model(cfg, model_key)
            # ``results`` is expected to be a pandas DataFrame (it is filtered
            # with boolean masks and concatenated with pandas below).
            results = train_and_evaluate(
                cfg=cfg,
                state=state,
                train_stream=train_stream,
                rng=training_key,
                test_data=test_data,
                val_data=val_data,
            )
            train_disparity = results["train_hard_constraint"]
            # Work on an explicit copy so adding the "seed" column never
            # touches (or warns about) the original results frame.
            filtered_results = results[train_disparity - gamma < 0].copy()
            filtered_results["seed"] = seed
            if len(filtered_results) == 0:
                max_accuracy = 0
            else:
                max_accuracy = filtered_results["val_accuracy"].max()

            df = filtered_results[filtered_results["val_accuracy"] == max_accuracy]
            metrics_history.append(df)

            # Drop the per-seed objects before the next iteration allocates
            # its own data and results.
            del results
            del train_stream
            del test_data
            del val_data
            gc.collect()
        except Exception:
            # logging.exception records the message at ERROR level together
            # with the traceback. Exit non-zero so callers can detect the
            # failure; KeyboardInterrupt/SystemExit are no longer intercepted.
            logging.exception(f"{cfg} has failed on seed {seed}")
            raise SystemExit(1)
    df = pandas.concat(metrics_history)
    df.to_csv(f"{load_cfg.target_folder}/final_test_results.csv")
    return 0


# Script entry point: only run the Hydra-decorated main() when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
