import torch
import torch.nn as nn
import torch.nn.functional as F

class BRDLoss(nn.Module):
    """Masked-generation feature distillation loss (after `Masked Generative Distillation`).

    The student feature map is optionally projected to the teacher width by a
    1x1 convolution, a random subset of its channels is zeroed, and a small
    conv-ReLU-conv generator is applied to the result. The generated map is
    averaged over its spatial positions and compared with the teacher
    embedding using a summed squared error that is divided by the batch size
    and scaled by ``alpha_mgd``.

    Args:
        student_emb (int): Number of channels in the student's feature map.
        teacher_emb (int): Number of channels in the teacher's embedding.
        alpha_mgd (float, optional): Weight of the distillation loss.
            Defaults to 0.00007.
        lambda_mgd (float, optional): Masking threshold; a channel is zeroed
            when its uniform random draw is below this value.
            Defaults to 0.15.
        use_clip (bool, optional): If True, rescale the teacher embedding so
            that its maximum equals the maximum of the (aligned) student
            features before the loss is computed. Defaults to True.
    """

    def __init__(self,
                 student_emb,
                 teacher_emb,
                 alpha_mgd=0.00007,
                 lambda_mgd=0.15,
                 use_clip=True,
                 ):
        super().__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd

        # A 1x1 projection is only needed when the two widths disagree.
        self.align = (
            nn.Conv2d(student_emb, teacher_emb, kernel_size=1, stride=1, padding=0)
            if student_emb != teacher_emb
            else None
        )

        self.use_clip = use_clip

        # Generator that reconstructs features from the masked student map.
        self.generation = nn.Sequential(
            nn.Conv2d(teacher_emb, teacher_emb, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_emb, teacher_emb, kernel_size=3, padding=1),
        )

    def forward(self, preds_S, preds_T):
        """Compute the weighted distillation loss.

        Args:
            preds_S(Tensor): Bs*D*H*W, student's feature map
            preds_T(Tensor): Bs*D, teacher's feature map

        Returns:
            Tensor: scalar loss, already multiplied by ``alpha_mgd``.
        """
        student = preds_S if self.align is None else self.align(preds_S)

        teacher = preds_T
        if self.use_clip:
            # Bring the teacher onto the student's scale via the two global maxima.
            teacher = teacher / teacher.max() * student.max()

        return self.get_dis_loss(student, teacher) * self.alpha_mgd

    def get_dis_loss(self, preds_S, preds_T):
        """Mask student channels, regenerate, pool, and compare to the teacher.

        Args:
            preds_S(Tensor): 4-D student features (already aligned).
            preds_T(Tensor): teacher embedding compared against the pooled
                generator output.

        Returns:
            Tensor: summed squared error divided by the batch size.
        """
        batch, channels, _, _ = preds_S.shape

        # One uniform draw per (sample, channel); the draw is made on the
        # default device and then moved to where the features live.
        noise = torch.rand((batch, channels, 1, 1)).to(preds_S.device)
        keep = torch.where(noise < self.lambda_mgd, 0, 1)

        generated = self.generation(preds_S * keep)
        # Collapse the spatial positions into a single per-channel mean.
        pooled = generated.flatten(2).mean(2)

        return F.mse_loss(pooled, preds_T, reduction='sum') / batch

def get_logits_loss(fc_t, fc_s, one_hot_label, temp, num_classes=1000):
    """Focal-weighted soft-label distillation loss on classification logits.

    For every sample the temperature-softened cross-entropy between the
    teacher and student distributions is computed and then multiplied by a
    per-sample weight ``1 - exp(-CE_student / CE_teacher)``, where the two
    cross-entropies are taken against ``one_hot_label`` at temperature 1.
    The weight is built from detached logits, so it acts as a constant and
    gradients reach ``fc_s`` only through the soft cross-entropy term.

    Args:
        fc_t (Tensor): teacher logits; classes are expected along dim 1.
        fc_s (Tensor): student logits with the same shape as ``fc_t``.
        one_hot_label (Tensor): label weights with the same shape as the
            logits (one-hot, or soft labels).
        temp (float): distillation temperature.
        num_classes (int, optional): unused; kept for backward compatibility.

    Returns:
        Tensor: scalar loss, the mean weighted soft cross-entropy times
        ``temp ** 2``.
    """
    # Temperature-softened cross-entropy between teacher and student.
    t_soft_label = F.softmax(fc_t / temp, dim=1)
    s_log_soft = F.log_softmax(fc_s / temp, dim=1)
    softmax_loss = -torch.sum(t_soft_label * s_log_soft, 1, keepdim=True)

    # Hard-label cross-entropy of both networks, detached so the focal
    # weight contributes no gradient.
    log_softmax_s = F.log_softmax(fc_s.detach(), dim=1)
    log_softmax_t = F.log_softmax(fc_t.detach(), dim=1)
    softmax_loss_s = -torch.sum(one_hot_label * log_softmax_s, 1, keepdim=True)
    softmax_loss_t = -torch.sum(one_hot_label * log_softmax_t, 1, keepdim=True)

    # Ratio of student to teacher error; the small constant guards against a
    # zero teacher loss.
    focal_weight = softmax_loss_s / (softmax_loss_t + 1e-7)
    # Lower-bound at zero without allocating a tensor on a fixed device, so
    # the function works for CPU and GPU inputs alike.
    focal_weight = torch.clamp(focal_weight, min=0)
    focal_weight = 1 - torch.exp(-focal_weight)

    soft_loss = (temp ** 2) * torch.mean(focal_weight * softmax_loss)

    return soft_loss


def kd_loss(logits_student, logits_teacher, temperature):
    """Temperature-scaled KL-divergence distillation loss.

    Both sets of logits are divided by ``temperature`` and normalised along
    dim 1. The element-wise KL terms are summed over dim 1, averaged over the
    remaining elements, and multiplied by ``temperature ** 2``.

    Args:
        logits_student (Tensor): student logits, classes along dim 1.
        logits_teacher (Tensor): teacher logits, same shape as the student's.
        temperature (float): softening temperature.

    Returns:
        Tensor: scalar distillation loss.
    """
    student_log_probs = F.log_softmax(logits_student / temperature, dim=1)
    teacher_probs = F.softmax(logits_teacher / temperature, dim=1)

    # Keep the element-wise terms so the class dimension can be summed
    # before averaging over samples.
    per_sample_kl = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(1)

    return per_sample_kl.mean() * temperature ** 2


def make_teacher(avg_fr, labels):
    """Build per-sample teacher outputs from the correctly classified time steps.

    For each sample, the outputs of every time step whose argmax equals the
    sample's class are averaged. A sample with no correct time step gets an
    all-zero row, because its zero numerator is divided by the small epsilon
    alone.

    Args:
        avg_fr: Tensor of shape [T, B, C] - SNN outputs at each time step
        labels: Tensor of shape [B] (class indices) or [B, C] (one-hot or
            mixup labels, reduced to class indices by argmax)
    Returns:
        teacher_labels: Tensor of shape [B, C] - dynamic teacher labels
    Raises:
        ValueError: if ``labels`` is neither 1-D nor 2-D.
    """
    step_preds = torch.argmax(avg_fr, dim=2)  # [T, B]

    # Reduce any supported label layout to class indices of shape [B].
    label_rank = labels.dim()
    if label_rank not in (1, 2):
        raise ValueError(f"Unexpected labels shape: {labels.shape}")
    targets = labels if label_rank == 1 else labels.argmax(dim=1)

    # True where the prediction at a time step matches the target class.
    hits = step_preds.eq(targets.unsqueeze(0))  # [T, B]

    # Sum of outputs over the matching steps only.
    hit_sum = (avg_fr * hits.unsqueeze(2)).sum(dim=0)  # [B, C]

    # Number of matching steps per sample; epsilon avoids division by zero.
    hit_count = hits.sum(dim=0).unsqueeze(1) + 1e-8  # [B, 1]

    return hit_sum / hit_count


def meom(logits_s, logits_t, labels, temp=3.0, alpha=0.5, beta=0.5):
    """
    Many Eyes, One Mind (MEOM) temporal distillation loss.

    Every time step of the student is pulled towards two fixed targets: the
    ANN teacher's distribution and a "dynamic" teacher built by
    ``make_teacher`` from the student's own outputs. Both targets are
    detached, so gradients reach the student only through its per-step
    log-probabilities.

    Args:
        logits_s: Tensor of shape [T, B, C] - student (SNN) logits at each time step
        logits_t: Tensor of shape [B, C] - teacher (ANN) logits
        labels: Tensor of shape [B] or [B, C] - ground truth labels
        temp: float - distillation temperature
        alpha: float - weight for ANN distillation loss
        beta: float - weight for dynamic teacher distillation loss
    Returns:
        loss: scalar - ``beta * dynamic_loss + alpha * ann_loss``, where each
            term is the time-averaged KL divergence multiplied by ``temp ** 2``
    """
    num_steps, _, _ = logits_s.shape

    # Dynamic teacher derived from the student's own outputs.
    dynamic_logits = make_teacher(logits_s, labels)

    # Fixed target distributions (no gradient flows into either teacher).
    dyn_log_p = F.log_softmax(dynamic_logits / temp, dim=-1).detach()  # [B, C]
    dyn_p = dyn_log_p.exp()
    ann_log_p = F.log_softmax(logits_t / temp, dim=-1).detach()  # [B, C]
    ann_p = ann_log_p.exp()

    dyn_terms = []
    ann_terms = []
    for step_logits in logits_s:
        # Student log-distribution at this time step, shared by both KL terms.
        step_log_q = F.log_softmax(step_logits / temp, dim=-1)  # [B, C]

        # Batch-mean KL(teacher || student) for each teacher.
        dyn_terms.append(torch.sum(dyn_p * (dyn_log_p - step_log_q), dim=-1).mean())
        ann_terms.append(torch.sum(ann_p * (ann_log_p - step_log_q), dim=-1).mean())

    # Average across time steps, then apply the temperature-squared scaling.
    temp_scale = temp ** 2
    dynamic_loss = (sum(dyn_terms, 0.0) / num_steps) * temp_scale
    ann_loss = (sum(ann_terms, 0.0) / num_steps) * temp_scale

    return beta * dynamic_loss + alpha * ann_loss


def meom_with_ce(logits_s, logits_t, labels, temp=3.0, alpha=0.5, beta=0.5, gamma=1.0):
    """
    MEOM distillation loss combined with a per-time-step cross-entropy term.

    The distillation part is delegated to ``meom``; the supervised part is the
    cross-entropy between each time step's student logits and ``labels``,
    averaged over the time steps.

    Args:
        logits_s: Tensor of shape [T, B, C] - student (SNN) logits at each time step
        logits_t: Tensor of shape [B, C] - teacher (ANN) logits
        labels: Tensor of shape [B] or [B, C] - ground truth labels; passed
            unchanged to ``F.cross_entropy``
        temp: float - distillation temperature
        alpha: float - weight for ANN distillation loss
        beta: float - weight for dynamic teacher distillation loss
        gamma: float - weight for cross-entropy loss
    Returns:
        loss: scalar - ``gamma * ce_loss + distill_loss``
    """
    num_steps, _, _ = logits_s.shape

    # Distillation term (ANN teacher + dynamic teacher).
    distill_loss = meom(logits_s, logits_t, labels, temp, alpha, beta)

    # Supervised term: mean cross-entropy over the time steps.
    step_ce = [F.cross_entropy(step_logits, labels) for step_logits in logits_s]
    ce_loss = sum(step_ce, 0.0) / num_steps

    return gamma * ce_loss + distill_loss