import os
import time
from dataclasses import asdict, replace
from math import sqrt

import torch

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    LoggingConfig,
    ModelConfig,
    SeedConfig,
    get_adv_trained_model_name,
    get_legacy_miniworld_config,
    get_model_name,
    get_model_save_name,
)
from mdp.darkroom_env import DarkroomEnv
from mdp.mdp_attacker import MDPGridClassifierAttacker
from mdp.mdp_controller import MDPImageTransformerController, MDPTransformerController
from mdp.miniworld_env import MDPMiniworldAttacker, MiniworldEnv
from mdp_algs import get_mdp_algs
from net import ImageTransformer, Transformer
from util.argparser_dataclass import parse_args_to_dataclass
from util.logger import PrintLogger, WandbLogger
from util.seed import set_seed
from utils import build_miniworld_model_filename

# Compute device shared by every env, attacker and model built in this script:
# the default CUDA device when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# torch.set_default_device(device)


# Extra 1-indexed rounds after which the attacker is checkpointed for specific victims,
# on top of the periodic schedule and the final round.
_EXTRA_SAVE_ROUNDS = {
    "npg": (5, 10, 50),
    "ql": (5, 10, 20),
}


def _should_save_checkpoint(attacker_against: str, round_num: int, n_rounds: int) -> bool:
    """Return True if checkpoints should be written after 1-indexed round ``round_num``.

    Saves after the final round, every 50 rounds on runs longer than 100 rounds, and at the
    hand-picked early rounds listed in ``_EXTRA_SAVE_ROUNDS`` for the given victim.
    """
    if round_num == n_rounds:
        return True
    if n_rounds > 100 and round_num % 50 == 0:
        return True
    return round_num in _EXTRA_SAVE_ROUNDS.get(attacker_against, ())


def _build_env_and_attacker(dataset_config: DatasetConfig, adv_train_config: AdversarialTrainingConfig, n_envs: int, n_steps_eval: int):
    """Sample the evaluation environments and build the matching attacker.

    Returns:
        ``(state_dim, n_actions, env, attacker)``.

    Raises:
        NotImplementedError: for ``"chain"`` (not ported to this script) and any unknown env.
    """
    if dataset_config.env == "darkroom":
        # Darkroom is a square grid whose side length is derived from the number of states.
        square_len = int(sqrt(dataset_config.n_states))
        env = DarkroomEnv.sample(n_envs, n_steps_eval, square_len, device=device)
        attacker = MDPGridClassifierAttacker(env.reward_map, env.goals, n_envs, square_len, lr=adv_train_config.attacker_lr, device=device)
        return 2, 5, env, attacker

    if dataset_config.env == "miniworld":
        env = MiniworldEnv.sample(n_envs, n_steps_eval, device=device)
        attacker = MDPMiniworldAttacker(env.task_ids, n_envs, lr=adv_train_config.attacker_lr, device=device)
        return 2, 4, env, attacker

    # "chain" (state_dim=1, n_actions=2) is not supported here yet.
    raise NotImplementedError(f"Adversarial training is not implemented for env '{dataset_config.env}'.")


def _resolve_model_path(dataset_config: DatasetConfig, model_config: ModelConfig, seed_config: SeedConfig, eval_config: EvalConfig) -> str:
    """Return the checkpoint path of the pretrained DPT to attack (``eval_config.epoch`` selects the epoch)."""
    if dataset_config.env != "miniworld":
        model_name = get_model_name(dataset_config, model_config)
        return f"models/{get_model_save_name(model_name, seed_config, model_config.n_epochs, eval_config.epoch)}.pt"

    # Miniworld checkpoints use the legacy naming scheme.
    legacy_config = get_legacy_miniworld_config(model_config, dataset_config, seed_config)
    if model_config.dropout == 0.0:
        # NOTE(review): presumably so the filename renders "0" rather than "0.0", matching
        # existing legacy checkpoint names — confirm against build_miniworld_model_filename.
        legacy_config["dropout"] = 0
    legacy_name = build_miniworld_model_filename(dataset_config.env, legacy_config)
    if eval_config.epoch is None:
        return f"models/{legacy_name}.pt"
    return f"models/{legacy_name}_epoch{eval_config.epoch}.pt"


def main(
    logging_config: LoggingConfig,
    seed_config: SeedConfig,
    dataset_config: DatasetConfig,
    model_config: ModelConfig,
    eval_config: EvalConfig,
    adv_train_config: AdversarialTrainingConfig,
):
    """Adversarially train an attacker against ``adv_train_config.attacker_against``.

    Each round:
      1. RL victims (``npg``/``ql``/``ppo``) are re-initialized and updated online over
         ``n_episodes`` attacked episodes.
      2. The victim is deployed once more under attack; its rewards (plus optional baselines)
         are logged.
      3. The attacker is updated on every dataset collected this round, then the victim is
         updated on the evaluation dataset.
      4. On the ``_should_save_checkpoint`` schedule, the attacker is saved, plus the trainable
         DPT weights ("AT-DPT") when attacking ``"dpt"``.

    ``logging_config.debug`` is a free-form flag string. If it contains ``"logbaselines"``, the
    frozen DPT is also deployed with and without the attacker (and, on miniworld, the
    ``img_rand`` policy). If it contains ``"showvideo"``, the evaluation rollout is written to
    ``test.mp4`` and the run blocks on keyboard input.
    """
    run_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config)

    if logging_config.log == "wandb":
        logger = WandbLogger(
            run_name,
            config={
                **asdict(dataset_config),
                **asdict(model_config),
                **asdict(eval_config),
                **asdict(adv_train_config),
            },
        )
    else:
        logger = PrintLogger(run_name, "Step")

    if logging_config.debug is None:
        logging_config.debug = ""
    log_baselines = "logbaselines" in logging_config.debug
    show_video = "showvideo" in logging_config.debug

    ####################
    # Init env and attacker
    ####################
    set_seed(seed_config.seed)

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len
    n_steps_eval = eval_config.n_steps_eval if eval_config.n_steps_eval is not None else dataset_config.context_len
    n_states = dataset_config.n_states
    eps_episodes = adv_train_config.eps_episodes
    eps_steps = adv_train_config.eps_steps
    attacker_against = adv_train_config.attacker_against

    state_dim, n_actions, env, attacker = _build_env_and_attacker(dataset_config, adv_train_config, n_envs, n_steps_eval)

    ####################
    # Load DPT
    ####################
    model_path = _resolve_model_path(dataset_config, model_config, seed_config, eval_config)

    is_image_env = dataset_config.env == "miniworld"
    ModelClass = ImageTransformer if is_image_env else Transformer
    ControllerClass = MDPImageTransformerController if is_image_env else MDPTransformerController
    model_params = model_config.get_params({"H": n_steps, "state_dim": state_dim, "action_dim": n_actions, "image_size": 25})

    # Read the checkpoint once. map_location lets GPU-saved weights load on CPU-only hosts, and
    # load_state_dict copies tensors, so the models below get independent parameters.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)

    dpt_policy: ControllerClass = None  # type: ignore

    if "dpt" in attacker_against:
        # Trainable DPT copy (updated with victim_lr); its weights are saved as AT-DPT.
        model = ModelClass(model_params).to(device)
        model.test = True
        model.load_state_dict(state_dict)
        model.eval()

        dpt_policy = ControllerClass(model, n_envs, n_steps, n_states, state_dim, n_actions, sample=True, lr=adv_train_config.victim_lr, device=device)

    # Frozen reference copy of the pretrained DPT.
    model_frozen = ModelClass(model_params).to(device)
    model_frozen.test = True
    model_frozen.load_state_dict(state_dict)
    model_frozen.eval()
    print(f"Loaded model {model_path}.")

    dpt_frozen_policy = ControllerClass(model_frozen, n_envs, n_steps, n_states, state_dim, n_actions, frozen=True, sample=True, device=device)

    ####################
    # Setup baseline policies
    ####################
    policies = get_mdp_algs(
        n_envs,
        n_steps,
        n_steps_eval,
        n_states,
        state_dim,
        n_actions,
        env.optimal_actions,
        dpt_policy,
        dpt_frozen_policy,
        device=device,
    )

    victim = policies[attacker_against]

    is_rl_alg = attacker_against in ("npg", "ql", "ppo")
    n_episodes = 100  # online victim-training episodes per round (RL victims only)

    ####################
    # Adv. training
    ####################
    global_step = 0
    attacker_step = 0
    victim_step = 0

    for round_idx in range(adv_train_config.n_rounds):
        print(f"Round {round_idx + 1}")
        start_time = time.perf_counter()

        # Every dataset collected this round; the attacker is trained on all of them.
        datasets = []
        if is_rl_alg:
            # Re-initialize the RL victim, then let it learn online against the current attacker.
            victim.reinitialize()
            for _ in range(n_episodes):
                with torch.no_grad():
                    dataset = env.deploy(victim, attacker, eps_episodes, eps_steps, context_len=dataset_config.context_len)
                datasets.append(dataset)
                victim.update(dataset, adv_train_config)  # per-episode metrics are not logged

        log_items = {}
        # Frozen-DPT baselines are logged last, after the other items.
        frozen_items = {}

        with torch.no_grad():
            dataset = env.deploy(victim, attacker, eps_episodes, eps_steps, context_len=dataset_config.context_len, save_video=False, force_show_progress=True)
            datasets.append(dataset)

            if log_baselines:
                dataset_frozen = env.deploy(policies["dpt_frozen"], attacker, eps_episodes, eps_steps, force_show_progress=True)
                dataset_frozen_noatt = env.deploy(policies["dpt_frozen"], force_show_progress=True)
                frozen_items["eval/rewards_dpt_frozen"] = dataset_frozen.rewards_original.sum().item() / n_envs
                frozen_items["eval/rewards_dpt_frozen_noatt"] = dataset_frozen_noatt.rewards_original.sum().item() / n_envs

                if is_image_env:
                    dataset_rand = env.deploy(policies["img_rand"], force_show_progress=True)
                    log_items["eval/rewards_rand"] = dataset_rand.rewards_original.sum().item() / n_envs

            end_time = time.perf_counter()

            if show_video:
                anim = env.visualize_dataset(dataset, attacker_weights=attacker.weights.data)  # type: ignore
                anim.save("test.mp4")
                input("Input to continue")

            # Rewards are summed over all envs and steps, then averaged per env.
            rewards_victim_total = dataset.rewards_original.sum().item() / n_envs

            if not is_image_env:
                dataset_opt = env.deploy(policies["opt"], force_show_progress=True)
                log_items["eval/rewards_opt"] = dataset_opt.rewards_original.sum().item() / n_envs

        logger.log(
            {
                "round": round_idx,
                "eval/gen_trajectories_time": end_time - start_time,
                f"eval/rewards_{attacker_against}": rewards_victim_total,
                **log_items,
                **frozen_items,
            },
            step=global_step,
        )

        metrics = attacker.update(datasets, adv_train_config)  # type: ignore
        for metrics_item in metrics:
            logger.log({"round": round_idx, "train/attacker_step": attacker_step, **metrics_item}, step=global_step)
            attacker_step += 1
            global_step += 1

        metrics, _ = victim.update(dataset, adv_train_config)
        for metrics_item in metrics:
            logger.log({"round": round_idx, "train/victim_step": victim_step, **metrics_item}, step=global_step)
            victim_step += 1
            global_step += 1

        if _should_save_checkpoint(attacker_against, round_idx + 1, adv_train_config.n_rounds):
            ####################
            # Save attacker & AT-DPT
            ####################
            # The output directory is named after the number of rounds completed so far.
            rounds_done_config = replace(adv_train_config, n_rounds=round_idx + 1)
            setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, rounds_done_config, print_against=False)

            save_dir = f"models/adv/{setup_name}"
            os.makedirs(save_dir, exist_ok=True)

            attacker_save_path = f"{save_dir}/attacker_against_{attacker_against}_{seed_config.seed}.pt"
            torch.save(attacker.state_dict(), attacker_save_path)
            print(f"Saved attacker to '{attacker_save_path}'.")

            if attacker_against == "dpt":
                # `model` is always bound here: "dpt" in "dpt" triggered its creation above.
                atdpt_save_path = f"{save_dir}/atdpt_{seed_config.seed}.pt"
                torch.save(model.state_dict(), atdpt_save_path)
                print(f"Saved AT-DPT to '{atdpt_save_path}'.")


if __name__ == "__main__":
    # Parse every CLI config dataclass, in the positional order main() expects them.
    config_types = (LoggingConfig, SeedConfig, DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig)
    logging_config, seed_config, dataset_config, model_config, eval_config, adv_train_config = parse_args_to_dataclass(config_types)
    run_configs = (logging_config, seed_config, dataset_config, model_config, eval_config, adv_train_config)

    # Echo each resolved config on its own line before starting the run.
    print(*run_configs, sep="\n")

    started_at = time.time()
    main(*run_configs)
    print(f"Total runtime: {time.time() - started_at:.2f} s")
