import torch
import torch.nn as nn
import torch.nn.functional as F


def KL(alpha, c, yita):
    """Row-wise KL divergence KL(Dir(alpha) || Dir(yita)).

    Args:
        alpha: Dirichlet parameters of the first distribution; rows are
            reduced over dim 1.
        c: unused; kept for call compatibility.
        yita: Dirichlet parameters of the reference distribution, same shape
            as ``alpha``.

    Returns:
        Tensor with dim 1 reduced to size 1 (``keepdim=True``).
    """
    alpha_0 = alpha.sum(dim=1, keepdim=True)
    yita_0 = yita.sum(dim=1, keepdim=True)
    # log of the normalizers: ln B(yita) - ln B(alpha) split into two terms.
    log_norm_alpha = torch.lgamma(alpha_0) - torch.lgamma(alpha).sum(dim=1, keepdim=True)
    log_norm_yita = torch.lgamma(yita).sum(dim=1, keepdim=True) - torch.lgamma(yita_0)
    digamma_gap = torch.digamma(alpha) - torch.digamma(alpha_0)
    cross = ((alpha - yita) * digamma_gap).sum(dim=1, keepdim=True)
    return cross + log_norm_alpha + log_norm_yita

def get_e_sigma_loss(y, evidence_a):
    """Squared coefficient of variation of per-class mean true-class evidence.

    For every class k, averages ``evidence_a[i, k]`` over the samples whose
    label ``y[i]`` is k. The loss is ``var(means) / mean(means)**2``, where
    ``torch.var`` uses its default (unbiased) estimator.

    Classes that do not occur in ``y`` get a mean of 0 (the ``1e-8`` only
    prevents 0/0). They are still included in the mean and the variance.

    Args:
        y: integer class labels accepted by ``F.one_hot``.
        evidence_a: evidence tensor whose dim 1 indexes classes.

    Returns:
        Scalar tensor.
    """
    n_cls = evidence_a.shape[1]
    onehot = F.one_hot(y, num_classes=n_cls).to(evidence_a.device)
    true_cls_evidence = (evidence_a * onehot.float()).sum(dim=0)
    per_class_count = onehot.sum(dim=0) + 1e-8
    class_means = true_cls_evidence / per_class_count

    grand_mean = class_means.sum(dim=0) / len(class_means)
    spread = torch.var(class_means - grand_mean)
    return spread / (grand_mean * grand_mean)

def con_loss(alpha1, alpha2, c):
    """Per-sample L1 distance between the marginal variances of two Dirichlets.

    For Dirichlet parameters ``alpha`` with strength ``S = sum(alpha)`` and mean
    ``p = alpha / S``, the quantity ``p * (1 - p) * u / (c + u)`` with
    ``u = c / S`` simplifies algebraically to ``p * (1 - p) / (S + 1)``. That is
    the marginal variance of each Dirichlet component.

    Args:
        alpha1, alpha2: Dirichlet parameters reduced over dim 1.
        c: number of classes (used in ``u = c / S``).

    Returns:
        Tensor with dim 1 reduced to size 1.
    """
    def _spread(alpha):
        strength = torch.sum(alpha, dim=1, keepdim=True)
        mean = alpha / strength
        unc = c / strength
        return mean * (1 - mean) * unc / (c + unc)

    gap = _spread(alpha1) - _spread(alpha2)
    return gap.abs().sum(dim=1, keepdim=True)

def get_con_loss_v1(evidences, yitas):
    """Cross-view consistency loss between per-view Dirichlet variances.

    For every unordered pair of views ``(v, m)`` with ``v < m``, the
    per-sample ``con_loss`` between ``evidences[v] + yitas[v]`` and
    ``evidences[m] + yitas[m]`` is summed. The total is scaled by
    ``1 / (V - 1)`` and averaged.

    Args:
        evidences: per-view evidence indexable by view ``0..V-1`` (a list or a
            dict keyed by view index). The last dimension is taken as the
            number of classes.
        yitas: per-view prior parameters, indexed like ``evidences``.

    Returns:
        Scalar tensor. With fewer than two views there is no pair to compare,
        so a zero scalar is returned instead of dividing by ``V - 1 == 0``.
    """
    num_views = len(evidences)
    if num_views < 2:
        if num_views == 0:
            return torch.zeros(())
        return evidences[0].new_zeros(())

    num_classes = evidences[0].shape[-1]
    total = 0
    for v in range(num_views):
        alpha1 = evidences[v] + yitas[v]
        # Start at v + 1: con_loss of a view with itself is exactly 0.
        for m in range(v + 1, num_views):
            alpha2 = evidences[m] + yitas[m]
            total = total + con_loss(alpha1, alpha2, num_classes)
    total = (1 / (num_views - 1)) * total
    return total.mean()

def ce_loss_train(p, alpha, c, yita, global_step, annealing_step, num_corrects, nums):
    """Class-balanced evidential loss used after warm-up.

    The per-sample loss has two parts:
      * the expected cross-entropy ``digamma(S) - digamma(alpha_y)``;
      * an annealed ``KL(Dir(alp) || Dir(yita))``, where ``alp`` keeps the
        prior but drops the evidence on the true class.

    These losses are then aggregated per class:
      1. Per-sample losses are summed within each class.
      2. Each class sum is scaled by ``1 / (nums[cls] + 1)``.
      3. The result is averaged over all ``c`` classes. Classes absent from
         the batch contribute 0 but still count in that mean.

    Args:
        p: integer class labels (long, as required by ``F.one_hot``).
        alpha: Dirichlet parameters whose dim 1 indexes the ``c`` classes.
        c: number of classes.
        yita: prior parameters; ``alpha - yita`` is the evidence.
        global_step, annealing_step: KL weight is
            ``min(1, global_step / annealing_step)``.
        num_corrects: unused; kept for call compatibility.
        nums: presumably per-class sample counts. Only ``nums[k] + 1`` is
            used, as a divisor, and ``nums[k]`` must accept ``float()`` for
            ``k in range(c)``.

    Returns:
        Scalar loss tensor in the dtype of the per-sample loss.
    """
    S = torch.sum(alpha, dim=1, keepdim=True)
    E = alpha - yita
    label = F.one_hot(p, num_classes=c)
    A = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
    annealing_coef = min(1, global_step / annealing_step)
    alp = E * (1 - label) + yita
    B = annealing_coef * KL(alp, c, yita)
    per_sample = (A + B).squeeze(1)

    # One scatter-add replaces a Python loop that forced a device sync per
    # class. Classes with no samples keep a sum of 0. Accumulating in the
    # per-sample dtype avoids truncating a float64 loss to float32.
    class_sums = per_sample.new_zeros(c).index_add(0, p, per_sample)
    counts = torch.tensor(
        [float(nums[k]) for k in range(c)],
        dtype=per_sample.dtype,
        device=per_sample.device,
    )
    per_class_loss = class_sums / (counts + 1)
    return per_class_loss.mean()


def ce_loss_warmup(p, alpha, c, yita, global_step, annealing_step):
    """Mean evidential loss used during warm-up.

    Per sample, adds the expected cross-entropy
    ``digamma(S) - digamma(alpha_y)`` to
    ``min(1, global_step / annealing_step) * KL(Dir(alp) || Dir(yita))``.
    Here ``alp`` keeps the prior ``yita`` but removes the evidence on the true
    class. Returns the plain mean over samples.
    """
    one_hot = F.one_hot(p, num_classes=c)
    strength = alpha.sum(dim=1, keepdim=True)
    expected_ce = (one_hot * (torch.digamma(strength) - torch.digamma(alpha))).sum(dim=1, keepdim=True)

    evidence = alpha - yita
    coef = min(1, global_step / annealing_step)
    alpha_tilde = evidence * (1 - one_hot) + yita
    kl_term = coef * KL(alpha_tilde, c, yita)

    return (expected_ce + kl_term).mean()



class TMC(nn.Module):
    """Multi-view evidential classifier with confidence-weighted evidence fusion.

    One ``Classifier`` per view maps that view's input to evidence. The
    Dirichlet parameters of view ``v`` are ``evidence[v] + yitas[v]``; the
    ``yitas`` / ``yita_a`` priors are supplied by the caller. ``WBF`` fuses the
    per-view evidence into a single opinion.
    """

    # Weight of the evidence-dispersion ("e-sigma") term in the total loss.
    # Defined at class level too, so instances pickled before the constructor
    # accepted ``e_sigma_weight`` still resolve a value.
    e_sigma_weight = 0.1

    def __init__(self, classes, views, classifier_dims, beta, lambda_epochs=1, e_sigma_weight=0.1):
        """
        Args:
            classes: number of classes.
            views: number of views.
            classifier_dims: ``classifier_dims[v]`` is passed to ``Classifier``
                for view ``v``.
            beta: weight of the cross-view consistency loss.
            lambda_epochs: forwarded to the CE losses as ``annealing_step``.
            e_sigma_weight: weight of the e-sigma loss (default 0.1).
        """
        super(TMC, self).__init__()
        self.beta = beta
        self.views = views
        self.classes = classes
        self.lambda_epochs = lambda_epochs
        self.e_sigma_weight = e_sigma_weight
        self.Classifiers = nn.ModuleList([Classifier(classifier_dims[i], self.classes) for i in range(self.views)])

    def get_confidece(self, evidences):
        """Return ``sum(e) / (classes + sum(e))`` over the last dim, in float64."""
        evidences = evidences.to(dtype=torch.float64)
        total = evidences.sum(dim=-1)
        return total / (self.classes + total)

    def WBF(self, alpha, yitas, yita_a):
        """Fuse per-view evidence, weighting each view by its confidence.

        Args:
            alpha: mapping view index -> Dirichlet parameters for that view
                (2-D; the last dim indexes classes).
            yitas: mapping view index -> prior parameters, same shapes.
            yita_a: prior parameters added to the fused evidence.

        Returns:
            ``fused_evidence + yita_a``. For each sample, ``fused_evidence`` is
            a convex combination of the view evidences, with weights
            proportional to ``get_confidece``. The result is float64 because
            the weights are.
        """
        evs = torch.stack([alpha[v] - yitas[v] for v in range(self.views)], dim=0)  # (V, N, C)
        conf = self.get_confidece(evs)  # (V, N)
        denominator = conf.sum(dim=0, keepdim=True)  # (1, N)
        # If every view has zero evidence for a sample, avoid 0/0; the weights
        # for that sample are then all 0.
        denominator = denominator.masked_fill(denominator == 0, 1e-8)
        weight = (conf / denominator).unsqueeze(-1)  # (V, N, 1)
        evidence_a = (evs * weight).sum(dim=0)
        return evidence_a + yita_a

    def forward(self, X, y, yitas, yita_a, global_step, num_corrects=None, num_corrects_per_view=None, nums=None):
        """Compute per-view and fused evidence plus the total training loss.

        If ``num_corrects``, ``num_corrects_per_view`` and ``nums`` are all
        None, the warm-up CE loss is used. Otherwise the class-balanced
        ``ce_loss_train`` is used: it receives ``num_corrects_per_view[v]``
        for view ``v`` and ``num_corrects`` for the fused opinion.

        Returns:
            ``(evidence, evidence_a, loss_total, loss, loss_con, loss_e_sigma)``.
        """
        # ``is None`` rather than ``== None``: ``== None`` on a NumPy array is
        # element-wise, and ``and`` would then raise on its truth value.
        warmup = num_corrects is None and num_corrects_per_view is None and nums is None

        evidence = self.infer(X)
        loss = 0
        loss_e_sigma = 0
        alpha = dict()
        for v_num in range(len(X)):
            alpha[v_num] = evidence[v_num] + yitas[v_num]
            loss_e_sigma += get_e_sigma_loss(y, evidence[v_num])
            if warmup:
                loss += ce_loss_warmup(y, alpha[v_num], self.classes, yitas[v_num], global_step, self.lambda_epochs)
            else:
                loss += ce_loss_train(y, alpha[v_num], self.classes, yitas[v_num], global_step, self.lambda_epochs, num_corrects_per_view[v_num], nums)

        loss_con = get_con_loss_v1(evidence, yitas)
        alpha_a = self.WBF(alpha, yitas, yita_a)
        evidence_a = alpha_a - yita_a
        loss_e_sigma += get_e_sigma_loss(y, evidence_a)
        if warmup:
            loss += ce_loss_warmup(y, alpha_a, self.classes, yita_a, global_step, self.lambda_epochs)
        else:
            loss += ce_loss_train(y, alpha_a, self.classes, yita_a, global_step, self.lambda_epochs, num_corrects, nums)

        loss_total = loss + loss_con * self.beta + loss_e_sigma * self.e_sigma_weight
        return evidence, evidence_a, loss_total, loss, loss_con, loss_e_sigma

    def infer(self, input):
        """Run each view's classifier; return a dict view index -> evidence."""
        return {v_num: self.Classifiers[v_num](input[v_num]) for v_num in range(self.views)}


class Classifier(nn.Module):
    """Per-view evidence head: a chain of Linear layers followed by Softplus.

    ``classifier_dims`` lists layer widths starting with the input size.
    Consecutive widths are joined by ``nn.Linear``, and a final ``nn.Linear``
    maps the last width to ``classes``. ``nn.Softplus`` makes the output
    non-negative.

    NOTE(review): no nonlinearity sits between the Linear layers, so several
    of them compose to a single affine map. Confirm this is intended.
    """

    def __init__(self, classifier_dims, classes):
        super(Classifier, self).__init__()
        self.num_layers = len(classifier_dims)
        layers = [
            nn.Linear(d_in, d_out)
            for d_in, d_out in zip(classifier_dims[:-1], classifier_dims[1:])
        ]
        layers.append(nn.Linear(classifier_dims[-1], classes))
        layers.append(nn.Softplus())
        self.fc = nn.ModuleList(layers)

    def forward(self, x):
        h = x
        for layer in self.fc:
            h = layer(h)
        return h
