import numpy as np
import torch
import torch.nn.functional as F
import logging
from typing import Optional, Union, List
from terminaltables import AsciiTable
from sklearn.metrics import roc_auc_score, f1_score

from utils.file_io import load_leison_classes

# Module-level logger; the print_class_score helpers in this file report through it.
logger = logging.getLogger(__name__)


# Small constant added to IoU denominators so empty classes do not divide by zero.
_EPS = 1e-10


class AverageMeter():
    """Track the latest value together with a running weighted sum, count and mean."""
    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val : float, n : int = 1) -> None:
        """Record ``val`` as observed ``n`` times and refresh the running mean.

        Args:
            val: the newly observed value (typically a per-batch average).
            n: weight of the observation, e.g. the batch size.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def dice_coef(pred: Union[torch.Tensor, np.array],
              target : Union[torch.Tensor, np.array],
              eps : float = 1e-8) -> Union[torch.Tensor, np.array]:
    """Compute the Dice coefficient per sample and per channel.

    Works on either torch tensors or numpy arrays; the matching ``einsum``
    implementation is chosen from the type of ``pred``.

    Args:
        pred: 4-D ``(n, c, i, j)`` prediction mask; the last two axes are summed out.
        target: ground-truth mask with exactly the same shape as ``pred``.
        eps: smoothing term added to numerator and denominator, so two empty
            masks score 1 instead of dividing by zero.

    Returns:
        ``(n, c)`` array/tensor of Dice scores ``(2*|P∩T| + eps) / (|P| + |T| + eps)``.
    """
    assert pred.shape == target.shape, "The shapes of input and target do not match"

    # Pick the backend from the input type; the original misspelt ``np.einsum``,
    # which made every numpy call fail with AttributeError.
    einsum = torch.einsum if isinstance(pred, torch.Tensor) else np.einsum

    # Operands are passed as separate arguments: np.einsum does not accept a
    # single tuple of operands (torch.einsum accepts both forms).
    inter = einsum("ncij,ncij->nc", pred, target)
    union = einsum("ncij->nc", pred) + einsum("ncij->nc", target)

    return (2 * inter + eps) / (union + eps)


def iou_coef(pred: torch.Tensor,
             target : torch.Tensor,
             num_classes : int = 1) -> Union[torch.Tensor, np.array]:
    """Compute intersection-over-union per sample and per class.

    Args:
        pred: when ``num_classes > 1``, an integer label map that is expanded with
            ``F.one_hot`` (channel last); otherwise it must already be a 4-D
            channel-last ``(n, i, j, c)`` mask.
        target: same layout and shape as ``pred``.
        num_classes: number of classes used for the one-hot expansion.

    Returns:
        ``(n, c)`` tensor of ``|P∩T| / (|P∪T| + _EPS)``.
    """
    assert pred.shape == target.shape, "The shapes of input and target do not match"

    if num_classes > 1:
        def _onehot(labels):
            return F.one_hot(labels, num_classes=num_classes).float()
        pred, target = _onehot(pred), _onehot(target)

    overlap = torch.einsum("nijc,nijc->nc", pred, target)
    area_pred = torch.einsum("nijc->nc", pred)
    area_target = torch.einsum("nijc->nc", target)
    # Inclusion-exclusion: |P ∪ T| = |P| + |T| - |P ∩ T|.
    return overlap / (area_pred + area_target - overlap + _EPS)


class Metric:
    """Base interface for the metric accumulators in this module.

    Every hook is a no-op that returns ``None``; concrete metrics override them.
    """

    @property
    def num_samples(self):
        """Number of accumulated samples; ``None`` for the base class."""

    def update(self):
        """Add a batch of results (no-op in the base class)."""

    def reset(self):
        """Discard accumulated state (no-op in the base class)."""

    def mean_score(self):
        """Overall score (``None`` in the base class)."""

    def class_score(self):
        """Per-class scores (``None`` in the base class)."""


class DiceMetric(Metric):
    """Dice score, which is equivalent to F1 score.

    Each ``update`` binarises the predictions with ``thres``, scores them against
    the target with ``dice_coef`` and appends the per-sample, per-channel scores
    to ``all_dices``.
    """
    def __init__(self, thres : float = 0.5,
                 num_classes : int = 2,
                 classes_path : Optional[str] = None) -> None:
        self.thres = thres
        self.num_classes = num_classes
        if classes_path is not None:
            self.classes, self.classes_abbrev = load_leison_classes(classes_path)
        else:
            self.classes, self.classes_abbrev = None, None
        # Accumulated scores, shape [N, C]: dice_coef sums out the spatial axes.
        self.all_dices : Optional[torch.Tensor] = None

    @property
    def num_samples(self):
        """Number of accumulated samples, or ``None`` before the first update."""
        return self.all_dices.shape[0] if self.all_dices is not None else None

    def update(self, pred : torch.Tensor, target : torch.Tensor) -> torch.Tensor:
        """Score one batch, accumulate it, and return the batch's mean Dice."""
        pred = (pred > self.thres).float()

        dices = dice_coef(pred, target)

        if self.all_dices is None:
            self.all_dices = dices
        else:
            self.all_dices = torch.cat((self.all_dices, dices), dim=0)

        return dices.mean()

    def reset(self):
        """Discard all accumulated scores."""
        self.all_dices = None

    def class_score(self) -> torch.Tensor:
        """Mean Dice per channel over all accumulated samples."""
        return self.all_dices.mean(dim=0)

    def mean_score(self) -> torch.Tensor:
        """Mean Dice over all accumulated samples and channels."""
        return self.all_dices.mean()

    def print_class_score(self) -> str:
        """Log one ``<1-based channel index>\\t<score>`` line per channel.

        Returns:
            The logged lines joined with newlines. The signature always promised
            ``str``, but the original returned ``None``.
        """
        lines = ["{}\t{:.4f}".format(i + 1, score.item())
                 for i, score in enumerate(self.class_score())]
        for line in lines:
            logger.info(line)
        return "\n".join(lines)


def intersect_and_union(pred_label, label, num_classes, ignore_index):
    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect, _ = np.histogram(
        intersect, bins=np.arange(num_classes + 1)
    )
    area_pred_label, _ = np.histogram(
        pred_label, bins=np.arange(num_classes + 1)
    )
    area_label, _ = np.histogram(
        label, bins=np.arange(num_classes + 1)
    )
    area_union = area_pred_label + area_label - area_intersect

    return area_intersect, area_union, area_pred_label, area_label


class IouMetric:
    """Accumulate per-class intersection/union areas over label maps; report IoU and accuracy."""
    def __init__(self, num_classes : int = 1,
                 classes : Optional[List[str]] = None,
                 ignore_index : int = -100) -> None:
        super().__init__()
        self.num_classes = num_classes
        if classes is None:
            self.classes = [str(x) for x in range(self.num_classes)]
        else:
            self.classes = classes
        self.ignore_index = ignore_index
        # Running per-class totals and sample counter (initialised by reset).
        self.reset()

    def _zeros(self) -> np.ndarray:
        """Return a fresh float64 per-class accumulator.

        ``np.float`` was removed in NumPy 1.24; it was an alias for float64.
        """
        return np.zeros((self.num_classes, ), dtype=np.float64)

    def reset(self) -> None:
        """Clear all accumulated areas and the sample counter.

        The original only set an unused ``all_scores`` attribute, so totals kept
        growing across resets.
        """
        self.total_area_intersect = self._zeros()
        self.total_area_union = self._zeros()
        self.total_area_pred_label = self._zeros()
        self.total_area_label = self._zeros()
        self.num_samples = 0

    def mean_score(self):
        """Return ``(macc, miou)``: class-averaged accuracy and IoU over all accumulated data.

        A class with zero area contributes 0 to the average (the tiny epsilon keeps
        the division finite).
        """
        macc = (self.total_area_intersect / (np.spacing(1) + self.total_area_label)).mean()
        miou = (self.total_area_intersect / (np.spacing(1) + self.total_area_union)).mean()
        return macc, miou

    def update(self, pred : np.array, target : np.array, accumulate=True) -> tuple:
        """Score one batch and optionally add its areas to the running totals.

        Args:
            pred: per-class scores with the class axis at dim 1; the predicted label
                at each position is the arg-max over that axis.
            target: integer label maps aligned with ``pred`` minus its class axis;
                positions equal to ``ignore_index`` are skipped.
            accumulate: when True, add this batch to the totals and ``num_samples``.

        Returns:
            ``(iou, acc)`` for this batch alone, pooled over all classes. Note that
            this order is the reverse of ``mean_score``'s.
        """
        pred_labels = np.argmax(pred, axis=1)
        batch_intersect = self._zeros()
        batch_union = self._zeros()
        batch_pred_label = self._zeros()
        batch_label = self._zeros()
        for i, pred_label in enumerate(pred_labels):
            area_intersect, area_union, area_pred_label, area_label = (
                intersect_and_union(pred_label,
                                    target[i],
                                    self.num_classes,
                                    self.ignore_index)
            )
            batch_intersect += area_intersect
            batch_union += area_union
            batch_pred_label += area_pred_label
            batch_label += area_label
        if accumulate:
            self.total_area_intersect += batch_intersect
            self.total_area_union += batch_union
            self.total_area_pred_label += batch_pred_label
            self.total_area_label += batch_label
            self.num_samples += pred_labels.shape[0]
        acc = batch_intersect.sum() / (np.spacing(1) + batch_label.sum())
        iou = batch_intersect.sum() / (np.spacing(1) + batch_union.sum())

        return iou, acc

    def print_class_score(self) -> str:
        """Log a per-class IoU/accuracy table (plus the class mean) and return it."""
        class_acc = self.total_area_intersect / (np.spacing(1) + self.total_area_label)
        class_iou = self.total_area_intersect / (np.spacing(1) + self.total_area_union)
        class_table_data = [["id", "Class", "IoU", "acc"]]
        for i, (iou, acc) in enumerate(zip(class_iou, class_acc)):
            class_table_data.append(
                [i, self.classes[i], "{:.4f}".format(iou), "{:.4f}".format(acc)]
            )
        class_table_data.append(
            ["", "mean",
             "{:.4f}".format(np.mean(class_iou)),
             "{:.4f}".format(np.mean(class_acc))]
        )
        table = AsciiTable(class_table_data)
        logger.info("\n" + table.table)
        return table.table


class RegressMetric:
    """Measuring regression results : MAE and RMSE.

    Signed errors ``pred - target`` are concatenated along dim 0 across ``update`` calls.
    """
    def __init__(self) -> None:
        # Accumulated signed errors, concatenated along dim 0 (e.g. [N, 1]).
        self.all_errors : Optional[torch.Tensor] = None

    def reset(self) -> None:
        """Discard all accumulated errors, matching the other metrics' ``reset``."""
        self.all_errors = None

    def update(self, pred : torch.Tensor, target : torch.Tensor) -> torch.Tensor:
        """Accumulate one batch of errors and return that batch's mean absolute error."""
        errors = pred - target
        # Store a detached copy: keeping ``errors`` itself would pin every batch's
        # autograd graph in memory for as long as the metric lives.
        stored = errors.detach()
        if self.all_errors is None:
            self.all_errors = stored
        else:
            self.all_errors = torch.cat((self.all_errors, stored), dim=0)

        return errors.abs().mean()

    @property
    def num_samples(self):
        """Number of accumulated rows, or ``None`` before the first update."""
        return self.all_errors.shape[0] if self.all_errors is not None else None

    def mae(self) -> torch.Tensor:
        """Mean absolute error over everything accumulated."""
        return self.all_errors.abs().mean()

    def rmse(self) -> torch.Tensor:
        """Root mean squared error over everything accumulated."""
        return torch.sqrt(self.all_errors.square().mean())

    def __str__(self) -> str:
        rmse_score = self.rmse().item()
        mae_score = self.mae().item()
        return "RMSE - {:.4f} MAE - {:.4f}".format(rmse_score, mae_score)


class SegmentResult():
    """Wrapper of segmentation results.

    Prediction and label arrays are accumulated across ``update`` calls by
    concatenation along axis 0.
    """
    def __init__(self) -> None:
        self.predicts : Optional[np.array] = None
        self.labels : Optional[np.array] = None

    def update(self, predicts : np.array, labels : np.array) -> None:
        """Append one batch of predictions and labels.

        Raises:
            ValueError: if any prediction lies outside ``[0, 1]``.
        """
        # np.max/np.min raise on zero-size arrays, so only range-check non-empty batches.
        if np.size(predicts) and (np.max(predicts) > 1 or np.min(predicts) < 0):
            raise ValueError("Invalid predictions")
        if self.predicts is None or self.labels is None:
            self.predicts = predicts
            self.labels = labels
        else:
            self.predicts = np.concatenate((self.predicts, predicts), axis=0)
            self.labels = np.concatenate((self.labels, labels), axis=0)

    def reset(self) -> None:
        """Discard all accumulated predictions and labels."""
        self.predicts = None
        self.labels = None

    @property
    def num_samples(self) -> int:
        """Number of accumulated samples along axis 0; 0 before the first update."""
        return 0 if self.predicts is None else self.predicts.shape[0]


class SegmentMetric(SegmentResult):
    """Evaluate segmentation metrics (pixel-level AUC and F1) on the accumulated results.

    Predictions are channel-first ``(N, num_classes, height, width)`` scores.
    Label layouts:
      * multi-label: channel-first binary masks with the same shape as the predictions;
      * single-label: ``(N, height, width)`` class-index maps (required by ``auc``).
        ``f1_score`` also accepts one-hot masks shaped like the predictions.
    """
    def __init__(self, num_classes : Optional[int] = None, multi_label : bool = False):
        super(SegmentMetric, self).__init__()
        self.num_classes = num_classes
        self.multi_label = multi_label

    def _channels(self) -> int:
        """Return the class count: the configured value, else the predictions' channel axis."""
        return self.num_classes if self.num_classes is not None else self.predicts.shape[1]

    def _to_rows(self, arr : np.array) -> np.array:
        """Reshape a channel-first ``(N, C, H, W)`` array into ``(N*H*W, C)`` per-pixel rows."""
        return np.reshape(np.einsum("kcij->kijc", arr), (-1, self._channels()))

    def auc(self):
        """Return the AUC for the current result.

        Multi-label: one AUC per class. Single-label: weighted one-vs-one AUC.
        """
        predicts = self._to_rows(self.predicts)
        if self.multi_label:
            # Multi-label targets are per-class masks, so they must be flattened
            # exactly like the scores. The original flattened them to 1-D, which
            # never matches the (pixels, C) score matrix.
            labels = self._to_rows(self.labels)
            auc = roc_auc_score(labels, predicts, average=None)
        else:
            labels = np.reshape(self.labels, (-1,))
            auc = roc_auc_score(labels, predicts, multi_class="ovo", average="weighted")

        return auc

    def f1_score(self, threshold : float = 0.5):
        """Return the pixel-level F1 score.

        Multi-label (or one-hot labels): scores are thresholded at ``threshold`` and
        compared per class. Single-label class-index labels: the arg-max class is
        used (as in ``IouMetric.update``), so ``threshold`` does not apply.
        """
        if self.multi_label or np.ndim(self.labels) == np.ndim(self.predicts):
            predicts = self._to_rows((self.predicts > threshold).astype(int))
            labels = self._to_rows(self.labels)
        else:
            # Index-map labels lack a channel axis; the 4-index einsum above would
            # raise on them, so compare predicted vs. true class ids directly.
            predicts = np.reshape(np.argmax(self.predicts, axis=1), (-1,))
            labels = np.reshape(self.labels, (-1,))
        if self.multi_label:
            score = f1_score(labels, predicts, average="samples")
        else:
            score = f1_score(labels, predicts, average="macro")

        return score
