import torch
from torch import Tensor, nn
from torch.nn import functional as F


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    flag = False
    if target.dim() == lprobs.dim() - 1:
        flag = True
        target = target.unsqueeze(-1)

    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)

    if flag:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / (lprobs.size(-1) - 1)
    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


class CrossEntropyLoss(nn.CrossEntropyLoss):
    def forward(self, scores: Tensor, target: Tensor, mask=None) -> Tensor:
        """
          scores: [N, ..., C], unnormalized scores
          target: [N, ...]
          mask: [N, ...], where elements with `True` are allowed and `False` are masked-out
        """
        n_tokens = target.numel()
        n_nonpad_tokens = target.ne(self.ignore_index).long().sum()

        bsz, num_classes = scores.shape[0], scores.shape[-1]

        if mask is not None:
            scores = scores[mask]  # [N * len, C]
            target = target[mask]  # [N]
        scores = scores.reshape(-1, num_classes)
        target = target.reshape(-1)

        if self.ignore_index is not None:
            sample_size = target.ne(self.ignore_index).long().sum()
        else:
            sample_size = torch.tensor(target.numel(), device=target.device)

        # smooth_loss = F.cross_entropy(
        #     scores.transpose(1, -1), target,
        #     weight=self.weight,
        #     ignore_index=self.ignore_index, reduction=self.reduction,
        #     label_smoothing=self.label_smoothing)

        loss, nll_loss = label_smoothed_nll_loss(
            lprobs=F.log_softmax(scores, dim=-1),
            target=target,
            epsilon=self.label_smoothing,
            ignore_index=self.ignore_index,
            reduce=True,
        )
        loss_avg = loss / sample_size
        ppl = torch.exp(nll_loss / sample_size)

        logging_output = {
            'nll_loss_sum': nll_loss.data,
            'loss_sum': loss.data,
            'ppl': ppl.data,
            'bsz': bsz,
            'sample_size': sample_size,
            'sample_ratio': sample_size / n_tokens,
            'nonpad_ratio': n_nonpad_tokens / n_tokens
        }
        return loss_avg, logging_output


class Coord2SeqCrossEntropyLoss(nn.CrossEntropyLoss):
    def forward(self, scores: Tensor, target: Tensor, label_mask=None, coord_mask=None, weights=None) -> Tensor:
        """
          scores: [N, L, C], unnormalized scores
          target: [N, L]
          coord_mask: FloatTensor [N, L], where elements with `True` are allowed and `False` are masked-out
        """
        if label_mask is None:
            label_mask = coord_mask

        bsz, num_classes = scores.shape[0], scores.shape[-1]

        n_tokens = target.numel()
        if self.ignore_index is not None:
            sample_size = n_nonpad_tokens = target.ne(self.ignore_index).float().sum()
        else:
            sample_size = n_nonpad_tokens = n_tokens

        # [N, L]
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs=F.log_softmax(scores, dim=-1),
            target=target,
            epsilon=self.label_smoothing,
            ignore_index=self.ignore_index,
            reduce=False,
        )
        if weights is not None:
            loss, nll_loss = loss * weights, nll_loss * weights
        fullseq_loss = loss.sum() / sample_size
        fullseq_nll_loss = nll_loss.sum() / sample_size

        # use coord masked loss for model training,
        # ignoring those position with missing coords (as nan)
        if label_mask is not None:
            label_mask = label_mask.float()
            sample_size = label_mask.sum()  # sample size should be set to valid coordinates
            loss = (loss * label_mask).sum() / sample_size
            nll_loss = (nll_loss * label_mask).sum() / sample_size
            acc = ((scores.argmax(-1) == target) * label_mask).sum() / sample_size
        else:
            loss, nll_loss = fullseq_loss, fullseq_nll_loss

        ppl = torch.exp(nll_loss)

        logging_output = {
            'nll_loss': nll_loss.data,
            'ppl': ppl.data,
            'acc': acc.data,
            'fullseq_loss': fullseq_loss.data,
            'fullseq_nll_loss': fullseq_nll_loss.data,
            'bsz': bsz,
            'sample_size': sample_size,
            'sample_ratio': sample_size / n_tokens,
            'nonpad_ratio': n_nonpad_tokens / n_tokens
        }
        return loss, logging_output

import torch
from torch import Tensor, nn
from torch.nn import functional as F


class CrossEntropyLoss1(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, scores: Tensor, target: Tensor, mask=None) -> Tensor:
        """
        计算 target 和 scores 具有相同维度的损失。

        参数
        ----
        scores: [B, N, C]，预测的 logits 或概率分布
        target: [B, N, C]，真实的概率分布
        mask: [B, N]，可选的掩码，用于忽略某些位置的损失计算

        返回
        ----
        loss: Tensor，计算的损失值
        """
        # 确保 scores 和 target 的形状一致
        assert scores.shape == target.shape, "scores 和 target 的形状必须一致"

        # 如果需要掩码，应用掩码
        if mask is not None:
            mask = mask.unsqueeze(-1)  # 扩展维度以匹配 [B, N, C]
            scores = scores * mask
            target = target * mask

        # 计算 KL 散度损失（假设 target 是概率分布）
        log_scores = F.log_softmax(scores, dim=-1)  # 对 scores 取 log_softmax
        loss = F.kl_div(log_scores, target, reduction='none')  # KL 散度

        # 如果需要 reduction，进行归约
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss

class Coord2SeqCrossEntropyLoss1(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, scores: Tensor, target: Tensor, mask=None, weights=None) -> Tensor:
        """
        计算 target 和 scores 具有相同维度的损失。

        参数
        ----
        scores: [B, N, C]，预测的 logits 或概率分布
        target: [B, N, C]，真实的概率分布
        mask: [B, N]，可选的掩码，用于忽略某些位置的损失计算
        weights: [B, N]，可选的权重，用于加权损失

        返回
        ----
        loss: Tensor，计算的损失值
        """
        # 确保 scores 和 target 的形状一致
        assert scores.shape == target.shape, "scores 和 target 的形状必须一致"

        # 计算 sample_size
        if mask is not None:
            sample_size = mask.sum()  # 有效位置的总数
        else:
            sample_size = scores.shape[0] * scores.shape[1]  # B * N

        # 如果需要掩码，应用掩码
        if mask is not None:
            mask = mask.unsqueeze(-1)  # 扩展维度以匹配 [B, N, C]
            scores = scores * mask
            target = target * mask

        # 计算 KL 散度损失（假设 target 是概率分布）
        log_scores = F.log_softmax(scores, dim=-1)  # 对 scores 取 log_softmax
        loss = F.kl_div(log_scores, target, reduction='none')  # KL 散度

        # 如果提供了权重，应用权重
        if weights is not None:
            weights = weights.unsqueeze(-1)  # 扩展维度以匹配 [B, N, C]
            loss = loss * weights

        # 如果需要 reduction，进行归约
        if self.reduction == 'mean':
            loss = loss.sum() / sample_size  # 使用 sample_size 归一化
        elif self.reduction == 'sum':
            loss = loss.sum()

        logging_output = {
            'nll_loss': loss.data,
            'sample_size': sample_size,
        }
        return loss, logging_output