import os
import pickle
import time

import numpy as np
import scipy.stats as stats
import torch

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    ModelConfig,
    NSeedsConfig,
    SeedConfig,
    get_adv_trained_model_name,
    get_legacy_filename_config,
    parse_args_to_dataclass,
)
from bandit2.bandit_attacker import BanditUniformRandomAttacker
from bandit2.bandit_ctrl import BanditTransformerController
from bandit2.bandit_env import BanditEnv
from bandit_algs import get_bandit_algs
from net import Transformer
from utils import build_bandit_model_filename

# Compute device shared by the model, environment, policies and attacker:
# the default CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _collect_regrets(env, policies, algs, n_steps, regrets, *deploy_args) -> None:
    """Deploy each policy in ``algs`` on ``env`` and append its regret to ``regrets``.

    Regret is ``(rewards_opt - rewards_alg).mean() * n_steps``, computed on the
    ``rewards_original`` returned by ``env.deploy``. ``policies["opt"]`` is
    deployed once, without ``deploy_args``. Every policy in ``algs`` is deployed
    with ``deploy_args`` forwarded positionally (empty for the clean setting).
    """
    rewards_opt = env.deploy(policies["opt"]).rewards_original
    for alg in algs:
        rewards_alg = env.deploy(policies[alg], *deploy_args).rewards_original
        regrets[alg].append((rewards_opt - rewards_alg).mean().item() * n_steps)


def _print_regret_summary(title: str, regrets: dict) -> None:
    """Print the per-algorithm mean regret +/- 2 standard errors (LaTeX ``$\\pm$`` format)."""
    print(f"\n{title}:")
    for alg, values in regrets.items():
        mean = np.mean(values)
        # NOTE: stats.sem of a single sample is nan (i.e. when n_seeds == 1).
        confidence = 2 * stats.sem(values)
        print(f"  {alg: <10}: {mean:.1f} $\\pm$ {confidence:.1f}")


def main(dataset_config: DatasetConfig, model_config: ModelConfig, eval_config: EvalConfig, adv_train_config: AdversarialTrainingConfig, n_seeds_config: NSeedsConfig) -> None:
    """Evaluate a pretrained bandit transformer and baselines on clean and attacked envs.

    For each seed in ``range(n_seeds_config.n_seeds)`` this does the following:

    - Load that seed's transformer checkpoint and wrap it as the ``dpt_frozen``
      controller.
    - Sample a fresh ``BanditEnv`` and build the other policies with
      ``get_bandit_algs``.
    - Record every algorithm's regret against ``policies["opt"]``, both on the
      clean env and under a ``BanditUniformRandomAttacker``.

    Mean regret +/- 2 standard errors over seeds is printed for both settings.
    The raw per-seed regrets are pickled to
    ``models/adv/<setup_name>/attacker_against_clean_unifrand_evals_seeds<N>.pkl``.

    Args:
        dataset_config: Supplies ``context_len`` (steps per episode), ``n_actions``
            and ``variance`` for env/attacker sampling.
        model_config: Transformer hyperparameters and part of the checkpoint filename.
        eval_config: Supplies ``n_envs_eval`` and an optional checkpoint ``epoch``.
        adv_train_config: Supplies ``eps_steps``, ``eps_episodes`` and
            ``max_poison_diff`` for the policies and attacker.
        n_seeds_config: Number of seeds (i.e. checkpoints / env samples) to average over.
    """
    algs = ["dpt_frozen", "ts", "rts", "rts_u", "rts_k", "ucb", "crucb", "crucb_v", "crucb_p"]

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len
    n_actions = dataset_config.n_actions

    regrets_clean = {alg: [] for alg in algs}
    regrets_unifrand = {alg: [] for alg in algs}

    for seed in range(n_seeds_config.n_seeds):
        transformer_config = model_config.get_params({"H": dataset_config.context_len, "state_dim": 1, "action_dim": dataset_config.n_actions})
        filename = build_bandit_model_filename("bandit", get_legacy_filename_config(model_config, dataset_config, SeedConfig(seed=seed)))
        if eval_config.epoch is None:
            model_path = f"models/{filename}.pt"
        else:
            model_path = f"models/{filename}_epoch{eval_config.epoch}.pt"
        model = Transformer(transformer_config).to(device)

        # NOTE(review): flag consumed inside Transformer; its meaning is not visible here.
        model.test = True
        # map_location remaps tensors saved on another device (e.g. a GPU checkpoint
        # loaded on a CPU-only host) instead of failing at deserialization.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"Loaded model {model_path}.")
        dpt_frozen_policy = BanditTransformerController(model, n_envs, n_steps, n_actions, sample=True, frozen=True, device=device)

        env = BanditEnv.sample(n_envs, n_steps, n_actions, dataset_config.variance, device=device)

        # NOTE(review): the positional `None` and `0.1` arguments are opaque from
        # here; confirm their meaning against get_bandit_algs' signature.
        policies = get_bandit_algs(
            n_envs,
            n_steps,
            n_actions,
            env.get_optimal_actions(),
            None,
            dpt_frozen_policy,
            adv_train_config.eps_steps,
            0.1,
            adv_train_config.eps_steps,
            adv_train_config.max_poison_diff,
            device=device,
        )

        # Clean environment: no attacker.
        _collect_regrets(env, policies, algs, n_steps, regrets_clean)

        # Same env, with rewards poisoned by a uniform-random attacker.
        attacker = BanditUniformRandomAttacker(n_envs, n_actions, dataset_config.variance, adv_train_config.max_poison_diff, device=device)
        _collect_regrets(
            env,
            policies,
            algs,
            n_steps,
            regrets_unifrand,
            attacker,
            adv_train_config.eps_episodes,
            adv_train_config.eps_steps,
        )

    _print_regret_summary("Unif. Rand. Attacker", regrets_unifrand)
    _print_regret_summary("Clean env", regrets_clean)

    setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, print_against=False)
    evals_path = f"models/adv/{setup_name}/attacker_against_clean_unifrand_evals_seeds{n_seeds_config.n_seeds}.pkl"
    # Create the output directory up front so a missing folder does not discard
    # an otherwise completed (and possibly long) evaluation run.
    os.makedirs(os.path.dirname(evals_path), exist_ok=True)
    with open(evals_path, "wb") as f:
        pickle.dump({"clean": regrets_clean, "unifrand": regrets_unifrand}, f)
    print(f"Saved to '{evals_path}'.")


if __name__ == "__main__":
    # Parse CLI arguments into the five config dataclasses consumed by main().
    dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config = parse_args_to_dataclass(
        (DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig, NSeedsConfig)
    )

    print(dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config, sep="\n")

    # perf_counter is monotonic and high-resolution, so the reported runtime is
    # not skewed by wall-clock adjustments the way time.time() can be.
    time_start = time.perf_counter()
    main(dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
