import os
import re
from typing import Any, Dict, List

import pandas as pd  # type: ignore

# Experiment result directories, relative to PATH. The first path component is
# used as the model name; the rest identifies the experiment configuration.
EXPERIMENTS = [
    # "reptile/results/omniglot/CNN5/64-filter/transductive",
    # "maml/results/second-order/omniglot/CNN5/64-filter",
    "reptile/results/miniimagenet/CNN5/32-filter/transductive",
    "maml/results/second-order/miniimagenet/CNN5/32-filter",
]

# Sub-directories appended to every experiment directory.
# WAYSHOT = ["5-way/5-shot", "5-way/1-shot", "20-way/5-shot", "20-way/1-shot"]
WAYSHOT = ["5-way/1-shot", "5-way/5-shot"]

# Metrics to report. They are paired positionally (via zip) with the filename
# patterns in PATHS[TABLE_TYPE], so the order here must match the order there.
# METRIC_ROWS = ["accuracy", "nll", "ece", "aupr"]
METRIC_ROWS = ["ece"]

# Filename patterns per table type; entry i is used for METRIC_ROWS[i].
# All patterns are raw strings so that "\d" and "\." reach the regex engine
# unchanged instead of being parsed as (invalid) string escape sequences.
PATHS = {
    "standard": [
        re.compile(r"standard-.*run-\d\.csv"),
        re.compile(r"standard-.*run-\d\.csv"),
        re.compile(r"standard-.*run-\d\.csv"),
        re.compile(r"ood-.*run-\d\.csv"),
    ],
    "corrupt-all": [re.compile(r"corrupt-.*run-\d-level-\d\.csv") for _ in range(4)],
    "ood-ece": [re.compile(r"ood-.*run-\d\.csv")],
}

# Root directory that the EXPERIMENTS paths are joined onto.
PATH = "/st2/ml-baselines"
# Key into PATHS selecting which set of filename patterns is used.
TABLE_TYPE = "ood-ece"


def df_files(filepaths: List[str], metric: str) -> pd.DataFrame:
    """Collect the final stats row of every CSV file into one DataFrame.

    Each file must start with a comma-separated header line. Only the last
    non-blank line after the header is read from each file; it holds the
    stats of that run.

    Args:
        filepaths: paths of the CSV files to read.
        metric: not used by this function; kept so existing callers keep working.

    Returns:
        A DataFrame with one column per header field and one row per file,
        in the order of ``filepaths``.

    Raises:
        ValueError: if a file's header differs from the first file's header,
            if a file has no row after the header, or if a value cannot be
            converted to ``float``.
    """
    header: List[str] = []
    cols: Dict[str, Any] = {}
    for i, path in enumerate(filepaths):
        with open(path, "r") as f:
            # Strip each field so the last column name does not keep the
            # trailing newline of the header line.
            fields = [name.strip() for name in f.readline().split(",")]
            if i == 0:
                header = fields
                cols = {k: [] for k in header}
            elif fields != header:
                # Names AND count must agree with the first file.
                raise ValueError(
                    f"each file needs to have the same stats: {path=} {header=} {fields=}"
                )

            # Walk to the last non-blank line; reset per file so a header-only
            # file can never reuse the previous file's row.
            last = None
            for line in f:
                if line.strip():
                    last = line

        if last is None:
            raise ValueError(f"no stats row found after the header in: {path=}")

        for stat, key in zip(last.strip().split(","), cols):
            cols[key].append(float(stat))

    return pd.DataFrame(cols)


def _split_experiment(experiment: str, wayshot: str) -> List[str]:
    """Return ``[model, exp]``: the first path component and the rest joined with ``wayshot``."""
    model, _, rest = experiment.partition("/")
    return [model, os.path.join(rest, wayshot)]


def print_results() -> None:
    """Print mean/std table cells for every experiment, way/shot and metric.

    For each entry of EXPERIMENTS and WAYSHOT, the result directory under PATH
    is scanned for files matching the pattern paired with each metric in
    PATHS[TABLE_TYPE]. The matching files are aggregated with ``df_files`` and
    the metric's mean and standard deviation are formatted as
    ``mean$\\pm$std``. "accuracy" and "ece" are scaled by 100 and shown with
    two decimals; other metrics are shown unscaled with three decimals.

    A missing directory, an empty set of matching files, or a metric without a
    paired pattern yields the placeholder "-" instead of a number.

    Finally one line per experiment is printed, with the cells separated by
    " & ".
    """
    stats: Dict[str, Dict[str, Dict[str, str]]] = {}
    for experiment in EXPERIMENTS:
        for wayshot in WAYSHOT:
            model, exp = _split_experiment(experiment, wayshot)
            results = stats.setdefault(model, {}).setdefault(exp, {})

            path = os.path.join(PATH, model, exp)
            print(path)
            for metric, pattern in zip(METRIC_ROWS, PATHS[TABLE_TYPE]):
                # Record the placeholder first so every early exit below still
                # leaves an entry for the printing loop.
                results[metric] = "-"

                if not os.path.exists(path):
                    print(f"no path exists at: {path=}")
                    continue

                files = [os.path.join(path, v) for v in os.listdir(path) if pattern.match(v)]
                if not files:
                    print(f"no files for: {path=}")
                    continue

                out = df_files(files, metric)
                mu, std = out.mean(), out.std()

                # accuracy and ece are reported as percentages
                if metric in ("accuracy", "ece"):
                    scale, digits = 100, 2
                else:
                    scale, digits = 1, 3

                results[metric] = (
                    f"{mu[metric] * scale:.{digits}f}"
                    + r"$\pm$"
                    + f"{std[metric] * scale:.{digits}f}"
                )

    for experiment in EXPERIMENTS:
        print(f"\n{experiment=}")
        for metric in METRIC_ROWS:
            for wayshot in WAYSHOT:
                model, exp = _split_experiment(experiment, wayshot)
                # .get covers metrics that had no paired pattern in PATHS[TABLE_TYPE]
                print(stats[model][exp].get(metric, "-"), end=" & ")
        print()


# Script entry point: print the results table when this file is run directly.
if __name__ == "__main__":
    print_results()
