import os
import pickle
import time
from math import sqrt

import numpy as np
import scipy.stats as stats
import torch

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    ModelConfig,
    NPGConfig,
    NSeedsConfig,
    PPOConfig,
    SeedConfig,
    get_adv_trained_model_name,
    get_model_name,
    get_model_save_name,
    parse_args_to_dataclass,
)
from mdp.darkroom_env import DarkroomEnv
from mdp.mdp_attacker import MDPGridRandomAttacker
from mdp.mdp_controller import MDPNPGController, MDPTransformerController, PPOController
from mdp.mdp_env import MDPController
from mdp_algs import get_mdp_algs
from net import Transformer
# NOTE(review): this rebinds the parse_args_to_dataclass imported from `args` above; this one wins.
from util.argparser_dataclass import parse_args_to_dataclass
from util.seed import set_seed

# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def main(dataset_config: DatasetConfig, model_config: ModelConfig, eval_config: EvalConfig, adv_train_config: AdversarialTrainingConfig, n_seeds_config: NSeedsConfig) -> None:
    """Evaluate frozen DPT victims with no attacker and with a uniform-random grid attacker.

    For each seed in ``range(n_seeds_config.n_seeds)`` the saved transformer
    checkpoint for that seed is loaded. ``eval_config.n_envs_eval`` Darkroom
    environments are sampled, and every victim in ``algs_victim`` is deployed
    twice: once without an attacker (``"clean"``) and once against
    ``MDPGridRandomAttacker`` (``"unifrand"``). The per-seed reward arrays are
    summarised on stdout as mean +/- 2 * SEM across seeds. The raw
    ``{attacker: {victim: [array per seed]}}`` dict is pickled to
    ``models/adv/<setup_name>/attacker_against_clean_unifrand_evals_seeds<n_seeds>.pkl``.

    Args:
        dataset_config: Supplies ``context_len`` (model horizon) and ``n_states``.
            ``n_states`` must be a perfect square because the Darkroom grid is
            built from its side length.
        model_config: Used to rebuild the ``Transformer`` and to locate its checkpoint.
        eval_config: Supplies ``n_envs_eval``, ``n_steps_eval`` and the checkpoint
            ``epoch``. When ``n_steps_eval`` is ``None``, ``context_len`` is used.
        adv_train_config: Supplies ``max_poison_diff``, ``eps_episodes`` and
            ``eps_steps`` for the random attacker, and determines the output folder.
        n_seeds_config: Number of seeds (one checkpoint per seed) to evaluate.

    Raises:
        ValueError: If ``dataset_config.n_states`` is not a perfect square.
    """
    algs_victim = ["dpt_frozen"]  # NPG and PPO require special handling, so "npg" and "ppo" are not listed

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len
    n_steps_eval = eval_config.n_steps_eval if eval_config.n_steps_eval is not None else dataset_config.context_len

    # attacker name -> victim name -> one reward array per seed
    rewards_algs: dict[str, dict[str, list[np.ndarray]]] = {
        "clean": {alg: [] for alg in algs_victim},
        "unifrand": {alg: [] for alg in algs_victim},
    }

    # Darkroom only: these dimensions are hard-coded for that environment.
    state_dim = 2
    n_actions = 5
    n_states = dataset_config.n_states
    square_len = int(sqrt(n_states))
    # The env is sampled from square_len alone while the controller receives n_states,
    # so a non-square n_states would make the two silently disagree.
    if square_len * square_len != n_states:
        raise ValueError(f"Darkroom requires a square number of states, got n_states={n_states}.")

    model_name = get_model_name(dataset_config, model_config)

    for seed in range(n_seeds_config.n_seeds):
        set_seed(seed)
        model_path = f"models/{get_model_save_name(model_name, SeedConfig(seed), model_config.n_epochs, eval_config.epoch)}.pt"
        model = Transformer(model_config.get_params({"H": n_steps, "state_dim": state_dim, "action_dim": n_actions})).to(device)

        model.test = True
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print(f"Loaded model {model_path}.")

        env = DarkroomEnv.sample(n_envs, n_steps_eval, square_len, device=device)

        dpt_frozen_policy = MDPTransformerController(model, n_envs, n_steps, n_states, state_dim, n_actions, sample=True, frozen=True, device=device)

        policies = get_mdp_algs(
            n_envs,
            n_steps,
            n_steps_eval,
            n_states,
            state_dim,
            n_actions,
            env.optimal_actions,
            None,
            dpt_frozen_policy,
            device=device,
        )

        # Clean eval
        for alg_victim in algs_victim:
            dataset_alg = env.deploy(policies[alg_victim])
            rewards_algs["clean"][alg_victim].append(dataset_alg.rewards_original.numpy(force=True))

        # Grid Random eval
        attacker = MDPGridRandomAttacker(n_envs, square_len, adv_train_config.max_poison_diff, device=device)

        for alg_victim in algs_victim:
            dataset_alg = env.deploy(policies[alg_victim], attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)
            rewards_algs["unifrand"][alg_victim].append(dataset_alg.rewards_original.numpy(force=True))

    for alg_att in ["clean", "unifrand"]:
        print(f"\n{alg_att} attacker:")
        for alg_victim in algs_victim:
            # One value per seed: sum over the last axis, then mean over the next.
            # Presumably this is the total reward per env averaged over envs.
            # TODO confirm the layout of rewards_original.
            rewards_alg = [arr.sum(-1).mean(-1) for arr in rewards_algs[alg_att][alg_victim]]

            mean = np.mean(rewards_alg)
            # Twice the standard error of the mean across seeds; nan when only one seed was run.
            confidence = 2 * stats.sem(rewards_alg)

            print(f"  {alg_victim: <10}: {mean:.1f} $\\pm$ {confidence:.1f}")

    setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, print_against=False)
    evals_path = f"models/adv/{setup_name}/attacker_against_clean_unifrand_evals_seeds{n_seeds_config.n_seeds}.pkl"
    # Create the output folder so the dump does not fail with FileNotFoundError when it is missing.
    os.makedirs(os.path.dirname(evals_path), exist_ok=True)
    with open(evals_path, "wb") as f:
        pickle.dump(rewards_algs, f)
    print(f"Saved to '{evals_path}'.")


if __name__ == "__main__":
    # Script entry point: parse the five config dataclasses from the command line,
    # echo them, run the evaluation and report the wall-clock runtime.
    dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config = parse_args_to_dataclass(
        (DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig, NSeedsConfig)
    )

    print(dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config, sep="\n")

    # perf_counter is monotonic, so the measured runtime is immune to system clock adjustments.
    time_start = time.perf_counter()
    main(dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
