import os
import pickle
import time
from dataclasses import asdict

import numpy as np
import torch

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    LoggingConfig,
    ModelConfig,
    SeedConfig,
    get_adv_trained_model_name,
    get_legacy_filename_config,
    parse_args_to_dataclass,
)
from bandit2.bandit_attacker import BanditAttacker, BanditUniformRandomAttacker
from bandit2.bandit_ctrl import BanditTransformerController
from bandit2.bandit_env import BanditEnv
from bandit_algs import get_bandit_algs
from net import Transformer
from util.logger import PrintLogger, WandbLogger
from util.seed import set_seed
from utils import build_bandit_model_filename

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _load_transformer(transformer_config, model_path: str) -> "Transformer":
    """Build a ``Transformer`` on ``device`` and load the checkpoint at ``model_path``.

    The returned model has ``test = True`` set and is in eval mode.
    """
    net = Transformer(transformer_config).to(device)
    net.test = True
    # map_location keeps checkpoints that were saved on a GPU loadable on
    # CPU-only hosts (or hosts with a different GPU layout).
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()
    return net


def main(
    logging_config: LoggingConfig,
    seed_config: SeedConfig,
    dataset_config: DatasetConfig,
    model_config: ModelConfig,
    eval_config: EvalConfig,
    adv_train_config: AdversarialTrainingConfig,
):
    """Run adversarial training rounds against a uniform-random attacker, then evaluate.

    Each round deploys the optimal policy and the victim policy (selected by
    ``adv_train_config.attacker_against``) in the sampled bandit environments
    under attack, logs their summed original rewards, and then calls
    ``update`` on the attacker followed by the victim using the victim's
    trajectories. After the last round the victim model weights are saved
    (only when the victim is ``"dpt"``), the attacker is evaluated against a
    fixed list of baseline policies, and the collected per-round rewards are
    optionally pickled.

    Args:
        logging_config: Selects the logger; ``log == "wandb"`` uses
            ``WandbLogger``, anything else uses ``PrintLogger``.
        seed_config: Provides ``seed`` for ``set_seed`` and for output filenames.
        dataset_config: Provides ``context_len``, ``n_actions`` and ``variance``.
        model_config: Used to build the transformer parameters and the
            checkpoint filename.
        eval_config: Provides ``n_envs_eval`` and the optional checkpoint ``epoch``.
        adv_train_config: Adversarial-training settings (rounds, attack budget,
            victim name, learning rate, round-reward logging flag).
    """
    global_step = 0
    attacker_step = 0
    victim_step = 0

    run_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config)

    if logging_config.log == "wandb":
        logger = WandbLogger(
            run_name,
            config={
                **asdict(seed_config),
                **asdict(dataset_config),
                **asdict(model_config),
                **asdict(eval_config),
                **asdict(adv_train_config),
            },
        )
    else:
        logger = PrintLogger(run_name, "Step")

    ####################
    # Init env and attacker
    ####################
    set_seed(seed_config.seed)

    shuffle = True

    n_envs = eval_config.n_envs_eval
    n_steps = dataset_config.context_len
    n_actions = dataset_config.n_actions
    attacker_against = adv_train_config.attacker_against

    env = BanditEnv.sample(n_envs, n_steps, n_actions, dataset_config.variance, device=device)
    attacker: BanditAttacker = BanditUniformRandomAttacker(n_envs, n_actions, dataset_config.variance, max_poison_diff=adv_train_config.max_poison_diff, device=device)

    ####################
    # Load DPT
    ####################
    transformer_config = model_config.get_params({"H": dataset_config.context_len, "state_dim": 1, "action_dim": dataset_config.n_actions})
    filename = build_bandit_model_filename("bandit", get_legacy_filename_config(model_config, dataset_config, seed_config))
    if eval_config.epoch is None:
        model_path = f"models/{filename}.pt"
    else:
        model_path = f"models/{filename}_epoch{eval_config.epoch}.pt"

    # The trainable copy is only built when the victim name contains "dpt";
    # otherwise the corresponding policy slot stays None.
    model = None
    dpt_policy: BanditTransformerController = None  # type: ignore
    if "dpt" in attacker_against:
        model = _load_transformer(transformer_config, model_path)
        dpt_policy = BanditTransformerController(model, n_envs, n_steps, n_actions, sample=True, lr=adv_train_config.victim_lr, device=device)

    # A second, independent copy of the same checkpoint wrapped with frozen=True.
    model_frozen = _load_transformer(transformer_config, model_path)
    print(f"Loaded model {model_path}.")
    dpt_frozen_policy = BanditTransformerController(model_frozen, n_envs, n_steps, n_actions, sample=True, frozen=True, device=device)

    ####################
    # Setup baseline policies
    ####################
    # NOTE(review): these three levels are only printed here; they are not
    # passed to get_bandit_algs below -- confirm that this is intended.
    rts_corruption_level_known = n_steps * adv_train_config.eps_steps * adv_train_config.max_poison_diff * (1 / n_actions)
    rts_corruption_level_unknown = torch.sqrt(n_steps * torch.log(torch.tensor(n_actions)) / n_actions).item()
    rts_corruption_level_tuned = 0.5
    print(f"Robust TS corruption levels: {rts_corruption_level_known}, {rts_corruption_level_unknown}, {rts_corruption_level_tuned}")

    policies = get_bandit_algs(
        n_envs,
        n_steps,
        n_actions,
        env.get_optimal_actions(),
        dpt_policy,
        dpt_frozen_policy,
        adv_train_config.eps_steps,
        0.1,
        adv_train_config.eps_steps,
        adv_train_config.max_poison_diff,
        device=device,
    )
    victim = policies[attacker_against]

    round_rewards: list[dict[str, np.ndarray]] = []

    ####################
    # Adv. training
    ####################
    # `round_idx` rather than `round` so the builtin is not shadowed.
    for round_idx in range(adv_train_config.n_rounds):
        print(f"Round {round_idx+1}")

        start_time = time.time()

        dataset_opt = env.deploy(policies["opt"], attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)
        dataset_victim = env.deploy(victim, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)

        deploy_time = time.time() - start_time

        # Extra reference rollouts are only collected when the victim is the
        # trainable DPT and round-reward logging is enabled.
        if adv_train_config.log_round_rewards and attacker_against == "dpt":
            dataset_dpt_frozen = env.deploy(dpt_frozen_policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)
            dataset_ts = env.deploy(policies["ts"], attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)

            # Deployed without passing an attacker or attack budget.
            dataset_dpt_frozen_noatt = env.deploy(dpt_frozen_policy)
            round_rewards.append(
                {
                    "opt": dataset_opt.rewards_original.numpy(force=True),
                    attacker_against: dataset_victim.rewards_original.numpy(force=True),
                    "dpt_frozen": dataset_dpt_frozen.rewards_original.numpy(force=True),
                    "dpt_frozen_noatt": dataset_dpt_frozen_noatt.rewards_original.numpy(force=True),
                    "ts": dataset_ts.rewards_original.numpy(force=True),
                }
            )

        logger.log(
            {
                "round": round_idx,
                "eval/gen_trajectories_time": deploy_time,
                f"eval/rewards_{attacker_against}": dataset_victim.rewards_original.sum().item() / n_envs,
                "eval/rewards_opt": dataset_opt.rewards_original.sum().item() / n_envs,
            },
            step=global_step,
        )

        # Attacker update first; every returned metrics mapping is logged at
        # its own global step.
        metrics = attacker.update(dataset_victim, adv_train_config)
        # train_attacker_reinforce(rewards_victim, dataset, vec_env, attacker, attacker_optimizer, adv_train_config)
        for metrics_item in metrics:
            logger.log({"round": round_idx, "train/attacker_step": attacker_step, **metrics_item}, step=global_step)
            attacker_step += 1
            global_step += 1

        # Victim update on the same trajectories; only the first element of
        # the returned pair (the metrics) is used.
        dataset_victim.shuffle = shuffle
        metrics, _ = victim.update(dataset_victim, adv_train_config)
        for metrics_item in metrics:
            logger.log({"round": round_idx, "train/victim_step": victim_step, **metrics_item}, step=global_step)
            victim_step += 1
            global_step += 1

    setup_name = get_adv_trained_model_name(dataset_config, model_config, eval_config, adv_train_config, print_against=False)

    os.makedirs(f"models/adv/{setup_name}", exist_ok=True)
    # attacker_save_path = f"models/adv/{setup_name}/attacker_against_{attacker_against}_{seed_config.seed}.pt"
    # torch.save(attacker.state_dict(), attacker_save_path)
    # print(f"Saved attacker to '{attacker_save_path}'.")

    if attacker_against == "dpt":
        # `model` is always set here because "dpt" in attacker_against held above.
        atdpt_save_path = f"models/adv/{setup_name}/atdpt_againstunifrand_{seed_config.seed}.pt"
        torch.save(model.state_dict(), atdpt_save_path)
        print(f"Saved AT-DPT to '{atdpt_save_path}'.")

    ####################
    # Evaluate attacker on baselines
    ####################
    print("Evaluating...")
    eval_rewards: dict[str, np.ndarray] = {}
    algs_eval = ["dpt_frozen", "ts", "ucb"]

    opt_dataset = env.deploy(policies["opt"], attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)
    opt_rew = opt_dataset.rewards_original.sum(-1).mean(-1)

    for alg in algs_eval:
        policy = policies[alg]
        if alg == "dpt":
            # TODO: check if can load maybe
            continue
        dataset = env.deploy(policy, attacker, adv_train_config.eps_episodes, adv_train_config.eps_steps)
        eval_rewards[alg] = dataset.rewards_original.numpy(force=True)

    # Iterate over what was actually evaluated so the header row and the
    # regret row stay aligned even if an entry was skipped above.
    for alg in eval_rewards:
        print("{" + f"{alg}", end="}")
    print()

    for alg, rewards in eval_rewards.items():
        regret = opt_rew - rewards.sum(-1).mean(-1)
        print("{" + f"{regret:.2f}", end="}")
    # Terminate the regret row so later messages start on a fresh line.
    print()

    # evals_path = f"models/adv/{setup_name}/attacker_against_{attacker_against}_{seed_config.seed}_evals.pkl"
    # with open(evals_path, "wb") as f:
    #     pickle.dump(eval_rewards, f)
    # print(f"Saved to '{evals_path}'.")

    ####################
    # Log round rewards
    ####################
    if adv_train_config.log_round_rewards:
        results_filename = f"models/adv/{setup_name}/attacker_unifrand_against_{attacker_against}_{seed_config.seed}_round_rewards.pkl"
        with open(results_filename, "wb") as f:
            pickle.dump(round_rewards, f)
        print(f"Saved round rewards to '{results_filename}'.")


if __name__ == "__main__":
    # Parse the command line into one dataclass instance per config group.
    (
        logging_config,
        seed_config,
        dataset_config,
        model_config,
        eval_config,
        adv_train_config,
    ) = parse_args_to_dataclass(
        (LoggingConfig, SeedConfig, DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig)
    )

    # Echo the resolved configuration, one config object per line.
    print(
        logging_config,
        seed_config,
        dataset_config,
        model_config,
        eval_config,
        adv_train_config,
        sep="\n",
    )

    # Time the whole run and report the wall-clock duration.
    time_start = time.time()
    main(
        logging_config,
        seed_config,
        dataset_config,
        model_config,
        eval_config,
        adv_train_config,
    )
    time_end = time.time()

    elapsed = time_end - time_start
    print(f"Total runtime: {elapsed:.2f} s")
