

import numpy as np
import os.path as osp
from collections import OrderedDict, defaultdict
import torch
from sklearn.metrics import f1_score, confusion_matrix

from . import tent
from .build import EVALUATOR_REGISTRY
import torch.nn.functional as F

import pandas as pd

import torch
from torch import nn

from sklearn.metrics import auc
from sklearn import metrics




class EvaluatorBase:
    """Minimal evaluator interface.

    Concrete evaluators override ``reset``, ``process`` and ``evaluate``;
    this base class only stores the configuration object.
    """

    def __init__(self, cfg):
        # The config is kept on the instance so subclasses can read it.
        self.cfg = cfg

    def reset(self):
        """Clear accumulated state (subclasses must override)."""
        raise NotImplementedError()

    def process(self, mo, gt):
        """Consume one batch of outputs ``mo`` and targets ``gt`` (subclasses must override)."""
        raise NotImplementedError()

    def evaluate(self):
        """Compute the final metrics (subclasses must override)."""
        raise NotImplementedError()


@EVALUATOR_REGISTRY.register()
class Classification_plain(EvaluatorBase):
    """Evaluator for classification with an extra "unknown" label.

    Samples whose ground-truth label equals ``len(lab2cname)`` are counted as
    unknown; every other sample is counted as known. ``evaluate`` reports the
    overall, known-only and unknown-only accuracy plus the H-score (harmonic
    mean of the known and unknown accuracies).
    """

    def __init__(self, cfg, lab2cname=None, **kwargs):
        """Create the evaluator.

        Args:
            cfg: config node; ``cfg.TEST.PER_CLASS_RESULT``,
                ``cfg.TEST.COMPUTE_CMAT`` and ``cfg.OUTPUT_DIR`` are read.
            lab2cname: mapping from label id to class name. Required in
                practice: its length defines the unknown label id
                (``len(None)`` raises ``TypeError``).
            **kwargs: accepted and ignored.
        """
        super().__init__(cfg)
        self._lab2cname = lab2cname
        self.save_path = './txt_for_show/'

        # Label id used for unknown samples (presumably known labels are
        # 0 .. len(lab2cname) - 1; confirm against the dataset).
        self.unknown_label = len(lab2cname)
        self.count = 0

        self._per_class_res = None
        if cfg.TEST.PER_CLASS_RESULT:
            assert lab2cname is not None
            self._per_class_res = defaultdict(list)

        self._clear_state()

    def _clear_state(self):
        """Zero every counter and empty every accumulation buffer."""
        self._correct = 0
        self._total = 0
        self._correct_known = 0
        self._total_known = 0
        self._correct_unknown = 0
        self._total_unknown = 0

        self.known_entropy = []
        self.unknown_entropy = []
        self.known_prob = []
        self.unknown_prob = []
        self.known_logits = []
        self.unknown_logits = []

        # Buffers for showing feature details in the paper; nothing in this
        # class fills them (presumably filled by external code).
        self.train_mean = []
        self.train_std = []
        self.test_known_mean = []
        self.test_known_std = []
        self.test_unknown_mean = []
        self.test_unknown_std = []
        self.test_known_feat = []
        self.test_unknown_feat = []
        self.train_feat_label = []

        self._y_true = []
        self._y_pred = []
        self._y_prob = []
        if self._per_class_res is not None:
            self._per_class_res = defaultdict(list)

    def reset(self):
        """Clear all accumulated counts and buffers before a new evaluation."""
        self._clear_state()

    def append_batch_tolist(self, data, list):
        """Append each row of tensor ``data`` (along dim 0) to ``list`` as Python values.

        The parameter name shadows the builtin ``list``; it is kept unchanged
        for compatibility with keyword callers.
        """
        # One device transfer for the whole batch instead of one per row.
        list.extend(data.detach().cpu().numpy().tolist())

    def process(self, mo, gt, feat=None, model=None, num_classes=None, stage='val'):
        """Accumulate statistics for one batch.

        Args:
            mo: class scores (e.g. logits) with classes along dim 1; the
                prediction is the index of the maximum along dim 1.
            gt: ground-truth labels, one per row of ``mo``.
            feat, model, stage: accepted for interface compatibility; unused.
            num_classes: stored on ``self.num_classes``.
        """
        self.num_classes = num_classes

        pred = mo.max(1)[1]
        matches = pred.eq(gt).float()

        # Split the batch by ground truth: unknown label vs. known classes.
        unknown_mask = gt == self.unknown_label
        known_mask = ~unknown_mask

        self._correct_known += int(matches[known_mask].sum().item())
        self._total_known += int(known_mask.sum().item())

        self._correct_unknown += int(matches[unknown_mask].sum().item())
        self._total_unknown += int(unknown_mask.sum().item())

        self._correct += int(matches.sum().item())
        self._total += gt.shape[0]

        # Record labels/predictions so the confusion matrix has data.
        gt_list = gt.detach().cpu().tolist()
        self._y_true.extend(gt_list)
        self._y_pred.extend(pred.detach().cpu().tolist())

        if self._per_class_res is not None:
            for label, hit in zip(gt_list, matches.detach().cpu().tolist()):
                self._per_class_res[label].append(int(hit))

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total``, or 0.0 when ``total`` is 0."""
        return 100.0 * correct / total if total else 0.0

    def _report_per_class(self):
        """Print the accuracy of each known class and return their mean."""
        # Exclude the unknown label by value (lab2cname is not expected to
        # name it) rather than assuming it is the largest label present,
        # which would drop a real class when a run has no unknown samples.
        known_labels = sorted(
            label for label in self._per_class_res if label != self.unknown_label
        )

        print("=> per-class result")
        accs = []
        for label in known_labels:
            classname = self._lab2cname[label]
            res = self._per_class_res[label]
            correct = sum(res)
            total = len(res)
            class_acc = 100.0 * correct / total
            accs.append(class_acc)
            print(
                "* class: {} ({})\t"
                "total: {:,}\t"
                "correct: {:,}\t"
                "acc: {:.2f}%".format(
                    label, classname, total, correct, class_acc
                )
            )
        mean_acc = np.mean(accs) if accs else 0.0
        print("* average: {:.2f}%".format(mean_acc))
        return mean_acc

    def evaluate(self):
        """Compute, print and return the accumulated metrics.

        Returns:
            OrderedDict with keys ``accuracy``, ``error_rate``,
            ``accuracy_known``, ``error_rate_known``, ``accuracy_unknown``,
            ``error_rate_unknown``, ``h_score`` and, when per-class results
            are enabled, ``perclass_accuracy``. Values are percentages.
        """
        results = OrderedDict()
        acc = self._percent(self._correct, self._total)
        err = 100.0 - acc

        acc_known = self._percent(self._correct_known, self._total_known)
        err_known = 100.0 - acc_known

        # Decide "no unknown samples" from the count itself. Comparing
        # acc == acc_known is wrong because it also holds whenever the known
        # and unknown accuracies happen to be equal.
        if self._total_unknown == 0:
            acc_unknown = 0.0
            err_unknown = 0.0
        else:
            acc_unknown = self._percent(self._correct_unknown, self._total_unknown)
            err_unknown = 100.0 - acc_unknown

        # H-score: harmonic mean of known and unknown accuracy; 0 when both are 0.
        denom = acc_known + acc_unknown
        h_score = 2.0 * acc_known * acc_unknown / denom if denom > 0 else 0.0

        results["accuracy"] = acc
        results["error_rate"] = err

        results["accuracy_known"] = acc_known
        results["error_rate_known"] = err_known
        results["accuracy_unknown"] = acc_unknown
        results["error_rate_unknown"] = err_unknown

        results["h_score"] = h_score

        print(
            "=> result\n"
            f"* total: {self._total:,}\n"
            f"* accuracy: {acc:.2f}%\n"
            f"* accuracy_known: {acc_known:.2f}%\n"
            f"* accuracy_unknown: {acc_unknown:.2f}%\n"
            f"* h_score: {h_score:.2f}%\n"
        )

        if self._per_class_res is not None:
            results["perclass_accuracy"] = self._report_per_class()

        if self.cfg.TEST.COMPUTE_CMAT:
            cmat = confusion_matrix(
                self._y_true, self._y_pred, normalize="true"
            )
            save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
            torch.save(cmat, save_path)
            print('Confusion matrix is saved to "{}"'.format(save_path))

        return results


