import os
import re
from typing import Any, Dict, List

import pandas as pd  # type: ignore

# Model names, in the order their results are printed by print_results().
# Each name is also used as a directory component of the results path.
MODELS = [
    # names without "Mahalanobis": print_results() pairs these with BASELINE_INFERENCE
    "ProtoDDUCNN4",
    "ProtonetCNN4",
    "ProtonetCNN4SN",
    "ProtoSNGPCNN4",
] + [
    # names containing "Mahalanobis": print_results() pairs these with
    # MAHALANOBIS_INFERENCE and an extra "no-residual/t-1.0" path component
    "ProtoMahalanobisCNN4Diag",
    "ProtoMahalanobisCNN4Rank-1",
    "ProtoMahalanobisCNN4Rank-2",
    "ProtoMahalanobisCNN4Rank-4",
    "ProtoMahalanobisCNN4Rank-8",
]

# Metrics to tabulate. Alternative setting covering four metrics:
# METRIC_ROWS = ["accuracy", "nll", "ece", "aupr"]
METRIC_ROWS = ["ece"]

# Substrings a result filename must contain, one per metric. print_results()
# zips these positionally with METRIC_ROWS, so entry i applies to METRIC_ROWS[i].
# NOTE(review): the Mahalanobis list has four entries, presumably matching the
# four-metric setting above -- confirm the order if METRIC_ROWS is changed.
MAHALANOBIS_INFERENCE = ["softmax-sample"] * 3 + ["distance"]
BASELINE_INFERENCE = ["distance"] * len(METRIC_ROWS)


# Dataset whose results are tabulated; must be a key of DATASET_WAY_SHOT.
DATASET = "miniimagenet"

# Experiment settings per dataset. Each entry is used both as a sub-path of the
# results directory and as a column of the printed table, in the listed order.
DATASET_WAY_SHOT = dict(
    omniglot=["5-way/5-shot", "5-way/1-shot", "20-way/5-shot", "20-way/1-shot"],
    miniimagenet=["5-way/1-shot", "5-way/5-shot"],
)

# Filename patterns of the per-run result CSVs, keyed by table type.
# print_results() zips each list positionally with METRIC_ROWS, so entry i is
# the pattern used for METRIC_ROWS[i].
# All patterns are raw strings: "\d" and "\." are not valid string escapes and
# trigger a DeprecationWarning/SyntaxWarning in a plain string literal.
# NOTE(review): "\d" matches exactly one digit, so run/level indices >= 10
# would not be picked up -- confirm that is intended.
PATHS = {
    "standard": [
        re.compile(r"standard-.*-run-\d\.csv"),
        re.compile(r"standard-.*-run-\d\.csv"),
        re.compile(r"standard-.*-run-\d\.csv"),
        re.compile(r"ood-.*-run-\d\.csv"),
    ],
    "corrupt-all": [re.compile(r"corrupt-.*-run-\d-level-\d\.csv") for _ in range(4)],
    "ood-ece": [re.compile(r"ood-.*-run-\d\.csv")],
}

# Root directory that contains the "results" tree read by print_results().
PATH = "/w2/path"
# Key into PATHS selecting which set of filename patterns is used.
TABLE_TYPE = "ood-ece"


def df_files(filepaths: List[str], metric: str) -> pd.DataFrame:
    """Collect the final stats row of every run file into one DataFrame.

    Each file must start with a comma-separated header line; the last
    non-blank line after the header is taken as the stats of that run.

    Args:
        filepaths: paths of the per-run CSV files. The header of the first
            file defines the column names.
        metric: not used by this function; kept so existing callers keep
            working.

    Returns:
        A DataFrame with one column per header field and one row per file, in
        the order of ``filepaths`` (empty when ``filepaths`` is empty).

    Raises:
        ValueError: if a file's header differs from the first file's header,
            a file has no stats row, the stats row does not have one value
            per column, or a value cannot be parsed as a float.
    """
    cols: Dict[str, List[float]] = {}
    for i, path in enumerate(filepaths):
        with open(path, "r") as f:
            # Strip every name so the last column does not keep the trailing
            # newline of the header line (which would break lookups by name).
            header = [name.strip() for name in f.readline().split(",")]
            if i == 0:
                cols = {name: [] for name in header}
            elif header != list(cols):
                # A difference in either the names or their count is an error.
                raise ValueError(f"each file needs to have the same stats: {cols.keys()=} s={header!r}")

            # Go to the last non-blank line of the file: the stats of the current run.
            last_row = None
            for line in f:
                if line.strip():
                    last_row = line

        if last_row is None:
            raise ValueError(f"no stats row found in: {path=}")

        values = last_row.split(",")
        if len(values) != len(cols):
            raise ValueError(f"expected {len(cols)} stats but found {len(values)} in: {path=}")

        for key, stat in zip(cols, values):
            cols[key].append(float(stat))

    return pd.DataFrame(cols)


def print_results() -> None:
    """Aggregate the per-run result files and print them as table cells.

    For every model in ``MODELS`` and every experiment of
    ``DATASET_WAY_SHOT[DATASET]``, the files inside
    ``PATH/results/DATASET/<model>/<addendum>/softmax/<experiment>`` are
    filtered by the filename pattern (``PATHS[TABLE_TYPE]``) and the
    inference substring paired with each metric of ``METRIC_ROWS``. The
    selected files are loaded with ``df_files`` and the mean and standard
    deviation over the runs are formatted as one cell, joined by the LaTeX
    plus-minus macro. "accuracy" and "ece" are scaled by 100 and shown with
    two decimals; every other metric is shown unscaled with three decimals.

    A cell is "-" when the result directory does not exist or holds no
    matching file. The full stats dict is printed first, followed by one
    line of " & "-separated cells per model.
    """
    missing = "-"  # placeholder cell when no result is available
    stats: Dict[str, Dict[str, Dict[str, str]]] = {}

    for model in MODELS:
        model_stats = stats.setdefault(model, {})

        is_mahalanobis = "Mahalanobis" in model
        inference = MAHALANOBIS_INFERENCE if is_mahalanobis else BASELINE_INFERENCE
        # Mahalanobis results live one directory level deeper than the others.
        addendum = "no-residual/t-1.0" if is_mahalanobis else ""

        # Metric, inference substring and filename pattern are paired by position.
        for metric, infer, pattern in zip(METRIC_ROWS, inference, PATHS[TABLE_TYPE]):
            for experiment in DATASET_WAY_SHOT[DATASET]:
                cells = model_stats.setdefault(experiment, {})
                path = os.path.join(PATH, "results", DATASET, model, addendum, "softmax", experiment)

                if not os.path.exists(path):
                    print(f"no path exists at: {path=}")
                    # Record a placeholder so the table printing below does not
                    # fail with a KeyError for this cell.
                    cells[metric] = missing
                    continue

                files = [
                    os.path.join(path, name)
                    for name in os.listdir(path)
                    if pattern.match(name) and infer in name
                ]
                if not files:
                    print(f"no files for: {path=}")
                    cells[metric] = missing
                    continue

                out = df_files(files, metric)
                mean, std = out.mean()[metric], out.std()[metric]

                # accuracy and ece are reported as percentages
                scale, digits = (100, 2) if metric in ("accuracy", "ece") else (1, 3)
                cells[metric] = f"{mean * scale:.{digits}f}" + r"$\pm$" + f"{std * scale:.{digits}f}"

    print(stats)
    for model in MODELS:
        print(f"\n{model=} {METRIC_ROWS=} {DATASET_WAY_SHOT[DATASET]}")
        for metric in METRIC_ROWS:
            for experiment in DATASET_WAY_SHOT[DATASET]:
                # .get() keeps the row printable even if a cell was never filled,
                # e.g. when fewer patterns than metrics are configured.
                print(stats[model].get(experiment, {}).get(metric, missing), end=" & ")
        print()


# Build and print the results table only when this file is run as a script,
# not when it is imported.
if __name__ == "__main__":
    print_results()
