import csv
import os
import numpy as np
import torch
from matplotlib import pyplot as plt
import torch.nn.functional as F
from .ood_evaluator import OODEvaluator
from typing import Dict, List
import torch.nn as nn
from torch.utils.data import DataLoader
from openood.utils import Config
from .metrics import compute_all_metrics
from ..postprocessors.edl_postprocessor import EDLPostprocessor


class OODEDLEvaluator(OODEvaluator):
    """OOD evaluator for evidential deep learning (EDL) models.

    Every uncertainty measure returned by ``EDLPostprocessor.inference`` is
    scored separately, so OOD metrics are reported per uncertainty method.
    """

    def __init__(self, config: Config):
        """Create the evaluator.

        Args:
            config (Config): Experiment configuration. Besides whatever the
                base evaluator reads, this class uses ``output_dir``,
                ``dataset``, ``network`` and ``recorder`` from it.
        """
        super().__init__(config)
        # TODO: APS-mode hyper-parameter search
        # (``config.postprocessor.APS_mode``) is not supported here yet; it
        # would need a 'val' split in both the ID and the OOD data loaders.

    @staticmethod
    def _set_eval_mode(net):
        """Switch ``net`` -- a module or a dict of modules -- to eval mode."""
        if isinstance(net, dict):
            for subnet in net.values():
                subnet.eval()
        else:
            net.eval()

    @staticmethod
    def _append_csv_row(csv_path, row, force_header=False):
        """Append ``row`` (a dict) to the CSV file at ``csv_path``.

        A header row is written when the file does not exist yet, or
        whenever ``force_header`` is true.
        """
        write_header = force_header or not os.path.exists(csv_path)
        with open(csv_path, 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=list(row.keys()))
            if write_header:
                writer.writeheader()
            writer.writerow(row)

    def eval_ood(self,
                 net: nn.Module,
                 id_data_loaders: Dict[str, DataLoader],
                 ood_data_loaders: Dict[str, Dict[str, DataLoader]],
                 postprocessor: EDLPostprocessor,
                 fsood: bool = False,
                 epoch_idx=-1):
        """Evaluate OOD detection on the 'nearood' and 'farood' splits.

        The ID test set is run through the postprocessor once and reused for
        both splits. ID alphas and labels are saved as ``.npy`` files under
        ``<output_dir>/alpha/<epoch_idx>/``.

        Args:
            net: Model, or dict of sub-modules, to evaluate.
            id_data_loaders: ID loaders; must contain a 'test' entry.
            ood_data_loaders: Maps a split name ('nearood' / 'farood') to a
                dict of ``{dataset_name: DataLoader}``.
            postprocessor: EDL postprocessor used for inference.
            fsood: Unused by this method.
            epoch_idx: Used to name the output directories and files.
        """
        self._set_eval_mode(net)
        assert 'test' in id_data_loaders, \
            'id_data_loaders should have the key: test!'
        dataset_name = self.config.dataset.name

        print(f'Performing inference on {dataset_name} dataset...', flush=True)
        (id_pred, id_conf, id_alpha, id_gt, id_uncertainty,
         id_logits) = postprocessor.inference(net, id_data_loaders['test'])

        save_dir = os.path.join(self.config.output_dir, 'alpha', str(epoch_idx))
        os.makedirs(save_dir, exist_ok=True)
        np.save(os.path.join(save_dir, 'id_alpha.npy'), id_alpha)
        np.save(os.path.join(save_dir, 'id_labels.npy'), id_gt)

        if self.config.recorder.get('save_scores', False):
            self._save_scores(id_pred, id_conf, id_gt, dataset_name)
            self._save_uncertainties(id_uncertainty, dataset_name)

        id_list = [id_pred, id_conf, id_gt, id_uncertainty, id_logits]
        # Load near-OOD, then far-OOD data and compute OOD metrics for each.
        for ood_split in ('nearood', 'farood'):
            print(u'\u2500' * 70, flush=True)
            self._eval_ood_(net,
                            id_list,
                            ood_data_loaders,
                            postprocessor,
                            ood_split=ood_split,
                            epoch_idx=epoch_idx)

    def _eval_ood_(self,
                   net: nn.Module,
                   id_list: List[np.ndarray],
                   ood_data_loaders: Dict[str, Dict[str, DataLoader]],
                   postprocessor: EDLPostprocessor,
                   ood_split: str = 'nearood',
                   epoch_idx=-1):
        """Score every OOD dataset of one split against the ID results.

        For each OOD dataset this method does the following:
        - saves its alphas to
          ``<output_dir>/alpha/<epoch_idx>/ood_<name>_alpha.npy``;
        - computes metrics for every uncertainty method;
        - writes them via ``_save_csv``.

        It also draws ID-vs-OOD confidence and vacuity KDE plots into
        ``<output_dir>/plot/kde/<epoch_idx>/<ood_split>.pdf``. Finally, it
        writes the per-method mean of the metrics over the split's datasets.

        Args:
            net: Model (or dict of sub-modules) passed to the postprocessor.
            id_list: ``[id_pred, id_conf, id_gt, id_uncertainty, id_logits]``
                as produced in ``eval_ood``.
            ood_data_loaders: Maps split name to ``{dataset_name: loader}``.
            postprocessor: EDL postprocessor used for inference.
            ood_split: Which split of ``ood_data_loaders`` to evaluate.
            epoch_idx: Used to name the output directories and files.
        """
        print(f'Processing {ood_split}...', flush=True)
        id_pred, id_conf, id_gt, id_uncertainty, _ = id_list
        split_loaders = ood_data_loaders[ood_split]
        if not split_loaders:
            # plt.subplots cannot build a zero-column grid.
            print(f'No OOD datasets in {ood_split}, skipping.', flush=True)
            return

        fig_settings = {
            'conf_dist': {
                'labels': ['ID Confidence', 'OOD Confidence'],
                'xlabel': 'Confidence'
            },
            'vacuity_dist': {
                'labels': ['ID Uncertainty', 'OOD Uncertainty'],
                'xlabel': 'Uncertainty'
            },
        }
        num_row = len(fig_settings)
        num_col = len(split_loaders)
        # squeeze=False keeps ``axes`` 2-D even with a single OOD dataset.
        fig, axes = plt.subplots(num_row, num_col,
                                 figsize=(6 * num_col, 6 * num_row),
                                 squeeze=False)

        alpha_dir = os.path.join(self.config.output_dir, 'alpha', str(epoch_idx))
        os.makedirs(alpha_dir, exist_ok=True)

        metrics_list = []
        try:
            for col_idx, (dataset_name, ood_dl) in enumerate(split_loaders.items()):
                print(f'Performing inference on {dataset_name} dataset...',
                      flush=True)
                (ood_pred, ood_conf, ood_alpha, ood_gt, ood_uncertainty,
                 _) = postprocessor.inference(net, ood_dl)

                ood_gt = -1 * np.ones_like(ood_gt)  # hard set to -1 as ood
                if self.config.recorder.get('save_scores', False):
                    self._save_scores(ood_pred, ood_conf, ood_gt, dataset_name)
                    self._save_uncertainties(ood_uncertainty, dataset_name)

                np.save(os.path.join(alpha_dir, f'ood_{dataset_name}_alpha.npy'),
                        ood_alpha)

                pred = np.concatenate([id_pred, ood_pred])
                label = np.concatenate([id_gt, ood_gt])

                ood_metrics = {}
                for fn in ood_uncertainty:
                    # Uncertainties are negated so larger means "more ID-like"
                    # (presumably the convention compute_all_metrics expects).
                    scores = np.concatenate(
                        [-id_uncertainty[fn], -ood_uncertainty[fn]])
                    print(f'Computing metrics on {dataset_name} dataset with uncertainty method {fn}')
                    ood_metrics[fn] = compute_all_metrics(scores, label, pred)

                # (ID values, OOD values) drawn in each row of the figure.
                dist_data = {
                    'conf_dist': (id_conf, ood_conf),
                    'vacuity_dist': (id_uncertainty['edl_vacuity'],
                                     ood_uncertainty['edl_vacuity']),
                }
                for row_idx, (name, settings) in enumerate(fig_settings.items()):
                    ax = axes[row_idx, col_idx]
                    ax.set_title(dataset_name)
                    id_values, ood_values = dist_data[name]
                    plot_dist(id_values,
                              ood_values,
                              ax=ax,
                              label_A=settings['labels'][0],
                              label_B=settings['labels'][1])
                    ax.set_xlabel(settings['xlabel'], fontsize='x-large')
                    ax.set_ylabel('Density', fontsize='x-large')
                    ax.legend()

                self._save_csv(ood_metrics, dataset_name=dataset_name,
                               epoch_idx=epoch_idx)
                metrics_list.append(ood_metrics)

            # Save once, after every column has been drawn.
            plot_dir = os.path.join(self.config.output_dir, 'plot', 'kde',
                                    str(epoch_idx))
            os.makedirs(plot_dir, exist_ok=True)
            fig.savefig(os.path.join(plot_dir, f'{ood_split}.pdf'),
                        transparent=True, pad_inches=0.0,
                        bbox_inches='tight', dpi=600)
        finally:
            # Release the figure; otherwise one leaks per split and epoch.
            plt.close(fig)

        print('Computing mean metrics...', flush=True)
        # Group each uncertainty method's metric vectors across datasets.
        split_by_uncertainty_method = {}
        for ood_metrics in metrics_list:
            for fn, values in ood_metrics.items():
                split_by_uncertainty_method.setdefault(fn, []).append(values)
        mean_metrics = {
            fn: np.mean(values, axis=0)
            for fn, values in split_by_uncertainty_method.items()
        }
        self._save_csv(mean_metrics, dataset_name=ood_split, epoch_idx=epoch_idx)

    def eval_acc(self,
                 net: nn.Module,
                 data_loader: DataLoader,
                 postprocessor: EDLPostprocessor = None,
                 epoch_idx=-1,
                 fsood: bool = False,
                 csid_data_loaders: DataLoader = None):
        """Evaluate ID classification quality of the EDL model.

        Computes accuracy, EDL loss, ECE, the mean total evidence
        (``alpha`` summed over classes) and ``K / alpha``. It saves scores
        and uncertainties under the name ``str(epoch_idx)`` and appends an
        ACC/ECE row to ``<output_dir>/ece.csv``.

        Args:
            net: Model, or dict of sub-modules, to evaluate.
            data_loader: ID loader to evaluate on.
            postprocessor: EDL postprocessor; required despite the default.
            epoch_idx: Written to the CSV and used to name saved files.
            fsood: When true, a header row is re-written before appending
                to an existing ``ece.csv``.
            csid_data_loaders: Unused by this method.

        Returns:
            dict with keys 'acc', 'epoch_idx', 'loss', 'ece', 'alpha' and
            'vacuity'.

        Raises:
            ValueError: If ``postprocessor`` is None.
        """
        if postprocessor is None:
            raise ValueError('eval_acc requires an EDLPostprocessor, got None.')
        self._set_eval_mode(net)

        (id_pred, id_conf, id_alpha, id_gt, uncertainty_list,
         id_logits) = postprocessor.inference(net, data_loader)

        # F.one_hot only accepts int64 index tensors, so cast explicitly.
        labels = torch.tensor(id_gt).long()
        one_hot_target = F.one_hot(
            labels, num_classes=int(self.config.network.num_classes))
        try:
            loss = postprocessor.edl_utils.get_edl_loss(
                torch.tensor(id_logits), one_hot_target)
        except ValueError as e:
            # Best effort: report a zero loss instead of aborting evaluation.
            print(e)
            loss = torch.tensor(0)
        ece = self.eval_ece(total_scores=id_conf,
                            total_preds=id_pred,
                            total_labels=id_gt)

        self._save_scores(id_pred, id_conf, id_gt, str(epoch_idx))
        self._save_uncertainties(uncertainty_list, str(epoch_idx))

        metrics = {}
        metrics['acc'] = np.mean(np.asarray(id_pred) == np.asarray(id_gt))
        metrics['epoch_idx'] = epoch_idx
        metrics['loss'] = loss.item()
        metrics['ece'] = ece
        metrics['alpha'] = id_alpha.sum(axis=-1).mean()
        K = self.config.dataset.num_classes
        metrics['vacuity'] = K / metrics['alpha']

        write_content = {
            'split': epoch_idx,
            'ACC': '{:.2f}'.format(100 * metrics['acc']),
            'ECE': '{:.2f}'.format(100 * metrics['ece']),
        }
        self._append_csv_row(os.path.join(self.config.output_dir, 'ece.csv'),
                             write_content,
                             force_header=fsood)
        return metrics

    def eval_ece(self, total_scores, total_preds, total_labels, num_bins=15):
        """Expected calibration error over ``num_bins`` equal-width bins.

        Bins are ``[t_i, t_{i+1})`` on ``[0, 1]``; the last bin is closed,
        so a score of exactly 1.0 is counted. Scores outside ``[0, 1]`` are
        ignored.

        Returns:
            The count-weighted mean ``|confidence - accuracy|`` over
            non-empty bins, or -1.0 if no score falls in any bin.
        """
        scores_np = np.reshape(total_scores, -1)
        preds_np = np.reshape(total_preds, -1)
        labels_np = np.reshape(total_labels, -1)
        correct = preds_np == labels_np
        bin_edges = np.linspace(0, 1, num_bins + 1)

        gaps, counts = [], []
        for i in range(num_bins):
            lower, upper = bin_edges[i], bin_edges[i + 1]
            if i == num_bins - 1:
                in_bin = (scores_np >= lower) & (scores_np <= upper)
            else:
                in_bin = (scores_np >= lower) & (scores_np < upper)
            count = int(np.sum(in_bin))
            if count == 0:
                continue
            # |mean predicted confidence - empirical accuracy| in this bin.
            gaps.append(abs(np.mean(scores_np[in_bin]) - np.mean(correct[in_bin])))
            counts.append(count)

        if not counts:
            print('Warning: sum(nb_items_bin)==0 in eval_ece function')
            return -1.0
        return np.average(gaps, weights=counts)

    def _save_csv(self, metrics, dataset_name, epoch_idx=-1):
        """Print OOD metrics and append them to ``ood_metrics/<epoch_idx>.csv``.

        Args:
            metrics: Maps an uncertainty-method name to
                ``[fpr, auroc, aupr_in, aupr_out, accuracy]``.
            dataset_name: Value written to the 'dataset' column.
            epoch_idx: Names the CSV file under ``<output_dir>/ood_metrics``.
        """
        save_dir = os.path.join(self.config.output_dir, 'ood_metrics')
        os.makedirs(save_dir, exist_ok=True)
        csv_path = os.path.join(save_dir, f'{epoch_idx}.csv')

        for fn, values in metrics.items():
            fpr, auroc, aupr_in, aupr_out, accuracy = values

            write_content = {
                'dataset': dataset_name,
                'uncertainty': fn,
                'FPR@95': '{:.2f}'.format(100 * fpr),
                'AUROC': '{:.2f}'.format(100 * auroc),
                'AUPR_IN': '{:.2f}'.format(100 * aupr_in),
                'AUPR_OUT': '{:.2f}'.format(100 * aupr_out),
                'ACC': '{:.2f}'.format(100 * accuracy),
            }

            # print ood metric results
            print('Uncertainty: {}, FPR@95: {:.2f}, AUROC: {:.2f}'.format(fn, 100 * fpr, 100 * auroc),
                  end=' ',
                  flush=True)
            print('AUPR_IN: {:.2f}, AUPR_OUT: {:.2f}'.format(
                100 * aupr_in, 100 * aupr_out),
                flush=True)
            print('ACC: {:.2f}'.format(accuracy * 100), flush=True)
            print(u'\u2500' * 70, flush=True)

            self._append_csv_row(csv_path, write_content)

    def _save_scores(self, pred, conf, gt, save_name):
        """Save predictions, confidences and labels to ``scores/<save_name>.npz``."""
        save_dir = os.path.join(self.config.output_dir, 'scores')
        os.makedirs(save_dir, exist_ok=True)
        np.savez(os.path.join(save_dir, save_name),
                 pred=pred,
                 conf=conf,
                 label=gt)

    def _save_uncertainties(self, ood_uncertainty, save_name):
        """Save each uncertainty array to ``uncertainties/<method>/<save_name>.npz``."""
        for fn in ood_uncertainty:
            save_dir = os.path.join(self.config.output_dir, 'uncertainties', fn)
            os.makedirs(save_dir, exist_ok=True)
            np.savez(os.path.join(save_dir, save_name), uncertainty=ood_uncertainty[fn])


def plot_dist(A, B, ax, label_A, label_B):
    """Overlay kernel-density estimates of two samples on ``ax``.

    Args:
        A: First sample, drawn with legend label ``label_A``.
        B: Second sample, drawn with legend label ``label_B``.
        ax: Matplotlib axes to draw on.
        label_A: Legend label for ``A``.
        label_B: Legend label for ``B``.
    """
    # Imported lazily so seaborn is only needed when plots are drawn.
    import seaborn as sns

    # ``shade`` is deprecated since seaborn 0.11. Unfilled curves are the
    # default on both old and new seaborn, so omitting it keeps the same
    # output without a deprecation warning.
    sns.kdeplot(A, ax=ax, label=label_A)
    sns.kdeplot(B, ax=ax, label=label_B)
