import os
import pickle
import time

import numpy as np
import scipy.stats as stats

from args import NSeedsConfig, PrintAdvAgainstsConfig, parse_args_to_dataclass


def _mean_and_confidence(values) -> tuple[float, float]:
    """Summarize a collection of regrets as ``(mean, 2 * standard error of the mean)``.

    Args:
        values: List or 1-D array of regret values.

    Returns:
        ``(mean, 2 * sem)``. When ``values`` is empty the sentinel ``(-1, -1)`` is
        returned instead of the NaN (plus warnings) that ``np.mean``/``stats.sem``
        would produce. With a single value the SEM is NaN because ``stats.sem``
        uses ``ddof=1`` by default.
    """
    if len(values) == 0:  # len() rather than truthiness so numpy arrays work too
        return -1, -1
    return np.mean(values), 2 * stats.sem(values)


def _after_training_regrets(
    base_dir: str, alg_against: str, algs_victim: list[str], n_seeds: int, n_steps: int
) -> dict[str, list[float]]:
    """Collect per-seed regrets of each victim from one attacker's after-training eval files.

    Seeds whose ``attacker_against_{alg_against}_{seed}_evals.pkl`` file is missing are
    reported on stdout and skipped; victims absent from a file get no entry for that seed.
    """
    regrets: dict[str, list[float]] = {alg: [] for alg in algs_victim}
    for seed in range(n_seeds):
        eval_path = f"{base_dir}/attacker_against_{alg_against}_{seed}_evals.pkl"
        if not os.path.exists(eval_path):
            print(f"File '{eval_path}' does not exist.")
            continue
        # NOTE(review): pickle.load can execute arbitrary code; these files are assumed to be
        # trusted outputs of the local eval scripts.
        with open(eval_path, "rb") as f:
            eval_data: dict[str, np.ndarray] = pickle.load(f)
        for alg_victim in algs_victim:
            if alg_victim in eval_data:
                # Mean gap to eval_data["opt"] (presumably the optimal arm's rewards — confirm),
                # scaled by n_steps to a total regret.
                regrets[alg_victim].append((eval_data["opt"] - eval_data[alg_victim]).mean() * n_steps)
    return regrets


def main(
    print_adv_againsts_config: "PrintAdvAgainstsConfig",
    n_seeds_config: "NSeedsConfig",
    *,
    use_after_training_evals: bool = False,
) -> None:
    """Print a victim-by-attacker regret table as ``{mean $\\pm$ confidence}`` cells.

    For every victim algorithm a header line ``"<victim>:"`` is printed, followed by one
    line with a cell per attacker, where ``confidence`` is twice the standard error of
    the mean. Cells with no data print as ``{-1.0 $\\pm$ -1.0}``.

    Args:
        print_adv_againsts_config: Provides ``setup_dir`` (files are read from
            ``models/adv/<setup_dir>``) and ``n_steps`` (regret scale for after-training evals).
        n_seeds_config: Provides ``n_seeds``, used in the eval file names.
        use_after_training_evals: Also read the per-seed after-training eval files. Their
            values are used only for victims whose individual eval file is missing, since
            the individual evals overwrite every cell of victims they cover.
    """
    # Annotations are quoted so this signature does not require the project `args` module at definition time.
    algs_against = ["dpt", "dpt_frozen", "ts", "rts", "ucb", "crucb", "crucb_v", "crucb_p"]
    algs_victim = ["dpt", "dpt_frozen", "ts", "rts", "rts_u", "rts_k", "ucb", "crucb", "crucb_v", "crucb_p"]
    algs_eval_individually = algs_victim

    n_steps = print_adv_againsts_config.n_steps
    n_seeds = n_seeds_config.n_seeds
    base_dir = f"models/adv/{print_adv_againsts_config.setup_dir}"

    # summary[victim][attacker] = (mean, confidence); cells never filled print the (-1, -1) sentinel.
    summary: dict[str, dict[str, tuple[float, float]]] = {alg: {} for alg in algs_victim}

    ####################
    # After training evals (optional fallback)
    ####################
    if use_after_training_evals:
        for alg_against in algs_against:
            regrets = _after_training_regrets(base_dir, alg_against, algs_victim, n_seeds, n_steps)
            for alg_victim in algs_victim:
                summary[alg_victim][alg_against] = _mean_and_confidence(regrets[alg_victim])

    ####################
    # Individual evals
    ####################
    for alg_victim in algs_eval_individually:
        eval_path = f"{base_dir}/attacker_against_all_{alg_victim}_evals_seeds{n_seeds}.pkl"
        if not os.path.exists(eval_path):
            # Report and keep going (like the after-training loader) instead of crashing.
            print(f"File '{eval_path}' does not exist.")
            continue
        # NOTE(review): pickle.load on trusted local outputs only.
        with open(eval_path, "rb") as f:
            regret_data: dict[str, list[float]] = pickle.load(f)
        for alg_against in algs_against:
            summary[alg_victim][alg_against] = _mean_and_confidence(regret_data.get(alg_against, []))

    ####################
    # The actual print
    ####################
    for alg_victim in algs_victim:
        print(f"{alg_victim: <10}:")
        for alg_against in algs_against:
            mean, confidence = summary[alg_victim].get(alg_against, (-1, -1))
            # Braces make each cell directly usable as a LaTeX macro argument.
            print(f"{{{mean:.1f} $\\pm$ {confidence:.1f}}}", end="")
        print()


if __name__ == "__main__":
    print_adv_againsts_config, n_seeds_config = parse_args_to_dataclass((PrintAdvAgainstsConfig, NSeedsConfig))

    print(print_adv_againsts_config, n_seeds_config, sep="\n")

    # perf_counter is monotonic, so the measured runtime is unaffected by system clock changes.
    time_start = time.perf_counter()
    main(print_adv_againsts_config, n_seeds_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
