# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from dataclasses import dataclass, field

import torch
import torch.distributed as dist

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


def adjust_label_smoothed_nll_loss(lprobs, target, epsilon=0.0):
    """Return the mean label-smoothed negative log-likelihood.

    Args:
        lprobs: log-probabilities; the last dimension indexes the classes.
        target: indices of the correct class, either with the same number of
            dimensions as ``lprobs`` or with the trailing class dimension
            omitted (it is added back here).
        epsilon: label-smoothing strength; ``0`` disables smoothing.

    Returns:
        A scalar tensor: the loss averaged over all positions.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)

    num_classes = lprobs.size(-1)
    # With a single class there is nothing to spread the smoothing mass over,
    # and epsilon / (num_classes - 1) would divide by zero (this happens for a
    # one-sample batch on a single worker), so fall back to the plain NLL.
    if epsilon != 0 and num_classes > 1:
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
        # Mass given to each non-target class.
        eps_i = epsilon / (num_classes - 1)
        loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
    else:
        loss = nll_loss
    return loss.mean()


@torch.no_grad()
def gather_without_grad(tensor):
    """
    Collect ``tensor`` from every process of the default group and
    concatenate the pieces along dim 0.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = dist.get_world_size()
    # One receive buffer per process, shaped like the local tensor.
    buffers = []
    for _ in range(world_size):
        buffers.append(torch.empty_like(tensor))
    dist.all_gather(buffers, tensor, async_op=False)
    return torch.cat(buffers, dim=0)


@dataclass
class AudioTextRetrievalCriterionConfig(FairseqDataclass):
    """Options for the ``audio_text_retrieval_criterion`` criterion."""

    label_smoothing: float = field(default=0.0, metadata={
        "help": "epsilon for label smoothing, 0 means no label smoothing"
    })


@register_criterion("audio_text_retrieval_criterion", dataclass=AudioTextRetrievalCriterionConfig)
class AudioTextRetrievalCriterion(FairseqCriterion):
    """Audio-text contrastive criterion.

    Scores local audio/text outputs against the outputs gathered from all
    workers and averages a label-smoothed cross-entropy over both retrieval
    directions (audio-to-text and text-to-audio).
    """

    def __init__(self, task, label_smoothing=0.0):
        super().__init__(task)
        self.label_smoothing = label_smoothing

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_input = sample['net_input']
        tokens = net_input['src_tokens']
        audios = net_input['src_audios']
        audio_masks = net_input['audio_padding_masks']

        text_logits = model(src_tokens=tokens, encoder_type='text')
        audio_logits = model(
            src_audios=audios,
            audio_padding_masks=audio_masks,
            encoder_type='audio',
        )

        # The "all" tensors never carry gradient: either they come from a
        # no-grad all_gather or they are the detached local outputs.
        if dist.is_initialized():
            text_logits_all = gather_without_grad(text_logits)
            audio_logits_all = gather_without_grad(audio_logits)
        else:
            text_logits_all = text_logits.data
            audio_logits_all = audio_logits.data

        logit_scale_exp = model(return_logit_scale=True)

        atc_loss, a2t_ncorrect, t2a_ncorrect = self.compute_atc_loss(
            audio_logits,
            text_logits,
            audio_logits_all,
            text_logits_all,
            logit_scale_exp,
        )

        sample_size = 1
        return atc_loss, sample_size, {
            "loss": atc_loss.data,
            "nsentences": sample['nsentences'],
            "sample_size": sample_size,
            "a2t_ncorrect": a2t_ncorrect,
            "t2a_ncorrect": t2a_ncorrect,
            "logit_scale_exp": logit_scale_exp,
        }

    def compute_atc_loss(self, audio_logits, text_logits, audio_logits_all, text_logits_all, logit_scale_exp):
        """Return ``(loss, a2t_ncorrect, t2a_ncorrect)`` for one batch.

        The loss is the average of the audio-to-text and text-to-audio
        label-smoothed losses; the two counts are the number of local samples
        whose highest-scoring candidate is their own pair.
        """
        rank = dist.get_rank() if dist.is_initialized() else 0
        local_bsz = audio_logits.size(0)
        # Local sample i sits at position rank * local_bsz + i of the
        # gathered tensors.
        first = local_bsz * rank
        targets = torch.arange(first, first + local_bsz).to(audio_logits.device)

        def one_direction(queries, candidates_all):
            similarity = (logit_scale_exp * queries) @ candidates_all.t()
            log_similarity = utils.log_softmax(similarity, dim=-1).type_as(similarity)
            loss = adjust_label_smoothed_nll_loss(log_similarity, targets, self.label_smoothing)
            with torch.no_grad():
                hits = (similarity.argmax(dim=1) == targets).float().sum()
            return loss, hits

        a2t_loss, a2t_ncorrect = one_direction(audio_logits, text_logits_all)
        t2a_loss, t2a_ncorrect = one_direction(text_logits, audio_logits_all)
        atc_loss = (a2t_loss + t2a_loss) / 2

        return atc_loss, a2t_ncorrect, t2a_ncorrect

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        def total(key, default=0):
            return sum(log.get(key, default) for log in logging_outputs)

        loss_sum = total("loss")
        logit_scale_exp_sum = total("logit_scale_exp")
        nsentences = total("nsentences", 1)
        sample_size = total("sample_size", 1)

        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
        metrics.log_scalar("logit_scale_exp", logit_scale_exp_sum / sample_size, sample_size, round=3)
        metrics.log_scalar("nsentences", nsentences, 1, round=3)
        metrics.log_scalar("sample_size", sample_size, 1, round=3)

        # Accuracies are only reported when the first log entry carries the
        # corresponding hit count.
        accuracy_keys = (
            ("a2t_ncorrect", "a2t_accuracy"),
            ("t2a_ncorrect", "t2a_accuracy"),
        )
        for count_key, metric_name in accuracy_keys:
            if len(logging_outputs) > 0 and count_key in logging_outputs[0]:
                ncorrect = total(count_key)
                metrics.log_scalar(
                    metric_name, 100.0 * ncorrect / nsentences, nsentences, round=1
                )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True improves distributed training speed.
        """
        return True
    

def gather_features(
        audio_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
):
    """Collect audio and text features from every worker.

    Args:
        audio_features: this worker's audio features; workers are stacked
            along dim 0.
        text_features: this worker's text features, same layout.
        local_loss: when False (and ``gather_with_grad`` is False) the slice
            that belongs to this worker is overwritten with the original,
            gradient-carrying tensors after the gather.
        gather_with_grad: use the autograd-aware
            ``torch.distributed.nn.all_gather`` instead of the plain
            collective.
        rank: index of this worker, used to locate its slice.
        world_size: number of workers taking part in the gather.

    Returns:
        ``(all_audio_features, all_text_features)``, each with ``world_size``
        times as many rows as the corresponding input. With ``world_size <= 1``
        the inputs themselves are returned.
    """
    if world_size <= 1:
        # Nothing to gather. Returning early also keeps single-process runs
        # working when no process group has been initialised, and the local
        # tensors keep their gradient.
        return audio_features, text_features

    if gather_with_grad:
        # Import the submodule explicitly rather than relying on it having
        # been loaded as a side effect of `import torch.distributed`.
        import torch.distributed.nn as dist_nn

        all_audio_features = torch.cat(dist_nn.all_gather(audio_features), dim=0)
        all_text_features = torch.cat(dist_nn.all_gather(text_features), dim=0)
        return all_audio_features, all_text_features

    # Allocate the full output once and gather straight into it, avoiding a
    # Python list of per-worker tensors plus a concatenation.
    all_audio_features = torch.empty(
        (world_size * audio_features.size(0), *audio_features.shape[1:]),
        dtype=audio_features.dtype, device=audio_features.device)
    all_text_features = torch.empty(
        (world_size * text_features.size(0), *text_features.shape[1:]),
        dtype=text_features.dtype, device=text_features.device)
    # NOTE: all_gather_into_tensor is only available in recent torch releases.
    dist.all_gather_into_tensor(all_audio_features, audio_features.contiguous())
    dist.all_gather_into_tensor(all_text_features, text_features.contiguous())

    if not local_loss:
        # The gathered copies carry no gradient; put the local tensors back so
        # gradients can flow through this worker's slice.
        batch_size = audio_features.size(0)
        start = rank * batch_size
        end = start + batch_size
        all_audio_features[start:end] = audio_features
        all_text_features[start:end] = text_features

    return all_audio_features, all_text_features

def sinkhorn_knopp(
    cost_matrix: torch.Tensor,
    a: torch.Tensor = None,
    b: torch.Tensor = None,
    reg: float = 0.1,
    numItermax: int = 50,
    stop_thresh: float = 1e-3,
    log_domain: bool = True,  # log-domain updates are numerically more stable
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Sinkhorn iterations written with plain torch ops, so the result stays on
    the device of ``cost_matrix`` and can be back-propagated through.

    Args:
        cost_matrix: (m, n) cost matrix.
        a: (m,) weights of the source distribution; uniform when None.
        b: (n,) weights of the target distribution; uniform when None.
        reg: entropic regularisation coefficient.
        numItermax: maximum number of iterations.
        stop_thresh: the loop stops early once the largest absolute change of
            the row scaling between two iterations falls below this value.
        log_domain: iterate on log-scalings instead of the scalings.
        eps: small constant added before taking logs / dividing.

    Returns:
        transport_plan: (m, n) transport matrix.
    """
    n_rows, n_cols = cost_matrix.shape
    target_device = cost_matrix.device

    # Uniform marginals unless the caller supplied them.
    if a is None:
        a = torch.ones(n_rows, device=target_device) / n_rows
    if b is None:
        b = torch.ones(n_cols, device=target_device) / n_cols

    if not log_domain:
        # --- Plain-domain variant (relies on eps to guard the divisions) ---
        kernel = torch.exp(-cost_matrix / reg)
        kernel = kernel / kernel.sum()  # preliminary normalisation

        row_scale = torch.ones_like(a) / n_rows
        col_scale = torch.ones_like(b) / n_cols

        for _ in range(numItermax):
            previous = row_scale.clone()

            col_scale = b / (kernel.T @ row_scale + eps)
            row_scale = a / (kernel @ col_scale + eps)

            if (row_scale - previous).abs().max() < stop_thresh:
                break

        return row_scale.unsqueeze(1) * kernel * col_scale.unsqueeze(0)

    # --- Log-domain variant ---
    log_a = torch.log(a + eps)
    log_b = torch.log(b + eps)
    log_kernel = -cost_matrix / reg  # kernel = exp(-C / reg)

    log_u = torch.zeros_like(a)
    log_v = torch.zeros_like(b)

    for _ in range(numItermax):
        previous = log_u.clone()

        # Column scaling first, then row scaling.
        log_v = log_b - torch.logsumexp(log_kernel + log_u.unsqueeze(1), dim=0)
        log_u = log_a - torch.logsumexp(log_kernel + log_v.unsqueeze(0), dim=1)

        if (log_u - previous).abs().max() < stop_thresh:
            break

    # P = diag(exp(log_u)) * kernel * diag(exp(log_v))
    return torch.exp(log_u.unsqueeze(1) + log_kernel + log_v.unsqueeze(0))

def sinkhorn_knopp_unbalanced(
    cost_matrix: torch.Tensor,
    a: torch.Tensor = None,
    b: torch.Tensor = None,
    reg: float = 0.1,
    reg_m: float = 0.1,
    numItermax: int = 10,
    stop_thresh: float = 1e-3,
    log_domain: bool = True,
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Unbalanced Sinkhorn iterations written with plain torch ops, so the result
    stays on the device of ``cost_matrix`` and can be back-propagated through.

    Args:
        cost_matrix: (m, n) cost matrix.
        a: (m,) source weights (need not be normalised); uniform when None.
        b: (n,) target weights (need not be normalised); uniform when None.
        reg: entropic regularisation term.
        reg_m: mass (marginal) penalty term.
        numItermax: maximum number of iterations.
        stop_thresh: the loop stops early once the largest absolute change of
            the row scaling between two iterations falls below this value.
        log_domain: iterate on log-scalings instead of the scalings
            (recommended).
        eps: small constant added before taking logs / dividing.

    Returns:
        transport_plan: (m, n) transport matrix.
    """
    n_rows, n_cols = cost_matrix.shape
    target_device = cost_matrix.device

    if a is None:
        a = torch.ones(n_rows, device=target_device) / n_rows
    if b is None:
        b = torch.ones(n_cols, device=target_device) / n_cols

    # Exponent that damps every scaling update.
    damping = reg_m / (reg_m + reg)

    if not log_domain:
        kernel = torch.exp(-cost_matrix / reg)
        row_scale = torch.ones_like(a)
        col_scale = torch.ones_like(b)

        for _ in range(numItermax):
            previous = row_scale.clone()

            # Row scaling first, then column scaling.
            row_mass = kernel @ col_scale + eps
            row_scale = (a / row_mass) ** damping

            col_mass = kernel.T @ row_scale + eps
            col_scale = (b / col_mass) ** damping

            if (row_scale - previous).abs().max() < stop_thresh:
                break

        return row_scale.unsqueeze(1) * kernel * col_scale.unsqueeze(0)

    log_kernel = -cost_matrix / reg
    log_u = torch.zeros_like(a)
    log_v = torch.zeros_like(b)

    log_a = torch.log(a + eps)
    log_b = torch.log(b + eps)

    for _ in range(numItermax):
        previous = log_u.clone()

        # Column scaling first, then row scaling.
        log_v = damping * (log_b - torch.logsumexp(log_kernel + log_u.unsqueeze(1), dim=0))
        log_u = damping * (log_a - torch.logsumexp(log_kernel + log_v.unsqueeze(0), dim=1))

        if (log_u - previous).abs().max() < stop_thresh:
            break

    # transport plan: diag(exp(log_u)) * kernel * diag(exp(log_v))
    return torch.exp(log_u.unsqueeze(1) + log_kernel + log_v.unsqueeze(0))
@dataclass
class AudioTextRetrievalCriterionFloatConfig(FairseqDataclass):
    """Options for the ``audio_text_retrieval_criterion_float`` criterion."""

    transfer_weight: float = field(
        default=0.0,
        metadata={"help": "weight of feature-level transfer loss"},
    )
    reg: float = field(
        default=0.03,
        metadata={"help": "sinkhorn entropic regularization"},
    )
    reg_m: float = field(
        default=0.5,
        metadata={"help": "mass relaxation regularization"},
    )
    sinkhorn_iter: int = field(
        default=10,
        metadata={"help": "number of Sinkhorn iterations"},
    )
    local_loss: bool = field(
        default=False,
        metadata={"help": "if True, compute local loss instead of global"},
    )
    use_transfer: bool = field(
        default=True,
        metadata={"help": "if True, compute feature-level transfer loss"},
    )

@register_criterion("audio_text_retrieval_criterion_float", dataclass=AudioTextRetrievalCriterionFloatConfig)
class AudioTextRetrievalCriterionFloat(FairseqCriterion):
    """Audio-text retrieval criterion built on entropic optimal transport.

    The main loss is the negative log of the Sinkhorn transport-plan mass
    placed on matched audio/text pairs. An optional feature-level term adds
    the cost of an unbalanced transport plan between the feature dimensions
    of the two modalities.
    """

    def __init__(self, task, reg, reg_m, sinkhorn_iter, use_transfer, transfer_weight, local_loss,
                 label_smoothing=0.0):
        """Store the loss hyper-parameters.

        Args:
            task: the fairseq task, forwarded to the base class.
            reg: entropic regularisation passed to both Sinkhorn solvers.
            reg_m: mass-relaxation term of the unbalanced solver.
            sinkhorn_iter: maximum number of Sinkhorn iterations.
            use_transfer: whether to add the feature-level transfer loss.
            transfer_weight: multiplier of the transfer loss.
            local_loss: score only this worker's samples against the
                gathered batch instead of the full gathered batch.
            label_smoothing: smoothing used by ``compute_atc_loss`` only.
        """
        super().__init__(task)
        self.reg = reg
        self.reg_m = reg_m
        self.sinkhorn_iter = sinkhorn_iter
        self.use_transfer = use_transfer
        self.transfer_weight = transfer_weight
        self.local_loss = local_loss
        # compute_atc_loss reads this attribute; without it that method
        # raised AttributeError.
        self.label_smoothing = label_smoothing

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size (always 1)
        3) logging outputs to display while training
        """
        net_input = sample['net_input']
        src_tokens = net_input['src_tokens']
        src_audios = net_input['src_audios']
        audio_padding_masks = net_input['audio_padding_masks']

        text_logits = model(src_tokens=src_tokens, encoder_type='text')
        audio_logits = model(
            src_audios=src_audios,
            audio_padding_masks=audio_padding_masks,
            encoder_type='audio'
        )

        distributed = dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        if world_size > 1:
            all_audio, all_text = gather_features(
                audio_logits, text_logits,
                local_loss=self.local_loss,
                gather_with_grad=False,
                rank=rank,
                world_size=world_size,
            )
        else:
            # gather_features issues collectives, which need an initialised
            # process group; with a single worker the gathered batch is just
            # the local batch.
            all_audio, all_text = audio_logits, text_logits

        entropic_ot_loss, a2t_ncorrect, t2a_ncorrect = self.compute_sinkhorn_loss(
            audio_logits, text_logits,
            all_audio, all_text,
            rank=rank
        )

        total_loss = entropic_ot_loss
        transfer_loss_value = 0.0
        if self.use_transfer:
            transfer_loss = self.computer_feature_transfer_loss(
                audio_logits, text_logits,
                all_audio, all_text
            )
            total_loss = entropic_ot_loss + self.transfer_weight * transfer_loss
            transfer_loss_value = transfer_loss.item()

        logging_output = {
            "loss": total_loss.item(),
            "entropic_ot_loss": entropic_ot_loss.item(),
            "transfer_loss": transfer_loss_value,
            "nsentences": sample['nsentences'],
            "sample_size": 1,
            "a2t_ncorrect": a2t_ncorrect,
            "t2a_ncorrect": t2a_ncorrect,
        }
        return total_loss, 1, logging_output

    def computer_feature_transfer_loss(self, audio_logits, text_logits, audio_logits_all, text_logits_all, logit_scale_exp=None):
        """Return the unbalanced-OT cost between audio and text feature dimensions.

        Every column (feature dimension) of the gathered audio matrix is
        compared with every column of the gathered text matrix by Euclidean
        distance; the distances, scaled by their maximum, form the cost
        matrix handed to ``sinkhorn_knopp_unbalanced``. The result is
        ``sum(plan * cost)``.

        ``audio_logits``, ``text_logits`` and ``logit_scale_exp`` are accepted
        for signature compatibility and are not used.

        NOTE(review): when the gathered tensors carry no gradient (e.g. they
        come from a plain all_gather and the local slice was not re-inserted)
        this term is constant with respect to the model.
        """
        # ||a_i - t_j||^2 = ||a_i||^2 + ||t_j||^2 - 2 <a_i, t_j>, where a_i / t_j
        # are columns. The cross term must be the column inner products
        # (A^T T), not the outer product of the squared norms.
        audio_sq = (audio_logits_all ** 2).sum(dim=0)
        text_sq = (text_logits_all ** 2).sum(dim=0)
        cross = audio_logits_all.T @ text_logits_all
        squared_dist = audio_sq.unsqueeze(1) + text_sq.unsqueeze(0) - 2.0 * cross
        # A tiny positive floor keeps sqrt's gradient finite when rounding
        # makes a squared distance zero or negative.
        features_m = torch.sqrt(torch.clamp(squared_dist, min=1e-12))
        features_m = features_m / features_m.max()

        # The solver runs on whatever device the cost matrix lives on, so no
        # host round-trip is needed.
        pi = sinkhorn_knopp_unbalanced(
            features_m, reg=self.reg, reg_m=self.reg_m, numItermax=self.sinkhorn_iter
        )
        return torch.sum(pi * features_m)

    def _uniform_sinkhorn_plan(self, row_features, col_features):
        """Sinkhorn plan between two feature sets under uniform marginals."""
        similarity = torch.matmul(row_features, col_features.T)
        # Maps a similarity of 1 to cost 0 and -1 to cost 1.
        # NOTE(review): assumes the features are L2-normalised - TODO confirm.
        cost = (1.0 - similarity) / 2
        a = torch.ones(len(row_features), device=row_features.device) / len(row_features)
        b = torch.ones(len(col_features), device=col_features.device) / len(col_features)
        return sinkhorn_knopp(cost, a, b, reg=self.reg, numItermax=self.sinkhorn_iter)

    def compute_sinkhorn_loss(self, audio_logits, text_logits, audio_logits_all, text_logits_all, rank, logit_scale_exp=None):
        """Return ``(loss, a2t_ncorrect, t2a_ncorrect)`` from Sinkhorn plans.

        Global mode (``local_loss`` False): one plan between the full gathered
        audio and text batches; the loss averages ``-log`` of its diagonal.

        Local mode (``local_loss`` True): one plan per direction between this
        worker's samples and the gathered batch of the other modality; the
        loss averages ``-log`` of the matched entries over both directions.
        NOTE(review): the symmetric two-plan form is a design choice so that
        both encoders receive gradient - confirm it matches the intent.

        In both modes the hit counts cover this worker's samples only, and a
        sample counts as correct when its own pair receives the most mass.
        Assumes every worker contributes the same number of rows.
        ``logit_scale_exp`` is accepted for signature compatibility only.
        """
        bsz = audio_logits.size(0)
        device = audio_logits_all.device
        local_rows = torch.arange(bsz, device=device)
        # Position of this worker's samples inside the gathered batch.
        targets = local_rows + rank * bsz

        if not self.local_loss:
            plan = self._uniform_sinkhorn_plan(audio_logits_all, text_logits_all)
            diagonal = torch.arange(len(plan), device=device)
            entropic_ot_loss = -torch.log(plan[diagonal, diagonal] + 1e-8).mean()
            # Score only this worker's rows / columns, so that summing the
            # counts over workers does not count every sample world_size times.
            a2t_scores = plan[targets]
            t2a_scores = plan[:, targets].T
        else:
            a2t_scores = self._uniform_sinkhorn_plan(audio_logits, text_logits_all)
            t2a_scores = self._uniform_sinkhorn_plan(text_logits, audio_logits_all)
            a2t_loss = -torch.log(a2t_scores[local_rows, targets] + 1e-8).mean()
            t2a_loss = -torch.log(t2a_scores[local_rows, targets] + 1e-8).mean()
            entropic_ot_loss = (a2t_loss + t2a_loss) / 2

        with torch.no_grad():
            # The plan holds transported mass: larger means a better match,
            # hence argmax rather than argmin.
            a2t_ncorrect = (a2t_scores.argmax(dim=1) == targets).float().sum()
            t2a_ncorrect = (t2a_scores.argmax(dim=1) == targets).float().sum()

        return entropic_ot_loss, a2t_ncorrect, t2a_ncorrect

    def compute_atc_loss(self, audio_logits, text_logits, audio_logits_all, text_logits_all, logit_scale_exp):
        """Return the softmax contrastive loss and hit counts for one batch.

        Not called by ``forward``; the loss is the average of the
        audio-to-text and text-to-audio label-smoothed losses.
        """
        slice_id = dist.get_rank() if dist.is_initialized() else 0
        bsz = audio_logits.size(0)
        start_idx = bsz * slice_id
        end_idx = start_idx + bsz
        targets = torch.arange(start_idx, end_idx).to(audio_logits.device)

        sim_a2t = logit_scale_exp * audio_logits @ text_logits_all.t()
        sim_t2a = logit_scale_exp * text_logits @ audio_logits_all.t()
        log_sim_a2t = utils.log_softmax(sim_a2t, dim=-1).type_as(sim_a2t)
        log_sim_t2a = utils.log_softmax(sim_t2a, dim=-1).type_as(sim_t2a)
        a2t_loss = adjust_label_smoothed_nll_loss(log_sim_a2t, targets, self.label_smoothing)
        t2a_loss = adjust_label_smoothed_nll_loss(log_sim_t2a, targets, self.label_smoothing)
        atc_loss = (a2t_loss + t2a_loss) / 2

        with torch.no_grad():
            a2t_preds = sim_a2t.argmax(dim=1)
            t2a_preds = sim_t2a.argmax(dim=1)
            a2t_ncorrect = (a2t_preds == targets).float().sum()
            t2a_ncorrect = (t2a_preds == targets).float().sum()

        return atc_loss, a2t_ncorrect, t2a_ncorrect

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ot_loss_sum = sum(log.get("entropic_ot_loss", 0) for log in logging_outputs)
        transfer_loss_sum = sum(log.get("transfer_loss", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 1) for log in logging_outputs)

        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=4)
        metrics.log_scalar("entropic_ot_loss", ot_loss_sum / sample_size, sample_size, round=4)
        metrics.log_scalar("transfer_loss", transfer_loss_sum / sample_size, sample_size, round=4)
        metrics.log_scalar("sample_size", sample_size, 1)
        metrics.log_scalar("nsentences", nsentences, 1)

        # Accuracies are only reported when the first log entry carries the
        # corresponding hit count.
        if len(logging_outputs) > 0 and "a2t_ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("a2t_ncorrect", 0) for log in logging_outputs)
            metrics.log_scalar(
                "a2t_accuracy", 100.0 * ncorrect / nsentences, nsentences, round=2
            )
        if len(logging_outputs) > 0 and "t2a_ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("t2a_ncorrect", 0) for log in logging_outputs)
            metrics.log_scalar(
                "t2a_accuracy", 100.0 * ncorrect / nsentences, nsentences, round=2
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`.
        """
        return True
