import os
import re
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd  # type: ignore
import seaborn as sns  # type: ignore
from matplotlib import pyplot as plt  # type: ignore
from matplotlib.colors import to_rgba  # type: ignore
from utils import shift_level_box_plot

# ---------------------------------------------------------------------------
# Experiment configuration.
#
# MODELS, MODEL_NAMES and MODEL_PATHS are parallel lists that are consumed
# together via zip(): entry i of each describes the same model (results
# directory name, legend label, results root).
# ---------------------------------------------------------------------------
MODELS = [
    "ProtoDDUCNN4",
    "ProtonetCNN4",
    "ProtonetCNN4SN",
    "ProtoSNGPCNN4",
    "MAML",
    "Reptile",
    "ProtoMahalanobisCNN4Diag",
    "ProtoMahalanobisCNN4Rank-1",
    # Disabled low-rank variants. Re-enabling one also needs matching
    # MODEL_NAMES and MODEL_PATHS entries, otherwise zip() drops it silently.
    # "ProtoMahalanobisCNN4Rank-2",
    # "ProtoMahalanobisCNN4Rank-4",
    # "ProtoMahalanobisCNN4Rank-8",
]

# Legend labels, index-aligned with MODELS.
MODEL_NAMES = [
    "ProtoDDU", "Protonet", "ProtonetSN", "ProtoSNGP",
    "MAML", "Reptile",
    "Ours (Diag)", "Ours (Rank-1)",
]

# Results roots, index-aligned with MODELS; every Proto* model shares one root.
_MAHALANOBIS_ROOT = "/st2/ml-baselines/mahalanobis"
MODEL_PATHS = (
    [_MAHALANOBIS_ROOT] * 4
    + ["/st2/ml-baselines/maml/", "/st2/ml-baselines/reptile/"]
    + [_MAHALANOBIS_ROOT] * 2
)

# Stats-CSV column names. The i-th metric is paired with the i-th inference
# mode below and the i-th file-name pattern in PATHS.
METRIC_ROWS = ["accuracy", "nll", "ece", "aupr"]

# Inference mode per metric; for Proto* models it selects result files whose
# name contains this string.
MAHALANOBIS_INFERENCE = ["softmax-sample"] * 3 + ["distance"]
BASELINE_INFERENCE = ["distance"] * len(METRIC_ROWS)

DATASETS = ["miniimagenet", "omniglot"]

# Way/shot sub-directories evaluated for each dataset.
DATASET_WAY_SHOT = dict(
    omniglot=["20-way/1-shot", "20-way/5-shot", "5-way/1-shot", "5-way/5-shot"],
    miniimagenet=["5-way/1-shot", "5-way/5-shot"],
)

# File-name pattern of the corrupted-data result CSVs, one entry per metric.
PATHS = [re.compile(r"corrupt-.*run-\d-level-\d\.csv")] * 4

# Box colour for each legend label.
METHODS_TO_COLORS = dict(
    [
        ("MAML", "tab:gray"),
        ("Reptile", "tab:cyan"),
        ("ProtoDDU", "tab:blue"),
        ("Protonet", "tab:orange"),
        ("ProtonetSN", "tab:green"),
        ("ProtoSNGP", "tab:red"),
        ("Ours (Diag)", "tab:pink"),
        ("Ours (Rank-1)", "tab:purple"),
    ]
)


# Display names for stats-CSV metric columns and for datasets.
metrics = dict(ece="ECE", nll="NLL", accuracy="Accuracy", aupr="AUPR")
datasets = dict(omniglot="Omniglot", miniimagenet="MiniImageNet")


def _results_dir(model: str, model_path: str, dataset: str, experiment: str) -> str:
    """Return the directory holding one model's result CSVs for one way/shot experiment.

    Raises:
        ValueError: if ``model`` is neither a Proto* model, "MAML" nor "Reptile".
    """
    # Only the MAML / Reptile layouts are split by the number of conv filters.
    filter_ = f"{32 if dataset == 'miniimagenet' else 64}-filter"
    if "Proto" in model:
        # Mahalanobis runs sit one directory level deeper; for the other Proto*
        # models the empty component is a no-op in os.path.join.
        addendum = "no-residual/t-1.0" if "Mahalanobis" in model else ""
        return os.path.join(model_path, "results", dataset, model, addendum, "softmax", experiment)
    if model == "MAML":
        return os.path.join(model_path, "results", "second-order", dataset, "CNN5", filter_, experiment)
    if model == "Reptile":
        return os.path.join(model_path, "results", dataset, "CNN5", filter_, "transductive", experiment)
    raise ValueError(f"got an unknown model: {model=}")


def _read_last_row(path: str) -> Tuple[List[str], Optional[List[str]]]:
    """Read the header and the last non-blank row of a comma-separated stats file.

    The last row holds the stats of the finished run. Every field is
    whitespace-stripped, so the final column name does not keep the trailing
    newline (which previously made the last metric column unmatchable).

    Returns:
        ``(header, row)``, where ``row`` is ``None`` if the file has no data rows.
    """
    with open(path, "r") as f:
        header = [field.strip() for field in f.readline().split(",")]
        last_line: Optional[str] = None
        for line in f:
            if line.strip():  # ignore trailing blank lines
                last_line = line
    if last_line is None:
        return header, None
    return header, [field.strip() for field in last_line.split(",")]


def _append_level(
    series: Dict[str, List[Any]], files: List[str], raw_metric: str, level: int, model_name: str
) -> None:
    """Append ``raw_metric`` from the last row of each file to ``series``.

    ``series`` has list-valued keys "level", "value" and "method". Level 0 is
    recorded as "Test". Files without a ``raw_metric`` column are skipped.

    Raises:
        ValueError: if the files do not all share the same header.
    """
    reference: Optional[List[str]] = None
    for file_path in files:
        header, row = _read_last_row(file_path)
        if reference is None:
            reference = header
        elif header != reference:
            raise ValueError(f"each file needs to have the same key order of stats: {reference=} {header=}")

        if row is None:
            print(f"no data rows in: {file_path=}")
            continue

        record = dict(zip(header, row))
        if raw_metric not in record:
            continue

        value = float(record[raw_metric])
        series["level"].append(level if level > 0 else "Test")
        series["value"].append(value)
        series["method"].append(model_name)

        # Values above 1.0 are reported as invalid for these metrics.
        if raw_metric in ("accuracy", "ece", "aupr") and value > 1.0:
            print(f"{file_path=} has invalid {raw_metric} {value}")


def _collect_stats(dataset: str) -> Dict[str, Dict[str, Dict[str, List[Any]]]]:
    """Gather per-corruption-level metric values of every model for one dataset.

    Returns:
        ``stats[experiment][metric_label] -> {"level": [...], "value": [...], "method": [...]}``.
        ``metric_label`` is the display name from ``metrics``. Every
        experiment/metric pair is present, even when no result files were found.

    Raises:
        ValueError: for a metric without a display name, an unknown model, or
            inconsistent CSV headers.
    """
    stats: Dict[str, Dict[str, Dict[str, List[Any]]]] = {}
    for model, model_path, model_name in zip(MODELS, MODEL_PATHS, MODEL_NAMES):
        inference = MAHALANOBIS_INFERENCE if "Mahalanobis" in model else BASELINE_INFERENCE

        for raw_metric, infer, pattern in zip(METRIC_ROWS, inference, PATHS):
            metric = metrics.get(raw_metric, "")
            if not metric:
                raise ValueError("metric must be defined here")

            for experiment in DATASET_WAY_SHOT[dataset]:
                series = stats.setdefault(experiment, {}).setdefault(
                    metric, {"level": [], "value": [], "method": []}
                )

                # The directory and its listing do not depend on the level,
                # so compute them once per experiment.
                path = _results_dir(model, model_path, dataset, experiment)
                if not os.path.exists(path):
                    print(f"no path exists at: {path=}")
                    continue

                names = [v for v in os.listdir(path) if pattern.match(v)]
                if "Proto" in model:
                    # Proto* result file names contain the inference mode;
                    # keep only the files for this metric's mode.
                    names = [v for v in names if infer in v]

                for level in range(6):
                    files = [os.path.join(path, v) for v in names if f"level-{level}" in v]
                    if not files:
                        print(f"no files for: {path=}")
                        continue
                    _append_level(series, files, raw_metric, level, model_name)
    return stats


def _restyle_boxes(ax: Any, lines_per_box: int = 6) -> None:
    """Set each box's face alpha to 0.8 and paint its outline and lines in its own colour.

    NOTE(review): assumes the boxes are in ``ax.artists`` and that each box owns
    ``lines_per_box`` consecutive entries of ``ax.lines``, in box order. This
    depends on the seaborn/matplotlib versions in use; confirm before upgrading.
    """
    for idx, box in enumerate(ax.artists):
        r, g, b, a = to_rgba(box.get_facecolor())
        box.set_facecolor((r, g, b, 0.8))
        box.set_edgecolor((r, g, b, a))
        for line in ax.lines[idx * lines_per_box:(idx + 1) * lines_per_box]:
            line.set_color((r, g, b, a))
            line.set_mfc((r, g, b, a))
            line.set_mec((r, g, b, a))


def plot_results() -> None:
    """Draw shift-level box plots of every metric for every dataset and way/shot setting.

    For each dataset in ``DATASETS``, collects the last-row stats of each model's
    corrupted-data CSVs. It then saves ``<metric>.png`` and ``<metric>.pdf`` per
    experiment under ``<MODEL_PATHS[-1]>/results/<dataset>/boxplots-2/<experiment>``.
    Experiment/metric pairs without any data are skipped.
    """
    # Plots are written under the last model's results root.
    output_root = MODEL_PATHS[-1]

    for dataset in DATASETS:
        stats = _collect_stats(dataset)
        print(stats)

        for experiment, per_metric in stats.items():
            for metric, data in per_metric.items():
                if not data["value"]:
                    print(f"nothing to plot: {experiment=} {metric=}")
                    continue

                print(f"plotting: {experiment=} {metric=}")
                df = pd.DataFrame(data)
                fig, ax = plt.gcf(), plt.gca()
                fig.set_size_inches(15, 4)

                shift_level_box_plot(ax, df, metric, METHODS_TO_COLORS, f"{datasets.get(dataset)}-C ({experiment})", hue_order=MODEL_NAMES)
                _restyle_boxes(ax)

                out_dir = os.path.join(output_root, "results", dataset, "boxplots-2", experiment)
                print(out_dir)
                os.makedirs(out_dir, exist_ok=True)

                fig.tight_layout()
                fig.savefig(os.path.join(out_dir, f"{metric}.png"))
                fig.savefig(os.path.join(out_dir, f"{metric}.pdf"))
                plt.close(fig)


if __name__ == "__main__":
    # Script entry point: collect the results and write all box plots.
    plot_results()
