import pickle
import time

import numpy as np
import scipy.stats as stats

from args import NSeedsConfig, PrintAdvAgainstsConfig, parse_args_to_dataclass


def _load_pickle(path: str):
    """Return the object stored in the pickle file at ``path``.

    NOTE(security): ``pickle.load`` can execute arbitrary code while
    deserializing, so this must only be pointed at trusted evaluation files.

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _mean_and_confidence(reward_arrays: list[np.ndarray]) -> tuple[float, float]:
    """Reduce a list of reward arrays to ``(mean, 2 * SEM)``.

    Every array is summed over its last axis and then averaged over the
    remaining last axis; the mean and twice the standard error of the mean
    are then taken across the resulting per-array values.

    NOTE(review): presumably there is one array per seed, shaped
    (episodes, steps) -- confirm against the script that writes the files.
    With fewer than two arrays ``stats.sem`` yields ``nan``.
    """
    per_array_returns = [arr.sum(-1).mean(-1) for arr in reward_arrays]
    return np.mean(per_array_returns), 2 * stats.sem(per_array_returns)


def main(print_adv_againsts_config: PrintAdvAgainstsConfig, n_seeds_config: NSeedsConfig) -> None:
    """Print a victim-by-attacker table of ``mean $\\pm$ 2*SEM`` rewards.

    Evaluation pickles are read from
    ``models/adv/<print_adv_againsts_config.setup_dir>``. Rows are the victim
    algorithms, columns the attacker algorithms, and every cell is printed as
    ``{mean $\\pm$ confidence}`` with one decimal place.

    Three file layouts are combined to fill the table:

    * one file per (attacker, victim) pair holding all seeds -- used for the
      ``ql`` attacker against ``dpt`` and ``dpt_frozen``;
    * one file per (attacker, victim, seed) -- used for the ``npg`` victim;
    * one file per victim holding a dict keyed by attacker -- used for the
      ``dpt``, ``dpt_frozen`` and ``ql`` victims.

    Args:
        print_adv_againsts_config: provides ``setup_dir``, the sub-directory
            of ``models/adv`` to read from.
        n_seeds_config: provides ``n_seeds``, which appears in the all-seeds
            file names and bounds the per-seed file index.

    Raises:
        OSError: if an expected evaluation file is missing or unreadable.
        KeyError: if a per-victim file lacks an expected attacker entry.
    """
    algs_att = ["dpt", "dpt_frozen", "npg", "ql", "unifrand", "clean"]
    algs_victim = ["dpt", "dpt_frozen", "npg", "ql"]

    base_dir = f"models/adv/{print_adv_againsts_config.setup_dir}"
    n_seeds = n_seeds_config.n_seeds

    # means[victim][attacker] / confidences[victim][attacker] hold one table cell each.
    means: dict[str, dict[str, float]] = {alg: {} for alg in algs_victim}
    confidences: dict[str, dict[str, float]] = {alg: {} for alg in algs_victim}

    # Individual cells, all seeds stored in a single file.
    for alg_victim in ["dpt", "dpt_frozen"]:
        for alg_att in ["ql"]:
            results_filename = f"{base_dir}/attacker_against_{alg_att}_victim_{alg_victim}_evals_seeds{n_seeds}.pkl"
            all_seed_rewards: list[np.ndarray] = _load_pickle(results_filename)
            means[alg_victim][alg_att], confidences[alg_victim][alg_att] = _mean_and_confidence(all_seed_rewards)

    # Individual cells, one file per seed.
    for alg_victim in ["npg"]:
        for alg_att in algs_att:
            per_seed_rewards: list[np.ndarray] = [
                _load_pickle(f"{base_dir}/attacker_against_{alg_att}_victim_{alg_victim}_{seed}_eval.pkl")
                for seed in range(n_seeds)
            ]
            means[alg_victim][alg_att], confidences[alg_victim][alg_att] = _mean_and_confidence(per_seed_rewards)

    # Whole rows: one file per victim containing every attacker.
    for alg_victim in ["dpt", "dpt_frozen", "ql"]:
        results_filename = f"{base_dir}/attacker_against_all_{alg_victim}_evals_seeds{n_seeds}.pkl"
        print(f"Loading '{results_filename}'...")
        rewards_by_attacker: dict[str, list[np.ndarray]] = _load_pickle(results_filename)
        for alg_att in algs_att:
            # The "ql" attacker column is taken from the row file only for the
            # "ql" victim; for the other victims it was filled from the
            # dedicated per-cell files above.
            if alg_att == "ql" and alg_victim != "ql":
                continue
            means[alg_victim][alg_att], confidences[alg_victim][alg_att] = _mean_and_confidence(
                rewards_by_attacker[alg_att]
            )

    ####################
    # The actual print
    ####################
    # Header row: an empty row label followed by one braced, padded attacker name per column.
    print(f"{'': <10}:", end=" ")
    for alg_att in algs_att:
        print("{ " + f"{alg_att: <14}", end="}")
    print()

    # One row per victim; each cell is "{mean $\pm$ confidence}".
    for alg_victim in algs_victim:
        print(f"{alg_victim: <10}:", end=" ")
        for alg_att in algs_att:
            cell = f"{means[alg_victim][alg_att]:.1f} $\\pm$ {confidences[alg_victim][alg_att]:.1f}"
            print("{" + cell, end="}")
        print()


if __name__ == "__main__":
    # Build both config objects in one call (presumably from the command-line
    # arguments -- see parse_args_to_dataclass in args).
    print_adv_againsts_config, n_seeds_config = parse_args_to_dataclass((PrintAdvAgainstsConfig, NSeedsConfig))

    # Echo the resolved configuration, one object per line.
    print(print_adv_againsts_config, n_seeds_config, sep="\n")

    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way time.time() differences can.
    time_start = time.perf_counter()
    main(print_adv_againsts_config, n_seeds_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
