import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist

def softplus_evidence(logits):
    """Convert raw logits into non-negative evidence using softplus.

    softplus(x) = log(1 + exp(x)) is smooth and strictly positive, so every
    class receives some (possibly tiny) evidence.
    """
    evidence = F.softplus(logits)
    return evidence

def get_edl_results(logits):
    """Build an evidential (Dirichlet) opinion from logits.

    The class axis is dim 1 (``K = logits.shape[1]``).

    Returns:
        A tuple ``(evidence, alpha, uncertainty, probs)`` where
        ``evidence = softplus(logits)``, ``alpha = evidence + 1``,
        ``S = alpha.sum(dim=1, keepdim=True)``, ``uncertainty = K / S`` and
        ``probs = alpha / S``.
    """
    num_classes = logits.shape[1]
    evidence = F.softplus(logits)
    alpha = 1.0 + evidence
    # Dirichlet strength; kept as a column so it broadcasts across classes.
    strength = alpha.sum(dim=1, keepdim=True)
    return evidence, alpha, num_classes / strength, alpha / strength

class FusionStrategies:
    """Strategies that fuse several logit tensors into one Dirichlet opinion.

    Every public method takes ``logits_list``, a non-empty iterable of logit
    tensors whose class axis is dim 1 (presumably ``(batch, num_classes)``,
    one tensor per view/modality; TODO confirm with callers). Each method
    returns ``(probs, alpha)``, where ``alpha`` are the fused Dirichlet
    parameters and ``probs = alpha / alpha.sum(dim=1, keepdim=True)``.
    An empty ``logits_list`` raises ``ValueError``.
    """

    def __init__(self):
        # Standard normal; ``saef`` uses its icdf to map probabilities to z-scores.
        self.normal = dist.Normal(0, 1)

    @staticmethod
    def _as_nonempty_list(logits_list):
        """Materialise ``logits_list`` (any iterable) and reject empty input."""
        views = list(logits_list)
        if not views:
            raise ValueError("logits_list must contain at least one logits tensor")
        return views

    @staticmethod
    def _normalize(alpha):
        """Return the Dirichlet mean ``alpha / sum(alpha, dim=1)``."""
        return alpha / torch.sum(alpha, dim=1, keepdim=True)

    def saef(self, logits_list, *, target_scale=5.0):
        """Fuse views with a weighted Stouffer combination of per-class z-scores.

        For each view:
        1. Standardise the logits per row: subtract the row mean and divide by
           the row std, with the std floored at 1e-2. Then multiply by
           ``target_scale``.
        2. Turn the result into an EDL opinion.
        3. Map the class probabilities (clamped to [1e-6, 1 - 1e-6]) to
           standard-normal z-scores.

        The views are combined as ``sum(w * z) / sqrt(sum(w ** 2))``, with
        weight ``w = (1 - uncertainty) + 1e-6`` per view. The fused z-scores
        are turned back into evidence with softplus.

        Args:
            logits_list: non-empty iterable of logit tensors (class axis = dim 1).
            target_scale: spread of the standardised logits (default 5.0).

        Returns:
            ``(p_fused, alpha_fused)``.

        Raises:
            ValueError: if ``logits_list`` is empty.
        """
        views = self._as_nonempty_list(logits_list)

        z_weighted_sum = 0
        w_squared_sum = 0
        for logits in views:
            mean = logits.mean(dim=1, keepdim=True)
            std = logits.std(dim=1, keepdim=True)
            # Floor the std so near-constant rows are not blown up to huge values.
            std_bounded = torch.clamp(std, min=1e-2)
            logits_safe = ((logits - mean) / (std_bounded + 1e-8)) * target_scale

            _, _, uncertainty, probs = get_edl_results(logits_safe)
            # More confident views (lower uncertainty) get larger weights; the
            # epsilon keeps every weight strictly positive.
            w = (1.0 - uncertainty) + 1e-6

            # Keep probabilities strictly inside (0, 1) so the icdf stays finite.
            p_val = torch.clamp(probs, min=1e-6, max=1 - 1e-6)
            z = self.normal.icdf(p_val)

            z_weighted_sum = z_weighted_sum + w * z
            w_squared_sum = w_squared_sum + w ** 2

        z_fused = z_weighted_sum / torch.sqrt(w_squared_sum + 1e-8)
        alpha_fused = F.softplus(z_fused) + 1.0
        return self._normalize(alpha_fused), alpha_fused

    def sum_fusion(self, logits_list):
        """Fuse by summing the softplus evidence of all views.

        Raises:
            ValueError: if ``logits_list`` is empty.
        """
        views = self._as_nonempty_list(logits_list)
        total_evidence = sum(softplus_evidence(logits) for logits in views)
        alpha_fused = total_evidence + 1
        return self._normalize(alpha_fused), alpha_fused

    def mean_fusion(self, logits_list):
        """Fuse by averaging the softplus evidence of all views.

        Raises:
            ValueError: if ``logits_list`` is empty.
        """
        views = self._as_nonempty_list(logits_list)
        total_evidence = sum(softplus_evidence(logits) for logits in views)
        mean_evidence = total_evidence / len(views)
        alpha_fused = mean_evidence + 1
        return self._normalize(alpha_fused), alpha_fused

    def weighted_fusion(self, logits_list):
        """Fuse by summing each view's evidence weighted by ``1 - uncertainty``.

        Raises:
            ValueError: if ``logits_list`` is empty.
        """
        views = self._as_nonempty_list(logits_list)
        weighted_evidence_sum = 0
        for logits in views:
            evidence, _, uncertainty, _ = get_edl_results(logits)
            weight = (1.0 - uncertainty) + 1e-8
            weighted_evidence_sum = weighted_evidence_sum + weight * evidence
        alpha_fused = weighted_evidence_sum + 1.0
        return self._normalize(alpha_fused), alpha_fused

    @staticmethod
    def _ds_combine_two(alpha1, alpha2):
        """Combine two Dirichlet opinions with Dempster's rule.

        Each opinion has beliefs ``b = (alpha - 1) / S`` and uncertainty
        ``u = K / S``. The conflict is
        ``C = (1 - u1)(1 - u2) - sum(b1 * b2)``, i.e. the mass that the two
        views put on different classes. The combined beliefs and uncertainty
        are renormalised by ``1 / (1 - C)``, then converted back to alpha.
        """
        K = alpha1.size(1)
        S1 = torch.sum(alpha1, dim=1, keepdim=True)
        S2 = torch.sum(alpha2, dim=1, keepdim=True)

        b1 = (alpha1 - 1) / S1
        b2 = (alpha2 - 1) / S2
        u1 = K / S1
        u2 = K / S2

        b1_b2_dot = torch.sum(b1 * b2, dim=1, keepdim=True)
        C = (1 - u1) * (1 - u2) - b1_b2_dot
        scale = 1.0 / (1 - C + 1e-8)

        b_new = (b1 * b2 + b1 * u2 + b2 * u1) * scale
        u_new = (u1 * u2) * scale

        # Invert u = K / S and b = e / S to recover evidence, then alpha.
        S_new = K / (u_new + 1e-8)
        e_new = b_new * S_new
        return e_new + 1

    def ds_fusion(self, logits_list):
        """Fuse views by folding Dempster's rule over them, left to right.

        Raises:
            ValueError: if ``logits_list`` is empty.
        """
        views = self._as_nonempty_list(logits_list)
        _, current_alpha, _, _ = get_edl_results(views[0])
        for logits in views[1:]:
            _, alpha, _, _ = get_edl_results(logits)
            current_alpha = self._ds_combine_two(current_alpha, alpha)
        return self._normalize(current_alpha), current_alpha

class UnifiedFusion(nn.Module):
    """``nn.Module`` that fuses a list of logits with one ``FusionStrategies`` method.

    Args:
        method: case-insensitive strategy name. One of ``'ds'``, ``'sum'``,
            ``'mean'``, ``'weighted'``, or ``'ours'`` / ``'stouffer'``. The last
            two both run ``saef``, which combines per-view z-scores as
            ``sum(w * z) / sqrt(sum(w ** 2))``, i.e. weighted Stouffer's method.
            Unknown names raise ``ValueError`` when ``forward`` is called.
    """

    # Method name -> name of the FusionStrategies method that implements it.
    # 'stouffer' must be listed: it is the constructor default.
    _DISPATCH = {
        'ds': 'ds_fusion',
        'sum': 'sum_fusion',
        'mean': 'mean_fusion',
        'ours': 'saef',
        'stouffer': 'saef',
        'weighted': 'weighted_fusion',
    }

    def __init__(self, method='stouffer'):
        super().__init__()
        self.strategies = FusionStrategies()
        self.method = method.lower()

        # Unused in this class; presumably lets callers query the module's
        # device/dtype via its buffers. NOTE(review): confirm before removing.
        self.register_buffer('dummy', torch.tensor(0.0))

    def forward(self, logits_list):
        """Fuse ``logits_list`` with the configured method.

        Returns:
            ``(probs, alpha)`` from the selected strategy.

        Raises:
            ValueError: if ``self.method`` is not a known strategy name.
        """
        strategy_name = self._DISPATCH.get(self.method)
        if strategy_name is None:
            raise ValueError(f"Unknown fusion method: {self.method}")
        probs, alpha = getattr(self.strategies, strategy_name)(logits_list)
        return probs, alpha

def relu_evidence(y):
    """Evidence function that zeroes negative entries of ``y`` (ReLU)."""
    clipped = torch.relu(y)
    return clipped


