import pickle
import time

import gymnasium as gym
import miniworld
import numpy as np
import scipy.stats as stats
import torch
from tqdm import tqdm

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    ModelConfig,
    NSeedsConfig,
    get_adv_trained_model_name,
    get_model_name,
    parse_args_to_dataclass,
)
from mdp.mdp_attacker import (
    MDPAttacker,
    MDPGridClassifierAttacker,
    MDPGridRandomAttacker,
)
from mdp.mdp_controller import MDPImageTransformerController
from mdp.miniworld_env import MDPMiniworldAttacker, MiniworldEnv
from mdp_algs import get_mdp_algs
from net import ImageTransformer
from util.argparser_dataclass import parse_args_to_dataclass
from util.seed import set_seed

# Use the current CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _load_image_transformer(model_config: ModelConfig, model_path: str, n_steps: int, state_dim: int, n_actions: int) -> ImageTransformer:
    """Build an ImageTransformer, load its weights from ``model_path`` and switch it to test/eval mode.

    Weights are mapped onto the module-level ``device`` so that a checkpoint saved on a GPU
    can still be loaded on a CPU-only machine (and vice versa).
    """
    model = ImageTransformer(model_config.get_params({"H": n_steps, "state_dim": state_dim, "action_dim": n_actions, "image_size": 25})).to(device)
    model.test = True
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"Loaded model {model_path}.")
    return model


def _make_miniworld_env(task_ids, n_steps: int, n_steps_eval: int) -> tuple[MiniworldEnv, list]:
    """Create one MiniWorld gym env per task id and wrap them in a ``MiniworldEnv``.

    ``task_ids`` is iterated element-wise and each element must support ``.item()``
    (e.g. a 1-D tensor). Returns the wrapper together with the raw gym envs so the
    caller can close them when done.
    """
    envs_inner = []
    for task_id in task_ids:
        # NOTE(review): episodes are capped at n_steps (context_len) while the wrapper is built
        # with n_steps_eval; these differ when eval_config.n_steps_eval is set — confirm intended.
        env_inner = gym.make("MiniWorld-OneRoomS6FastMultiFourBoxesFixedInit-v0", max_episode_steps=n_steps)  # type: ignore
        env_inner.unwrapped.set_task(env_id=int(task_id.item()))  # type: ignore
        envs_inner.append(env_inner)
    return MiniworldEnv(envs_inner, n_steps_eval, device=device), envs_inner


def _print_reward_table(victim_alg: str, algs: list[str], rewards_by_alg: dict[str, list[np.ndarray]]) -> None:
    """Print one LaTeX-style row of ``{mean $\\pm$ 2*SEM}`` cells (one per alg), preceded by a header row."""
    print(f"{victim_alg} reward against:")
    for alg in algs:
        print("{", f"{alg: ^17}", end="}", sep="")
    print()
    for alg in algs:
        # One scalar per seed: sum over the last axis, then mean over the next one
        # (presumably steps, then envs — TODO confirm the reward array layout).
        per_seed_rewards: list[float] = [arr.sum(-1).mean(-1) for arr in rewards_by_alg[alg]]

        mean = np.mean(per_seed_rewards)
        # scipy's sem uses ddof=1, so with a single seed this is nan.
        confidence = 2 * stats.sem(per_seed_rewards)

        print("{", f"{mean:.1f} $\\pm$ {confidence:.1f}", end="}", sep="")
    print()


def main(dataset_config: DatasetConfig, model_config: ModelConfig, eval_config: EvalConfig, adv_train_config: AdversarialTrainingConfig, n_seeds_config: NSeedsConfig) -> None:
    """Evaluate a victim algorithm on MiniWorld against trained attackers and two baselines.

    For every seed in ``range(n_seeds_config.n_seeds)`` this:
      * builds the victim policy named by ``adv_train_config.attacker_against`` (loading an
        ImageTransformer checkpoint for ``"dpt"`` / ``"dpt_frozen"``),
      * for each alg in ``algs_against`` loads the attacker saved at
        ``models/adv/{setup_name}/attacker_against_{alg}_{seed}.pt`` and records the victim's
        original (un-poisoned) rewards under that attacker,
      * records the rewards under an ``MDPGridRandomAttacker`` ("unifrand") and with no
        attacker ("clean").

    Finally it prints a mean ± 2·SEM table across seeds and pickles the raw per-seed reward
    arrays to ``models/adv/{setup_name}/attacker_against_all_{victim_alg}_evals_seeds{n}.pkl``.
    """
    algs_against = ["dpt", "dpt_frozen"]  # , "qlfa"]
    baseline_algs = ["unifrand", "clean"]
    all_algs = algs_against + baseline_algs

    online_alg_episodes = 100  # hardcoded for now

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len
    n_steps_eval = eval_config.n_steps_eval if eval_config.n_steps_eval is not None else dataset_config.context_len
    n_states = dataset_config.n_states
    n_actions = dataset_config.n_actions
    victim_alg = adv_train_config.attacker_against
    is_rl_alg = victim_alg in ["ppo", "npg", "ql", "qlfa"]

    state_dim = 2

    setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, print_against=False)

    rewards_victim_against: dict[str, list[np.ndarray]] = {alg: [] for alg in all_algs}

    for seed in range(n_seeds_config.n_seeds):
        set_seed(seed)
        dpt_policy = None
        dpt_frozen_policy = None
        if victim_alg == "dpt":
            atdpt_seed = (seed + 1) % n_seeds_config.n_seeds  # to prevent atdpt being evaluated on same attacker
            atdpt_path = f"models/adv/{setup_name}/atdpt_{atdpt_seed}.pt"
            model = _load_image_transformer(model_config, atdpt_path, n_steps, state_dim, n_actions)

            dpt_policy = MDPImageTransformerController(model, n_envs, n_steps, n_states, state_dim, n_actions, sample=True, lr=adv_train_config.victim_lr, device=device)
        elif victim_alg == "dpt_frozen":
            # Temporary for miniworld: checkpoints use a legacy, hard-coded naming scheme.
            # TODO: switch to get_model_save_name(get_model_name(dataset_config, model_config), SeedConfig(seed), model_config.n_epochs, eval_config.epoch)
            context_len = dataset_config.context_len
            epoch = eval_config.epoch
            legacy_model_name = f"miniworld_shufTrue_lr0.0001_do0_embd32_layer4_head4_envs60000_hists1_samples1_H{context_len}_seed{seed}_epoch{epoch}"

            model_path = f"models/{legacy_model_name}.pt"
            model = _load_image_transformer(model_config, model_path, n_steps, state_dim, n_actions)

            dpt_frozen_policy = MDPImageTransformerController(model, n_envs, n_steps, n_states, state_dim, n_actions, sample=True, frozen=True, device=device)

        policies = get_mdp_algs(
            n_envs,
            n_steps,
            n_steps_eval,
            0,
            state_dim,
            n_actions,
            None,
            dpt_policy,
            dpt_frozen_policy,
            device=device,
        )

        victim_policy: MDPImageTransformerController = policies[victim_alg]  # type: ignore

        # Raw gym envs created for this seed; closed in `finally` so they do not accumulate across seeds.
        opened_envs: list = []
        env: MiniworldEnv | None = None
        try:
            for alg_against in algs_against:
                dummy_env = MiniworldEnv.sample(n_envs, n_steps_eval, device=device, seed=60000 + seed)
                attacker: MDPAttacker = MDPMiniworldAttacker(dummy_env.task_ids, n_envs, lr=adv_train_config.attacker_lr, device=device)
                attacker_save_path = f"models/adv/{setup_name}/attacker_against_{alg_against}_{seed}.pt"
                attacker.load_state_dict(torch.load(attacker_save_path, map_location=device))
                print(f"Loaded attacker from '{attacker_save_path}'.")

                # The evaluation env is built from the loaded attacker's task ids.
                env, envs_inner = _make_miniworld_env(attacker._original_task_ids, n_steps, n_steps_eval)
                opened_envs.extend(envs_inner)

                if is_rl_alg:
                    # NOTE(review): the same victim_policy keeps learning across alg_against iterations
                    # (and is then reused for the baselines) — confirm it should not be reset per attacker.
                    for _ in tqdm(range(online_alg_episodes), desc=f"Learning Online - {victim_alg.upper()}"):
                        dataset_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps, force_show_progress=True)

                        victim_policy.update(dataset_victim, adv_train_config)

                dataset_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps, force_show_progress=True)
                rewards_victim_against[alg_against].append(dataset_victim.rewards_original.numpy(force=True))

            # The baselines reuse the env built for the last trained attacker.
            if env is None:
                raise RuntimeError("algs_against must contain at least one algorithm to build the evaluation environment.")

            # NOTE(review): the meaning of the literal 7 depends on MDPGridRandomAttacker's signature — TODO confirm.
            attacker = MDPGridRandomAttacker(n_envs, 7, adv_train_config.max_poison_diff, device=device)

            dataset_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps, force_show_progress=True)
            rewards_victim_against["unifrand"].append(dataset_victim.rewards_original.numpy(force=True))

            dataset_victim = env.deploy(victim_policy, force_show_progress=True)
            rewards_victim_against["clean"].append(dataset_victim.rewards_original.numpy(force=True))
        finally:
            for env_inner in opened_envs:
                env_inner.close()

    _print_reward_table(victim_alg, all_algs, rewards_victim_against)

    results_filename = f"models/adv/{setup_name}/attacker_against_all_{victim_alg}_evals_seeds{n_seeds_config.n_seeds}.pkl"
    with open(results_filename, "wb") as f:
        pickle.dump(rewards_victim_against, f)
    print(f"Saved to '{results_filename}'.")


if __name__ == "__main__":
    # Note: parse_args_to_dataclass resolves to the util.argparser_dataclass version, because that
    # later import shadows the one from `args`.
    dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config = parse_args_to_dataclass(
        (DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig, NSeedsConfig)
    )

    print(dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config, sep="\n")

    # perf_counter is monotonic, so the reported runtime is immune to wall-clock adjustments.
    time_start = time.perf_counter()
    main(dataset_config, model_config, eval_config, adversarial_training_config, n_seeds_config)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
