import pickle
import time

import numpy as np
import scipy.stats as stats
import torch

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    ModelConfig,
    NSeedsConfig,
    SeedConfig,
    get_adv_trained_model_name,
    get_legacy_filename_config,
    parse_args_to_dataclass,
)
from bandit2.bandit_attacker import BanditNaiveAttacker, BanditUniformRandomAttacker
from bandit2.bandit_ctrl import BanditOptimalController, BanditTransformerController
from bandit2.bandit_env import BanditEnv
from bandit_algs import get_bandit_algs
from net import Transformer
from utils import build_bandit_model_filename

# Run everything on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _load_transformer(transformer_config, checkpoint_path):
    """Build a ``Transformer`` on ``device`` and load its weights from *checkpoint_path*.

    The checkpoint is deserialized with ``map_location=device`` so that weights
    saved on a GPU can still be loaded when the script falls back to the CPU.
    The model is switched to test/eval mode before it is returned.
    """
    model = Transformer(transformer_config).to(device)
    model.test = True
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    print(f"Loaded model {checkpoint_path}.")
    return model


def _scaled_mean_regret(rewards_opt, rewards_victim, n_steps):
    """Return the mean reward gap (optimal minus victim) multiplied by *n_steps*."""
    return (rewards_opt - rewards_victim).mean().item() * n_steps


def main(dataset_config: DatasetConfig, model_config: ModelConfig, eval_config: EvalConfig, adv_train_config: AdversarialTrainingConfig, n_seeds_config: NSeedsConfig) -> None:
    """Evaluate one victim algorithm against every saved attacker and report its regret.

    The victim is ``adv_train_config.attacker_against``. For each seed the victim
    policy is built (loading a transformer checkpoint when the victim is ``"dpt"``
    or ``"dpt_frozen"``), then deployed against:

    * the saved attacker trained against each algorithm in ``algs_against``,
    * a uniform-random attacker (key ``"unifrand"``),
    * no attacker at all (key ``"clean"``).

    Per-seed regrets are collected per key, a table of ``mean ± 2 * SEM`` is
    printed, and the raw per-seed lists are pickled under ``models/adv/<setup_name>/``.

    Args:
        dataset_config: Supplies ``context_len``, ``n_actions`` and ``variance``.
        model_config: Supplies the transformer parameters via ``get_params``.
        eval_config: Supplies ``n_envs_eval``, ``n_steps_eval`` and ``epoch``.
        adv_train_config: Supplies the victim name and the attack settings
            (``eps_steps``, ``eps_episodes``, ``max_poison_diff``).
        n_seeds_config: Supplies ``n_seeds``, the number of seeds evaluated.
    """
    transformer_config = model_config.get_params({"H": dataset_config.context_len, "state_dim": 1, "action_dim": dataset_config.n_actions})

    algs_against = ["dpt", "dpt_frozen", "ts", "rts", "ucb", "crucb", "crucb_p", "crucb_v"]

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len
    # Evaluation horizon defaults to the context length when not given explicitly.
    n_steps_eval = eval_config.n_steps_eval if eval_config.n_steps_eval is not None else dataset_config.context_len
    n_actions = dataset_config.n_actions
    victim_alg = adv_train_config.attacker_against

    setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, print_against=False)

    # One list of per-seed regrets for every attacker, plus the two baselines.
    regrets_victim_against: dict[str, list[float]] = {alg: [] for alg in algs_against}
    regrets_victim_against["unifrand"] = []
    regrets_victim_against["clean"] = []

    for seed in range(n_seeds_config.n_seeds):
        dpt_policy = None
        if victim_alg == "dpt":
            atdpt_seed = (seed + 1) % n_seeds_config.n_seeds  # to prevent atdpt being evaluated on same attacker
            atdpt_path = f"models/adv/{setup_name}/atdpt_{atdpt_seed}.pt"
            model = _load_transformer(transformer_config, atdpt_path)

            dpt_policy = BanditTransformerController(model, n_envs, n_steps, n_actions, sample=True, device=device)

        dpt_frozen_policy = None
        if victim_alg == "dpt_frozen":
            filename = build_bandit_model_filename("bandit", get_legacy_filename_config(model_config, dataset_config, SeedConfig(seed)))
            if eval_config.epoch is None:
                dpt_frozen_path = f"models/{filename}.pt"
            else:
                dpt_frozen_path = f"models/{filename}_epoch{eval_config.epoch}.pt"
            model = _load_transformer(transformer_config, dpt_frozen_path)

            dpt_frozen_policy = BanditTransformerController(model, n_envs, n_steps, n_actions, sample=True, frozen=True, device=device)

        policies = get_bandit_algs(
            n_envs,
            n_steps,
            n_actions,
            None,
            dpt_policy,
            dpt_frozen_policy,
            adv_train_config.eps_steps,
            0.1,
            adv_train_config.eps_steps,
            adv_train_config.max_poison_diff,
            device=device,
        )

        victim_policy = policies[victim_alg]

        for alg_against in algs_against:
            # The zero tensor is only a placeholder; the saved state is loaded right after.
            attacker = BanditNaiveAttacker(n_envs, n_actions, torch.zeros((n_envs, n_actions)), device=device)
            attacker_save_path = f"models/adv/{setup_name}/attacker_against_{alg_against}_{seed}.pt"
            # map_location keeps loading possible on a CPU-only machine.
            attacker.load_state_dict(torch.load(attacker_save_path, map_location=device))
            print(f"Loaded attacker from '{attacker_save_path}'.")

            env = BanditEnv(attacker.means_original, n_steps_eval, dataset_config.variance, device=device)
            optimal_policy = BanditOptimalController(n_envs, n_steps_eval, n_actions, env.get_optimal_actions(), device=device)

            rewards_opt = env.deploy(optimal_policy).rewards_original
            rewards_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps).rewards_original

            # NOTE(review): the environment runs for n_steps_eval steps but the regret is
            # scaled by n_steps (the context length) — confirm this is intended when they differ.
            regrets_victim_against[alg_against].append(_scaled_mean_regret(rewards_opt, rewards_victim, n_steps))

        # The two baselines below reuse `env` and `optimal_policy` from the LAST
        # iteration of the attacker loop above (i.e. the last entry of algs_against).
        attacker = BanditUniformRandomAttacker(n_envs, n_actions, dataset_config.variance, adv_train_config.max_poison_diff, device=device)
        rewards_opt = env.deploy(optimal_policy).rewards_original
        rewards_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps).rewards_original
        regrets_victim_against["unifrand"].append(_scaled_mean_regret(rewards_opt, rewards_victim, n_steps))

        # Clean run: the victim is deployed without any attacker.
        rewards_opt = env.deploy(optimal_policy).rewards_original
        rewards_victim = env.deploy(victim_policy).rewards_original
        regrets_victim_against["clean"].append(_scaled_mean_regret(rewards_opt, rewards_victim, n_steps))

    algs_against += ["unifrand", "clean"]

    # Print a two-row table: column headers, then "mean $\pm$ 2*SEM" per column.
    print(f"{victim_alg} regret against:")
    for alg_against in algs_against:
        print("{", f"{alg_against: ^14}", end="}", sep="")
    print()
    for alg_against in algs_against:
        mean = np.mean(regrets_victim_against[alg_against])
        confidence = 2 * stats.sem(regrets_victim_against[alg_against])
        print("{", f"{mean: >4.1f} $\\pm$ {confidence:.1f}", end="}", sep="")
    print()

    results_filename = f"models/adv/{setup_name}/attacker_against_all_{victim_alg}_evals_seeds{n_seeds_config.n_seeds}.pkl"
    with open(results_filename, "wb") as f:
        pickle.dump(regrets_victim_against, f)
    print(f"Saved to '{results_filename}'.")


if __name__ == "__main__":
    # Parse the command line into one dataclass instance per config class.
    config_classes = (DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig, NSeedsConfig)
    dataset_cfg, model_cfg, eval_cfg, adv_train_cfg, seeds_cfg = parse_args_to_dataclass(config_classes)
    parsed_configs = (dataset_cfg, model_cfg, eval_cfg, adv_train_cfg, seeds_cfg)

    # Echo the effective configuration, one config per line.
    print(*parsed_configs, sep="\n")

    # Time the whole evaluation run.
    started_at = time.time()
    main(*parsed_configs)
    finished_at = time.time()

    print(f"Total runtime: {finished_at - started_at:.2f} s")
