import os
import glob
import argparse
from typing import List

import torch
import matplotlib.pyplot as plt
import seaborn as sns

from source.utils.metrics import mse, nll, crps
from source.constants import RESULTS_PATH_AL, PLOTS_PATH

sns.set_theme(style="whitegrid")

# LaTeX legend label for every acquisition function, keyed by the identifier
# that appears in result directory names. Insertion order is the plotting order.
LABELS = {
    # total-risk variants
    "total_1_1": "$\\hat{R}_{Tot}^{1,1}$",
    "total_2_1": "$\\hat{R}_{Tot}^{2,1}$",
    "total_3a_1": "$\\hat{R}_{Tot}^{3a,1}$",
    "total_3b_1": "$\\hat{R}_{Tot}^{3b,1}$",
    "total_3a_2": "$\\hat{R}_{Tot}^{3a,2}$",
    "total_3b_2": "$\\hat{R}_{Tot}^{3b,2}$",
    # Bayes-risk variants
    "bayes_1": "$\\hat{R}_{Bayes}^1$",
    "bayes_2": "$\\hat{R}_{Bayes}^2$",
    "bayes_3a": "$\\hat{R}_{Bayes}^{3a}$",
    "bayes_3b": "$\\hat{R}_{Bayes}^{3b}$",
    # excess-risk variants
    "excess_1_1": "$\\hat{R}_{Exc}^{1,1}$",
    "excess_2_1": "$\\hat{R}_{Exc}^{2,1}$",
    "excess_3a_1": "$\\hat{R}_{Exc}^{3a,1}$",
    "excess_3b_1": "$\\hat{R}_{Exc}^{3b,1}$",
    "excess_3a_2": "$\\hat{R}_{Exc}^{3a,2}$",
    "excess_3b_2": "$\\hat{R}_{Exc}^{3b,2}$",
    # baseline
    "random": "Random",
}

ACQ_ORDER = list(LABELS)
# LINESTYLES and COLORS are indexed in lock-step with ACQ_ORDER.
LINESTYLES = ["-" for _ in ACQ_ORDER]
# 6 blues for the total variants, 4 oranges for Bayes, 6 greens for excess,
# black for random; the first (lightest) shade of each palette is skipped.
COLORS = [
    *sns.color_palette("Blues", n_colors=7)[1:],
    *sns.color_palette("Oranges", n_colors=5)[1:],
    *sns.color_palette("Greens", n_colors=7)[1:],
    "black",
]

# Column titles of the grid, one per scoring rule used by the acquisition runs.
SR_LABELS = {
    "crps": "Continuous Ranked Probability Score",
    "log": "Logarithmic Score",
    "mse": "Squared Error Score",
    "quadratic": "Quadratic Score",
}
# Supported scoring rules, in column order (the keys of SR_LABELS).
SCORING_RULES = list(SR_LABELS)

# Cache directory for computed performance tensors; created when the module is imported.
os.makedirs(os.path.join(RESULTS_PATH_AL, "PERFS"), exist_ok=True)


def compute_performance(means: torch.Tensor, variances: torch.Tensor, y: torch.Tensor, type: str) -> torch.Tensor:
    """Score predictions with the metric named by ``type`` and average over the last axis.

    Args:
        means: predictive means.
        variances: predictive variances (not used by ``"mse"``).
        y: targets, broadcastable against ``means``.
        type: metric name, one of ``"mse"``, ``"nll"`` or ``"crps"``.

    Returns:
        The metric averaged over ``dim=-1``. For inputs shaped
        (acq, seeds, iterations, n_test), this is (acq, seeds, iterations).

    Raises:
        ValueError: if ``type`` names an unsupported metric.
    """
    # Metrics are wrapped in thunks so that only the selected one is evaluated.
    metric_table = (
        ("mse", lambda: mse(means, y)),
        ("nll", lambda: nll(means, variances, y)),
        ("crps", lambda: crps(means, variances, y)),
    )
    for metric_name, evaluate in metric_table:
        if type == metric_name:
            return evaluate().mean(dim=-1)
    raise ValueError(f"Unknown performance type: {type}")


def _read_n_iterations(run_dir: str):
    """Return the positive ``n_iterations`` recorded in ``<run_dir>/args.txt``, or None.

    ``args.txt`` is parsed as ``key: value`` lines. The first ``n_iterations`` line wins.
    None is returned when the file or the key is missing, or when the value is not
    a positive integer.
    """
    args_path = os.path.join(run_dir, "args.txt")
    if not os.path.exists(args_path):
        return None
    with open(args_path, "r") as f:
        for line in f:
            key, sep, value = line.partition(":")
            if not sep or key.strip() != "n_iterations":
                continue
            try:
                # Only the text up to a further ':' is the value (as in the original parser).
                n = int(value.split(":", 1)[0].strip())
            except ValueError:
                return None
            return n if n > 0 else None
    return None


def _load_seed_run(run_dir: str, n_iterations: int):
    """Load ``means_{i}.pt`` / ``vars_{i}.pt`` for i in [0, n_iterations) from ``run_dir``.

    Returns a ``(means, vars)`` pair. Each tensor is cast to float32 on CPU and stacked
    along a new leading iteration axis. Returns None if any iteration file is missing.
    """
    cpu = torch.device("cpu")
    means_iters, vars_iters = [], []
    for i in range(n_iterations):
        means_path = os.path.join(run_dir, f"means_{i}.pt")
        vars_path = os.path.join(run_dir, f"vars_{i}.pt")
        if not (os.path.exists(means_path) and os.path.exists(vars_path)):
            return None
        means_iters.append(torch.load(means_path, map_location=cpu).to(torch.float32))
        vars_iters.append(torch.load(vars_path, map_location=cpu).to(torch.float32))
    return torch.stack(means_iters, dim=0), torch.stack(vars_iters, dim=0)


def load_runs(dataset: str, method: str, scoring_rule: str, seeds: List[int]):
    """Load every available acquisition-function run for one dataset/method/scoring rule.

    For each acquisition function in ACQ_ORDER and each seed, the run directory is the
    lexicographically last match of ``{dataset}_{method}_{sr}_{acq}_seed{seed}_*`` under
    RESULTS_PATH_AL (presumably the most recent run; verify the suffix sorts by time).
    The number of iterations is read from the first seed's ``args.txt``.

    An acquisition function is skipped, not treated as an error, when any of its seeds:
      - has no run directory,
      - has no readable ``n_iterations``, or
      - is missing an iteration file.

    Args:
        dataset: dataset name used in the run directory names.
        method: method name used in the run directory names.
        scoring_rule: scoring-rule key. The ``random`` baseline is always looked up
            under ``crps``.
        seeds: seeds to load, in the order they are stacked.

    Returns:
        means_runs: stacked as (acq, seeds, iterations, *per-file tensor shape).
        vars_runs: same layout as ``means_runs``.
        valid_acq: acquisition-function names for the first axis, in ACQ_ORDER order.
        y_test: targets loaded from the first seed's ``crps``/``random`` run.

    Raises:
        RuntimeError: if no acquisition function could be loaded, or ``y_test.pt``
            cannot be found. torch.stack also raises RuntimeError when the loaded
            acquisition functions have mismatched shapes.
    """
    means_all, vars_all, valid_acq = [], [], []
    for acq in ACQ_ORDER:
        # The random baseline does not depend on the scoring rule; its runs are stored under "crps".
        sr_key = "crps" if acq == "random" else scoring_rule
        means_acq, vars_acq = [], []
        n_iterations = None
        complete = True
        for seed in seeds:
            pattern = os.path.join(RESULTS_PATH_AL, f"{dataset}_{method}_{sr_key}_{acq}_seed{seed}_*")
            files = sorted(glob.glob(pattern))
            if not files:
                complete = False
                break
            run_dir = files[-1]
            if n_iterations is None:
                n_iterations = _read_n_iterations(run_dir)
                if n_iterations is None:
                    print(f"Could not read n_iterations for {dataset} {method} {sr_key} {acq} seed {seed}")
                    complete = False
                    break
            loaded = _load_seed_run(run_dir, n_iterations)
            if loaded is None:
                print(f"Missing iteration files for {dataset} {method} {scoring_rule} {acq} seed {seed}")
                complete = False
                break
            seed_means, seed_vars = loaded  # each (iterations, *per-file shape)
            means_acq.append(seed_means)
            vars_acq.append(seed_vars)
        if not complete or not means_acq:
            continue
        means_all.append(torch.stack(means_acq, dim=0))  # (seeds, iterations, ...)
        vars_all.append(torch.stack(vars_acq, dim=0))
        valid_acq.append(acq)

    if not means_all:
        raise RuntimeError(f"No acquisition functions found for {dataset} {method} {scoring_rule}")

    means_runs = torch.stack(means_all, dim=0)  # (acq, seeds, iterations, ...)
    vars_runs = torch.stack(vars_all, dim=0)

    # y_test comes from the random (crps) run of the first seed. Sort the matches so the
    # choice is deterministic and consistent with the run selection above
    # (glob returns matches in arbitrary order).
    pattern = os.path.join(RESULTS_PATH_AL, f"{dataset}_{method}_crps_random_seed{seeds[0]}_*")
    random_files = sorted(glob.glob(pattern))
    if not random_files:
        raise RuntimeError(f"y_test not found for {dataset} {method}")
    y_test_path = os.path.join(random_files[-1], "y_test.pt")
    if not os.path.exists(y_test_path):
        # RuntimeError (not FileNotFoundError) so callers such as plot_grid can report it per cell.
        raise RuntimeError(f"y_test not found for {dataset} {method}")
    y_test = torch.load(y_test_path, map_location=torch.device("cpu"))

    return means_runs, vars_runs, valid_acq, y_test


def _perf_cache_path(prefix: str, dataset: str, method: str, sr: str, seeds: List[int]) -> str:
    """Return the PERFS cache path for one grid cell. ``prefix`` is a metric name or "acqs"."""
    seed_tag = "_".join(map(str, seeds))
    return os.path.join(RESULTS_PATH_AL, "PERFS", f"{prefix}_{dataset}_{method}_{sr}_{seed_tag}.pt")


def _load_cached_perfs(dataset: str, method: str, sr: str, seeds: List[int]):
    """Return cached ``(nll_perfs, acqs)`` for one grid cell, or None if either file is missing."""
    nll_path = _perf_cache_path("nll", dataset, method, sr, seeds)
    acqs_path = _perf_cache_path("acqs", dataset, method, sr, seeds)
    # Both files are needed to plot; a partial cache is treated as a miss and recomputed.
    if not (os.path.exists(nll_path) and os.path.exists(acqs_path)):
        return None
    return torch.load(nll_path), torch.load(acqs_path)


def _compute_and_cache_perfs(means_runs, vars_runs, y_test, acqs, dataset, method, sr, seeds):
    """Compute mse/crps/nll curves for one grid cell and cache them with ``acqs``.

    Returns the nll curves.
    """
    # Flatten y_test and broadcast it over the leading (acq, seeds, iterations) axes and the
    # last axis of means_runs. This assumes means_runs is (acq, seeds, iterations, n_test, k)
    # with a trailing per-prediction axis k — TODO confirm.
    y_expand = y_test.view(1, 1, 1, -1, 1).expand(
        means_runs.shape[0], means_runs.shape[1], means_runs.shape[2], -1, means_runs.shape[-1]
    )
    perfs_by_metric = {}
    for pm in ("mse", "crps", "nll"):
        # compute_performance averages the last axis; the extra mean averages the next one.
        perfs = compute_performance(means_runs, vars_runs, y_expand, pm).mean(dim=-1)
        torch.save(perfs, _perf_cache_path(pm, dataset, method, sr, seeds))
        perfs_by_metric[pm] = perfs
    torch.save(acqs, _perf_cache_path("acqs", dataset, method, sr, seeds))
    return perfs_by_metric["nll"]


def _plot_cell(ax, perfs: torch.Tensor, acqs: List[str]) -> None:
    """Draw one mean ± std curve (across seeds) per acquisition function on ``ax``.

    ``perfs[idx]`` is indexed as (seeds, iterations) for ``acqs[idx]``. Infinite entries
    are replaced by NaN in place, so the nanmean-based legend summary ignores them.
    """
    for idx, acq in enumerate(acqs):
        curves = perfs[idx]
        if not curves.mean(dim=(0, 1)).isfinite():
            curves[curves.isinf()] = torch.nan
        style_idx = ACQ_ORDER.index(acq)
        color = COLORS[style_idx]
        ls = LINESTYLES[style_idx]
        mean_curve = curves.mean(dim=0)  # mean over seeds -> (iterations,)
        std_curve = curves.std(dim=0)
        # Legend entry: overall nanmean, ± std across seeds of each seed's nanmean over iterations.
        summary = f"${curves.nanmean(dim=(0, 1)):.3f}\\pm{curves.nanmean(dim=1).std(dim=0):.3f}$"
        ax.plot(mean_curve, label=summary, color=color, linestyle=ls, linewidth=1.4, zorder=99)
        ax.fill_between(range(mean_curve.shape[0]), mean_curve - std_curve, mean_curve + std_curve,
                        color=color, alpha=0.1, linewidth=0, zorder=1)


def plot_grid(datasets: List[str], method: str, seeds: List[int], scoring_rules: List[str], output: str, show: bool):
    """Draw a datasets × scoring-rules grid of NLL learning curves and save it to ``output``.

    Each cell (row = dataset, column = scoring rule) shows, for every acquisition function
    that could be loaded, the NLL per iteration averaged over seeds with a ±1 std band.
    Cells whose runs cannot be loaded (``load_runs`` raises RuntimeError) show the error
    text instead. A figure-level legend maps colours to every acquisition function
    plotted in any cell.

    Args:
        datasets: one grid row per dataset.
        method: method name used in result paths.
        seeds: seeds averaged over in each curve.
        scoring_rules: one grid column per scoring rule.
        output: path of the saved figure; its parent directory is created if needed.
        show: display the figure interactively instead of closing it after saving.
    """
    # squeeze=False always yields a 2-D array of Axes, including the 1x1 case.
    fig, axes = plt.subplots(len(datasets), len(scoring_rules),
                             figsize=(4.5 * len(scoring_rules), 3.2 * len(datasets)),
                             sharex=True, sharey=False, squeeze=False)

    plotted_acqs = set()
    for i, dataset in enumerate(datasets):
        for j, sr in enumerate(scoring_rules):
            print(dataset, method, sr, seeds)
            ax = axes[i, j]

            cached = _load_cached_perfs(dataset, method, sr, seeds)
            if cached is not None:
                perfs, acqs = cached
            else:
                try:
                    means_runs, vars_runs, acqs, y_test = load_runs(dataset, method, sr, seeds)
                except RuntimeError as e:
                    ax.text(0.5, 0.5, str(e), ha='center', va='center', fontsize=8, wrap=True)
                    ax.set_axis_off()
                    continue
                perfs = _compute_and_cache_perfs(means_runs, vars_runs, y_test, acqs, dataset, method, sr, seeds)

            _plot_cell(ax, perfs, acqs)
            plotted_acqs.update(acqs)

            if i == len(datasets) - 1:
                ax.set_xlabel("Iteration")
            if j == 0:
                ax.set_ylabel(f"NLL ({dataset.upper()})")
            else:
                ax.set_ylabel("")
            if i == 0:
                ax.set_title(SR_LABELS.get(sr, sr), fontsize=12)
            ax.grid(alpha=0.3, linewidth=0.5)
            ax.legend(loc="upper right", ncol=2, fontsize=6)

    # Figure-level legend for acquisition functions, built from every cell (not only the last one).
    # Dummy empty lines are added to the last axis after its own legend was drawn, so they do not
    # appear there.
    legend_acqs = [acq for acq in ACQ_ORDER if acq in plotted_acqs]
    if legend_acqs:
        legend_ax = axes[-1, -1]
        for acq in legend_acqs:
            style_idx = ACQ_ORDER.index(acq)
            legend_ax.plot([], label=LABELS[acq], color=COLORS[style_idx],
                           linestyle=LINESTYLES[style_idx], linewidth=1.4)
        handles, labels_list = legend_ax.get_legend_handles_labels()
        handles = handles[-len(legend_acqs):]
        labels_list = labels_list[-len(legend_acqs):]
        fig.legend(handles, labels_list, loc='upper center', ncol=max(9, len(labels_list)), fontsize=10)
    # Leave the top 2.5% of the figure for the figure-level legend.
    fig.tight_layout(rect=(0, 0, 1, 0.975))

    out_dir = os.path.dirname(output)
    if out_dir:  # a bare filename has no directory to create
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output, dpi=300)
    if show:
        plt.show()
    else:
        plt.close(fig)


def parse_args():
    """Define the command-line options for the grid plot and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``datasets``, ``method``, ``seeds``, ``scoring_rules``,
        ``output`` and ``no_show``.
    """
    parser = argparse.ArgumentParser()
    default_datasets = ['ymsd', 'sgemm', 'ccpp', 'casp', 'news', 'blog']
    default_seeds = [42, 142, 242, 342, 442]
    parser.add_argument('--datasets', nargs='+', default=default_datasets,
                        help='Datasets (rows)')
    parser.add_argument('--method', default='ensemble',
                        help='Method name (e.g., ensemble)')
    parser.add_argument('--seeds', nargs='+', type=int, default=default_seeds,
                        help='Seeds used during AL')
    parser.add_argument('--scoring-rules', nargs='+', default=SCORING_RULES, choices=SCORING_RULES,
                        help='Scoring rules (columns)')
    parser.add_argument('--output', default=os.path.join(PLOTS_PATH, 'active_learning_overview.pdf'),
                        help='Output figure path')
    parser.add_argument('--no-show', action='store_true',
                        help='Do not display the figure')
    namespace = parser.parse_args()
    return namespace


def main():
    """Script entry point: parse the CLI, render and save the grid, then report the path."""
    cli = parse_args()
    plot_grid(
        cli.datasets,
        cli.method,
        cli.seeds,
        cli.scoring_rules,
        cli.output,
        show=not cli.no_show,
    )
    print(f"Saved grid figure to {cli.output}")


if __name__ == '__main__':
    main()
