import pickle
import time
from dataclasses import asdict
from math import sqrt

import torch
from tqdm import tqdm

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    ModelConfig,
    SeedConfig,
    get_adv_trained_model_name,
    get_model_name,
    get_model_save_name,
    parse_args_to_dataclass,
)
from mdp.darkroom_env import DarkroomEnv, get_optimal_action
from mdp.mdp_attacker import MDPGridClassifierAttacker, MDPGridRandomAttacker
from mdp.mdp_controller import MDPTransformerController
from mdp_algs import get_mdp_algs
from net import Transformer
from util.argparser_dataclass import parse_args_to_dataclass
from util.seed import set_seed

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def main(seed_config: SeedConfig, dataset_config: DatasetConfig, model_config: ModelConfig, eval_config: EvalConfig, adv_train_config: AdversarialTrainingConfig) -> None:
    """Evaluate one victim algorithm against one attacker and pickle the rewards.

    ``adv_train_config.attacker_against`` is split on ``"-"`` into a victim
    name and an attacker name. The victim policy is looked up in the dict
    returned by ``get_mdp_algs``; for ``"dpt"`` / ``"dpt_frozen"`` a transformer
    checkpoint is loaded first. The attacker is either ``"clean"`` (no
    attacker), ``"unifrand"`` (``MDPGridRandomAttacker``) or anything else, in
    which case a trained ``MDPGridClassifierAttacker`` is loaded from disk.
    Victims named ``"ppo"``, ``"npg"`` or ``"ql"`` are first updated online for
    a fixed number of deployments. The original rewards of the final
    deployment are printed in aggregate and written to
    ``models/adv/<setup_name>/attacker_against_<att>_victim_<victim>_<seed>_eval.pkl``.

    Args:
        seed_config: Provides ``seed`` (used for seeding and in file names).
        dataset_config: Provides ``env``, ``context_len`` and ``n_states``.
        model_config: Transformer hyper-parameters (via ``get_params``).
        eval_config: Provides ``n_envs_eval``, ``n_steps_eval`` and ``epoch``.
        adv_train_config: Provides ``attacker_against``, ``victim_lr``,
            ``max_poison_diff``, ``eps_episodes`` and ``eps_steps``.

    Raises:
        NotImplementedError: If ``dataset_config.env`` is not ``"darkroom"``.
    """
    import os  # local import: only needed to create the results directory

    online_alg_episodes = 100  # hardcoded for now

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len
    n_steps_eval = eval_config.n_steps_eval if eval_config.n_steps_eval is not None else dataset_config.context_len

    # Expected format: "<victim>-<attacker>" (exactly one dash).
    alg_victim, alg_att = adv_train_config.attacker_against.split("-")

    if dataset_config.env == "darkroom":
        state_dim = 2
        n_actions = 5
        n_states = dataset_config.n_states
        square_len = int(sqrt(n_states))
    else:
        # "chain" and every other environment are not supported by this script.
        # (A chain setup would use state_dim = 1 and n_actions = 2.)
        raise NotImplementedError(f"Environment '{dataset_config.env}' is not supported.")

    # Every (row, col) coordinate pair of the square grid, one pair per row.
    all_states = torch.stack(
        [
            torch.arange(square_len, device=device).repeat_interleave(square_len),
            torch.arange(square_len, device=device).repeat(square_len),
        ],
        dim=1,
    )

    setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, print_against=False)

    model_name = get_model_name(dataset_config, model_config)

    seed = seed_config.seed

    set_seed(seed)

    ####################
    # Init policy controller
    ####################
    dpt_policy = None
    dpt_frozen_policy = None
    if alg_victim == "dpt":
        atdpt_seed = (seed + 1) % 10  # to prevent atdpt being evaluated on same attacker
        atdpt_path = f"models/adv/{setup_name}/atdpt_{atdpt_seed}.pt"
        model = _load_transformer(model_config, n_steps, state_dim, n_actions, atdpt_path)

        dpt_policy = MDPTransformerController(model, n_envs, n_steps, n_states, state_dim, n_actions, sample=True, lr=adv_train_config.victim_lr, device=device)
    elif alg_victim == "dpt_frozen":
        model_path = f"models/{get_model_save_name(model_name, SeedConfig(seed), model_config.n_epochs, eval_config.epoch)}.pt"
        model = _load_transformer(model_config, n_steps, state_dim, n_actions, model_path)

        dpt_frozen_policy = MDPTransformerController(model, n_envs, n_steps, n_states, state_dim, n_actions, sample=True, frozen=True, device=device)

    policies = get_mdp_algs(
        n_envs,
        n_steps,
        n_steps_eval,
        n_states,
        state_dim,
        n_actions,
        None,
        dpt_policy,
        dpt_frozen_policy,
        device=device,
    )
    victim_policy = policies[alg_victim]

    ####################
    # Init attacker
    ####################

    if alg_att == "clean":
        env = DarkroomEnv.sample(n_envs, n_steps_eval, square_len, device=device)
        attacker = None
    elif alg_att == "unifrand":
        env = DarkroomEnv.sample(n_envs, n_steps_eval, square_len, device=device)
        attacker = MDPGridRandomAttacker(n_envs, square_len, adv_train_config.max_poison_diff, device=device)
    else:
        # Learned attacker: build it on a sampled env, then overwrite its state
        # with the trained checkpoint.
        dummy_env = DarkroomEnv.sample(n_envs, n_steps_eval, square_len, device=device)
        attacker = MDPGridClassifierAttacker(dummy_env.reward_map, dummy_env.goals, n_envs, n_actions, device=device)
        attacker_setup_name = setup_name
        if alg_att in ("ql", "npg"):
            # FIXME: temporary for ql / npg -- their attackers are looked up
            # under a setup name with attacker_iters=1 and n_rounds=5.
            override_config = AdversarialTrainingConfig(**{**asdict(adv_train_config), "attacker_iters": 1, "n_rounds": 5})
            attacker_setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, override_config, print_against=False)
        attacker_save_path = f"models/adv/{attacker_setup_name}/attacker_against_{alg_att}_{seed}.pt"
        # map_location keeps checkpoints saved on a GPU loadable on a CPU-only machine.
        attacker.load_state_dict(torch.load(attacker_save_path, map_location=device))
        print(f"Loaded attacker from '{attacker_save_path}'.")

        # Darkroom specific:
        optimal_actions = torch.zeros((n_envs, square_len, square_len, n_actions), device=device)

        for env_idx in tqdm(range(n_envs), desc="Getting Optimal Actions"):
            for state in all_states:
                optimal_actions[env_idx, state[0], state[1]] = get_optimal_action(state, attacker._original_goals[env_idx, -1])

        # Evaluate on the environment stored in the attacker checkpoint
        # (its original reward map and goals), not on the freshly sampled one.
        env = DarkroomEnv(attacker._original_reward_map, n_envs, n_steps_eval, square_len, attacker._original_goals, optimal_actions, device=device)

    if alg_victim in ["ppo", "npg", "ql"]:
        # Online learners get a fixed number of deploy/update rounds first.
        for _ in tqdm(range(online_alg_episodes), desc=f"Learning Online - {alg_victim.upper()}"):
            dataset_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)

            victim_policy.update(dataset_victim, adv_train_config)
    dataset_victim = env.deploy(victim_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)

    debug = False
    if debug:
        # NOTE(review): dereferences attacker.weights, so this would fail for
        # the "clean" attacker (attacker is None) -- confirm before enabling.
        env.visualize_dataset(dataset_victim, attacker_weights=attacker.weights.data, env_idx=None)  # type: ignore

    rewards_victim = dataset_victim.rewards_original.numpy(force=True)

    # Sum over the last axis, then average over the next one. The ".1f" format
    # below needs a scalar, so rewards_victim is presumably 2-D -- TODO confirm.
    reward_alg: float = rewards_victim.sum(-1).mean(-1)

    print(f"reward (seed={seed}): " + "{" + f"{reward_alg:.1f}" + "}")

    results_filename = f"models/adv/{setup_name}/attacker_against_{alg_att}_victim_{alg_victim}_{seed}_eval.pkl"
    # The setup directory may not exist yet (e.g. online victim vs. "clean"
    # attacker); create it so the finished evaluation is not lost on save.
    os.makedirs(os.path.dirname(results_filename), exist_ok=True)
    with open(results_filename, "wb") as f:
        pickle.dump(rewards_victim, f)
    print(f"Saved to '{results_filename}'.")


def _load_transformer(model_config: ModelConfig, n_steps: int, state_dim: int, n_actions: int, checkpoint_path: str) -> Transformer:
    """Build a Transformer on ``device``, load ``checkpoint_path`` into it and put it in eval mode."""
    model = Transformer(model_config.get_params({"H": n_steps, "state_dim": state_dim, "action_dim": n_actions})).to(device)
    model.test = True
    # map_location keeps checkpoints saved on a GPU loadable on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    print(f"Loaded model {checkpoint_path}.")
    return model


if __name__ == "__main__":
    # One dataclass type per configuration group; each is parsed from the CLI.
    config_types = (SeedConfig, DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig)
    seed_config, dataset_config, model_config, eval_config, adversarial_training_config = parse_args_to_dataclass(config_types)

    # Echo the resolved configuration, one config object per line.
    parsed_configs = (seed_config, dataset_config, model_config, eval_config, adversarial_training_config)
    print(*parsed_configs, sep="\n")

    # Time the whole evaluation run with wall-clock time.
    time_start = time.time()
    main(*parsed_configs)
    time_end = time.time()

    print(f"Total runtime: {time_end - time_start:.2f} s")
