#!/usr/bin/env python
# encoding: utf-8

# filename: model
import torch

from .utils.loss_utils import edl_loss, mse_loss, focal_edl_loss, softmax_cross_entropy_loss
from typing import Optional, Union, Callable
from enum import Enum
from pydantic import BaseModel, Field
from .utils.activation_utils import *
from .utils.prior_utils import *
from .utils.loss_utils import loglikelihood_bias, loglikelihood_variance, compute_fisher_inverse


class EvidenceOptions(str, Enum):
    """Names of the built-in logits-to-evidence functions.

    Each value is the name of a function that ``EDL.logits_to_evidence``
    resolves by name and applies to the logits, so the strings must stay in
    sync with the functions available in this module's namespace.
    """

    exp = "exp"
    relu = "relu"
    softplus = "softplus"
    none = "none"

class PriorOptions(str, Enum):
    """Names of the built-in prior constructors.

    Each value is the name of a function that ``EDL.get_priors`` resolves by
    name and calls with the number of classes ``K``. The spellings are
    therefore part of the runtime contract and must not be "corrected" here
    without renaming the matching function.
    """

    jefferys = "jefferys"
    uniform = "uniform"
    uniform_divided = "uniform_divided"
    zero = "zero"


class UncertaintyOptions(str, Enum):
    """Names of the uncertainty scores implemented as ``EDL`` methods.

    ``EDL.get_uncertainty`` looks up the method whose name equals the member's
    value, so every value here must match a method name on ``EDL``.
    """

    # Scores derived from alpha.
    edl_vacuity = "edl_vacuity"
    edl_shannon_entropy = "edl_shannon_entropy"
    edl_mutual_information = "edl_mutual_information"
    edl_expected_data_uncertainty = "edl_expected_data_uncertainty"
    edl_dissonance = "edl_dissonance"
    edl_dirichlet_entropy = "edl_dirichlet_entropy"
    # Entropy of the plain softmax distribution over the logits.
    softmax_shannon_entropy = "softmax_shannon_entropy"
    # Underscore-prefixed entries map to EDL methods that return the negated
    # value of the same-named "max_*" method.
    _max_softmax_probability = "_max_softmax_probability"
    _max_belief = "_max_belief"
    _max_probability = "_max_probability"
    _max_alpha = "_max_alpha"
    # NOTE: the EDL.edl_bias method additionally takes a target argument.
    edl_bias = "edl_bias"
    edl_variance = "edl_variance"
    edl_cpu = "edl_cpu"
    edl_margin = "edl_margin"

class BeliefOptions(str, Enum):
    """Names of the belief/confidence scores implemented as ``EDL`` methods.

    ``EDL.get_belief`` looks up the method whose name equals the member's
    value.
    """

    max_softmax_probability = "max_softmax_probability"
    max_belief = "max_belief"
    max_probability = "max_probability"


class EdlLossOptions(str, Enum):
    """Names of the loss functions selectable through ``EDL.get_edl_loss``."""

    edl_mse_loss = "edl_mse_loss"
    mse_loss = "mse_loss"
    edl_log_loss = "edl_log_loss"
    cross_entropy_loss = "cross_entropy_loss"
    edl_cross_entropy_loss = "edl_cross_entropy_loss"
    # Accepts optional ``gamma`` / ``beta`` keyword arguments via get_edl_loss.
    focal_edl_cross_entropy_loss = "focal_edl_cross_entropy_loss"
    softmax_cross_entropy_loss = "softmax_cross_entropy_loss"
    edl_reconciled_loss = "edl_reconciled_loss"


class EDL(BaseModel):
    """Evidential deep learning (EDL) helper: losses, beliefs and uncertainty scores.

    Network outputs ("logits") are mapped to evidence by ``evi_fn``; adding the
    prior from ``prior_fn`` yields ``alpha``, which is used as the
    concentration of a Dirichlet distribution (see ``edl_dirichlet_entropy``).
    All other methods derive losses, probabilities or uncertainty scores from
    ``alpha``.

    Most reductions are taken over ``dim=1``, so the methods presumably expect
    ``(batch, K)`` tensors -- TODO confirm against callers.
    """

    # Number of classes; validated to be at least 2.
    K: int = Field(ge=2)
    # Logits -> evidence mapping: an EvidenceOptions name or a custom callable.
    evi_fn: Optional[Union[EvidenceOptions, Callable]] = None
    # Prior added to the evidence: a PriorOptions name, a callable taking K,
    # or a constant number.
    prior_fn: Optional[Union[PriorOptions, Callable, float]] = None
    # Score used by get_uncertainty: an UncertaintyOptions name or a callable.
    uncertainty_fn: Optional[Union[UncertaintyOptions, Callable]] = None
    # Score used by get_belief: a BeliefOptions name or a callable.
    belief_fn: Optional[Union[BeliefOptions, Callable]] = None
    # Loss used by get_edl_loss: an EdlLossOptions name or a callable.
    loss_fn: Optional[Union[EdlLossOptions, Callable]] = None
    # When True, logits_to_alpha returns its input unchanged.
    logits_is_alpha: bool = False
    # Natural gradient descent flag. No method of this class currently reads
    # it: the gradient hook that used it in logits_to_alpha is disabled.
    ngd: bool = False

    # ------------------------------------------------------------------
    # configuration
    # ------------------------------------------------------------------
    def set_evidence_fn(self, evidence_fn: Union[EvidenceOptions, Callable]):
        """
        Replace the logits-to-evidence function.

        :param evidence_fn: an ``EvidenceOptions`` name or a callable applied to the logits
        :return: None
        """
        self.evi_fn = evidence_fn

    def set_prior_fn(self, prior_fn: Union[PriorOptions, Callable, float]):
        """
        Replace the prior.

        :param prior_fn: a ``PriorOptions`` name, a callable taking ``K``, or a constant number
        :return: None
        """
        self.prior_fn = prior_fn

    def set_uncertainty_fn(self, uncertainty_fn: Union[UncertaintyOptions, Callable]):
        """
        Replace the score used by ``get_uncertainty``.

        :param uncertainty_fn: an ``UncertaintyOptions`` name or a callable
        :return: None
        """
        self.uncertainty_fn = uncertainty_fn

    def set_belief_fn(self, belief_fn: Union[BeliefOptions, Callable]):
        """
        Replace the score used by ``get_belief``.

        :param belief_fn: a ``BeliefOptions`` name or a callable
        :return: None
        """
        self.belief_fn = belief_fn

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------
    def _module_function(self, option):
        """
        Return the module-level function whose name is the value of ``option``.

        This is a plain namespace lookup used instead of ``eval`` so that an
        option string can never be executed as an expression.

        :param option: an enum member (or plain string) naming the function
        :return: the function object found in this module's globals
        :raises NameError: if no such name exists in this module
        """
        name = option.value if isinstance(option, Enum) else option
        try:
            return globals()[name]
        except KeyError:
            raise NameError(f"name '{name}' is not defined") from None

    def _logits_to_belief(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Return evidence divided by the sum of alpha over the last dimension.

        NOTE(review): with ``logits_is_alpha=True`` the evidence is still
        computed by applying ``evi_fn`` to the input, which then holds alpha
        rather than logits -- confirm whether that combination is intended.

        :param logits: network output
        :return: tensor with the same shape as the evidence
        """
        alpha = self.logits_to_alpha(logits)
        evidence = self.logits_to_evidence(logits)
        strength = torch.sum(alpha, dim=-1, keepdim=True)
        return evidence / strength

    # ------------------------------------------------------------------
    # dispatchers
    # ------------------------------------------------------------------
    def get_edl_loss(self, logits: torch.Tensor, label: torch.Tensor, **kwargs):
        """
        Compute the loss selected by ``loss_fn``.

        :param logits: network output; its last dimension must equal ``K``
        :param label: target tensor; its last dimension must equal ``K``
        :param kwargs: forwarded unchanged to a callable ``loss_fn``; for the
            focal loss only ``gamma`` and ``beta`` are used
        :return: the value returned by the selected loss
        :raises ValueError: if ``loss_fn`` is unset or not a known option
        """
        assert logits.shape[-1] == label.shape[-1] == self.K, \
            "last dimension of logits and label must both equal K"

        loss_fn = self.loss_fn
        if callable(loss_fn):
            return loss_fn(logits, label, **kwargs)

        # Enum members are str instances, so plain option strings work too.
        if isinstance(loss_fn, str):
            if loss_fn == EdlLossOptions.focal_edl_cross_entropy_loss:
                # Forward only the values actually supplied so that the
                # method's own defaults apply; passing None through used to
                # end in ``None * tensor``.
                focal_kwargs = {key: kwargs[key] for key in ('gamma', 'beta')
                                if kwargs.get(key) is not None}
                return self.focal_edl_cross_entropy_loss(logits, label, **focal_kwargs)

            if loss_fn == EdlLossOptions.softmax_cross_entropy_loss:
                # Module-level helper, not a method of this class.
                return softmax_cross_entropy_loss(logits, label)

            # Options whose value is the name of a (logits, label) method.
            method_backed = (
                EdlLossOptions.edl_cross_entropy_loss,
                EdlLossOptions.edl_log_loss,
                EdlLossOptions.edl_mse_loss,
                EdlLossOptions.mse_loss,
                EdlLossOptions.cross_entropy_loss,
                EdlLossOptions.edl_reconciled_loss,
            )
            if loss_fn in method_backed:
                return getattr(self, EdlLossOptions(loss_fn).value)(logits, label)

        raise ValueError("Please set loss_fn")

    def get_uncertainty(self, logits, *args, **kwargs):
        """
        Compute the uncertainty score selected by ``uncertainty_fn``.

        :param logits: network output
        :param args: extra positional arguments for the selected score;
            ``edl_bias`` needs the target here
        :param kwargs: extra keyword arguments for the selected score
        :return: the value returned by the selected score
        """
        assert self.uncertainty_fn is not None, "Please set uncertainty_fn"
        score_fn = self.uncertainty_fn
        if not callable(score_fn):
            # A named option is the name of a method on this instance.
            name = score_fn.value if isinstance(score_fn, Enum) else score_fn
            score_fn = getattr(self, name)
        return score_fn(logits, *args, **kwargs)

    def get_belief(self, logits):
        """
        Compute the belief score selected by ``belief_fn``.

        :param logits: network output
        :return: the value returned by the selected score
        """
        assert self.belief_fn is not None, "Please set belief_fn"
        score_fn = self.belief_fn
        if not callable(score_fn):
            # A named option is the name of a method on this instance.
            name = score_fn.value if isinstance(score_fn, Enum) else score_fn
            score_fn = getattr(self, name)
        return score_fn(logits)

    # ------------------------------------------------------------------
    # logits -> evidence / alpha / probability
    # ------------------------------------------------------------------
    def logits_to_evidence(self, logits: torch.Tensor, reconciled=False) -> torch.Tensor:
        """
        Apply ``evi_fn`` to the logits.

        :param logits: network output
        :param reconciled: when True, subtract the minimum over ``dim=1`` so
            that the smallest entry along that dimension becomes zero
        :return: evidence tensor
        :raises ValueError: if ``evi_fn`` is unset or not a known option
        """
        evi_fn = self.evi_fn
        if callable(evi_fn):
            evidence = evi_fn(logits)
        elif isinstance(evi_fn, str):
            # EvidenceOptions(...) validates the name (ValueError if unknown)
            # before it is looked up in the module namespace.
            evidence = self._module_function(EvidenceOptions(evi_fn))(logits)
        else:
            raise ValueError("evi_fn must be an EvidenceOptions value or a callable")

        if reconciled:
            evidence = evidence - torch.min(evidence, dim=1, keepdim=True)[0]
        return evidence

    def get_priors(self):
        """
        Build the prior described by ``prior_fn``.

        :return: the result of calling the prior function with ``K``, or a
            scalar tensor when ``prior_fn`` is a number
        :raises ValueError: if ``prior_fn`` is unset or not a known option
        """
        prior_fn = self.prior_fn
        if callable(prior_fn):
            return prior_fn(self.K)
        if isinstance(prior_fn, str):
            # PriorOptions(...) validates the name before the lookup.
            return self._module_function(PriorOptions(prior_fn))(self.K)
        # bool is an int subclass but is not a meaningful prior value.
        if isinstance(prior_fn, (int, float)) and not isinstance(prior_fn, bool):
            return torch.tensor(float(prior_fn))
        raise ValueError("prior_fn must be a PriorOptions value, a callable or a number")

    def logits_to_alpha(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Convert logits to alpha (evidence plus prior).

        :param logits: network output, or alpha itself when ``logits_is_alpha`` is set
        :return: alpha tensor
        """
        if self.logits_is_alpha:
            return logits

        evidence = self.logits_to_evidence(logits)
        prior = self.get_priors()
        # NOTE: a natural-gradient hook on alpha (built from
        # compute_fisher_inverse and gated by ``ngd``) used to be sketched
        # here; it is disabled, so ``ngd`` has no effect.
        return evidence + prior.to(evidence.device)

    def logits_to_prob(self, logits: torch.Tensor):
        """
        Normalize alpha over ``dim=1``.

        :param logits: network output
        :return: alpha divided by its sum over ``dim=1``
        """
        alpha = self.logits_to_alpha(logits)
        return alpha / torch.sum(alpha, dim=1, keepdim=True)

    # ------------------------------------------------------------------
    # uncertainty scores
    # ------------------------------------------------------------------
    def softmax_shannon_entropy(self, logits):
        """
        Entropy of the categorical distribution given by softmax over ``dim=1``.

        :param logits: network output
        :return: entropy tensor
        """
        from torch.distributions import Categorical
        # Softmax output is already normalized; Categorical normalizes its
        # ``probs`` once more, so no manual renormalization is needed.
        prob = torch.nn.functional.softmax(logits, dim=1)
        return Categorical(probs=prob).entropy()

    def edl_vacuity(self, logits: torch.Tensor):
        """
        Vacuity: ``K`` divided by the sum of alpha over ``dim=1``.

        :param logits: network output
        :return: tensor with ``dim=1`` kept at size 1
        """
        alpha = self.logits_to_alpha(logits)
        return self.K / torch.sum(alpha, dim=1, keepdim=True)

    def edl_cpu(self, logits: torch.Tensor):
        """
        Score ``P / ((K - 1) * a_max - (S - a_max) + P)``.

        ``P`` is the sum of the prior, ``a_max`` the largest alpha over
        ``dim=1`` and ``S`` the sum of alpha over ``dim=1``.

        :param logits: network output
        :return: tensor with ``dim=1`` kept at size 1
        """
        alpha = self.logits_to_alpha(logits)
        prior_sum = self.get_priors().sum().to(alpha.device)
        alpha_top = alpha.max(dim=1, keepdim=True)[0]
        alpha_rest = torch.sum(alpha, dim=1, keepdim=True) - alpha_top
        return prior_sum / ((self.K - 1) * alpha_top - alpha_rest + prior_sum)

    def edl_margin(self, logits: torch.Tensor):
        """
        Score ``P / (a_max - a_second + P)``.

        ``P`` is the sum of the prior; ``a_max`` and ``a_second`` are the two
        largest alpha values over ``dim=1``.

        :param logits: network output
        :return: tensor shaped like the ``keepdim`` maximum of alpha over ``dim=1``
        """
        alpha = self.logits_to_alpha(logits)
        prior_sum = self.get_priors().sum().to(alpha.device)
        alpha_top = alpha.max(dim=1, keepdim=True)[0]
        alpha_sorted, _ = alpha.sort(dim=1, descending=True)
        alpha_second = alpha_sorted[:, 1].reshape(alpha_top.shape)
        return prior_sum / (alpha_top - alpha_second + prior_sum)

    def edl_dirichlet_entropy(self, logits: torch.Tensor):
        """
        Entropy of the Dirichlet distribution with concentration alpha.

        :param logits: network output
        :return: entropy tensor
        """
        from torch.distributions import Dirichlet
        return Dirichlet(self.logits_to_alpha(logits)).entropy()

    def edl_shannon_entropy(self, logits: torch.Tensor):
        """
        Entropy of the categorical distribution given by ``logits_to_prob``.

        :param logits: network output
        :return: entropy tensor
        """
        from torch.distributions import Categorical
        return Categorical(probs=self.logits_to_prob(logits)).entropy()

    def edl_mutual_information(self, logits: torch.Tensor):
        """
        ``edl_shannon_entropy`` minus ``edl_expected_data_uncertainty``.

        :param logits: network output
        :return: difference of the two scores
        """
        return self.edl_shannon_entropy(logits) - self.edl_expected_data_uncertainty(logits)

    def edl_expected_data_uncertainty(self, logits):
        """
        Sum over classes of ``(alpha_i / S) * (digamma(S + 1) - digamma(alpha_i + 1))``.

        ``S`` is the sum of alpha over the last dimension.

        :param logits: network output
        :return: tensor reduced over ``dim=1``
        """
        alpha = self.logits_to_alpha(logits)
        strength = torch.sum(alpha, dim=-1, keepdim=True)
        # Vectorized over the class dimension instead of one Python-level
        # iteration per class.
        per_class = (alpha / strength) * (torch.digamma(strength + 1) - torch.digamma(alpha + 1))
        return torch.sum(per_class, dim=1)

    def edl_dissonance(self, logits: torch.Tensor):
        """
        Dissonance of the belief tensor, computed by the project's dissonance ops.

        :param logits: network output
        :return: value returned by ``dissonance_gpu``, or by ``dissonance_cpu``
            when the former raises ``ImportError``
        """
        belief = self._logits_to_belief(logits)
        # The call stays inside the ``try`` on purpose: an ImportError raised
        # while running the GPU variant also falls back to the CPU variant.
        try:
            from .ops.dissonance.dissonance_utils import dissonance_gpu
            return dissonance_gpu(belief)
        except ImportError:
            from .ops.dissonance.dissonance_utils import dissonance_cpu
            return dissonance_cpu(belief)

    def edl_bias(self, logits: torch.Tensor, target: torch.Tensor):
        """
        ``loglikelihood_bias`` of the target under alpha.

        :param logits: network output
        :param target: target tensor
        :return: value returned by ``loglikelihood_bias``
        """
        return loglikelihood_bias(target, self.logits_to_alpha(logits))

    def edl_variance(self, logits: torch.Tensor):
        """
        ``loglikelihood_variance`` of alpha.

        :param logits: network output
        :return: value returned by ``loglikelihood_variance``
        """
        return loglikelihood_variance(self.logits_to_alpha(logits))

    # Negated confidence scores: a larger confidence gives a smaller value.
    def _max_softmax_probability(self, logits: torch.Tensor):
        """Return the negated ``max_softmax_probability``."""
        return -self.max_softmax_probability(logits)

    def _max_probability(self, logits: torch.Tensor):
        """Return the negated ``max_probability``."""
        return -self.max_probability(logits)

    def _max_alpha(self, logits: torch.Tensor):
        """Return the negated ``max_alpha``."""
        return -self.max_alpha(logits)

    def _max_belief(self, logits: torch.Tensor):
        """Return the negated ``max_belief``."""
        return -self.max_belief(logits)

    # ------------------------------------------------------------------
    # belief / confidence scores
    # ------------------------------------------------------------------
    def max_softmax_probability(self, logits: torch.Tensor):
        """
        Largest softmax probability over ``dim=1``.

        :param logits: network output
        :return: maximum values over ``dim=1``
        """
        prob = torch.nn.functional.softmax(logits, dim=1)
        return prob.max(dim=1)[0]

    def max_alpha(self, logits: torch.Tensor):
        """
        Largest alpha over ``dim=1``.

        :param logits: network output
        :return: maximum values over ``dim=1``
        """
        return self.logits_to_alpha(logits).max(dim=1)[0]

    def max_belief(self, logits: torch.Tensor):
        """
        Largest belief (evidence divided by the sum of alpha) over ``dim=1``.

        :param logits: network output
        :return: maximum values over ``dim=1``
        """
        return self._logits_to_belief(logits).max(dim=1)[0]

    def max_probability(self, logits: torch.Tensor):
        """
        Largest normalized alpha over ``dim=1``.

        :param logits: network output
        :return: maximum values over ``dim=1``
        """
        return self.logits_to_prob(logits).max(dim=1)[0]

    # ------------------------------------------------------------------
    # losses
    # ------------------------------------------------------------------
    def edl_cross_entropy_loss(self, logits: torch.Tensor, target: torch.Tensor):
        """
        expected cross entropy loss

        :param logits: network output
        :param target: target tensor
        :return: mean of ``edl_loss`` evaluated with ``torch.digamma``
        """
        alpha = self.logits_to_alpha(logits)
        return torch.mean(edl_loss(torch.digamma, target, alpha))

    def cross_entropy_loss(self, logits: torch.Tensor, target: torch.Tensor):
        """
        Cross entropy between the raw logits and the argmax of the target.

        :param logits: network output, passed to the loss unchanged
        :param target: target tensor; class indices are taken as argmax over ``dim=1``
        :return: value of ``torch.nn.functional.cross_entropy``
        """
        labels = torch.argmax(target, dim=1)
        return torch.nn.functional.cross_entropy(logits, labels)

    def edl_reconciled_loss(self, logits: torch.Tensor, target: torch.Tensor):
        """
        Same as ``edl_cross_entropy_loss`` but with reconciled evidence.

        The evidence is shifted so that its minimum over ``dim=1`` is zero
        before the prior is added. ``logits_is_alpha`` is not consulted here.

        :param logits: network output
        :param target: target tensor
        :return: mean of ``edl_loss`` evaluated with ``torch.digamma``
        """
        evidence = self.logits_to_evidence(logits, reconciled=True)
        prior = self.get_priors()
        alpha = evidence + prior.to(evidence.device)
        return torch.mean(edl_loss(torch.digamma, target, alpha))

    def edl_log_loss(self, logits: torch.Tensor, target: torch.Tensor):
        """
        expected negative log likelihood loss

        :param logits: network output
        :param target: target tensor
        :return: mean of ``edl_loss`` evaluated with ``torch.log``
        """
        alpha = self.logits_to_alpha(logits)
        return torch.mean(edl_loss(torch.log, target, alpha))

    def edl_mse_loss(self, logits: torch.Tensor, target: torch.Tensor):
        """
        Mean of the module-level ``mse_loss`` with its default arguments.

        :param logits: network output
        :param target: target tensor
        :return: scalar loss
        """
        alpha = self.logits_to_alpha(logits)
        # ``mse_loss`` here is the module-level helper, not the method below.
        return torch.mean(mse_loss(target, alpha))

    def mse_loss(self, logits: torch.Tensor, target: torch.Tensor):
        """
        Mean of the module-level ``mse_loss`` called with ``with_bias=False``.

        :param logits: network output
        :param target: target tensor
        :return: scalar loss
        """
        alpha = self.logits_to_alpha(logits)
        # Inside a method body the bare name resolves to the module-level
        # helper, so this is not a recursive call.
        return torch.mean(mse_loss(target, alpha, with_bias=False))

    def focal_edl_cross_entropy_loss(self, logits: torch.Tensor, target: torch.Tensor, gamma=1, beta=None):
        """
        Focal EDL loss plus ``beta`` times the mean vacuity.

        :param logits: network output
        :param target: target tensor
        :param gamma: passed to ``focal_edl_loss``; ``None`` means the default of 1
        :param beta: weight of the mean vacuity; ``None`` means "same as gamma"
        :return: scalar loss
        """
        if gamma is None:
            gamma = 1
        if beta is None:
            beta = gamma
        alpha = self.logits_to_alpha(logits)
        focal = torch.mean(focal_edl_loss(torch.digamma, target, alpha, gamma))
        mean_vacuity = torch.mean(self.edl_vacuity(logits))
        return focal + beta * mean_vacuity


# Script entry point placeholder: this guard intentionally does nothing.
if __name__ == '__main__':
    pass
