import pickle
import time
from dataclasses import asdict
from math import sqrt

import numpy as np
import scipy.stats as stats
import torch
from tqdm import tqdm

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    ModelConfig,
    NSeedsConfig,
    SeedConfig,
    get_adv_trained_model_name,
    get_model_name,
    get_model_save_name,
    parse_args_to_dataclass,
)
from mdp.darkroom_env import DarkroomEnv, get_optimal_action
from mdp.mdp_attacker import (
    MDPAttacker,
    MDPGridClassifierAttacker,
    MDPGridRandomAttacker,
)
from mdp.mdp_controller import MDPTransformerController
from mdp_algs import get_mdp_algs
from net import Transformer
from util.argparser_dataclass import parse_args_to_dataclass
from util.seed import set_seed

# Use the GPU whenever PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def main(dataset_config: DatasetConfig, model_config: ModelConfig, eval_config: EvalConfig, adv_train_config: AdversarialTrainingConfig, n_seeds_config: NSeedsConfig) -> None:
    """Evaluate one victim algorithm against every stored attacker, over several seeds.

    For each seed the victim selected by ``adv_train_config.attacker_against`` is
    obtained from ``get_mdp_algs`` (a transformer checkpoint is loaded first when
    the victim is ``"dpt"`` or ``"dpt_frozen"``). It is then deployed in a
    ``DarkroomEnv`` against each attacker loaded from ``models/adv/...``, against
    an ``MDPGridRandomAttacker`` (key ``"unifrand"``) and without an attacker (key
    ``"clean"``). The collected ``rewards_original`` arrays are summarised on
    stdout as ``mean $\\pm$ 2 * SEM`` per attacker and pickled to
    ``models/adv/<setup_name>/attacker_against_all_<victim>_evals_seeds<n>.pkl``.

    Args:
        dataset_config: Supplies ``env``, ``context_len`` and ``n_states``.
        model_config: Supplies the transformer parameters (``get_params``) and ``n_epochs``.
        eval_config: Supplies ``n_envs_eval``, ``n_steps_eval`` and ``epoch``.
        adv_train_config: Supplies ``attacker_against`` (the victim), ``victim_lr``,
            ``eps_episodes``, ``eps_steps`` and ``max_poison_diff``.
        n_seeds_config: Supplies ``n_seeds``, the number of seeds to evaluate.

    Raises:
        NotImplementedError: If ``dataset_config.env`` is anything but ``"darkroom"``.
    """
    if dataset_config.env == "darkroom":
        algs_against = ["dpt", "dpt_frozen", "npg", "ql"]  # "ppo" is currently excluded
    else:
        algs_against = ["dpt", "dpt_frozen"]

    online_alg_episodes = 100  # hardcoded for now

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len
    n_steps_eval = eval_config.n_steps_eval if eval_config.n_steps_eval is not None else dataset_config.context_len
    victim_alg = adv_train_config.attacker_against

    if dataset_config.env != "darkroom":
        # Only Darkroom is wired up; "chain" (state_dim=1, n_actions=2) and any
        # other environment are not implemented yet.
        raise NotImplementedError()

    # Darkroom constants: 2-D grid coordinates and a fixed set of 5 actions.
    state_dim = 2
    n_actions = 5
    n_states = dataset_config.n_states
    square_len = int(sqrt(n_states))

    # Every (row, col) pair of the grid, one per row of the tensor.
    all_states = torch.stack(
        [
            torch.arange(square_len, device=device).repeat_interleave(square_len),
            torch.arange(square_len, device=device).repeat(square_len),
        ],
        dim=1,
    )

    def load_transformer(checkpoint_path: str) -> Transformer:
        """Build the transformer, load ``checkpoint_path`` into it and put it in eval mode."""
        net = Transformer(model_config.get_params({"H": n_steps, "state_dim": state_dim, "action_dim": n_actions})).to(device)
        net.test = True
        # map_location lets checkpoints that were saved from GPU tensors load on CPU-only hosts.
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
        net.eval()
        print(f"Loaded model {checkpoint_path}.")
        return net

    setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, print_against=False)

    # One list entry per seed for every attacker key.
    rewards_victim_against: dict[str, list[np.ndarray]] = {alg: [] for alg in algs_against}
    rewards_victim_against["unifrand"] = []
    rewards_victim_against["clean"] = []

    model_name = get_model_name(dataset_config, model_config)
    for seed in range(n_seeds_config.n_seeds):
        set_seed(seed)
        dpt_policy = None
        dpt_frozen_policy = None
        if victim_alg == "dpt":
            atdpt_seed = (seed + 1) % n_seeds_config.n_seeds  # to prevent atdpt being evaluated on same attacker
            atdpt_path = f"models/adv/{setup_name}/atdpt_{atdpt_seed}.pt"
            dpt_policy = MDPTransformerController(
                load_transformer(atdpt_path), n_envs, n_steps, n_states, state_dim, n_actions, sample=True, lr=adv_train_config.victim_lr, device=device
            )
        elif victim_alg == "dpt_frozen":
            model_path = f"models/{get_model_save_name(model_name, SeedConfig(seed), model_config.n_epochs, eval_config.epoch)}.pt"
            dpt_frozen_policy = MDPTransformerController(
                load_transformer(model_path), n_envs, n_steps, n_states, state_dim, n_actions, sample=True, frozen=True, device=device
            )

        policies = get_mdp_algs(
            n_envs,
            n_steps,
            n_steps_eval,
            n_states,
            state_dim,
            n_actions,
            None,
            dpt_policy,
            dpt_frozen_policy,
            device=device,
        )

        # NOTE(review): the victim is fetched once per seed, so any online updates
        # applied below carry over from one attacker to the next — confirm intended.
        victim_policy = policies[victim_alg]
        is_rl_alg = victim_alg in ["ppo", "npg", "ql"]

        attacker: MDPAttacker | None = None
        for alg_against in algs_against:
            # The sampled environment only provides tensors to construct the attacker;
            # load_state_dict below then replaces the attacker's state with the stored one.
            dummy_env = DarkroomEnv.sample(n_envs, n_steps_eval, square_len, device=device)
            attacker = MDPGridClassifierAttacker(dummy_env.reward_map, dummy_env.goals, n_envs, n_actions, device=device)

            if alg_against in ("ql", "npg"):
                # FIXME: temporary for ql/npg — their attackers are looked up under a
                # setup name built with attacker_iters=1 and n_rounds=5.
                attacker_setup_name = get_adv_trained_model_name(
                    dataset_config, model_config, eval_config, AdversarialTrainingConfig(**{**asdict(adv_train_config), "attacker_iters": 1, "n_rounds": 5}), print_against=False
                )
            else:
                attacker_setup_name = setup_name

            attacker_save_path = f"models/adv/{attacker_setup_name}/attacker_against_{alg_against}_{seed}.pt"
            attacker.load_state_dict(torch.load(attacker_save_path, map_location=device))
            print(f"Loaded attacker from '{attacker_save_path}'.")

            # Darkroom specific: tabulate the optimal action for every grid cell of every environment.
            optimal_actions = torch.zeros((n_envs, square_len, square_len, n_actions), device=device)

            for env_idx in tqdm(range(n_envs), desc="Getting Optimal Actions"):
                goal = attacker._original_goals[env_idx, -1]
                for state in all_states:
                    optimal_actions[env_idx, state[0], state[1]] = get_optimal_action(state, goal)

            env = DarkroomEnv(attacker._original_reward_map, n_envs, n_steps_eval, square_len, attacker._original_goals, optimal_actions, device=device)

            if is_rl_alg:
                # Online learners first get to adapt to the attacked environment.
                for _ in tqdm(range(online_alg_episodes), desc=f"Learning Online - {victim_alg.upper()}"):
                    dataset_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)

                    victim_policy.update(dataset_victim, adv_train_config)

            dataset_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)
            rewards_victim_against[alg_against].append(dataset_victim.rewards_original.numpy(force=True))

        # The two baselines below reuse `env` as built for the last attacker in `algs_against`.
        attacker = MDPGridRandomAttacker(n_envs, square_len, adv_train_config.max_poison_diff, device=device)

        dataset_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)
        rewards_victim_against["unifrand"].append(dataset_victim.rewards_original.numpy(force=True))

        dataset_victim = env.deploy(victim_policy)
        rewards_victim_against["clean"].append(dataset_victim.rewards_original.numpy(force=True))

    algs_against += ["unifrand", "clean"]

    # Header row, then one "mean $\pm$ confidence" cell per attacker, each wrapped in braces.
    print(f"{victim_alg} reward against:")
    for alg_against in algs_against:
        print("{", f"{alg_against: ^17}", end="}", sep="")
    print()
    for alg_against in algs_against:
        # Reduce each seed's reward array to one number: sum over the last axis,
        # then mean over the remaining last axis.
        rewards_alg: list[float] = [arr.sum(-1).mean(-1) for arr in rewards_victim_against[alg_against]]

        mean = np.mean(rewards_alg)
        confidence = 2 * stats.sem(rewards_alg)  # two standard errors across seeds

        print("{", f"{mean:.1f} $\\pm$ {confidence:.1f}", end="}", sep="")
    print()

    results_filename = f"models/adv/{setup_name}/attacker_against_all_{victim_alg}_evals_seeds{n_seeds_config.n_seeds}.pkl"
    with open(results_filename, "wb") as f:
        pickle.dump(rewards_victim_against, f)
    print(f"Saved to '{results_filename}'.")


if __name__ == "__main__":
    # Parse one dataclass instance per config type from the command line.
    parsed_configs = parse_args_to_dataclass(
        (DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig, NSeedsConfig)
    )
    dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config = parsed_configs

    # Echo the effective configuration, one config object per line.
    for config in (dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config):
        print(config)

    # Time the whole evaluation run.
    started_at = time.time()
    main(dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config)
    elapsed = time.time() - started_at

    print(f"Total runtime: {elapsed:.2f} s")
