import pickle
import time

import numpy as np
import scipy.stats as stats
import torch

from args import NSeedsConfig, PrintAdvAgainstsConfig, parse_args_to_dataclass
from util.argparser_dataclass import parse_args_to_dataclass

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(print_adv_againsts_config: PrintAdvAgainstsConfig, n_seeds_config: NSeedsConfig) -> None:
    """Print the victim's mean reward against an adversarial attacker, aggregated over seeds.

    For each seed ``0 .. n_seeds - 1`` this loads the pickled evaluation rewards from
    ``models/adv/<setup_dir>/attacker_against_<alg_att>_victim_<alg_victim>_<seed>_eval.pkl``,
    reduces each seed's array to one scalar, and prints the across-seed mean together with
    ``2 * SEM`` in the LaTeX-style form ``{mean $\\pm$ conf}``.

    Args:
        print_adv_againsts_config: Supplies ``meta`` (``"<victim_alg>-<attacker_alg>"``) and
            ``setup_dir`` (the subdirectory under ``models/adv``).
        n_seeds_config: Supplies ``n_seeds``, the number of per-seed result files to load.

    Raises:
        ValueError: If ``meta`` does not contain exactly one ``-``, or ``n_seeds`` is below 1.
        FileNotFoundError: If one of the per-seed result files does not exist.
    """
    meta = print_adv_againsts_config.meta
    # Exactly one "-" must separate the two algorithm names. Anything else was already
    # rejected by the tuple unpacking, but only with an opaque "values to unpack" message.
    if meta.count("-") != 1:
        raise ValueError(f"Expected meta of the form '<victim_alg>-<attacker_alg>', got {meta!r}")
    alg_victim, alg_att = meta.split("-")

    n_seeds = n_seeds_config.n_seeds
    if n_seeds < 1:
        # With no seeds, np.mean / stats.sem would silently produce nan instead of a result.
        raise ValueError(f"n_seeds must be at least 1, got {n_seeds}")

    base_dir = f"models/adv/{print_adv_againsts_config.setup_dir}"

    rewards_victim: list[np.ndarray] = []
    for seed in range(n_seeds):
        results_filename = f"{base_dir}/attacker_against_{alg_att}_victim_{alg_victim}_{seed}_eval.pkl"
        # NOTE: pickle.load can execute arbitrary code; only load result files this project wrote.
        with open(results_filename, "rb") as f:
            rewards_victim_oneseed: np.ndarray = pickle.load(f)
        rewards_victim.append(rewards_victim_oneseed)

    # Sum over the last axis, then average over the next one, giving one scalar per seed.
    # Assumes each array is 2-D, presumably (episodes, timesteps) -- TODO confirm against the
    # code that writes the *_eval.pkl files.
    alg_rewards: list[float] = [arr.sum(-1).mean(-1) for arr in rewards_victim]

    mean = np.mean(alg_rewards)
    # Twice the standard error of the mean (roughly a 95% normal-approximation interval).
    # stats.sem uses ddof=1 by default, so a single seed yields nan here.
    confidence = 2 * stats.sem(alg_rewards)

    # Braces and $\pm$ produce a LaTeX-ready cell, presumably for pasting into a results table.
    print("{" + f"{mean:.1f} $\\pm$ {confidence:.1f}" + "}")


if __name__ == "__main__":
    # Build both config dataclasses, presumably from command-line arguments.
    print_adv_againsts_config, n_seeds_config = parse_args_to_dataclass((PrintAdvAgainstsConfig, NSeedsConfig))

    # Echo the effective configuration so the run's output records which settings were used.
    print(print_adv_againsts_config, n_seeds_config, sep="\n")

    # perf_counter is monotonic and high-resolution; time.time() is wall-clock time and can
    # jump when the system clock is adjusted, corrupting elapsed-time measurements.
    time_start = time.perf_counter()
    main(print_adv_againsts_config, n_seeds_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
