import argparse
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

from args import (
    AdversarialTrainingConfig,
    DatasetConfig,
    EvalConfig,
    ModelConfig,
    NSeedsConfig,
    get_adv_trained_model_name,
    parse_args_to_dataclass,
)

# Global matplotlib styling, applied once at import time to every figure
# this script creates.
plt.rcParams.update(
    {
        # Base font size for all text in the figure.
        "font.size": 14,
        # Emit text in SVG output as text elements instead of outlines
        # (see matplotlib's "svg.fonttype" rcParam).
        "svg.fonttype": "none",
        # NOTE(review): presumably requires the "Latin Modern Math" font to be
        # installed on the machine that renders the figure - confirm.
        "font.family": "Latin Modern Math",
    }
)


def main(
    n_seeds_config: NSeedsConfig,
    dataset_config: DatasetConfig,
    model_config: ModelConfig,
    eval_config: EvalConfig,
    adv_train_config: AdversarialTrainingConfig,
) -> None:
    """Plot the regret of each algorithm per adversarial-training round.

    For every seed in ``range(n_seeds_config.n_seeds)`` a pickle file
    ``models/adv/<setup_name>/attacker_against_<attacker_against>_<seed>_round_rewards.pkl``
    is loaded. Each file is expected to hold a list, indexed by round, of
    dicts mapping ``"opt"`` and every plotted algorithm key (``"dpt"``,
    ``"dpt_frozen"``, ``"dpt_frozen_noatt"``, ``"ts"``) to a reward array.

    The regret of an algorithm in a round is ``(opt - alg).sum(-1).mean(-1)``.
    The curves show the mean over seeds; the shaded band is +/- 2 standard
    errors of the mean and is only drawn when at least two seeds are
    available (the standard error is undefined for a single sample).

    The figure is written to
    ``figs/<setup_name>/adv_rounds_against<attacker_against>_seeds<n_seeds>.svg``
    and the output path is printed.

    Args:
        n_seeds_config: Provides ``n_seeds``, the number of seeds to aggregate.
        dataset_config: Forwarded to ``get_adv_trained_model_name``.
        model_config: Forwarded to ``get_adv_trained_model_name``.
        eval_config: Forwarded to ``get_adv_trained_model_name``.
        adv_train_config: Provides ``n_rounds`` and ``attacker_against``; also
            forwarded to ``get_adv_trained_model_name``.

    Raises:
        FileNotFoundError: If the rewards pickle of any seed is missing.
    """
    n_rounds = adv_train_config.n_rounds
    n_seeds = n_seeds_config.n_seeds
    attacker_against = adv_train_config.attacker_against

    setup_name = get_adv_trained_model_name(
        dataset_config, model_config, eval_config, adv_train_config, print_against=False
    )
    results_dir = f"models/adv/{setup_name}"

    # (key in the rewards dict, line color, legend label, line style)
    plottables: list[tuple[str, str, str, str | None]] = [
        ("dpt", "tab:blue", "DPT", None),
        ("dpt_frozen", "tab:blue", "DPT frozen", "--"),
        ("dpt_frozen_noatt", "tab:green", "DPT frozen (no attack)", "--"),
        ("ts", "tab:orange", "TS", None),
    ]

    # One (n_rounds, n_seeds) regret matrix per algorithm.
    alg_regrets = {alg: np.zeros((n_rounds, n_seeds)) for alg, _, _, _ in plottables}

    for seed in range(n_seeds):
        rewards_path = f"{results_dir}/attacker_against_{attacker_against}_{seed}_round_rewards.pkl"
        # NOTE: pickle.load executes arbitrary code embedded in the file; only
        # point this script at result files produced by trusted runs.
        with open(rewards_path, "rb") as f:
            round_rewards: list[dict[str, np.ndarray]] = pickle.load(f)

        for round_idx in range(n_rounds):
            rewards = round_rewards[round_idx]
            opt_rewards: np.ndarray = rewards["opt"]
            for alg, _, _, _ in plottables:
                # Sum over the last axis, then average over the remaining one.
                # NOTE(review): presumably rewards are 2-D, e.g.
                # (environments, steps) - confirm against the producer.
                alg_regrets[alg][round_idx, seed] = (opt_rewards - rewards[alg]).sum(-1).mean(-1)

    # Use an explicit figure instead of pyplot's implicit "current figure" so
    # repeated calls never draw on top of an earlier plot, and make sure the
    # figure is released even if saving fails.
    fig, ax = plt.subplots()
    try:
        rounds = range(n_rounds)
        for alg, color, label, linestyle in plottables:
            regrets = alg_regrets[alg]
            means = regrets.mean(axis=-1)
            ax.plot(rounds, means, linestyle=linestyle, label=label, color=color)

            # stats.sem uses ddof=1 and yields NaN (plus a RuntimeWarning) for
            # a single seed, so the band is only drawn when it is defined.
            if n_seeds > 1:
                confidence = 2 * stats.sem(regrets, axis=-1)
                ax.fill_between(rounds, means - confidence, means + confidence, color=color, alpha=0.1)

        fig.set_size_inches(5, 3)
        ax.set_ylabel("Cumulative Regret")
        ax.set_box_aspect(1)
        # Anchor the legend just outside the top-right corner of the axes.
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1.03))
        ax.grid(alpha=0.15)
        ax.set_xlabel("Rounds")

        fig.set_dpi(150)
        fig.tight_layout()

        os.makedirs(f"figs/{setup_name}", exist_ok=True)
        output_path = f"figs/{setup_name}/adv_rounds_against{attacker_against}_seeds{n_seeds}.svg"
        fig.savefig(output_path)
    finally:
        plt.close(fig)

    print(f"Saved to '{output_path}'.")


if __name__ == "__main__":
    # Script entry point: parse the command line into the five config
    # dataclasses, echo them, run the plot and report the wall-clock time.
    config_classes = (NSeedsConfig, DatasetConfig, ModelConfig, EvalConfig, AdversarialTrainingConfig)
    (
        n_seeds_config,
        dataset_config,
        model_config,
        eval_config,
        adv_train_config,
    ) = parse_args_to_dataclass(config_classes)

    # Echo every parsed config on its own line.
    for config in (n_seeds_config, dataset_config, model_config, eval_config, adv_train_config):
        print(config)

    time_start = time.time()
    main(
        n_seeds_config,
        dataset_config,
        model_config,
        eval_config,
        adv_train_config,
    )
    elapsed = time.time() - time_start

    print(f"Total runtime: {elapsed:.2f} s")
