import os
import re
from typing import List

import numpy as np  # type: ignore
import pandas as pd  # type: ignore

# NOTE(review): EXPERIMENT, TO_PRINT and RUNS are not read by df_files or
# print_results; they are kept in case other code imports them.
EXPERIMENT = "standard"
TO_PRINT = ["accuracy", "loss", "nll", "ece", "aupr", "auroc"]
RUNS = 5

# Models to report on. Each name is used as a directory component of the
# results path built in print_results.
MODELS = [
    "ProtoDDUFC",
    "ProtonetFC",
    "ProtonetFCSN",
    "SNGPProtoFC",
    "ProtoMahalanobisDiag",
    "ProtoMahalanobisRank-1",
    "ProtoMahalanobisRank-2",
    "ProtoMahalanobisRank-4",
    "ProtoMahalanobisRank-8",
    "ProtoMahalanobisRank-16",
    "ProtoMahalanobisRank-32",
    "ProtoMahalanobisRank-64",
]

# Models whose name contains "Mahalanobis" use this training directory and
# inference tag. The inference value is matched as a substring of result
# file names.
MAHALANOBIS_TRAIN = "softmax"
MAHALANOBIS_INFERENCE = "softmax-sample"

# Training directory / inference tag for every other model.
BASELINE_TRAIN = "softmax"
BASELINE_INFERENCE = "distance"

# Metrics reported, in order. Each PATHS[...] list below is zipped against
# this list, so the two must stay the same length and order.
METRIC_ROWS = ["accuracy", "ece", "aupr", "loss"]

# Per dataset, the "<N>-way/<K>-shot" sub-directories to report.
DATASET_WAY_SHOT = {
    "omniglot": ["20-way/1-shot", "20-way/5-shot", "5-way/1-shot", "5-way/5-shot"],
    # "miniimagenet": ["5-way/1-shot", "5-way/5-shot"]
}

# Filename patterns for each table type, one per entry of METRIC_ROWS. Raw
# strings keep "\d" and "\." as regex escapes rather than invalid Python
# string escapes (a SyntaxWarning on Python 3.12+).
PATHS = {
    "standard": [
        re.compile(r"standard-.*-run-\d\.csv"),  # accuracy
        re.compile(r"standard-.*-run-\d\.csv"),  # ece
        re.compile(r"ood-.*-run-\d\.csv"),  # aupr
        re.compile(r"standard-.*-run-\d\.csv"),  # loss
    ],
    "corrupt-all": [re.compile(r"corrupt-.*-run-\d\.csv") for _ in range(4)],
    # TODO: does the corrupt level chart need to go in a different file?
    # "corrupt-level": re.compile(r"corrupt-.*-run-\d-level-\d\.csv"),
}

# Root directory containing the "results" tree.
PATH = "mahalanobis"
# Key into PATHS selecting which set of filename patterns to use.
TABLE_TYPE = "standard"


def df_files(filepaths: List[str], metric: str) -> pd.DataFrame:
    """Collect the final stats line of several CSV result files into one frame.

    Every file must begin with the same comma-separated header line. The last
    non-blank line of each file is parsed as floats and becomes one row of the
    result. There is one row per file, in ``filepaths`` order.

    Args:
        filepaths: Paths of the CSV files to read.
        metric: Not used by this function; kept for caller compatibility.

    Returns:
        A DataFrame whose columns are the header fields (surrounding
        whitespace and the trailing newline stripped). Its rows are the last
        data line of each file.

    Raises:
        ValueError: If a file's header differs from the first file's header,
            or a file has no non-blank line after its header.
    """
    cols = {}
    header = None
    for path in filepaths:
        with open(path, "r") as f:
            # Strip so the last column name does not keep the "\n".
            s = [field.strip() for field in f.readline().split(",")]
            if header is None:
                header = s
                cols = {k: [] for k in header}
            elif s != header:
                # Any difference (names or count) is an error.
                raise ValueError(f"each file needs to have the same stats: {cols.keys()=} {s=}")

            # Track the last non-blank line per file. A header-only file must
            # not reuse the previous file's stats, and trailing blank lines
            # must not be parsed as numbers.
            last = None
            for line in f:
                if line.strip():
                    last = line
            if last is None:
                raise ValueError(f"no stats found after the header in: {path=}")

            for stat, key in zip(last.split(","), cols):
                cols[key].append(float(stat))

    return pd.DataFrame(cols)
    


def print_results() -> None:
    """Print one per-run stats table for every configured result directory.

    Loop order is: each model in ``MODELS``, each dataset in
    ``DATASET_WAY_SHOT``, each ``(metric, pattern)`` pair from ``METRIC_ROWS``
    zipped with ``PATHS[TABLE_TYPE]``, then each way/shot setting. For each
    combination, this function:

    * looks in ``<PATH>/results/<dataset>/<model>/<train>/<setting>``;
    * selects files whose name matches ``pattern`` and contains the model's
      inference tag;
    * combines them with ``df_files``;
    * prints a label line followed by the table.

    Missing directories and directories with no matching files are reported
    and skipped.
    """
    for model in MODELS:
        # Mahalanobis variants use their own training directory / inference tag.
        if "Mahalanobis" in model:
            train, inference = MAHALANOBIS_TRAIN, MAHALANOBIS_INFERENCE
        else:
            train, inference = BASELINE_TRAIN, BASELINE_INFERENCE

        for dataset, experiments in DATASET_WAY_SHOT.items():
            for metric, pattern in zip(METRIC_ROWS, PATHS[TABLE_TYPE]):
                for experiment in experiments:
                    path = os.path.join(PATH, "results", dataset, model, train, experiment)

                    if not os.path.exists(path):
                        print(f"no path exists at: {path=}")
                        continue

                    # os.listdir returns entries in arbitrary order; sort so the
                    # run rows of each table come out in a stable order.
                    files = [
                        os.path.join(path, name)
                        for name in sorted(os.listdir(path))
                        if pattern.match(name) and inference in name
                    ]

                    if not files:
                        print(f"no files for: {path=}")
                        continue

                    # Label each table so the printed output can be attributed.
                    print(f"{model} | {dataset} | {experiment} | {metric}")
                    print(df_files(files, metric))


if __name__ == "__main__":
    print_results()
