import torch.nn as nn
import torch.nn.functional as F
import torch 
import torch.autograd as autograd
import math
class BCEWithLogitsLossWithIgnoreIndex(nn.Module):
    """Binary cross-entropy on logits against one-hot targets, skipping ignored pixels.

    Integer label maps are one-hot encoded over the ``C`` channels of ``inputs``.
    Pixels equal to ``ignore_index`` receive an all-zero target row and are
    excluded from the ``'mean'`` / ``'sum'`` reductions.

    Args:
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
        ignore_index: label value that marks pixels to ignore.
    """

    def __init__(self, reduction='mean', ignore_index=255):
        super().__init__()
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, inputs, targets, weight=None):
        """Compute the loss.

        Args:
            inputs: logits, assumed ``(B, C, H, W)``. The one-hot is permuted with
                ``(0, 3, 1, 2)``, so ``targets`` must be 3-D.
            targets: integer label map, assumed ``(B, H, W)``.
            weight: optional tensor or scalar multiplied into the per-pixel loss
                after summing over classes. It must broadcast against ``(B, H, W)``.

        Returns:
            For ``'mean'`` / ``'sum'``: a scalar over non-ignored pixels.
            Otherwise: the per-pixel loss with ignored pixels zeroed.
        """
        n_cl = inputs.shape[1]
        # Map ignored pixels to an extra class index C so one_hot accepts them;
        # that extra channel is sliced off right after encoding.
        labels_new = torch.where(
            targets != self.ignore_index,
            targets,
            torch.full_like(targets, n_cl),
        )
        one_hot = F.one_hot(labels_new, n_cl + 1).float().permute(0, 3, 1, 2)
        one_hot = one_hot[:, :n_cl, :, :]

        loss = F.binary_cross_entropy_with_logits(inputs, one_hot, reduction='none')
        loss = loss.sum(dim=1)  # sum per-class contributions

        if weight is not None:
            # Scale by the caller-supplied weight (not a fixed constant).
            loss = loss * weight

        # 1 for labelled pixels, 0 for ignored ones.
        valid = one_hot.sum(dim=1)
        if self.reduction == 'mean':
            return torch.masked_select(loss, valid != 0).mean()
        elif self.reduction == 'sum':
            return torch.masked_select(loss, valid != 0).sum()
        else:
            return loss * valid


class MinLossAndProjection(nn.Module):
    """Distillation toward an old-model distribution remapped by a label mask.

    The L_min term is always computed. In the ``'mean'`` reduction a small
    projection term (L_pro) is also added.

    Args:
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
            Only ``'mean'`` includes the projection term.
        alpha: multiplier applied to ``targets`` before the softmax.
        kd_cil_weights: if True, weight each pixel by ``1 + entropy`` of the
            softmax of ``targets``.
        proj_weight: coefficient of the projection term in the ``'mean'``
            reduction. The default ``0.00001`` is the previously hard-coded value.
    """

    def __init__(self, reduction='mean', alpha=1., kd_cil_weights=False, proj_weight=0.00001):
        super().__init__()
        self.reduction = reduction
        self.alpha = alpha
        self.kd_cil_weights = kd_cil_weights
        self.proj_weight = proj_weight

    def forward(self, inputs, targets, masks=None):
        """Compute the loss.

        Args:
            inputs: new-model logits, assumed ``(B, C_new, H, W)`` with
                ``C_new >= C_old``.
            targets: old-model logits, assumed ``(B, C_old, H, W)``.
            masks: optional integer label map, assumed ``(B, H, W)``. For every
                label other than 0 and 255, the old background probability
                ``labels[:, 0]`` is moved into that label's channel where the
                mask matches. If ``None``, no remapping is done.
        """
        # Parameter-driven minimization of the old knowledge distribution.
        outputs = torch.log_softmax(inputs, dim=1)
        labels = torch.softmax(targets * self.alpha, dim=1)
        new_labels = torch.zeros_like(outputs)
        new_labels[:, :labels.shape[1]] = labels

        if masks is not None:
            for cls in torch.unique(masks):
                if cls == 0 or cls == 255:
                    continue
                mask = torch.where(masks == cls, 1, 0)
                new_labels[:, cls] = mask * labels[:, 0]
                new_labels[:, 0] = (1 - mask) * new_labels[:, 0]

        # Quantify projection strength: <new_labels, outputs> / ||outputs||^2
        # per pixel. eps guards the division.
        dot_product = (new_labels * outputs).sum(dim=1, keepdim=True)
        b_norm_squared = (outputs * outputs).sum(dim=1, keepdim=True) + 1e-8
        confidence_map = dot_product / b_norm_squared
        projection = confidence_map * new_labels
        # L_pro
        loss_projection = 1 - projection.mean(dim=1) / outputs.mean(dim=1)
        # L_min
        loss = (outputs * new_labels).mean(dim=1)

        if self.kd_cil_weights:
            w = -(torch.softmax(targets, dim=1) * torch.log_softmax(targets, dim=1)).sum(dim=1) + 1.0
            # Elementwise per-pixel weighting. Both loss and w are (B, H, W);
            # w[:, None] would broadcast to (B, B, H, W) and mix samples.
            loss = loss * w

        if self.reduction == 'mean':
            return -torch.mean(loss) + self.proj_weight * torch.mean(loss_projection)
        elif self.reduction == 'sum':
            return -torch.sum(loss)
        else:
            return -loss

 

class CKDAndEstimationLoss(nn.Module):
    """Weighted squared-error distillation plus a second-derivative term (L_lap).

    Args:
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
        alpha: multiplier applied to both squared-error terms.
        beta: step size applied to the second-derivative map.
    """

    def __init__(self, reduction='mean', alpha=1, beta=0.5):
        super().__init__()
        self.reduction = reduction
        self.alpha = alpha
        self.beta = beta

    def forward(self, inputs, targets, weights=1):
        """Compute the loss.

        ``inputs`` must require grad: it is differentiated twice via
        ``autograd.grad``.

        Args:
            inputs: new-model tensor.
            targets: tensor broadcastable to ``inputs``.
            weights: scalar or tensor multiplied into both terms. With a tensor
                and ``'mean'``, only elements with ``weights > 0`` are counted.

        Returns:
            For ``'mean'`` / ``'sum'``: a scalar. Otherwise: the elementwise loss.
        """
        loss = (inputs - targets) ** 2
        loss = loss * weights * self.alpha
        # First derivative, kept in the graph so it can be differentiated again.
        inputs_grad = autograd.grad(loss.sum(), inputs, create_graph=True)[0]
        # Identify low-curvature regions: derivative of the summed first derivative.
        position_map = autograd.grad(inputs_grad.sum(), inputs, create_graph=True)[0]
        lap_medium = inputs - self.beta * position_map
        loss_lap = (inputs - lap_medium) ** 2
        loss_lap = loss_lap * weights * self.alpha
        # L_lap is added with unit coefficient.
        total_loss = loss + 1.0 * loss_lap

        if self.reduction == 'mean':
            if torch.is_tensor(weights):
                # Average over positively-weighted elements only. The clamp
                # avoids 0/0 = NaN when no element has positive weight.
                count = torch.where(weights > 0, 1, 0).expand_as(loss).sum()
                return torch.sum(total_loss) / count.clamp(min=1)
            # Any scalar weight: plain mean over all elements.
            return torch.mean(total_loss)
        elif self.reduction == 'sum':
            return torch.sum(total_loss)
        else:
            return total_loss


class EntropyInducedLoss(nn.Module):
    """Class-token regulariser: an orthogonality penalty on new-class tokens
    combined with entropy-based mutual-information and KL-to-uniform terms.

    Args:
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
        classes: number of old classes. Either an int, or a sequence whose first
            element is used.
        lambda_mi: weight of the mutual-information term (subtracted).
        lambda_kl: weight of the KL-to-uniform term (added).
    """

    def __init__(self, reduction='mean', classes=21, lambda_mi=0.001, lambda_kl=0.001):
        super().__init__()
        self.reduction = reduction
        self.num_classes = classes
        self.lambda_mi = lambda_mi
        self.lambda_kl = lambda_kl

    def _num_old_classes(self):
        """Return ``classes`` when it is an int, otherwise ``classes[0]``."""
        classes = self.num_classes
        return classes if isinstance(classes, int) else classes[0]

    def forward(self, class_token, weight=1.):
        """Compute the regulariser.

        Entropy-induced optimization of the overlap between the new and old
        knowledge distributions.

        Args:
            class_token: token tensor, assumed ``(B, C, F)`` (TODO confirm). It is
                L2-normalised along the last dim.
            weight: multiplier for the orthogonality term.

        Returns:
            ``relu(weight * L_orth - lambda_mi * L_mi + lambda_kl * L_kl)``,
            reduced according to ``self.reduction``.
        """
        n_old = self._num_old_classes()
        class_token = class_token / class_token.norm(dim=-1, keepdim=True)
        # Pairwise cosine similarity [B, C, C]. The right operand is detached, so
        # gradients flow only through the left factor.
        sim = torch.matmul(class_token, class_token.permute(0, 2, 1).detach())
        sim.diagonal(dim1=-2, dim2=-1).zero_()  # drop self-similarity
        sim[:, :n_old] = 0  # keep only rows belonging to non-old classes
        # Mean |similarity| over the non-zero entries. The clamp avoids 0/0 = NaN
        # when there are no such entries (e.g. C <= n_old).
        n_nonzero = (sim != 0).sum().clamp(min=1)
        loss_orth = sim.abs().sum() / n_nonzero

        prob = F.softmax(class_token, dim=1)  # [B, C, F]
        # Marginal probability over the batch.
        marginal_prob = prob.mean(dim=0)  # [C, F]
        # Marginal entropy: reflects a well-balanced category distribution.
        marginal_entropy = -torch.sum(marginal_prob * torch.log(marginal_prob + 1e-6))
        # Conditional entropy: uncertainty of the new knowledge given the old.
        conditional_entropy = -torch.sum(prob * torch.log(prob + 1e-6), dim=1).mean()
        # L_max, normalised by log(n_old). This divides by zero if n_old == 1.
        mi_loss = (marginal_entropy - conditional_entropy) / math.log(n_old)
        uniform_prob = torch.full_like(prob, 1.0 / n_old)
        # log_softmax is the numerically stable form of softmax(...).log().
        kl_loss = F.kl_div(F.log_softmax(class_token, dim=1), uniform_prob, reduction='batchmean')
        total_loss = torch.relu(loss_orth * weight - self.lambda_mi * mi_loss + self.lambda_kl * kl_loss)

        if self.reduction == 'mean':
            return torch.mean(total_loss)
        elif self.reduction == 'sum':
            return torch.sum(total_loss)
        else:
            return total_loss


class KnowledgeDistillationLoss(nn.Module):
    """Soft-target distillation of the old-model distribution.

    The loss is the negative mean over channels of
    ``softmax(alpha * targets) * log_softmax(inputs)``, computed on the first
    ``targets.shape[1]`` channels of ``inputs``.

    Args:
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
        alpha: multiplier applied to ``targets`` before the softmax.
        kd_cil_weights: if True, weight each pixel by ``1 + entropy`` of the
            softmax of ``targets``.
    """

    def __init__(self, reduction='mean', alpha=1., kd_cil_weights=False):
        super().__init__()
        self.reduction = reduction
        self.alpha = alpha
        self.kd_cil_weights = kd_cil_weights

    def forward(self, inputs, targets, masks=None):
        """Compute the loss.

        Args:
            inputs: new-model logits, assumed ``(B, C_new, H, W)``.
            targets: old-model logits, assumed ``(B, C_old, H, W)``.
            masks: accepted for interface compatibility; not used.
        """
        # Distil only the channels the old model knows about.
        inputs = inputs.narrow(1, 0, targets.shape[1])
        outputs = torch.log_softmax(inputs, dim=1)
        labels = torch.softmax(targets * self.alpha, dim=1)

        loss = (outputs * labels).mean(dim=1)
        if self.kd_cil_weights:
            w = -(torch.softmax(targets, dim=1) * torch.log_softmax(targets, dim=1)).sum(dim=1) + 1.0
            # Elementwise per-pixel weighting. Both loss and w are (B, H, W);
            # w[:, None] would broadcast to (B, B, H, W) and mix samples.
            loss = loss * w

        if self.reduction == 'mean':
            return -torch.mean(loss)
        elif self.reduction == 'sum':
            return -torch.sum(loss)
        else:
            return -loss


