import pickle
import time

import numpy as np
import scipy.stats as stats
import torch

from args import NSeedsConfig, PrintAdvAgainstsConfig, parse_args_to_dataclass
from util.argparser_dataclass import parse_args_to_dataclass

# Use the CUDA device when torch reports one as available, otherwise the CPU.
# NOTE(review): not referenced anywhere else in the visible part of this file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_victim_rewards(
    base_dir: str, algs: list[str], single_algs: list[str], n_seeds: int
) -> dict[str, list[np.ndarray]]:
    """Gather, per algorithm, one reward array for every seed.

    Args:
        base_dir: Directory containing the pickled evaluation results.
        algs: Every algorithm that gets a slot in the returned mapping.
        single_algs: Algorithms whose rewards live in a dedicated
            ``attacker_against_{alg}_ppo_evals_seed{seed}.pkl`` file.
        n_seeds: Number of seeds to read, i.e. seeds ``0 .. n_seeds - 1``.

    Returns:
        Mapping from algorithm name to the list of loaded arrays, in seed order.

    Raises:
        FileNotFoundError: If an expected result file is missing.
        KeyError: If the combined file holds an algorithm other than
            "unifrand"/"clean" that is not listed in ``algs``.
    """
    collected: dict[str, list[np.ndarray]] = {alg: [] for alg in algs}

    for seed in range(n_seeds):
        # NOTE: pickle.load can execute arbitrary code; only point this script at
        # result files produced by trusted runs.
        combined_path = f"{base_dir}/attacker_against_all_ppo_evals_seed{seed}.pkl"
        with open(combined_path, "rb") as f:
            rewards_by_alg: dict[str, np.ndarray] = pickle.load(f)
        for alg, rewards in rewards_by_alg.items():
            # These two are never taken from the combined file; when they are
            # reported, they come from their dedicated files below.
            if alg in ("unifrand", "clean"):
                continue
            collected[alg].append(rewards)

        for alg in single_algs:
            single_path = f"{base_dir}/attacker_against_{alg}_ppo_evals_seed{seed}.pkl"
            with open(single_path, "rb") as f:
                collected[alg].append(pickle.load(f))

    return collected


def _mean_and_confidence(per_seed_rewards: list[np.ndarray]) -> tuple[float, float]:
    """Return the across-seed mean and twice the standard error of the mean.

    Each array is reduced to one score by summing over its last axis and then
    averaging over the (new) last axis.
    NOTE(review): presumably arrays are (episodes, steps), giving the mean
    episode return per seed -- confirm against the code that writes the files.
    """
    seed_scores = [arr.sum(-1).mean(-1) for arr in per_seed_rewards]
    if not seed_scores:
        # Nothing loaded for this algorithm: report NaN without numpy warnings.
        return float("nan"), float("nan")

    mean = np.mean(seed_scores)
    if len(seed_scores) < 2:
        # The sample standard error is undefined for a single observation.
        return mean, float("nan")
    return mean, 2 * stats.sem(seed_scores)


def main(print_adv_againsts_config: PrintAdvAgainstsConfig, n_seeds_config: NSeedsConfig) -> None:
    """Print the PPO reward against each attacker as a brace-delimited row.

    Results are read from ``models/adv/{setup_dir}`` for seeds
    ``0 .. n_seeds - 1``. The first printed row holds the algorithm names, the
    second holds ``mean $\\pm$ 2*SEM`` across seeds for each algorithm.

    Args:
        print_adv_againsts_config: Supplies ``setup_dir``, the results
            sub-directory. When it contains "epss0.4", the "clean" algorithm
            is reported as well.
        n_seeds_config: Supplies ``n_seeds``, the number of seeds to aggregate.
    """
    base_dir = f"models/adv/{print_adv_againsts_config.setup_dir}"

    algs = ["dpt", "dpt_frozen", "unifrand"]
    single_algs = ["unifrand"]
    if "epss0.4" in print_adv_againsts_config.setup_dir:
        algs.append("clean")
        single_algs.append("clean")

    rewards_victim_against = _load_victim_rewards(base_dir, algs, single_algs, n_seeds_config.n_seeds)

    print("PPO reward against:")
    # Header row: each name centred in a 17-character cell wrapped in braces.
    print("".join("{" + f"{alg_against: ^17}" + "}" for alg_against in algs))

    cells = []
    for alg_against in algs:
        mean, confidence = _mean_and_confidence(rewards_victim_against[alg_against])
        cells.append("{" + f"{mean:.1f} $\\pm$ {confidence:.1f}" + "}")
    print("".join(cells))


if __name__ == "__main__":
    # parse_args_to_dataclass returns one config object per class passed in
    # (presumably populated from the command-line arguments).
    print_adv_againsts_config, n_seeds_config = parse_args_to_dataclass((PrintAdvAgainstsConfig, NSeedsConfig))

    # Echo the effective configuration, one config per line.
    print(print_adv_againsts_config, n_seeds_config, sep="\n")

    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments the way time.time() differences can.
    time_start = time.perf_counter()
    main(print_adv_againsts_config, n_seeds_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
