"""Build table with nearood/farood accuracy for different postprocessors."""
import argparse
import os
import sys
import torch
import numpy as np
from collections import defaultdict, OrderedDict
from sklearn.metrics import roc_curve, auc

from openood.postprocessors.deed_postprocessor import DEEDPostprocessor
from openood.utils.config import Config


# Dataset key (prefix of the "<dataset>-<method>" names) -> title printed in
# the table. Insertion order is the order in which datasets are printed.
DATA_ORDER = OrderedDict([
    ("mnist", "MNIST"),
    ("cifar10", "CIFAR10"),
    ("cifar100", "CIFAR100"),
    ("tin", "TIN"),
])
# Method key (suffix of the "<dataset>-<method>" names) -> title printed in
# the table. Insertion order is the order in which methods are printed.
ORDER = OrderedDict([
    ("de", "DE"),
    ("ncl", "NCL"),
    ("adp", "ADP"),
    ("dice", "DICE"),
    ("gradcam-min", "GradCAM"),
    # ("gradcam-iou-logit-average", "GradCAM-IOU"),
])


def parse_arguments():
    """Parse the command line of the table builder.

    Returns:
        argparse.Namespace with the attributes ``root`` (str),
        ``aggregations`` (list of str) and ``single_method`` (str or None).
    """
    # The first positional parameter of ArgumentParser is ``prog``, so the
    # module docstring has to be passed as ``description`` by keyword;
    # otherwise it replaces the program name in the usage line.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("root", help="Path to dumped logits and labels for multiple methods")
    parser.add_argument("--aggregations", nargs="+", help="Aggregations to compare",
                        default=["average", "prob-average", "prob-minmax", "prob-maxstd", "logit-average", "minmax", "maxstd"])
    parser.add_argument("--single-method", help="Display comparison for a single method")
    return parser.parse_args()


def get_postprocessors(aggregations):
    """Create one DEEDPostprocessor per aggregation name.

    Args:
        aggregations: iterable of aggregation names.

    Returns:
        dict mapping each aggregation name to a DEEDPostprocessor built from
        a config whose only varying field is ``aggregation``.
    """
    def build(aggregation_name):
        # Everything except the aggregation is a fixed placeholder value.
        postprocessor_args = {
            "aggregation": aggregation_name,
            "network_name": None,
            "checkpoint_root": "",
            "num_networks": -1
        }
        settings = {
            "network": {"name": None},
            "postprocessor": {"postprocessor_args": postprocessor_args}
        }
        return DEEDPostprocessor(Config(settings))

    return {name: build(name) for name in aggregations}


def collect(root):
    """Returns mapping (method, seed)->dataset->path.

    Walks ``root`` recursively and picks up every ``.npz`` file whose name
    (without extension) is either ``test`` or contains exactly one ``-``.
    The first two directory components below ``root`` are taken as the method
    and the seed; trailing digits 1-9 are stripped from the method name.

    Files that are not nested at least two directories below ``root`` have no
    (method, seed) pair and are skipped.
    """
    data = defaultdict(lambda: defaultdict(dict))
    for dirpath, _dirnames, filenames in os.walk(root):
        components = os.path.relpath(dirpath, root).split(os.sep)
        if len(components) < 2:
            # ``root`` itself (".") or a method directory without a seed
            # level: unpacking method/seed would fail, so ignore its files.
            continue
        method, seed = components[:2]
        # NOTE(review): "0" is not in the stripped character set, so a suffix
        # such as "10" is kept as is -- confirm this is intended.
        key = (method.rstrip("123456789"), seed)
        for filename in filenames:
            if not filename.endswith(".npz"):
                continue
            name = os.path.splitext(filename)[0]
            if (name == "test") or (name.count("-") == 1):
                data[key][name] = os.path.join(dirpath, filename)
    return data


def get_auroc(test_scores, ood_scores):
    """Compute the area under the ROC curve for test-vs-OOD separation.

    Test samples are the positive class (label 1) and OOD samples the
    negative class (label 0); the scores are passed to ``roc_curve`` as is.
    """
    merged_scores = np.concatenate((test_scores, ood_scores), axis=0)
    num_test, num_ood = len(test_scores), len(ood_scores)
    # Same layout as the scores: all test entries first, then all OOD entries.
    targets = np.concatenate((np.ones(num_test, dtype=int), np.zeros(num_ood, dtype=int)))
    false_positive_rate, true_positive_rate, _thresholds = roc_curve(targets, merged_scores)
    return auc(false_positive_rate, true_positive_rate)


def compute_metrics(paths, postprocessors):
    """Returns mapping aggregation -> split -> AUROC.

    Args:
        paths: mapping dataset name -> path of an ``.npz`` file holding a
            ``logits`` array. A name is either ``test`` or
            ``<split>-<dataset>``.
        postprocessors: mapping aggregation name -> object whose
            ``postprocess(net=..., data=...)`` returns ``(pred, conf)``.

    Returns:
        Nested defaultdict aggregation -> split -> AUROC, where the AUROC of
        a split is the mean over all OOD datasets of that split, each
        compared against the ``test`` scores.

    Raises:
        ValueError: if a dataset name is neither ``test`` nor contains "-".
        KeyError: if OOD entries are present but there is no ``test`` entry.
    """
    test_scores = {}  # Aggregation -> scores.
    ood_scores = defaultdict(dict)  # Aggregation -> (split, dataset) -> scores.
    for dataset, path in paths.items():
        # np.load returns a lazily-read NpzFile for .npz archives; the context
        # manager closes the underlying file once the array is in memory.
        with np.load(path) as data:
            logits = torch.as_tensor(data["logits"])  # (N, B, C).
        agg_logits = logits.mean(0)  # (B, C).

        def net(data, return_ensemble):
            # Stand-in for a network: replays the dumped logits.
            return agg_logits, logits

        for aggregation, postprocessor in postprocessors.items():
            _pred, conf = postprocessor.postprocess(net=net, data=None)  # (B), (B).
            if "-" in dataset:
                split, baseset = dataset.split("-")
                ood_scores[aggregation][(split, baseset)] = conf
            elif dataset == "test":
                test_scores[aggregation] = conf
            else:
                raise ValueError(dataset)
    metrics = defaultdict(lambda: defaultdict(list))
    for aggregation, agg_scores in ood_scores.items():
        test_conf = test_scores[aggregation]
        aurocs = defaultdict(list)  # Split -> AUROC per OOD dataset.
        for (split, _baseset), ood_conf in agg_scores.items():
            aurocs[split].append(get_auroc(test_conf, ood_conf))
        for split, values in aurocs.items():
            metrics[aggregation][split] = np.mean(values)
    return metrics


def boldify(field, position):
    """Wrap one " / "-separated part of ``field`` in LaTeX bold markup.

    Args:
        field: string whose parts are separated by " / ".
        position: index of the part to wrap in ``{\\bf ...}``.

    Returns:
        The field with the selected part wrapped, other parts untouched.
    """
    separator = " / "
    pieces = field.split(separator)
    pieces[position] = "".join((r"{\bf ", pieces[position], r"}"))
    return separator.join(pieces)


def show(metrics, aggregations):
    """Print a LaTeX table comparing aggregations for all datasets and methods.

    Each cell holds "near / far" mean AUROC in percent. For every
    (dataset, method) pair the best near and the best far value over the
    aggregations are printed in bold.

    Args:
        metrics: method->aggregation->split->[auroc for seed], where the
            method key has the form "<dataset>-<method>".
        aggregations: sequence of aggregation names, one table row each.
    """
    def lookup(dataset, aggregation, method):
        # (near, far) averaged over seeds, or None when results are missing.
        key = dataset + "-" + method
        if key not in metrics:
            return None
        per_aggregation = metrics[key]
        if aggregation not in per_aggregation:
            return None
        splits = per_aggregation[aggregation]
        return (np.mean(splits["nearood"]), np.mean(splits["farood"]))

    # Compute results: dataset -> aggregation -> method -> (near, far) | None.
    table = OrderedDict(
        (dataset, OrderedDict(
            (aggregation, OrderedDict(
                (method, lookup(dataset, aggregation, method))
                for method in ORDER))
            for aggregation in aggregations))
        for dataset in DATA_ORDER)

    # Find best: dataset -> method -> (best near aggregation, best far aggregation).
    winners = {}
    for dataset in DATA_ORDER:
        winners[dataset] = {}
        for method in ORDER:
            best_near = best_far = 0
            near_choice = far_choice = None
            for aggregation in aggregations:
                cell = table[dataset][aggregation][method]
                if cell is None:
                    continue
                near, far = cell
                # Strict comparison: the first aggregation wins ties.
                if near > best_near:
                    best_near, near_choice = near, aggregation
                if far > best_far:
                    best_far, far_choice = far, aggregation
            winners[dataset][method] = (near_choice, far_choice)

    # Print.
    titles = [r"\bf Dataset", r"\bf Aggregation"]
    titles.extend(r"\bf " + title for title in ORDER.values())
    print(" & ".join(titles) + " " + r"\\")
    for dataset, dataset_name in DATA_ORDER.items():
        print(r"\midrule")
        print(r"\multirow{%d}{*}{\bf %s}" % (len(aggregations), dataset_name))
        for aggregation in aggregations:
            # The leading empty token leaves the dataset column blank.
            row = ["", aggregation]
            for method in ORDER:
                cell = table[dataset][aggregation][method]
                if cell is None:
                    row.append("")
                    continue
                near, far = cell
                text = f"{near * 100:.2f} / {far * 100:.2f}"
                near_choice, far_choice = winners[dataset][method]
                if near_choice == aggregation:
                    text = boldify(text, 0)
                if far_choice == aggregation:
                    text = boldify(text, 1)
                row.append(text)
            print("", " & ".join(row), r"\\")


def show_single(metrics, aggregations):
    """Print a LaTeX table for a single method.

    Rows are aggregations and columns are datasets. Each cell holds
    "near +- std / far +- std" AUROC in percent, with mean and standard
    deviation taken over seeds. Cells without results are left empty.

    Args:
        metrics: method->aggregation->split->[auroc for seed], where the
            method key has the form "<dataset>-<method>". The method name is
            taken from the first key.
        aggregations: iterable of aggregation names, one table row each.
    """
    def mean_std(values):
        # "<mean> {\tiny $\pm$ <std>}" in percent with two decimals.
        return f"{np.mean(values) * 100:.2f}" + " {\\tiny $\\pm$ " + f"{np.std(values) * 100:.2f}" + "}"

    # With no results at all (e.g. nothing matched the requested method) there
    # is no first key; fall back to a table of empty cells instead of failing.
    # partition() likewise tolerates a key without "-".
    first_key = next(iter(metrics), None)
    method = None if first_key is None else first_key.partition("-")[2]

    header = [r"\bf Aggregation"]
    header.extend(r"\bf " + dataset_name for dataset_name in DATA_ORDER.values())
    print(" & ".join(header) + r"\\")
    print(r"\midrule")
    for aggregation in aggregations:
        tokens = [aggregation]
        for dataset in DATA_ORDER:
            full_name = None if method is None else dataset + "-" + method
            if full_name is None or full_name not in metrics:
                tokens.append("")
                continue
            method_metrics = metrics[full_name]
            if aggregation not in method_metrics:
                tokens.append("")
                continue
            by_split = method_metrics[aggregation]
            tokens.append(mean_std(by_split["nearood"]) + " / " + mean_std(by_split["farood"]))
        print("", " & ".join(tokens), r"\\")


def main(args):
    """Collect dumped logits, compute AUROC per seed and print a LaTeX table.

    Args:
        args: namespace with ``root``, ``aggregations`` and ``single_method``
            attributes, as produced by ``parse_arguments``.
    """
    postprocessors = get_postprocessors(args.aggregations)
    data = collect(args.root)
    if args.single_method is not None:
        # Keep only "<dataset>-<method>" entries of the requested method.
        # partition() yields "" for names without "-", so such directories
        # are filtered out instead of raising IndexError.
        data = OrderedDict(
            (key, paths) for key, paths in data.items()
            if key[0].partition("-")[2] == args.single_method)
    # method -> aggregation -> split -> AUROCs per seed.
    metrics = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for (method, _seed), paths in data.items():
        # Progress goes to stderr so that stdout holds only the table.
        print("Process", method, file=sys.stderr)
        seed_metrics = compute_metrics(paths, postprocessors)
        for aggregation, by_split in seed_metrics.items():
            for split, metric in by_split.items():
                metrics[method][aggregation][split].append(metric)
    if args.single_method is not None:
        show_single(metrics, args.aggregations)
    else:
        show(metrics, args.aggregations)


# Script entry point: parse the command line and print the table.
if __name__ == "__main__":
    main(parse_arguments())
