import pickle
import time
from dataclasses import asdict

import numpy as np
import scipy.stats as stats
import torch

from args import (
    AdaptiveAttackerConfig,
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    ModelConfig,
    NSeedsConfig,
    SeedConfig,
    get_adaptive_adv_trained_model_name,
    get_adv_trained_model_name,
    get_legacy_filename_config,
    parse_args_to_dataclass,
)
from bandit2.bandit_attacker import (
    BanditAdaptiveAttacker,
    BanditNaiveAttacker,
    BanditUniformRandomAttacker,
)
from bandit2.bandit_ctrl import BanditOptimalController, BanditTransformerController
from bandit2.bandit_env import BanditEnv
from bandit_algs import get_bandit_algs
from net import Transformer
from utils import build_bandit_model_filename

# Compute device for all models, attackers and environments: CUDA when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Prefix that distinguishes non-adaptive attacker/victim names (e.g. "nonadaptive_dpt").
NONADAPTIVE: str = "nonadaptive_"


def _load_transformer(transformer_config, path: str):
    """Build a ``Transformer`` on ``device`` and load its weights from ``path``.

    The model is put in test mode (``model.test = True``) and ``eval()`` mode before
    being returned. ``map_location=device`` allows checkpoints saved on a GPU to be
    loaded on a CPU-only machine (and vice versa).
    """
    model = Transformer(transformer_config).to(device)
    model.test = True
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    print(f"Loaded model {path}.")
    return model


def _mean_regret(env, optimal_policy, victim_policy, horizon: int, *attack_args) -> float:
    """Return the mean cumulative regret of ``victim_policy`` on ``env``.

    The optimal policy is deployed first, then the victim. ``attack_args``
    (attacker, eps_episodes, eps_steps) are forwarded verbatim to the victim's
    ``env.deploy`` call; pass none for a clean (unattacked) rollout. Regret is
    computed on ``rewards_original`` (presumably the un-poisoned rewards — confirm
    in BanditEnv). It is the mean per-step reward gap multiplied by ``horizon``,
    the number of steps the environment was built with.
    """
    dataset_opt = env.deploy(optimal_policy)
    dataset_victim = env.deploy(victim_policy, *attack_args)
    return (dataset_opt.rewards_original - dataset_victim.rewards_original).mean().item() * horizon


def main(
    dataset_config: DatasetConfig,
    model_config: ModelConfig,
    eval_config: EvalConfig,
    adv_train_config: AdversarialTrainingConfig,
    n_seeds_config: NSeedsConfig,
    adaptive_attacker_config: AdaptiveAttackerConfig,
) -> None:
    """Evaluate the victim ``adv_train_config.attacker_against`` against saved attackers.

    For every seed the victim policy is loaded, or taken from ``get_bandit_algs``.
    Its mean cumulative regret is then measured against:

    * each saved adaptive and non-adaptive attacker in ``algs_attacker``,
    * a uniform-random attacker (``"unifrand"``),
    * no attacker at all (``"clean"``).

    Mean and 2 * SEM over seeds are printed per attacker. The raw per-seed regrets
    are pickled to ``models/adv/<setup_name>/``.
    """
    transformer_config = model_config.get_params({"H": dataset_config.context_len, "state_dim": 1, "action_dim": dataset_config.n_actions})

    algs_attacker = ["dpt", "ts", f"{NONADAPTIVE}dpt", f"{NONADAPTIVE}ts"]

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len  # context length handed to the victim controllers
    n_steps_eval = eval_config.n_steps_eval if eval_config.n_steps_eval is not None else dataset_config.context_len
    n_actions = dataset_config.n_actions
    victim_alg = adv_train_config.attacker_against

    setup_name = get_adaptive_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, adaptive_attacker_config, print_against=False)
    # Hard-coded hyperparameters under which the non-adaptive attackers/victims were trained.
    nonadaptive_adv_config = AdversarialTrainingConfig(**{**asdict(adv_train_config), "attacker_lr": 0.03, "n_rounds": 20})
    nonadaptive_setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, nonadaptive_adv_config, print_against=False)

    # Insertion order (attackers, then "unifrand", then "clean") is also the print order.
    regrets_victim_against: dict[str, list[float]] = {alg: [] for alg in algs_attacker}
    regrets_victim_against["unifrand"] = []
    regrets_victim_against["clean"] = []

    for seed in range(n_seeds_config.n_seeds):
        dpt_policy = None
        if victim_alg in ("dpt", f"{NONADAPTIVE}dpt"):
            # Use the AT-DPT checkpoint of a different seed so the victim is not evaluated
            # on the same attacker it was trained against.
            # NOTE(review): with n_seeds == 1 this maps seed 0 onto itself.
            atdpt_seed = (seed + 1) % n_seeds_config.n_seeds
            atdpt_dir = setup_name if victim_alg == "dpt" else nonadaptive_setup_name
            model = _load_transformer(transformer_config, f"models/adv/{atdpt_dir}/atdpt_{atdpt_seed}.pt")
            dpt_policy = BanditTransformerController(model, n_envs, n_steps, n_actions, sample=True, device=device)

        dpt_frozen_policy = None
        if victim_alg == "dpt_frozen":
            filename = build_bandit_model_filename("bandit", get_legacy_filename_config(model_config, dataset_config, SeedConfig(seed)))
            epoch_suffix = "" if eval_config.epoch is None else f"_epoch{eval_config.epoch}"
            model = _load_transformer(transformer_config, f"models/{filename}{epoch_suffix}.pt")
            dpt_frozen_policy = BanditTransformerController(model, n_envs, n_steps, n_actions, sample=True, frozen=True, device=device)

        policies = get_bandit_algs(
            n_envs,
            n_steps,
            n_actions,
            None,
            dpt_policy,
            dpt_frozen_policy,
            adv_train_config.eps_steps,
            0.1,
            adv_train_config.eps_steps,
            adv_train_config.max_poison_diff,
            device=device,
        )
        if victim_alg == f"{NONADAPTIVE}dpt":
            # The non-adaptive AT-DPT victim is registered under its own key here.
            assert dpt_policy is not None
            policies[victim_alg] = dpt_policy

        victim_policy = policies[victim_alg]

        attack_params = (adv_train_config.eps_episodes, adv_train_config.eps_steps)

        for alg_against in algs_attacker:
            if alg_against.startswith(NONADAPTIVE):
                attacker = BanditNaiveAttacker(n_envs, n_actions, torch.zeros((n_envs, n_actions)), device=device)
                alg_against_nonadaptive = alg_against.removeprefix(NONADAPTIVE)
                attacker_save_path = f"models/adv/{nonadaptive_setup_name}/attacker_against_{alg_against_nonadaptive}_{seed}.pt"
            else:
                attacker = BanditAdaptiveAttacker(adaptive_attacker_config, n_envs, n_actions, torch.zeros((n_envs, n_actions)), device=device)
                attacker_save_path = f"models/adv/{setup_name}/attacker_against_{alg_against}_{seed}.pt"
            attacker.load_state_dict(torch.load(attacker_save_path, map_location=device))
            print(f"Loaded attacker from '{attacker_save_path}'.")

            # The evaluation environment is built from the attacker's means_original
            # (presumably restored from the checkpoint above).
            env = BanditEnv(attacker.means_original, n_steps_eval, dataset_config.variance, device=device)
            optimal_policy = BanditOptimalController(n_envs, n_steps_eval, n_actions, env.get_optimal_actions(), device=device)

            # Scale by the horizon the env was actually built with (n_steps_eval), not context_len.
            regrets_victim_against[alg_against].append(_mean_regret(env, optimal_policy, victim_policy, n_steps_eval, attacker, *attack_params))

        # NOTE: the unifrand and clean baselines reuse env / optimal_policy from the last
        # iteration above, i.e. the means of the f"{NONADAPTIVE}ts" attacker.
        attacker = BanditUniformRandomAttacker(n_envs, n_actions, dataset_config.variance, adv_train_config.max_poison_diff, device=device)
        regrets_victim_against["unifrand"].append(_mean_regret(env, optimal_policy, victim_policy, n_steps_eval, attacker, *attack_params))
        regrets_victim_against["clean"].append(_mean_regret(env, optimal_policy, victim_policy, n_steps_eval))

    print(f"{victim_alg} regret against:")
    for alg_against in regrets_victim_against:
        print("{", f"{alg_against: ^14}", end="}", sep="")
    print()
    for regrets in regrets_victim_against.values():
        mean = np.mean(regrets)
        confidence = 2 * stats.sem(regrets)  # NaN when only one seed was run
        print("{", f"{mean: >4.1f} $\\pm$ {confidence:.1f}", end="}", sep="")
    print()

    results_filename = f"models/adv/{setup_name}/attacker_against_all_{victim_alg}_evals_seeds{n_seeds_config.n_seeds}.pkl"
    with open(results_filename, "wb") as f:
        pickle.dump(regrets_victim_against, f)
    print(f"Saved to '{results_filename}'.")


if __name__ == "__main__":
    # Parse every config dataclass from the command line, in the order main() expects.
    dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config, adaptive_attacker_config = parse_args_to_dataclass(
        (DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig, NSeedsConfig, AdaptiveAttackerConfig)
    )

    print(dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config, adaptive_attacker_config, sep="\n")

    # perf_counter is monotonic, so the measured runtime is unaffected by system clock changes.
    time_start = time.perf_counter()
    main(dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config, adaptive_attacker_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
