import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import torch.distributed.nn
    from torch import distributed as dist

    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


from .sinkhorn import sinkhorn_knopp, sinkhorn_knopp_unbalanced


def gather_features(
        audio_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    """Collect audio and text features from every worker into two tensors.

    Args:
        audio_features: Local audio features; indexed as 2-D (rows, columns)
            in the non-differentiable torch.distributed path.
        text_features: Local text features, handled the same way.
        local_loss: When False (and ``gather_with_grad`` is False), the rows
            belonging to this rank are overwritten with the local tensors
            after the gather.
        gather_with_grad: When True, use the gather variants that are called
            outside ``torch.no_grad`` (``hvd.allgather`` /
            ``torch.distributed.nn.all_gather``).
        rank: Index of this worker; used to locate its rows in the result.
        world_size: Number of workers; used to size the output buffers.
        use_horovod: Use Horovod instead of ``torch.distributed``.

    Returns:
        Tuple ``(all_audio_features, all_text_features)``.

    Note:
        The slice arithmetic uses the audio batch size for both modalities and
        ``rank * batch_size`` as offset, so it assumes every rank contributes
        the same number of rows and that audio/text batches have equal length.
    """
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'

    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            return hvd.allgather(audio_features), hvd.allgather(text_features)
        with torch.no_grad():
            gathered_audio = hvd.allgather(audio_features)
            gathered_text = hvd.allgather(text_features)
    elif gather_with_grad:
        gathered_audio = torch.cat(torch.distributed.nn.all_gather(audio_features), dim=0)
        gathered_text = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        return gathered_audio, gathered_text
    else:
        # Pre-allocate flat output buffers, then let each collective fill them.
        gathered_audio = torch.empty(
            world_size * audio_features.size(0), audio_features.size(1),
            dtype=audio_features.dtype, device=audio_features.device)
        gathered_text = torch.empty(
            world_size * text_features.size(0), text_features.size(1),
            dtype=text_features.dtype, device=text_features.device)
        dist.all_gather_into_tensor(gathered_audio, audio_features)
        dist.all_gather_into_tensor(gathered_text, text_features)

    # Only the non-differentiable gathers reach this point.
    if not local_loss:
        # Write the local tensors back into this rank's rows -- presumably so
        # those rows keep their autograd history (TODO confirm intent).
        rows = audio_features.shape[0]
        lo = rank * rows
        hi = lo + rows
        gathered_audio[lo:hi] = audio_features
        gathered_text[lo:hi] = text_features
    return gathered_audio, gathered_text


class Entropic_OT_Loss(nn.Module):
    """Sample-level entropic OT loss (IOT) plus a feature-level unbalanced OT term (UWD).

    On top of the IOT + UWD architecture this module adds:
      - reliability-weighted marginal priors (a, b) over feature channels;
      - batch-quality gating (rho) that scales the total mass budget.
    The network itself is untouched; only the UWD marginal initialisation and
    the total mass budget change.
    """

    def __init__(self,
                 reg=0.03,
                 num_iter=10,
                 local_loss=False,
                 float_loss=False,
                 transfer_weight=0.0,
                 gather_with_grad=False,
                 rank=0,
                 world_size=1,
                 use_horovod=False,
                 # ===== New parameters =====
                 uwd_reg_m=0.5,                # KL mass regularization in UWD
                 rel_temp=0.7,                 # Reliability softmax temperature T (smaller -> sharper)
                 rel_floor=1e-6,               # Minimum mass floor per channel to avoid numerical issues
                 rho0=0.9,                     # Base total mass budget rho0 in (0, 1]
                 snr_alpha=8.0,                # SNR gating sigmoid slope
                 snr_tau=0.0,                  # SNR gating threshold (after margin normalization)
                 use_uniform_marginals=False,  # If True, fall back to uniform marginals
                 topk_ratio=None,              # Optional: allocate mass only to top k% channels
                 ):
        super().__init__()
        self.reg = reg
        self.num_iter = num_iter
        self.local_loss = local_loss
        self.float_loss = float_loss
        self.transfer_weight = transfer_weight
        self.gather_with_grad = gather_with_grad
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # Reliability-prior / gating hyper-parameters.
        self.uwd_reg_m = uwd_reg_m
        self.rel_temp = rel_temp
        self.rel_floor = rel_floor
        self.rho0 = rho0
        self.snr_alpha = snr_alpha
        self.snr_tau = snr_tau
        self.use_uniform_marginals = use_uniform_marginals
        self.topk_ratio = topk_ratio

        # Running (EMA) per-channel statistics; empty until the first batch is
        # seen. Non-persistent, so they are not written to the state dict.
        self.register_buffer('ema_var',  torch.tensor([]), persistent=False)
        self.register_buffer('ema_kurt', torch.tensor([]), persistent=False)
        self.ema_momentum = 0.97

    def get_features(self, audio_features, text_features):
        """Return the feature pair the sample-level loss should be computed on.

        With ``world_size > 1`` the features are gathered across workers; under
        ``local_loss`` the audio side stays local while the text side is the
        gathered tensor. With a single worker the inputs are returned as-is.
        """
        if self.world_size > 1:
            all_audio_features, all_text_features = gather_features(
                audio_features,
                text_features,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=self.rank,
                world_size=self.world_size,
                use_horovod=self.use_horovod,
            )
            if self.local_loss:
                return audio_features, all_text_features
            return all_audio_features, all_text_features
        return audio_features, text_features

    # ====== IOT ======
    def entropic_ot(self, audio_features, text_features):
        """Run Sinkhorn on the max-normalised pairwise L2 cost with uniform marginals.

        Returns whatever ``sinkhorn_knopp`` returns for that cost (used by
        ``forward`` as a sample-to-sample coupling matrix).
        """
        cost = torch.cdist(audio_features, text_features, p=2)
        cost = cost / (cost.max() + 1e-8)
        a = torch.ones(len(audio_features), device=audio_features.device) / len(audio_features)
        b = torch.ones(len(text_features), device=text_features.device) / len(text_features)
        dist_matrix = sinkhorn_knopp(cost, a, b, reg=self.reg, numItermax=self.num_iter)
        return dist_matrix

    # ====== UWD ======
    @torch.no_grad()
    def compute_transfer_plan(self, Feat_M, a, b, reg=0.1, reg_m=0.5, numItermax=10):
        """Call the unbalanced Sinkhorn solver under ``no_grad``.

        Because of the decorator the returned plan carries no autograd history.
        """
        return sinkhorn_knopp_unbalanced(Feat_M, a, b, reg=reg, reg_m=reg_m, numItermax=numItermax)

    @torch.no_grad()
    def _batch_snr(self, A, T):
        """Scalar batch-quality score in [-1, 1] from paired distances.

        ``A`` and ``T`` are expected to be row-aligned ``[B, D]`` matrices. For
        each row the diagonal distance (positive pair) is compared with the
        mean off-diagonal distance (negatives) as a relative margin; the
        margins are averaged and clamped.
        """
        Dmat = torch.cdist(A, T, p=2)
        n_cols = Dmat.size(1)
        if n_cols < 2:
            # No negatives exist, so the margin is undefined (the original
            # division by ``n_cols - 1`` produced NaN). Return a neutral score.
            return Dmat.new_zeros(())

        pos = torch.arange(Dmat.size(0), device=A.device)
        d_pos = Dmat[pos, pos]
        d_neg = (Dmat.sum(dim=1) - d_pos) / (n_cols - 1)

        margin = (d_neg - d_pos) / (d_neg + 1e-6)
        snr = margin.mean()
        return snr.clamp(-1.0, 1.0)

    # ====== reliability r_d in (0, 1) ======
    @torch.no_grad()
    def _channel_reliability(self, U, V):
        """Per-channel reliability score of shape ``[D]``.

        ``U`` and ``V`` are row-aligned ``[B, D]`` matrices. A channel scores
        higher when the two inputs correlate on it and lower when its variance
        or its excess kurtosis is large relative to the running (EMA)
        statistics. The three signals are rank-normalised and combined through
        a sigmoid. Side effect: updates ``ema_var`` / ``ema_kurt``.
        """
        mu_u = U.mean(dim=0)
        mu_v = V.mean(dim=0)
        var_u = U.var(dim=0, unbiased=False)
        var_v = V.var(dim=0, unbiased=False)
        # Fourth standardised moment (kurtosis), guarded against zero variance.
        kurt_u = ((U - mu_u) ** 4).mean(dim=0) / (var_u.clamp_min(1e-6) ** 2 + 1e-6)
        kurt_v = ((V - mu_v) ** 4).mean(dim=0) / (var_v.clamp_min(1e-6) ** 2 + 1e-6)

        var = (var_u + var_v) * 0.5
        kurt = (kurt_u + kurt_v) * 0.5

        # EMA update; the first batch seeds the running statistics.
        if self.ema_var.numel() == 0:
            self.ema_var = var.detach().clone()
            self.ema_kurt = kurt.detach().clone()
        else:
            m = self.ema_momentum
            self.ema_var = m * self.ema_var + (1 - m) * var.detach()
            self.ema_kurt = m * self.ema_kurt + (1 - m) * kurt.detach()

        # Current statistics relative to their running averages.
        s_var = var / (self.ema_var + 1e-6)
        s_kurt = torch.abs(kurt - 3.0) / (torch.abs(self.ema_kurt - 3.0) + 1e-6)

        # Absolute per-channel correlation between the two inputs.
        u = (U - mu_u) / (var_u.sqrt() + 1e-6)
        v = (V - mu_v) / (var_v.sqrt() + 1e-6)
        corr_num = (u * v).sum(dim=0)
        corr_den = (u.pow(2).sum(dim=0).sqrt() * v.pow(2).sum(dim=0).sqrt() + 1e-6)
        s_corr = torch.abs(corr_num / corr_den).clamp(0, 1)

        def ranknorm(x):
            # Map values to their rank, scaled into the open interval (0, 1).
            ranks = torch.argsort(torch.argsort(x))
            return (ranks.float() + 1) / (x.numel() + 1)

        r = torch.sigmoid(
            ranknorm(s_corr) - ranknorm(s_var) - ranknorm(s_kurt)
        )
        return r  # [D]

    # ====== gated marginals (a, b) ======
    @torch.no_grad()
    def _build_marginals(self, U, V, rel_temp, rho0, snr_alpha, snr_tau, topk_ratio):
        """Build the reliability-weighted, SNR-gated UWD marginals.

        Args:
            U, V: Row-aligned ``[B, D]`` feature matrices.
            rel_temp: Softmax temperature applied to the reliability scores.
            rho0: Base mass budget, scaled by the SNR gate.
            snr_alpha, snr_tau: Slope and threshold of the sigmoid SNR gate.
            topk_ratio: If not None, only the top ``ratio * D`` channels (by
                reliability, ties included) receive softmax mass; every other
                channel ends up at ``rel_floor``.

        Returns:
            ``(a, b, r, snr)``: the two marginals (equal values, separate
            tensors), the (possibly masked) reliability vector and the batch
            SNR. The marginals are not renormalised after flooring, since UWD
            does not require mass conservation.
        """
        D = U.size(1)
        r = self._channel_reliability(U, V).detach()  # [D]
        logits = r / rel_temp

        if topk_ratio is not None:
            # Clamp k into [1, D] so that ratios above 1 do not make topk raise.
            k = min(D, max(1, int(D * float(topk_ratio))))
            thr = torch.topk(r, k).values.min()
            keep = r >= thr
            r = r * keep.float() + 1e-9  # keep strictly positive
            # Mask in logit space: zeroing r alone would still hand each masked
            # channel exp(0)/Z mass after the softmax.
            logits = logits.masked_fill(~keep, float('-inf'))

        a = torch.softmax(logits, dim=0)
        b = a.clone()

        snr = self._batch_snr(U, V)  # [-1, 1]
        gate = torch.sigmoid(snr_alpha * (snr - snr_tau))  # (0, 1)
        rho = rho0 * gate

        a = a * rho
        b = b * rho

        # Numerical lower bound so no channel has exactly zero mass.
        a = torch.clamp(a, min=self.rel_floor)
        b = torch.clamp(b, min=self.rel_floor)

        return a, b, r, snr

    # ====== feature-level UWD loss (with reliability-prior a, b) ======
    def feature_transfer_loss(self, audio_features, text_features,
                              reg=None, reg_m=None, numItermax=None,
                              use_uniform_marginals=None):
        """Feature-level unbalanced OT loss between audio and text channels.

        Arguments left as None fall back to the values stored on the module.
        Both inputs must have the same number of rows, since the channel cost
        is a ``cdist`` between their transposes.

        Returns:
            ``(transfer_loss, meta)`` where ``meta`` always contains the keys
            ``"rho"``, ``"snr"`` and ``"avg_r"`` (the latter two are None when
            uniform marginals are used).
        """
        if reg is None:
            reg = self.reg
        if reg_m is None:
            reg_m = self.uwd_reg_m
        if numItermax is None:
            numItermax = self.num_iter
        if use_uniform_marginals is None:
            use_uniform_marginals = self.use_uniform_marginals

        # Channel-to-channel cost: pairwise L2 distance between feature
        # columns, normalised by its maximum.
        features_m = torch.cdist(audio_features.T, text_features.T, p=2)
        features_m = features_m / (features_m.max() + 1e-8)

        D_src, D_tgt = features_m.size(0), features_m.size(1)
        device = features_m.device

        if use_uniform_marginals:
            Feat_a = torch.ones(D_src, device=device) / D_src
            Feat_b = torch.ones(D_tgt, device=device) / D_tgt
            pi = self.compute_transfer_plan(features_m, Feat_a, Feat_b,
                                            reg=reg, reg_m=reg_m, numItermax=numItermax)
            transfer_loss = torch.sum(pi * features_m)
            # Same key set as the reliability path so callers can index
            # uniformly (a missing "avg_r" used to raise KeyError in forward).
            meta = {"rho": 1.0, "snr": None, "avg_r": None}
            return transfer_loss, meta

        # Reliability prior + batch-quality gating (runs under no_grad).
        a, b, r, snr = self._build_marginals(audio_features, text_features,
                                             self.rel_temp, self.rho0,
                                             self.snr_alpha, self.snr_tau,
                                             self.topk_ratio)
        # UWD with the prior marginals; the plan has no autograd history, so
        # gradients reach the features only through the cost matrix.
        pi = self.compute_transfer_plan(features_m, a, b,
                                        reg=reg, reg_m=reg_m, numItermax=numItermax)
        transfer_loss = torch.sum(pi * features_m)
        meta = {"rho": a.sum(), "snr": snr, "avg_r": r.mean()}
        return transfer_loss, meta

    # ====== forward ======
    def forward(self, audio_features, text_features, output_dict=False):
        """Compute the IOT loss and, if enabled, the feature-level UWD term.

        Returns:
            If ``output_dict`` is False: the scalar
            ``loss_iot + transfer_weight * transfer_loss`` (or just
            ``loss_iot`` when the transfer term is disabled).
            If ``output_dict`` is True: a dict with ``"entropic_ot_loss"`` and,
            when the transfer term is enabled, ``"transfer_loss"``,
            ``"rho_used"``, ``"batch_snr"`` and ``"avg_channel_reliability"``.
        """
        batch_size = audio_features.size(0)
        device = audio_features.device

        # Keep the local pair: the feature-level term needs row-aligned inputs.
        local_audio, local_text = audio_features, text_features
        audio_features, text_features = self.get_features(audio_features, text_features)

        # --- sample-level IOT ---
        # NOTE(review): under local_loss without gather_with_grad the gathered
        # text features appear to carry no autograd history, so this term would
        # send no gradient to the text branch -- confirm this is intended.
        dist_matrix = self.entropic_ot(audio_features, text_features)
        if not self.local_loss:
            labels = torch.arange(len(dist_matrix), device=device)
            loss_iot = -torch.log(dist_matrix[labels, labels] + 1e-8).mean()
        else:
            # Local rows pair with this rank's columns in the gathered text.
            labels = torch.arange(batch_size, device=device)
            loss_iot = -torch.log(dist_matrix[labels, labels + self.rank * batch_size] + 1e-8).mean()

        if not (self.float_loss or self.transfer_weight > 0):
            return {"entropic_ot_loss": loss_iot} if output_dict else loss_iot

        # --- feature-level UWD (with reliability prior) ---
        if self.local_loss:
            # Gathered text has world_size * B rows while audio has B rows,
            # which would make the channel cdist fail; use the local pair.
            feat_audio, feat_text = local_audio, local_text
        else:
            feat_audio, feat_text = audio_features, text_features

        transfer_loss, meta = self.feature_transfer_loss(
            feat_audio, feat_text,
            reg=self.reg, reg_m=self.uwd_reg_m, numItermax=self.num_iter,
            use_uniform_marginals=self.use_uniform_marginals
        )
        if output_dict:
            return {
                "entropic_ot_loss": loss_iot,
                "transfer_loss": transfer_loss,
                "rho_used": meta["rho"],
                "batch_snr": meta["snr"],
                "avg_channel_reliability": meta.get("avg_r"),
            }
        return loss_iot + self.transfer_weight * transfer_loss
