import os
import time
from dataclasses import asdict

import numpy as np
import torch

from args import (
    AdaptiveAttackerConfig,
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    LoggingConfig,
    ModelConfig,
    SeedConfig,
    get_adaptive_adv_trained_model_name,
    get_legacy_filename_config,
    parse_args_to_dataclass,
)
from bandit2.bandit_attacker import BanditAdaptiveAttacker, BanditAttacker
from bandit2.bandit_ctrl import BanditTransformerController
from bandit2.bandit_env import BanditEnv
from bandit_algs import get_bandit_algs
from net import Transformer
from util.logger import PrintLogger, WandbLogger
from util.seed import set_seed
from utils import build_bandit_model_filename

# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _load_eval_transformer(transformer_config, state_dict):
    """Build a ``Transformer`` on ``device``, load ``state_dict`` into it and put it in eval mode."""
    net = Transformer(transformer_config).to(device)
    net.test = True
    net.load_state_dict(state_dict)
    net.eval()
    return net


def main(
    logging_config: LoggingConfig,
    seed_config: SeedConfig,
    dataset_config: DatasetConfig,
    model_config: ModelConfig,
    eval_config: EvalConfig,
    adv_train_config: AdversarialTrainingConfig,
    adaptive_attacker_config: AdaptiveAttackerConfig,
):
    """Alternate attacker and victim updates on a sampled bandit environment.

    Each round deploys the optimal policy and the attacked victim in the
    environment, logs their summed original rewards, then calls
    ``attacker.update`` followed by ``victim.update`` on the victim's
    trajectories. The attacker (and the trained DPT model when the victim is
    ``"dpt"``) is saved under ``models/adv/<setup_name>/`` on the last round,
    and additionally every 50 rounds when more than 100 rounds are requested.

    Args:
        logging_config: ``log`` selects the wandb or print logger; ``debug``
            set to ``"logattperf"`` enables two extra evaluation deployments
            per round.
        seed_config: Provides ``seed`` for seeding and for the saved filenames.
        dataset_config: Provides ``context_len``, ``n_actions`` and ``variance``.
        model_config: Used to build the transformer parameters and the
            checkpoint filename.
        eval_config: Provides ``n_envs_eval`` and the optional checkpoint
            ``epoch``.
        adv_train_config: Provides ``attacker_against``, ``n_rounds``,
            ``victim_lr``, ``eps_steps``, ``eps_episodes`` and
            ``max_poison_diff``; also passed to both ``update`` calls.
        adaptive_attacker_config: Passed to ``BanditAdaptiveAttacker``.
    """
    global_step = 0
    attacker_step = 0
    victim_step = 0

    run_name = get_adaptive_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, adaptive_attacker_config)

    if logging_config.log == "wandb":
        logger = WandbLogger(
            run_name,
            config={
                **asdict(seed_config),
                **asdict(dataset_config),
                **asdict(model_config),
                **asdict(eval_config),
                **asdict(adv_train_config),
                **asdict(adaptive_attacker_config),
            },
        )
    else:
        logger = PrintLogger(run_name, "Step")

    ####################
    # Init env and attacker
    ####################
    set_seed(seed_config.seed)

    shuffle = True

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len
    n_actions = dataset_config.n_actions
    attacker_against = adv_train_config.attacker_against
    log_attacker_perf = logging_config.debug == "logattperf"

    env = BanditEnv.sample(n_envs, n_steps, n_actions, dataset_config.variance, device=device)
    attacker: BanditAttacker = BanditAdaptiveAttacker(adaptive_attacker_config, n_envs, n_actions, env.original_means, device=device)

    ####################
    # Load DPT
    ####################
    transformer_config = model_config.get_params({"H": dataset_config.context_len, "state_dict": None, "state_dim": 1, "action_dim": dataset_config.n_actions}) if False else model_config.get_params({"H": dataset_config.context_len, "state_dim": 1, "action_dim": dataset_config.n_actions})
    filename = build_bandit_model_filename("bandit", get_legacy_filename_config(model_config, dataset_config, seed_config))
    if eval_config.epoch is None:
        model_path = f"models/{filename}.pt"
    else:
        model_path = f"models/{filename}_epoch{eval_config.epoch}.pt"

    # Read the checkpoint once. map_location remaps the stored tensors onto the
    # active device, so a checkpoint saved on a GPU still loads on a CPU-only
    # host. load_state_dict copies values into each model's own parameters, so
    # the trainable and frozen models can share this one dict.
    pretrained_state = torch.load(model_path, map_location=device)

    # Trainable copy: only built when the victim name contains "dpt".
    model = None
    dpt_policy: BanditTransformerController = None  # type: ignore
    if "dpt" in attacker_against:
        model = _load_eval_transformer(transformer_config, pretrained_state)
        dpt_policy = BanditTransformerController(model, n_envs, n_steps, n_actions, sample=True, lr=adv_train_config.victim_lr, device=device)

    # Frozen copy: always built from the same checkpoint.
    model_frozen = _load_eval_transformer(transformer_config, pretrained_state)
    print(f"Loaded model {model_path}.")
    dpt_frozen_policy = BanditTransformerController(model_frozen, n_envs, n_steps, n_actions, sample=True, frozen=True, device=device)

    ####################
    # Setup baseline policies
    ####################
    # NOTE(review): the meaning of the positional 0.1 and of passing eps_steps
    # twice is defined by get_bandit_algs -- confirm against its signature.
    policies = get_bandit_algs(
        n_envs,
        n_steps,
        n_actions,
        env.get_optimal_actions(),
        dpt_policy,
        dpt_frozen_policy,
        adv_train_config.eps_steps,
        0.1,
        adv_train_config.eps_steps,
        adv_train_config.max_poison_diff,
        device=device,
    )
    victim = policies[attacker_against]

    # TODO: collect per-round rewards (round_rewards: list[dict[str, np.ndarray]]).

    ####################
    # Adv. training
    ####################
    n_rounds = adv_train_config.n_rounds
    # `round_idx` rather than `round` so the builtin is not shadowed.
    for round_idx in range(n_rounds):
        print(f"Round {round_idx + 1}")

        start_time = time.time()

        dataset_opt = env.deploy(policies["opt"])
        dataset_victim = env.deploy(victim, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)

        # Optional debug metrics: "ts" policy under attack and the victim without attack.
        debug_metrics = {}
        if log_attacker_perf:
            dataset_ts = env.deploy(policies["ts"], attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)
            dataset_victim_noatt = env.deploy(victim)

            debug_metrics = {
                "eval/rewards_ts": dataset_ts.rewards_original.sum(-1).mean(-1),
                f"eval/rewards_{attacker_against}_noatt": dataset_victim_noatt.rewards_original.sum(-1).mean(-1),
            }

        # The timing covers every deployment above, including the debug ones.
        deploy_time = time.time() - start_time

        logger.log(
            {
                "round": round_idx,
                "eval/gen_trajectories_time": deploy_time,
                f"eval/rewards_{attacker_against}": dataset_victim.rewards_original.sum().item() / n_envs,
                "eval/rewards_opt": dataset_opt.rewards_original.sum().item() / n_envs,
                **debug_metrics,
            },
            step=global_step,
        )

        # Attacker update: one log entry per returned metrics item.
        metrics = attacker.update(dataset_victim, adv_train_config)
        for metrics_item in metrics:
            logger.log({"round": round_idx, "train/attacker_step": attacker_step, **metrics_item}, step=global_step)
            attacker_step += 1
            global_step += 1

        # Victim update on the same trajectories.
        dataset_victim.shuffle = shuffle
        metrics, _ = victim.update(dataset_victim, adv_train_config)
        for metrics_item in metrics:
            logger.log({"round": round_idx, "train/victim_step": victim_step, **metrics_item}, step=global_step)
            victim_step += 1
            global_step += 1

        # Checkpoint every 50 rounds for runs longer than 100 rounds, and always on the last round.
        is_periodic_checkpoint = n_rounds > 100 and (round_idx + 1) % 50 == 0
        is_last_round = round_idx == n_rounds - 1
        if is_periodic_checkpoint or is_last_round:
            ####################
            # Save attacker & AT-DPT
            ####################
            # The saved setup name uses the number of rounds completed so far.
            modified_adv_train_config = AdversarialTrainingConfig(**{**asdict(adv_train_config), "n_rounds": round_idx + 1})
            setup_name = get_adaptive_adv_trained_model_name(dataset_config, model_config, eval_config, modified_adv_train_config, adaptive_attacker_config, print_against=False)
            os.makedirs(f"models/adv/{setup_name}", exist_ok=True)
            attacker_save_path = f"models/adv/{setup_name}/attacker_against_{attacker_against}_{seed_config.seed}.pt"
            torch.save(attacker.state_dict(), attacker_save_path)
            print(f"Saved attacker to '{attacker_save_path}'.")

            # `model` is always set here: "dpt" == attacker_against implies the branch above ran.
            if attacker_against == "dpt":
                atdpt_save_path = f"models/adv/{setup_name}/atdpt_{seed_config.seed}.pt"
                torch.save(model.state_dict(), atdpt_save_path)
                print(f"Saved AT-DPT to '{atdpt_save_path}'.")


if __name__ == "__main__":
    # Build one instance of each configuration dataclass via parse_args_to_dataclass.
    config_classes = (LoggingConfig, SeedConfig, DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig, AdaptiveAttackerConfig)
    (
        logging_config,
        seed_config,
        dataset_config,
        model_config,
        eval_config,
        adv_train_config,
        adaptive_attacker_config,
    ) = parse_args_to_dataclass(config_classes)
    parsed_configs = (
        logging_config,
        seed_config,
        dataset_config,
        model_config,
        eval_config,
        adv_train_config,
        adaptive_attacker_config,
    )

    # Echo the parsed configuration, one dataclass per line.
    print(*parsed_configs, sep="\n")

    # Time the whole run using wall-clock time.
    started_at = time.time()
    main(*parsed_configs)
    elapsed = time.time() - started_at

    print(f"Total runtime: {elapsed:.2f} s")
