import os
import re
from typing import Any, Dict, List

import numpy as np  # type: ignore
import pandas as pd  # type: ignore

# ---------------------------------------------------------------------------
# Report configuration. Everything below is read by print_results().
# ---------------------------------------------------------------------------

EXPERIMENT = "standard"
TO_PRINT = ["accuracy", "loss", "nll", "ece", "aupr", "auroc"]
RUNS = 5

# Model directory names to report on. (An earlier short list
# ["protonet", "mahalanobis"] was assigned here and immediately overwritten,
# so only this list ever took effect.)
MODELS = [
    "ProtoDDUCNN4",
    "ProtonetCNN4",
    "ProtonetCNN4SN",
    "ProtoSNGPCNN4",
    "ProtoMahalanobisCNN4Diag",
    "ProtoMahalanobisCNN4Rank-1",
    "ProtoMahalanobisCNN4Rank-2",
    "ProtoMahalanobisCNN4Rank-4",
    "ProtoMahalanobisCNN4Rank-8",
]

# Metrics reported per table. The *_INFERENCE lists and every PATHS entry are
# zipped positionally with this list, so they must stay the same length and
# in the same order.
METRIC_ROWS = ["accuracy", "nll", "ece", "aupr"]

# Settings used for models whose name contains "Mahalanobis".
MAHALANOBIS_TRAIN = "softmax"
MAHALANOBIS_INFERENCE = [
    "softmax-sample",
    "softmax-sample",
    "softmax-sample",
    "distance",
]

# Settings used for every other model.
BASELINE_TRAIN = "softmax"
BASELINE_INFERENCE = ["distance" for _ in METRIC_ROWS]


# dataset name -> experiment sub-directories ("<ways>-way/<shots>-shot").
DATASET_WAY_SHOT = {
    # "omniglot": ["20-way/1-shot", "20-way/5-shot", "5-way/1-shot", "5-way/5-shot"],
    "miniimagenet": ["5-way/1-shot", "5-way/5-shot"]
}

# table type -> one result-file name pattern per entry of METRIC_ROWS.
# All patterns are raw strings so that "\d" and "\." reach the regex engine
# without relying on Python passing unknown string escapes through.
PATHS = {
    "standard": [
        re.compile(r"standard-.*-run-\d\.csv"),
        re.compile(r"standard-.*-run-\d\.csv"),
        re.compile(r"standard-.*-run-\d\.csv"),
        re.compile(r"ood-.*-run-\d\.csv"),
    ],
    "corrupt-all": [re.compile(r"corrupt-.*-run-\d-level-\d\.csv") for _ in range(4)],
    # "corrupt-level": re.compile(r"corrupt-.*-run-\d-level-\d\.csv"),
}

# Root directory that contains the "results" folder ("" = current directory).
PATH = ""
# Key into PATHS selecting which set of result files is aggregated.
TABLE_TYPE = "corrupt-all"


def df_files(filepaths: List[str], metric: str) -> pd.DataFrame:
    """Collect the final stats row of every run file into one DataFrame.

    Each file is read as comma-separated text: the first line names the
    stats and the last non-blank line after it holds the values that are
    kept for that run.

    Args:
        filepaths: paths of the per-run files to combine.
        metric: not used by this function; kept so existing callers keep
            working.

    Returns:
        A DataFrame with one column per stat name (surrounding whitespace
        and the trailing newline removed) and one row per file, in the order
        of ``filepaths``. Empty when ``filepaths`` is empty.

    Raises:
        ValueError: if a file's header differs from the first file's header,
            if a file has no data row, if its last row has fewer values than
            the header, or if a value cannot be parsed as a float.
    """
    header: List[str] = []
    cols: Dict[str, Any] = {}
    for i, path in enumerate(filepaths):
        with open(path, "r") as f:
            # Strip each name so the last column is not stored as "name\n".
            names = [name.strip() for name in f.readline().split(",")]
            if i == 0:
                header = names
                cols = {name: [] for name in header}
            elif names != header:
                # Any difference (names OR column count) is an error.
                raise ValueError(f"each file needs to have the same stats: {header=} {names=}")

            # Walk to the last non-blank line: the stats of the finished run.
            last = ""
            for line in f:
                if line.strip():
                    last = line

        if not last:
            # Without this guard a header-only file would silently reuse the
            # previous file's row (or fail with an unbound variable).
            raise ValueError(f"no stats row found in: {path=}")

        values = last.split(",")
        if len(values) < len(header):
            raise ValueError(f"stats row is shorter than the header in: {path=}")

        # Extra trailing fields beyond the header are ignored by zip().
        for key, stat in zip(header, values):
            cols[key].append(float(stat))

    return pd.DataFrame(cols)


def print_results() -> None:
    """Print "mean ± std" table rows for every model, dataset and experiment.

    For each model in ``MODELS``, each dataset/experiment in
    ``DATASET_WAY_SHOT`` and each metric in ``METRIC_ROWS``, the result files
    under ``<PATH>/results/<dataset>/<model>/[no-residual/t-1.0/]<train>/<experiment>``
    are filtered by the pattern from ``PATHS[TABLE_TYPE]``, by the inference
    name, and by ``level-<n>`` for n in 0..5. The matching files are combined
    with ``df_files`` and the metric's mean and standard deviation over the
    files are formatted as one table cell.

    The raw ``stats`` dict is printed first, then one `` & ``-separated row
    per (metric, experiment). Cells for which no directory or no files were
    found are printed as ``n/a`` instead of aborting the report.
    """
    missing = "n/a"
    levels = range(6)

    # stats[model][dataset][experiment][level][metric] -> formatted cell.
    # Keyed by dataset as well, so datasets that share experiment names
    # (e.g. "5-way/1-shot") do not overwrite each other.
    stats: Dict[str, Any] = {}
    for model in MODELS:
        is_mahalanobis = "Mahalanobis" in model
        train = MAHALANOBIS_TRAIN if is_mahalanobis else BASELINE_TRAIN
        inference = MAHALANOBIS_INFERENCE if is_mahalanobis else BASELINE_INFERENCE
        # Mahalanobis results live one directory level deeper.
        addendum = "no-residual/t-1.0" if is_mahalanobis else ""

        for dataset, experiments in DATASET_WAY_SHOT.items():
            for metric, infer, pattern in zip(METRIC_ROWS, inference, PATHS[TABLE_TYPE]):
                for experiment in experiments:
                    path = os.path.join(PATH, "results", dataset, model, addendum, train, experiment)
                    if not os.path.exists(path):
                        print(f"no path exists at: {path=}")
                        continue

                    # List the directory once; only the level filter changes
                    # between iterations of the loop below.
                    candidates = [
                        name
                        for name in sorted(os.listdir(path))
                        if pattern.match(name) and infer in name
                    ]

                    for level in levels:
                        # Filter on the file name, not the joined path, so a
                        # directory component can never satisfy the filter.
                        files = [
                            os.path.join(path, name)
                            for name in candidates
                            if f"level-{level}" in name
                        ]
                        print(files)

                        if not files:
                            print(f"no files for: {path=}")
                            continue

                        out = df_files(files, metric)
                        mu, std = out.mean(), out.std()
                        entry = f"{mu[metric]:.3f}" + r"$\pm$" + f"{std[metric]:.3f}"
                        cell = (
                            stats.setdefault(model, {})
                            .setdefault(dataset, {})
                            .setdefault(experiment, {})
                            .setdefault(level, {})
                        )
                        cell[metric] = entry

    print(stats)
    for model in MODELS:
        # Iterate the datasets explicitly instead of reusing the loop
        # variable left over from the collection phase above.
        for dataset, experiments in DATASET_WAY_SHOT.items():
            print(f"\n{model=} {dataset=} {METRIC_ROWS=} {experiments}")
            by_experiment = stats.get(model, {}).get(dataset, {})
            for metric in METRIC_ROWS:
                for experiment in experiments:
                    print(f"{metric=} {experiment=}")
                    by_level = by_experiment.get(experiment, {})
                    for level in levels:
                        # Skipped cells were already reported above; print a
                        # placeholder rather than raising KeyError.
                        print(by_level.get(level, {}).get(metric, missing), end=" & ")
                    print()
        print()


# Build and print the report only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    print_results()
