
import numpy as np
import torch
import torch.nn as nn
import math
    
class DFALC(nn.Module):
    """Fuzzy-logic model of ALC-style axioms over learnable embeddings.

    ``cEmb`` is indexed by concept id and ``rEmb`` by role id. ``forward``
    evaluates one batch of axioms of a single type (``atype`` 0..6) and
    returns a scalar loss.

    Two independent switches select the semantics:

    * ``tnorm`` is one of "godel", "product", "lukas", "yager" or
      "hamacher". It selects the connectives in :meth:`t_norm` and
      :meth:`t_cnorm`, and the residuum in :meth:`compute_implication`.
    * ``name`` (stored as ``logic_name``) selects the quantifiers in
      :meth:`exist` / :meth:`forall` and the loss used by :meth:`forward`.
      The per-``logic_name`` connective branches in :meth:`t_norm` /
      :meth:`t_cnorm` are reached only when ``tnorm`` is not one of the
      five names above.
    """

    def __init__(self, params, conceptSize, roleSize, cEmb_init, rEmb_init, device, loss_weight, name="Godel", tnorm="godel", implication="R", lambda_hamacher=0.5, p_yager=2.0):
        super().__init__()
        self.params = params
        self.conceptSize, self.roleSize = conceptSize, roleSize
        self.device = device
        # Learnable concept / role embeddings, initialised from the given arrays.
        self.cEmb = nn.Parameter(torch.tensor(cEmb_init))
        self.rEmb = nn.Parameter(torch.tensor(rEmb_init))
        self.relu = torch.nn.ReLU()
        self.logic_name = name
        # T-norm family: "godel", "product", "lukas", "yager", "hamacher".
        self.tnorm = tnorm.lower()
        # Implication family used by HierarchyLoss when logic_name == "Hierarchy":
        # "R" (residuum of the t-norm) or "S" (max(1 - x, y)).
        self.implication = implication.upper()
        self.epsilon = 1e-2  # shrink factor used by pi_0 / pi_1
        # Yager parameter p (> 0); also the p-mean exponent of the LTN/BoxEL quantifiers.
        self.p = p_yager
        self.lambda_hamacher = lambda_hamacher  # Hamacher parameter λ (>= 0)
        self.tau = 0.1
        self.box_dim = 2
        # Learnable Hamacher-style parameter used by the "ELEmbedding" connectives.
        self.gamma = nn.Parameter(torch.tensor(0.5))
        # Input clamp margin for the Hamacher / ELEmbedding connectives.
        self.value_clamp = 1e-3
        # Weight of HierarchyLoss vs. rule_based_loss when logic_name == "Combined".
        self.loss_weight = loss_weight

    def to_sparse(self, A):
        """Convert a dense array ``A`` to a sparse COO tensor of its nonzeros."""
        # Scan A once and reuse the coordinates for both indices and values.
        coords = np.nonzero(A != 0)
        return torch.sparse_coo_tensor(np.stack(coords), A[coords], A.shape)

    def index_sparse(self, A, idx):
        """Locate the stored entries of sparse ``A`` whose row index is in ``idx``.

        Positions refer to the entry order of ``A.coalesce()``. For a tensor
        built by :meth:`to_sparse` this is row-major order.

        Returns:
            A 1-tuple holding a tensor of positions, as returned by ``torch.where``.
        """
        rows = A.coalesce().indices()[0]
        wanted = torch.as_tensor(idx, device=rows.device)
        return torch.where(torch.isin(rows, wanted))

    def pi_0(self, x):
        """Affinely map [0, 1] to [epsilon, 1]."""
        return (1-self.epsilon)*x+self.epsilon

    def pi_1(self, x):
        """Affinely map [0, 1] to [0, 1 - epsilon]."""
        return (1-self.epsilon)*x

    def neg(self, x, negf):
        """Apply the per-row flag ``negf`` to ``x``.

        Assumes ``negf`` is 1-D (one flag per batch row); it is unsqueezed to
        broadcast across each row of ``x``. With ``negf2 = 1 - 2 * negf``, a
        row with ``negf == 0`` is returned unchanged and a row with
        ``negf == 1`` is returned as ``-x``.

        NOTE(review): a standard fuzzy negation would be ``1 - x``, i.e.
        ``(1 - negf2) / 2 + negf2 * x``. This method returns ``negf2 * x``.
        Confirm which one is intended before relying on negated concepts.
        """
        negf = negf.unsqueeze(1)
        negf2 = negf*(-2) + 1
        return negf2*x

    def t_norm(self, x, y):
        """Fuzzy conjunction of ``x`` and ``y``, elementwise.

        Chosen by ``self.tnorm``. Only when that is unrecognised does
        ``self.logic_name`` select a model-specific variant, with Gödel (min)
        as the final fallback.
        """
        if self.tnorm == "godel":
            return torch.minimum(x, y)  # Gödel t-norm

        elif self.tnorm == "product":
            return torch.mul(x, y)  # product t-norm

        elif self.tnorm == "lukas":
            return torch.maximum(torch.add(x, y) - 1, torch.tensor(0.0, device=x.device))  # Łukasiewicz t-norm

        elif self.tnorm == "yager":
            # Yager t-norm: 1 - ((1-x)^p + (1-y)^p)^(1/p), clipped to [0, 1].
            p = self.p
            one = torch.tensor(1.0, device=x.device)

            # Keep inputs off the endpoints to avoid pow() edge cases.
            x = torch.clamp(x, 1e-6, 1.0 - 1e-6)
            y = torch.clamp(y, 1e-6, 1.0 - 1e-6)

            term = torch.pow(one - x, p) + torch.pow(one - y, p)
            term = torch.clamp(term, 1e-8, float('inf'))
            result = one - torch.pow(term, 1.0 / p)
            return torch.clamp(result, 0.0, 1.0)

        elif self.tnorm == "hamacher":
            # Hamacher t-norm: T_λ(x, y) = xy / (λ + (1-λ)(x + y - xy)), λ >= 0.
            epsilon = 1e-8
            x = torch.clamp(x, self.value_clamp, 1.0 - self.value_clamp)
            y = torch.clamp(y, self.value_clamp, 1.0 - self.value_clamp)

            numerator = x * y
            denominator = self.lambda_hamacher + (1 - self.lambda_hamacher) * (x + y - x * y) + epsilon
            return numerator / denominator

        # Model-specific connectives, reached only for an unrecognised tnorm.
        elif self.logic_name == "LTN":
            return self.pi_0(x)*self.pi_0(y)
        elif self.logic_name == "Falcon":
            return torch.maximum(x + y - 1, torch.tensor(0.0, device=x.device))
        elif self.logic_name == "BoxEL":
            epsilon = 1e-3
            return ((1-epsilon)*x+epsilon)*((1-epsilon)*y+epsilon)
        elif self.logic_name == "Box2EL":
            return torch.maximum(x + y - 1, torch.tensor(0.0, device=x.device))
        elif self.logic_name == "ELEmbedding":
            x = torch.clamp(x, self.value_clamp, 1.0 - self.value_clamp)
            y = torch.clamp(y, self.value_clamp, 1.0 - self.value_clamp)
            numerator = x * y
            denominator = self.gamma + (1 - self.gamma) * (x + y - x * y)
            return numerator / (denominator + 1e-8)
        else:
            # Default: Gödel t-norm.
            return torch.minimum(x, y)

    def t_cnorm(self, x, y):
        """Fuzzy disjunction (t-conorm dual to :meth:`t_norm`), elementwise.

        Dispatch mirrors :meth:`t_norm`: ``self.tnorm`` first, then
        ``self.logic_name``, then Gödel (max).
        """
        if self.tnorm == "godel":
            return torch.maximum(x, y)  # Gödel t-conorm

        elif self.tnorm == "product":
            return torch.add(x, y) - torch.mul(x, y)  # probabilistic sum

        elif self.tnorm == "lukas":
            return torch.minimum(torch.add(x, y), torch.tensor(1.0, device=x.device))  # Łukasiewicz t-conorm

        elif self.tnorm == "yager":
            # Yager t-conorm: (x^p + y^p)^(1/p), clipped to [0, 1].
            p = self.p
            one = torch.tensor(1.0, device=x.device)

            # Keep inputs off the endpoints to avoid pow() edge cases.
            x = torch.clamp(x, 1e-6, 1.0 - 1e-6)
            y = torch.clamp(y, 1e-6, 1.0 - 1e-6)

            term = torch.pow(x, p) + torch.pow(y, p)
            term = torch.clamp(term, 1e-8, float('inf'))
            result = torch.pow(term, 1.0 / p)
            return torch.clamp(result, 0.0, 1.0)

        elif self.tnorm == "hamacher":
            # Hamacher t-conorm: S_λ(x, y) = (x + y - (2-λ)xy) / (1 - (1-λ)xy), λ >= 0.
            epsilon = 1e-8
            x = torch.clamp(x, self.value_clamp, 1.0 - self.value_clamp)
            y = torch.clamp(y, self.value_clamp, 1.0 - self.value_clamp)

            numerator = x + y - (2 - self.lambda_hamacher) * x * y
            denominator = 1 - (1 - self.lambda_hamacher) * x * y + epsilon
            return numerator / denominator

        # Model-specific connectives, reached only for an unrecognised tnorm.
        elif self.logic_name == "LTN":
            a = self.pi_1(x)
            b = self.pi_1(y)
            return a+b-a*b
        elif self.logic_name == "Falcon":
            return torch.minimum(x+y,torch.tensor(1.0, device=x.device))
        elif self.logic_name == "BoxEL":
            epsilon = 1e-3
            a = (1-epsilon)*x
            b = (1-epsilon)*y
            return a+b-a*b
        elif self.logic_name == "Box2EL":
            return torch.minimum(x+y,torch.tensor(1.0, device=x.device))
        elif self.logic_name == "ELEmbedding":
            x = torch.clamp(x, self.value_clamp, 1.0 - self.value_clamp)
            y = torch.clamp(y, self.value_clamp, 1.0 - self.value_clamp)
            numerator = x + y - (2 - self.gamma) * x * y
            denominator = 1 - (1 - self.gamma) * x * y
            return numerator / (denominator + 1e-8)
        else:
            # Default: Gödel t-conorm.
            return torch.maximum(x, y)

    def forall(self, r, x):
        """Universal restriction ∀r.x, chosen by ``self.logic_name``.

        Assumes ``r`` is (batch, n, n) and ``x`` is (batch, n). ``x`` is
        broadcast so that position ``[b, i, j]`` pairs ``r[b, i, j]`` with
        ``x[b, j]``.

        "Godel", "Hierarchy", "Rule", "Combined" and any unrecognised name use
        ``min_j S(1 - r, x)``.
        """
        xs = x.unsqueeze(1).expand(r.shape)
        name = self.logic_name
        if name in ("LTN", "BoxEL"):
            # 1 - p-mean over axis 2 of (1 - pi_1(S(r, x))).
            # NOTE(review): uses r rather than 1 - r inside the t-conorm; confirm.
            violation = 1 - self.pi_1(self.t_cnorm(r, xs))
            return 1 - torch.pow(torch.mean(torch.pow(violation, self.p), 2), 1 / self.p)
        if name in ("Falcon", "Box2EL"):
            # NOTE(review): reduces over axis 1 and uses r rather than 1 - r; confirm.
            return torch.min(self.t_cnorm(r, xs), dim=1).values
        if name == "ELEmbedding":
            return torch.min(self.t_cnorm(1 - r, xs), dim=1).values
        # Gödel-style quantifier. Used for the Godel-family names, including
        # "Combined" (which mixes the Hierarchy and Rule losses), and as the
        # fallback instead of returning None.
        return torch.min(self.t_cnorm(1 - r, xs), 2).values

    def exist(self, r, x):
        """Existential restriction ∃r.x, chosen by ``self.logic_name``.

        Same broadcasting as :meth:`forall`. "Godel", "Hierarchy", "Rule",
        "Combined" and any unrecognised name use ``max_j T(r, x)``.
        """
        xs = x.unsqueeze(1).expand(r.shape)
        name = self.logic_name
        if name in ("LTN", "BoxEL"):
            # p-mean over axis 2 of pi_0(T(r, x)).
            return torch.pow(torch.mean(torch.pow(self.pi_0(self.t_norm(r, xs)), self.p), 2), 1 / self.p)
        if name in ("Falcon", "Box2EL", "ELEmbedding"):
            # NOTE(review): reduces over axis 1 while the Gödel branch uses axis 2; confirm.
            return torch.max(self.t_norm(r, xs), dim=1).values
        # Gödel-style quantifier (also the fallback, see forall).
        return torch.max(self.t_norm(r, xs), 2).values

    def L2(self, x, dim=1):
        """Euclidean norm of ``x`` along ``dim``."""
        return torch.sqrt(torch.sum((x)**2, dim))

    def L2_dist(self, x, y, dim=1):
        """Euclidean distance between ``x`` and ``y`` along ``dim``."""
        return torch.sqrt(torch.sum((x-y)**2, dim))

    def L1(self,x,dim=1):
        """L1 norm of ``x`` along ``dim``."""
        return torch.sum(torch.abs(x),dim)

    def L1_dist(self,x,y,dim=1):
        """L1 distance between ``x`` and ``y`` along ``dim``."""
        return torch.sum(torch.abs(x-y),dim)

    def compute_implication(self, x, y):
        """Fuzzy implication ``x -> y``, evaluated elementwise.

        ``implication == "R"``: residuum of the configured t-norm. Inputs are
        clamped away from 0/1; after clamping, every branch returns 1 wherever
        ``x <= y``. ``implication == "S"``: ``max(1 - x, y)``.

        Raises:
            ValueError: if ``self.implication`` is neither "R" nor "S".
        """
        if self.implication == "R":
            if self.tnorm == "godel":
                # Gödel: 1 if x <= y else y.
                return torch.where(x <= y, torch.ones_like(x), y)

            elif self.tnorm == "product":
                # Goguen: min(1, y / x). Clamping keeps the division stable.
                epsilon = 1e-3
                x_safe = torch.clamp(x, epsilon, 1.0 - epsilon)
                y_safe = torch.clamp(y, epsilon, 1.0 - epsilon)
                return torch.clamp(y_safe / x_safe, 0.0, 1.0)

            elif self.tnorm == "lukas":
                # Łukasiewicz: min(1, 1 - x + y).
                return torch.clamp(1 - x + y, 0.0, 1.0)

            elif self.tnorm == "yager":
                # Yager residuum: 1 if x <= y, else 1 - ((1-y)^p - (1-x)^p)^(1/p).
                p = self.p
                epsilon = 1e-3  # clamp margin; also smooths the root near 0
                x = torch.clamp(x, epsilon, 1.0 - epsilon)
                y = torch.clamp(y, epsilon, 1.0 - epsilon)

                if p == 1.0:  # Yager with p = 1 is Łukasiewicz
                    return torch.clamp(1 - x + y, 0.0, 1.0)
                elif p >= 10.0:  # large p: approximated by Gödel
                    return torch.where(x <= y, torch.ones_like(x), y)
                else:
                    x_complement = torch.clamp(1 - x, epsilon, 1.0 - epsilon)
                    y_complement = torch.clamp(1 - y, epsilon, 1.0 - epsilon)

                    term1 = torch.pow(y_complement, p)
                    term2 = torch.pow(x_complement, p)

                    diff = term1 - term2  # > 0 exactly where x > y
                    positive_diff = torch.clamp(diff, 0.0, 1e6)
                    # The +epsilon keeps the root's gradient finite at 0.
                    root_term = torch.pow(positive_diff + epsilon, 1.0 / p)
                    smoothed = 1 - torch.clamp(root_term, 0.0, 1.0)
                    # Decide per element, so a satisfied element (x <= y)
                    # gets exactly 1 whatever the rest of the batch holds.
                    result = torch.where(diff > 0, smoothed, torch.ones_like(smoothed))
                    return torch.clamp(result, 0.0, 1.0)

            elif self.tnorm == "hamacher":
                # Residuum of T_λ(x, y) = xy / (λ + (1-λ)(x + y - xy)),
                # obtained by solving T_λ(x, z) = y for z:
                #   1                                          if x <= y
                #   y (λ + (1-λ) x) / (x - (1-λ) y (1-x))      otherwise
                # λ = 1 gives y / x (product), λ = 0 gives xy / (x - y + xy).
                epsilon = 1e-3
                x_safe = torch.clamp(x, epsilon, 1.0 - epsilon)
                y_safe = torch.clamp(y, epsilon, 1.0 - epsilon)
                lam = self.lambda_hamacher

                numerator = y_safe * (lam + (1 - lam) * x_safe)
                denominator = x_safe - (1 - lam) * y_safe * (1 - x_safe)
                # Positive whenever x > y and λ >= 0. The floor only guards the
                # x <= y lanes, which are replaced by 1 below.
                denominator = torch.clamp(denominator, min=1e-8)
                ratio = torch.clamp(numerator / denominator, 0.0, 1.0)
                return torch.where(x_safe <= y_safe, torch.ones_like(ratio), ratio)

            else:
                # Default: Gödel implication.
                return torch.where(x <= y, torch.ones_like(x), y)

        elif self.implication == "S":
            # S-implication: x -> y = max(1 - x, y).
            return torch.maximum(1 - x, y)
        else:
            raise ValueError(f"Unknown implication type: {self.implication}")

    def HierarchyLoss(self, lefte, righte):
        """Subsumption loss for ``lefte ⊑ righte``.

        For logic_name "Hierarchy", the loss is the mean of ``1 - (lefte ->
        righte)`` using :meth:`compute_implication`. For every other name it is
        ``relu(lefte - righte)``. In both cases, inputs with more than one
        axis are first summed (L1) along axis 1.

        NOTE(review): "Combined" therefore takes the ReLU path, not the
        implication path; confirm that this is intended.

        Args:
            lefte: truth values of the left-hand concept.
            righte: truth values of the right-hand concept.
        """
        if self.logic_name != "Hierarchy":
            if lefte.dim() > 1:
                return torch.mean(self.L1(self.relu(lefte - righte)))
            return torch.mean(self.relu(lefte - righte))

        impl_loss = 1 - self.compute_implication(lefte, righte)
        if lefte.dim() > 1:
            return torch.mean(self.L1(impl_loss))
        return torch.mean(impl_loss)

    def rule_based_loss(self, lefte, atype, righte, left, right, negf):
        """Hand-crafted per-axiom-type penalty used by "Rule" and "Combined".

        ``negf`` is accepted for interface symmetry but not read. Returns the
        integer ``0`` for an ``atype`` outside 0..6.
        """
        size = lefte.shape[-1]

        if atype in (0, 1, 2):
            return torch.mean(self.L1((1 - righte) * self.relu(lefte - righte)))

        if atype == 3:
            # Author's notes on candidate terms: rule1 = A + C, rule2 = B + C
            # (the one used here), rule3 = C alone.
            r = self.rEmb[right[:, 0]]
            lhs_grid = lefte.unsqueeze(-1).repeat(1, 1, size)
            rhs_grid = righte.unsqueeze(-2).repeat(1, size, 1)
            term_b = torch.sum((1 - r) * self.relu(self.t_norm(lhs_grid, rhs_grid) - r), dim=1)
            term_c = lefte * self.relu(lefte - 0.2) * self.relu(0.8 - torch.sum(self.t_norm(rhs_grid, r), dim=2))
            return torch.mean(self.L1(term_b + term_c))

        if atype == 4:
            r = self.rEmb[right[:, 0]]
            reach = torch.sum(self.t_norm(lefte.unsqueeze(-1).repeat(1, 1, size), r), dim=1)
            return torch.mean(self.L1((1 - righte) * self.relu(0.8 - righte) * self.relu(reach - 0.8)))

        if atype == 5:
            r = self.rEmb[left[:, 0]]
            reach = torch.sum(self.t_norm(lefte.unsqueeze(-2).repeat(1, size, 1), r), dim=2)
            return torch.mean(self.L1((1 - righte) * self.relu(0.8 - righte) * self.relu(reach - 0.8)))

        if atype == 6:
            r = self.rEmb[left[:, 0]]
            forward_reach = torch.sum(self.t_norm(lefte.unsqueeze(-2).repeat(1, size, 1), r), dim=2)
            backward_reach = torch.sum(self.t_norm(righte.unsqueeze(-1).repeat(1, 1, size), r), dim=1)
            return torch.mean(self.L1(
                (1 - righte) * self.relu(0.8 - righte) * self.relu(forward_reach - 0.8)
                + (1 - lefte) * self.relu(0.8 - lefte) * self.relu(backward_reach - 0.8)))

        return 0

    def forward(self, batch, atype, device):
        """Compute the loss of one batch of axioms of type ``atype``.

        Args:
            batch: ``(left, right, negf)``. Index tensors for the left side
                and the right side, plus per-operand flags passed to
                :meth:`neg`. Which columns are read depends on ``atype``.
            atype: axiom shape, in 0..6:
                0: C ⊑ D
                1: C1 ⊓ C2 ⊑ D
                2: C ⊑ D1 ⊓ D2
                3: C ⊑ ∃r.D
                4: C ⊑ ∀r.D
                5: ∃r.C ⊑ D
                6: ∀r.C ⊑ D
            device: not used by this method.

        Returns:
            A scalar loss:
                - rule_based_loss for "Rule";
                - ``loss_weight * HierarchyLoss + (1 - loss_weight) *
                  rule_based_loss`` for "Combined";
                - HierarchyLoss for every other name.

        Raises:
            ValueError: if ``atype`` is not in 0..6.
        """
        left, right, negf = batch

        # Pin the last two concept rows in place, outside autograd:
        # positive entries of row -1 become 1.0, entries of row -2 below 1
        # become 0.0 (presumably the top / bottom concepts; confirm).
        with torch.no_grad():
            last_row, second_last_row = self.cEmb[-1], self.cEmb[-2]
            last_row.masked_fill_(last_row.gt(0.0), 1.0)
            second_last_row.masked_fill_(second_last_row.lt(1), 0.0)

        if atype == 0:
            lefte = self.neg(self.cEmb[left], -negf[:, 0])
            righte = self.neg(self.cEmb[right], negf[:, 1])
        elif atype == 1:
            righte = self.neg(self.cEmb[right], negf[:, 2])
            lefte = self.t_norm(self.neg(self.cEmb[left[:, 0]], negf[:, 0]),
                                self.neg(self.cEmb[left[:, 1]], negf[:, 1]))
        elif atype == 2:
            lefte = self.neg(self.cEmb[left], negf[:, 0])
            righte = self.t_norm(self.neg(self.cEmb[right[:, 0]], negf[:, 1]),
                                 self.neg(self.cEmb[right[:, 1]], negf[:, 2]))
        elif atype == 3:
            lefte = self.neg(self.cEmb[left], negf[:, 0])
            righte = self.exist(self.rEmb[right[:, 0]], self.neg(self.cEmb[right[:, 1]], negf[:, 1]))
        elif atype == 4:
            lefte = self.neg(self.cEmb[left], negf[:, 0])
            righte = self.forall(self.rEmb[right[:, 0]], self.neg(self.cEmb[right[:, 1]], negf[:, 1]))
        elif atype == 5:
            righte = self.neg(self.cEmb[right], negf[:, 1])
            lefte = self.exist(self.rEmb[left[:, 0]], self.neg(self.cEmb[left[:, 1]], negf[:, 0]))
        elif atype == 6:
            righte = self.neg(self.cEmb[right], negf[:, 1])
            lefte = self.forall(self.rEmb[left[:, 0]], self.neg(self.cEmb[left[:, 1]], negf[:, 0]))
        else:
            raise ValueError(f"Unknown axiom type: {atype}")

        if self.logic_name == "Rule":
            return self.rule_based_loss(lefte, atype, righte, left, right, negf)
        if self.logic_name == "Combined":
            return (self.loss_weight * self.HierarchyLoss(lefte, righte)
                    + (1 - self.loss_weight) * self.rule_based_loss(lefte, atype, righte, left, right, negf))
        # "Hierarchy" and every other name.
        return self.HierarchyLoss(lefte, righte)