"""Script to compute metrics for the pre-computed scores (LLR and L>k) for a HVAE"""

import argparse
import os
import logging

from collections import defaultdict
from typing import *

import rich
import numpy as np
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import sklearn.metrics
import matplotlib.pyplot as plt
import oodd.evaluators
import oodd.models
import oodd.losses
import oodd.utils


# Module-level logger. ``logging.getLogger()`` with no name returns the root
# logger. NOTE(review): LOGGER is not referenced in the visible part of this
# script.
LOGGER = logging.getLogger()



# Helper methods
def collapse_multiclass_to_binary(y_true, zero_label=None):
    """Collapse multi-class labels to binary 0/1 labels.

    Entries equal to ``zero_label`` become 0 and every other entry becomes 1.
    A NEW integer array is returned and ``y_true`` is left untouched. The
    former in-place overwrite corrupted the caller's labels when the same
    array was collapsed twice with a ``zero_label`` other than 0.

    Args:
        y_true: Array-like of class labels.
        zero_label: Label mapped to 0. With the default ``None``, no numeric
            label compares equal, so every entry maps to 1.

    Returns:
        ``np.ndarray`` of 0/1 integers with the same shape as ``y_true``.
    """
    y_true = np.asarray(y_true)
    return np.where(y_true == zero_label, 0, 1)


def compute_roc_auc(y_true=None, y_score=None, zero_label=None):
    """ROC curve and ROC-AUC for a binary problem (only binary is supported).

    Labels are first collapsed with ``collapse_multiclass_to_binary``:
    ``zero_label`` becomes 0 and every other label becomes 1.

    Returns:
        Tuple ``(roc_auc, fpr, tpr, thresholds)``.
    """
    binary_labels = collapse_multiclass_to_binary(y_true, zero_label)
    auc_value = sklearn.metrics.roc_auc_score(binary_labels, y_score, average="macro")
    false_pos, true_pos, cutoffs = sklearn.metrics.roc_curve(binary_labels, y_score)
    return auc_value, false_pos, true_pos, cutoffs


def compute_pr_auc(y_true=None, y_score=None, zero_label=None):
    """Precision-recall curve and average precision (only binary is supported).

    Labels are first collapsed with ``collapse_multiclass_to_binary``:
    ``zero_label`` becomes 0 and every other label becomes 1.

    Returns:
        Tuple ``(pr_auc, precision, recall, thresholds)``, where ``pr_auc`` is
        sklearn's ``average_precision_score``.
    """
    binary_labels = collapse_multiclass_to_binary(y_true, zero_label)
    avg_precision = sklearn.metrics.average_precision_score(binary_labels, y_score, average="macro")
    prec, rec, cutoffs = sklearn.metrics.precision_recall_curve(binary_labels, y_score)
    return avg_precision, prec, rec, cutoffs


def compute_roc_pr_metrics(y_true, y_score, reference_class):
    """Compute ROC and PR metrics plus the FPR at 80% TPR.

    Args:
        y_true: Array of labels. ``reference_class`` is treated as the
            negative class (0) and every other label as positive (1).
        y_score: Array of scores. Per sklearn's convention, higher means
            "more positive".
        reference_class: Label collapsed to 0.

    Returns:
        ``((roc_auc, fpr, tpr, roc_thresholds),
        (pr_auc, precision, recall, pr_thresholds), fpr80)``.
        ``fpr80`` is NaN when TPR never reaches 0.8 (e.g. no positive labels).
    """
    # Give each helper its own copy: label collapsing must never see labels
    # already collapsed by the other call, nor modify the caller's array.
    y_true = np.asarray(y_true)
    roc_auc, fpr, tpr, roc_thresholds = compute_roc_auc(
        y_true=y_true.copy(), y_score=y_score, zero_label=reference_class
    )
    # Separate names so the ROC thresholds are not overwritten by the PR ones.
    pr_auc, precision, recall, pr_thresholds = compute_pr_auc(
        y_true=y_true.copy(), y_score=y_score, zero_label=reference_class
    )
    # First ROC operating point whose TPR reaches 0.8.
    reached = np.flatnonzero(tpr >= 0.8)
    fpr80 = fpr[reached[0]] if reached.size else np.nan
    return (roc_auc, fpr, tpr, roc_thresholds), (pr_auc, precision, recall, pr_thresholds), fpr80


def get_dataset(file_name):
    """Return the dataset name encoded in a score-file name.

    The third and fourth ``-``-separated fields are joined with a space,
    e.g. ``"x-y-CIFAR10-test-..."`` -> ``"CIFAR10 test"``.
    """
    fields = file_name.split("-")
    name_parts = fields[2:4]
    return " ".join(name_parts)


def get_iw(file_name):
    """Parse the two importance-weight counts from a score-file name.

    The counts are read from the sixth and seventh ``-``-separated fields.
    NOTE(review): the slice offsets imply fields shaped like ``iw_elbo<N>``
    and ``iw_lK<N>.pt``; confirm against the file producer.

    Returns:
        Tuple ``(iw_elbo, iw_lK)`` of ints.
    """
    fields = file_name.split("-")
    elbo_count = int(fields[5][len("iw_elbo"):])
    lk_count = int(fields[6][len("iw_lK"):-len(".pt")])
    return elbo_count, lk_count


def get_k(file_name):
    """Return the integer that follows the last ``-k`` in ``file_name``.

    The integer ends at the next ``-``.
    """
    tail = file_name.rsplit("-k", 1)[-1]
    k_text, _, _ = tail.partition("-")
    return int(k_text)


def load_data(files, negate_scores: bool = False, source_dir: Optional[str] = None):
    """Load score files into a nested mapping keyed by dataset and configuration.

    Each file name is parsed with ``get_dataset``, ``get_iw`` and ``get_k``.
    The file itself is read with ``torch.load`` and iterated with ``.items()``,
    i.e. it is expected to map test-dataset names to score sequences
    (TODO confirm against the producer).

    Args:
        files: File names (not paths), relative to ``source_dir``.
        negate_scores: If True, the negated scores are stored.
        source_dir: Directory containing ``files``. Defaults to the
            module-level ``args.source_dir`` for backward compatibility.

    Returns:
        Nested ``defaultdict`` indexed as
        ``data[reference_dataset][test_dataset][k][iw_elbo][iw_elbo_k]``,
        whose leaves are ``np.ndarray`` scores.
    """
    if source_dir is None:
        source_dir = args.source_dir

    data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(defaultdict))))
    for file_name in files:
        reference_dataset = get_dataset(file_name)
        iw_elbo, iw_elbo_k = get_iw(file_name)
        k = get_k(file_name)

        # NOTE: torch.load may unpickle arbitrary objects; only load trusted files.
        scores_per_dataset = torch.load(os.path.join(source_dir, file_name))

        for test_dataset, raw_scores in scores_per_dataset.items():
            values = np.array(raw_scores)
            data[reference_dataset][test_dataset][k][iw_elbo][iw_elbo_k] = -values if negate_scores else values

    return data


def get_save_path(name, source_dir=None):
    """Return ``"<source_dir>/<name>"`` with spaces in ``name`` replaced by ``-``.

    Args:
        name: Base name of the output.
        source_dir: Output directory. Defaults to the module-level
            ``args.source_dir`` for backward compatibility.
    """
    if source_dir is None:
        source_dir = args.source_dir
    name = name.replace(" ", "-")
    return f"{source_dir}/{name}"


def write_text_file(filepath, string):
    """Write ``string`` to ``filepath`` as UTF-8, overwriting any existing file.

    The encoding is explicit so the output does not depend on the platform's
    locale default.
    """
    with open(filepath, "w", encoding="utf-8") as file_buffer:
        file_buffer.write(string)

def _plot_roc_curve(fpr, tpr, label):
    """Draw one ROC curve and show it (``plt.show`` is called per curve)."""
    plt.plot(fpr, tpr, label=label)
    plt.ylabel('True Positive Rate', fontsize=16)
    plt.xlabel('False Positive Rate', fontsize=16)
    plt.legend(loc=4)
    plt.grid()
    plt.show()


def compute_results(score, score_name, model_name, ood_datasets=('SVHN test', 'MNIST test')):
    """Compute ROC/PR metrics, and show ROC curves, for every score configuration.

    For each reference (in-distribution) dataset in ``score`` and each test
    dataset in ``ood_datasets``, every ``(k, iw_elbo, iw_elbo_k)`` configuration
    is evaluated. Reference scores are labelled 0 and test scores 1.

    Args:
        score: Nested mapping indexed as
            ``score[reference_dataset][test_dataset][k][iw_elbo][iw_elbo_k]``
            (the structure built by ``load_data``).
        score_name: Name of the score. Not used in the computation; kept for
            interface compatibility.
        model_name: Legend label for the ROC curves.
        ood_datasets: Test datasets to evaluate. Defaults to the two that
            were previously hard-coded.

    Returns:
        Nested ``defaultdict`` with the same indexing as ``score``. Each leaf is
        a dict with keys ``"roc_auc"``, ``"pr_auc"`` and ``"fpr80"``.
    """
    results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(defaultdict))))
    for reference_dataset in list(score.keys()):
        per_test = score[reference_dataset]
        for test_dataset in sorted(per_test.keys()):
            if test_dataset not in ood_datasets:
                continue
            per_k = per_test[test_dataset]
            for k in sorted(per_k.keys()):
                print(f'k={k}')
                per_iw_elbo = per_k[k]
                for iw_elbo in sorted(per_iw_elbo.keys()):
                    for iw_elbo_k in sorted(per_iw_elbo[iw_elbo].keys()):
                        reference_scores = score[reference_dataset][reference_dataset][k][iw_elbo][iw_elbo_k]
                        test_scores = per_iw_elbo[iw_elbo][iw_elbo_k]

                        y_true = np.concatenate([
                            np.zeros(len(reference_scores), dtype=int),
                            np.ones(len(test_scores), dtype=int),
                        ])
                        y_score = np.concatenate([reference_scores, test_scores])

                        (roc_auc, fpr, tpr, _roc_thresholds), (pr_auc, _precision, _recall, _pr_thresholds), fpr80 = (
                            compute_roc_pr_metrics(y_true=y_true, y_score=y_score, reference_class=0)
                        )
                        results[reference_dataset][test_dataset][k][iw_elbo][iw_elbo_k] = {
                            "roc_auc": roc_auc,
                            "pr_auc": pr_auc,
                            "fpr80": fpr80,
                        }
                        # Use the caller-supplied label, not the module-level save_name global.
                        _plot_roc_curve(fpr, tpr, label=f'{model_name}')
    return results



parser = argparse.ArgumentParser(
    description="Compute OOD metrics and plots from pre-computed HVAE scores (LLR and L>k)."
)
# Default model directory under ./results. Despite the variable name, the
# active default is the FashionMNIST model; CIFAR10 candidates are listed below.
# =======> FASHIONMNIST
cifar_model_dir = "5-layer-vae-epoch1000"
# OURS: vae-dc-fashinmnist
# HVAE: VAE-FASIONMINST-EPOCH600-046043
# LVAE: LVAE-FASHIONMNIST-EPOCH400-773564
# BIVA: BIVA-FashionMNIST-EPOCH349-987397  biva-binaryFashionMNIST-epoch600


# ======> CIFAR10
# cifar_model_dir = "VAE-CIFAR10Dequantized-EPOCH300-075664"
# OURS: "vae-dc-warm-cifar10-0.978-epo51-454529"    "vae-dc-0.1-cifar10-0.96-epo32-317843"   VAEdc_CIFAR10Dequantized-2022-03-13-00-44-02.083833
# VAE : VAE-CIFAR10Dequantized-EPOCH300-075664
# LVAE: LVAE-CIFAR10Dequantized-EPOCH100-150420   LVAE-CIFAR10-EPOCH257-150420
# BIVA: BIVA-CIFAR10Dequantized-0.86-epoch1-946984    BIVA-choose-cifar10  BIVA-CIFAR10-GOOD  biva-cifar10-densityshow

parser.add_argument("--source_dir", type=str, default="./results/" + cifar_model_dir, help="directory from which to load scores")
# The run configuration used to be hard-coded; the defaults below reproduce it.
parser.add_argument("--save_name", type=str, default='5-vanilla-vae', help="model name used in figure file names and ROC legend labels")
parser.add_argument("--cifar10", action="store_true", help="density plot for CIFAR10 vs SVHN instead of FashionMNIST vs MNIST")
parser.add_argument("--draw_density", action="store_true", help="draw in/out score density histograms")
parser.add_argument("--skip_roc", action="store_true", help="skip metric computation and ROC curves")

args = parser.parse_args()
rich.print(vars(args))

save_name = args.save_name
Cifar10 = args.cifar10
draw_density = args.draw_density
draw_roc = not args.skip_roc

# Sorted because os.listdir order is arbitrary; this keeps loading deterministic.
all_files = sorted(f for f in os.listdir(args.source_dir) if f.endswith(".pt"))

all_scores = [f for f in all_files if "scores" in f]
all_elbo_k = [f for f in all_files if "elbos_k" in f]
all_kls = [f for f in all_files if "kl_sum" in f]

scores = load_data(all_scores, negate_scores=False)
kls = load_data(all_kls, negate_scores=False)
elbo_k = load_data(all_elbo_k, negate_scores=True)


def _plot_score_density(reference_dataset, test_dataset, train_label, file_tag):
    """Plot normalized in/out score histograms for k=2, iw_elbo=1, then save and show them.

    Reads the module-level ``scores`` mapping and ``save_name``. Both score
    sets are min-max normalized by their pooled range. One figure is produced
    per ``iw_elbo_k``; each is written to
    ``./figs/{save_name}-{file_tag}-density.pdf``, so later ones overwrite
    earlier ones. Nothing is drawn if the test dataset/configuration is absent.
    """
    # .get avoids inserting empty entries into the defaultdict-based mapping.
    per_iw_elbo_k = scores.get(reference_dataset, {}).get(test_dataset, {}).get(2, {}).get(1, {})
    for iw_elbo_k in sorted(per_iw_elbo_k.keys()):
        reference_scores = scores[reference_dataset][reference_dataset][2][1][iw_elbo_k]
        test_scores = per_iw_elbo_k[iw_elbo_k]
        # Named `pooled` so the module-level `all_scores` file list is not clobbered.
        pooled = np.concatenate([reference_scores, test_scores])
        min_all, max_all = pooled.min(), pooled.max()
        # Guard against a zero range (all scores identical) to avoid 0/0.
        span = (max_all - min_all) or 1.0
        normalized_reference = (reference_scores - min_all) / span
        normalized_test = (test_scores - min_all) / span
        plt.hist(normalized_reference, bins=100, density=True, facecolor="deepskyblue", alpha=0.7, label=f'{reference_dataset} (in)')
        plt.hist(normalized_test, bins=100, density=True, facecolor="orangered", alpha=0.5, label=f'{test_dataset} (out)')
        plt.xlabel('Normalized $LLR^{>2}$', fontsize=18)
        plt.ylabel('Density', fontsize=18)
        plt.title(f'Trained on {train_label}', fontsize=20)
        plt.legend(loc=1)
        plt.tight_layout()
        os.makedirs('./figs', exist_ok=True)
        plt.savefig(f'./figs/{save_name}-{file_tag}-density.pdf')
        plt.show()


# draw densities plot
if draw_density:
    if Cifar10:
        _plot_score_density('CIFAR10 test', 'SVHN test', 'CIFAR10', 'cifar10')
    else:
        _plot_score_density('FashionMNIST test', 'MNIST test', 'FashionMNIST', 'fashionmnist')


# Compute ROC/PR metrics for the LLR scores and show the ROC curves.
if draw_roc:
    results_scores = compute_results(
        scores,
        score_name="llr",
        model_name=save_name,
    )
