"""
AgentSteerTTS - Complete Model Implementation

Core modules:
1. ADM (Adversarial Disentanglement Module): speaker-emotion disentanglement
2. DAC (Dual-stream Anchoring Controller): dual-stream anchoring control
3. Fast-Slow Feedback: inference-time intensity calibration

Based on the OmniTTS implementation (indextts/gpt/model_v2.py).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple


# =============================================================================
# Gradient Reversal Layer (GRL) - core component of ADM
# =============================================================================

class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by ``-alpha`` in the backward pass.

    ``alpha`` may be a Python number or a (0-dim) tensor. It receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Backward only needs alpha. Saving ``x`` as well kept the whole activation alive
        # until backward for no reason. ``as_tensor`` lets a plain float alpha be saved too
        # (save_for_backward only accepts tensors).
        ctx.save_for_backward(torch.as_tensor(alpha))
        # ``view_as`` is the conventional way for a custom Function to return its input
        # "unchanged" while still giving autograd a distinct output tensor.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        (alpha,) = ctx.saved_tensors
        # Reverse (and scale) the gradient for x; alpha itself gets no gradient.
        return -alpha * grad_output, None


class GradientReversal(nn.Module):
    """Gradient reversal layer for adversarial training.

    Acts as the identity in the forward pass. In the backward pass it scales incoming
    gradients by ``-alpha`` (via ``GradientReversalFunction``).
    """

    def __init__(self, alpha: float = 1.0):
        super().__init__()
        # A non-persistent buffer follows ``.to()`` / ``.cuda()`` together with the rest of
        # the module. Like the previous plain attribute, it stays out of ``state_dict``, so
        # existing checkpoints still load. ``float()`` pins a floating dtype even when an
        # int such as ``1`` is passed.
        self.register_buffer("alpha", torch.tensor(float(alpha)), persistent=False)

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.alpha)

    def extra_repr(self) -> str:
        return f"alpha={self.alpha.item()}"


# =============================================================================
# Conditioning encoders
# =============================================================================

class AttentionBlock(nn.Module):
    """Residual self-attention followed by LayerNorm, on a channel-first sequence.

    Input and output are laid out as (batch, channels, time). Internally the tensor is
    transposed to (batch, time, channels) for ``nn.MultiheadAttention(batch_first=True)``.
    """

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        seq = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        context, _weights = self.attn(seq, seq, seq)
        normed = self.norm(seq + context)
        return normed.transpose(1, 2)  # back to (B, C, T)


class ConditioningEncoder(nn.Module):
    """Speaker/emotion conditioning encoder.

    A 1x1 ``Conv1d`` projects ``input_dim`` channels to ``output_dim``. It is followed by
    ``num_attn_blocks`` residual self-attention blocks. Input is channel-first
    (B, input_dim, T). Output is (B, output_dim, T), or (B, output_dim) when
    ``mean_pooling`` averages over the time axis. No padding mask is applied, so the
    time average includes any padded frames.
    """

    def __init__(self, input_dim=1024, output_dim=512, num_attn_blocks=6, num_heads=4, mean_pooling=False):
        super().__init__()
        self.init_conv = nn.Conv1d(input_dim, output_dim, kernel_size=1)
        blocks = [AttentionBlock(output_dim, num_heads) for _ in range(num_attn_blocks)]
        self.attn_blocks = nn.ModuleList(blocks)
        self.mean_pooling = mean_pooling
        self.output_dim = output_dim

    def forward(self, x):
        hidden = self.init_conv(x)
        for block in self.attn_blocks:
            hidden = block(hidden)
        if self.mean_pooling:
            return hidden.mean(dim=2)
        return hidden


class PerceiverResampler(nn.Module):
    """Perceiver resampler: maps a variable-length context onto a fixed number of latents.

    ``num_latents`` learned query vectors cross-attend to ``context`` (B, T_ctx,
    dim_context). Each step is followed by a residual LayerNorm: one after the attention,
    and one after a GELU feed-forward. The output is (B, num_latents, dim). ``mask`` is
    forwarded as the attention ``key_padding_mask`` (True marks context positions to
    ignore).
    """

    def __init__(self, dim=512, dim_context=512, num_latents=32, num_heads=8, ff_mult=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            dim, num_heads, kdim=dim_context, vdim=dim_context, batch_first=True
        )
        ff_width = dim * ff_mult
        self.ff = nn.Sequential(nn.Linear(dim, ff_width), nn.GELU(), nn.Linear(ff_width, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, context, mask=None):
        queries = self.latents[None].expand(context.shape[0], -1, -1)
        attended, _ = self.cross_attn(queries, context, context, key_padding_mask=mask)
        hidden = self.norm1(queries + attended)
        return self.norm2(hidden + self.ff(hidden))


class Discriminator(nn.Module):
    """MLP classifier used as the adversary in ADM training.

    Structure: two (Linear -> ReLU -> Dropout) stages of width ``hidden_dim``, then a
    final Linear to ``num_classes`` logits. The layers live in ``self.classifier`` at
    indices 0..6.
    """

    def __init__(self, input_dim=512, hidden_dim=256, num_classes=10, dropout=0.1):
        super().__init__()
        layers = []
        width = input_dim
        for _ in range(2):
            layers += [nn.Linear(width, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]
            width = hidden_dim
        layers.append(nn.Linear(width, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.classifier(x)
        return logits


class LatentConsistencyPredictor(nn.Module):
    """LCP: the Fast Agent's latent-consistency predictor.

    Two Conv1d+ReLU layers extract per-frame features. These are averaged over time and
    mapped by a 3-layer MLP to an ``output_dim`` embedding. The conv stack expects
    channel-first (B, mel_dim, T) input. A tensor whose last dim equals ``mel_dim`` is
    treated as channel-last and transposed first.

    NOTE(review): when T == mel_dim the layout is ambiguous and the input is always
    transposed. Callers with square inputs should pass channel-last.
    """

    def __init__(self, mel_dim=80, hidden_dim=256, output_dim=768):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(mel_dim, hidden_dim, **conv_kwargs), nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, **conv_kwargs), nn.ReLU(),
        )
        mlp = [
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.predictor = nn.Sequential(*mlp)
        self.mel_dim = mel_dim

    def forward(self, mel):
        looks_channel_last = mel.shape[-1] == self.mel_dim
        x = mel.transpose(-1, -2) if looks_channel_last else mel
        pooled = self.feature_extractor(x).mean(dim=-1)  # average over time
        return self.predictor(pooled)


# =============================================================================
# AgentSteerTTS main model
# =============================================================================

class AgentSteerTTS(nn.Module):
    """
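    AgentSteerTTS full model.

    Integrated components:
    1. ADM: adversarial disentanglement (speaker / emotion separation)
    2. DAC: dual-stream anchoring control (prototype retrieval + fusion)
    3. Fast-Slow Feedback: inference-time refinement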
    """
    
    def __init__(
        self,
        model_dim: int = 512,
        semantic_dim: int = 1024,
        text_dim: int = 768,
        mel_dim: int = 80,
        gpt_layers: int = 8,
        gpt_heads: int = 8,
        max_text_tokens: int = 120,
        max_mel_tokens: int = 250,
        num_text_tokens: int = 256,
        num_mel_codes: int = 8194,
        start_mel_token: int = 8192,
        stop_mel_token: int = 8193,
        num_speakers: int = 100,
        num_emotions: int = 8,
        lambda_adv: float = 0.1,
        lambda_orth: float = 0.01,
        num_basic_emotions: int = 6,
        condition_num_latent: int = 32,
        fast_lr: float = 5e-3,
        fast_iterations: int = 2,
        speaker_style_dim: int = 256,
    ):
        """Build all sub-modules.

        Args:
            model_dim: width of the conditioning / GPT hidden space.
            semantic_dim: channel width of the input semantic features.
            text_dim: output width of the LCP (Fast Agent target space).
            mel_dim: mel channel count expected by the LCP.
            gpt_layers, gpt_heads: TransformerEncoder depth and head count.
            max_text_tokens, max_mel_tokens: sizes of the positional embeddings (+2 each).
            num_text_tokens, num_mel_codes: vocabulary sizes (text gets one extra id).
            start_mel_token, stop_mel_token: special mel token ids (stored only).
            num_speakers: speaker-discriminator classes and rows per prototype matrix.
            num_emotions: emotion-discriminator classes.
            lambda_adv, lambda_orth: ADM loss weights.
            num_basic_emotions: number of DAC prototype matrices.
            condition_num_latent: number of speaker latents from the perceiver.
            fast_lr, fast_iterations: Fast Agent step size and step count.
            speaker_style_dim: width of each row of ``spk_matrix``. A ``speaker_style``
                passed to ``get_emovec_from_vector`` must have this width. The default 256
                is the value previously hard-coded here.
        """
        super().__init__()

        self.model_dim = model_dim
        self.semantic_dim = semantic_dim
        self.num_emotions = num_emotions
        self.num_basic_emotions = num_basic_emotions
        self.start_mel_token = start_mel_token
        self.stop_mel_token = stop_mel_token
        self.max_mel_tokens = max_mel_tokens
        self.speaker_style_dim = speaker_style_dim

        # NOTE: sub-modules are created in the same order as before, so a given seed yields
        # the same initial weights.

        # ===== Speaker conditioning (E_id path) =====
        self.speaker_encoder = ConditioningEncoder(input_dim=semantic_dim, output_dim=model_dim, num_attn_blocks=6, num_heads=8)
        self.speaker_perceiver = PerceiverResampler(dim=model_dim, dim_context=model_dim, num_latents=condition_num_latent)

        # ===== Emotion conditioning (E_emo path) =====
        # The single emotion latent lives in semantic_dim and is projected to model_dim by
        # emovec_layer.
        self.emotion_encoder = ConditioningEncoder(input_dim=semantic_dim, output_dim=model_dim, num_attn_blocks=4, num_heads=8)
        self.emotion_perceiver = PerceiverResampler(dim=semantic_dim, dim_context=model_dim, num_latents=1)
        self.emovec_layer = nn.Linear(semantic_dim, model_dim)
        self.emo_layer = nn.Linear(model_dim, model_dim)

        # ===== ADM: adversarial discriminators =====
        self.grl_for_speaker = GradientReversal(alpha=1.0)
        self.grl_for_emotion = GradientReversal(alpha=1.0)
        self.speaker_discriminator = Discriminator(input_dim=model_dim, num_classes=num_speakers)
        self.emotion_discriminator = Discriminator(input_dim=model_dim, num_classes=num_emotions)
        self.lambda_adv = lambda_adv
        self.lambda_orth = lambda_orth

        # ===== DAC: emotion prototype matrices (one pair per basic emotion) =====
        # emo_matrix[k] holds prototypes in semantic space. spk_matrix[k] holds the matching
        # speaker-style keys used for cosine retrieval.
        self.emo_matrix = nn.ParameterList([nn.Parameter(torch.randn(num_speakers, semantic_dim) * 0.02) for _ in range(num_basic_emotions)])
        self.spk_matrix = nn.ParameterList([nn.Parameter(torch.randn(num_speakers, speaker_style_dim) * 0.02) for _ in range(num_basic_emotions)])

        # ===== GPT components =====
        self.text_embedding = nn.Embedding(num_text_tokens + 1, model_dim)
        self.mel_embedding = nn.Embedding(num_mel_codes, model_dim)
        self.text_pos_embedding = nn.Embedding(max_text_tokens + 2, model_dim)
        self.mel_pos_embedding = nn.Embedding(max_mel_tokens + 2, model_dim)
        self.speed_emb = nn.Embedding(2, model_dim)
        # std=0.0 makes this a zero initialisation. It is left as normal_ rather than
        # zero_ so the random stream seen by later layers is not disturbed, since normal_
        # may draw from the RNG even with std=0.
        self.speed_emb.weight.data.normal_(mean=0.0, std=0.0)

        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=gpt_heads, dim_feedforward=model_dim * 4, batch_first=True)
        self.gpt = nn.TransformerEncoder(encoder_layer, num_layers=gpt_layers)
        self.final_norm = nn.LayerNorm(model_dim)
        self.mel_head = nn.Linear(model_dim, num_mel_codes)

        # ===== LCP: Fast Agent =====
        self.lcp = LatentConsistencyPredictor(mel_dim=mel_dim, hidden_dim=256, output_dim=text_dim)
        self.fast_lr = fast_lr
        self.fast_iterations = fast_iterations

        self.text_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.mel_embedding.weight.data.normal_(mean=0.0, std=0.02)

    # ========== ADM methods ==========
    
    def get_speaker_conditioning(self, semantic_features, lengths=None):
        """提取说话人条件 (z_id)"""
        if semantic_features.shape[-1] == self.semantic_dim:
            semantic_features = semantic_features.transpose(1, 2)
        h = self.speaker_encoder(semantic_features).transpose(1, 2)
        mask = None
        if lengths is not None:
            max_len = h.shape[1]
            mask = torch.arange(max_len, device=h.device).unsqueeze(0) >= lengths.unsqueeze(1)
        return self.speaker_perceiver(h, mask)
    
    def get_emotion_conditioning(self, semantic_features, lengths=None):
        """提取情感条件 (z_emo)"""
        if semantic_features.shape[-1] == self.semantic_dim:
            semantic_features = semantic_features.transpose(1, 2)
        h = self.emotion_encoder(semantic_features).transpose(1, 2)
        mask = None
        if lengths is not None:
            max_len = h.shape[1]
            mask = torch.arange(max_len, device=h.device) >= lengths.unsqueeze(1)
        emo_latent = self.emotion_perceiver(h, mask).squeeze(1)
        return self.emo_layer(self.emovec_layer(emo_latent))
    
    def compute_cross_covariance(self, z_id, z_emo):
        """计算批次交叉协方差 (用于 L_orth)"""
        batch_size = z_id.shape[0]
        z_id_centered = z_id - z_id.mean(dim=0, keepdim=True)
        z_emo_centered = z_emo - z_emo.mean(dim=0, keepdim=True)
        return torch.mm(z_id_centered.t(), z_emo_centered) / batch_size
    
    def compute_adm_losses(self, semantic_features, speaker_labels, emotion_labels, lengths=None):
        """计算 ADM 解耦损失: L_ADM = λ_adv * L_adv + λ_orth * L_orth"""
        z_id_full = self.get_speaker_conditioning(semantic_features, lengths)
        z_id = z_id_full.mean(dim=1)
        z_emo = self.get_emotion_conditioning(semantic_features, lengths)
        
        # 对抗损失
        z_emo_grl = self.grl_for_speaker(z_emo)
        l_adv_speaker = F.cross_entropy(self.speaker_discriminator(z_emo_grl), speaker_labels)
        z_id_grl = self.grl_for_emotion(z_id)
        l_adv_emotion = F.cross_entropy(self.emotion_discriminator(z_id_grl), emotion_labels)
        l_adv = l_adv_speaker + l_adv_emotion
        
        # 正交损失
        cov = self.compute_cross_covariance(z_id, z_emo)
        l_orth = torch.norm(cov, p='fro') ** 2
        
        return {'loss': self.lambda_adv * l_adv + self.lambda_orth * l_orth, 
                'l_adv': l_adv, 'l_orth': l_orth, 'z_id': z_id, 'z_emo': z_emo}

    # ========== DAC methods ==========
    
    def get_emovec_from_vector(self, emo_weights, speaker_style=None, random_select=False):
        """从离散情感权重计算情感向量 (DAC 向量控制路径)"""
        batch_size = emo_weights.shape[0]
        device = emo_weights.device
        selected_prototypes = []
        
        for k, emo_matrix in enumerate(self.emo_matrix):
            if random_select or speaker_style is None:
                idx = torch.randint(0, emo_matrix.shape[0], (batch_size,), device=device)
            else:
                spk_matrix = self.spk_matrix[k]
                similarities = F.cosine_similarity(speaker_style.unsqueeze(1), spk_matrix.unsqueeze(0), dim=2)
                idx = similarities.argmax(dim=1)
            selected_prototypes.append(emo_matrix[idx])
        
        prototypes = torch.stack(selected_prototypes, dim=1)
        weighted_proto = (prototypes * emo_weights.unsqueeze(-1)).sum(dim=1)
        return self.emo_layer(self.emovec_layer(weighted_proto))
    
    def fuse_emotion_vectors(self, audio_emo, vector_emo, speaker_base_emo, emo_alpha=1.0, emo_merge_alpha=1.0):
        """融合多个情感控制信号 (DAC 自适应融合, Eq. 9-10)"""
        if audio_emo is not None and vector_emo is not None:
            final_emo = audio_emo * emo_merge_alpha + vector_emo * (1 - emo_merge_alpha)
            return final_emo * emo_alpha + (1 - emo_alpha) * speaker_base_emo
        elif audio_emo is not None:
            return audio_emo * emo_alpha + (1 - emo_alpha) * speaker_base_emo
        elif vector_emo is not None:
            return vector_emo
        return None

    # ========== Fast Agent ==========
    
    def fast_loop_calibrate_alpha(self, z_id, z_emo, target_embedding, initial_alpha=1.0):
        """Fast Agent: 通过梯度下降校准强度 α (Eq. 11-12)"""
        batch_size = z_id.shape[0]
        device = z_id.device
        alpha = torch.full((batch_size, 1), initial_alpha, device=device)
        final_distance = torch.zeros(batch_size, device=device)
        
        with torch.enable_grad():
            alpha = alpha.clone().requires_grad_(True)
            for _ in range(self.fast_iterations):
                z_emo_scaled = alpha * z_emo
                z_combined = z_id.detach() + z_emo_scaled.unsqueeze(1)
                z_pooled = z_combined.mean(dim=1)
                fake_mel = z_pooled.unsqueeze(-1).expand(-1, -1, 100)
                fake_mel_proj = F.interpolate(fake_mel.unsqueeze(1), size=(80, 100), mode='bilinear', align_corners=False).squeeze(1)
                e_phi = self.lcp(fake_mel_proj)
                
                e_phi_norm = F.normalize(e_phi, dim=-1)
                target_norm = F.normalize(target_embedding.detach(), dim=-1)
                distance = 1 - (e_phi_norm * target_norm).sum(dim=-1)
                loss = distance.mean()
                final_distance = distance.detach()
                
                grad_alpha = torch.autograd.grad(loss, alpha, create_graph=False, retain_graph=False, allow_unused=True)[0]
                if grad_alpha is None:
                    grad_alpha = torch.zeros_like(alpha)
                alpha = (alpha.detach() - self.fast_lr * grad_alpha.detach()).clamp(0.0, 2.0).requires_grad_(True)
        
        return alpha.detach(), final_distance

    # ========== Main forward pass ==========
    
    def forward(
        self,
        speaker_semantic: torch.Tensor,
        speaker_lengths: torch.Tensor,
        text_tokens: torch.Tensor,
        text_lengths: torch.Tensor,
        emo_semantic: Optional[torch.Tensor] = None,
        emo_lengths: Optional[torch.Tensor] = None,
        emo_vector: Optional[torch.Tensor] = None,
        emo_alpha: float = 1.0,
        emo_merge_alpha: float = 1.0,
        use_fast_calibration: bool = False,
        target_embedding: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        TTS 生成的前向传播
        
        Args:
            speaker_semantic: 说话人参考 W2V-BERT 特征 (B, C, T)
            speaker_lengths: 说话人特征长度 (B,)
            text_tokens: 输入文本 tokens (B, L)
            text_lengths: 文本长度 (B,)
            emo_semantic: 可选情感参考特征 (B, C, T')
            emo_lengths: 情感特征长度 (B,)
            emo_vector: 可选离散情感向量 (B, 6)
            emo_alpha: 情感强度缩放
            emo_merge_alpha: 音频/向量混合因子
            use_fast_calibration: 是否使用 Fast Agent α 校准
            target_embedding: Fast Agent 目标嵌入 (B, text_dim)
        """
        batch_size = speaker_semantic.shape[0]
        device = speaker_semantic.device
        
        # 提取说话人条件 (z_id)
        speaker_cond = self.get_speaker_conditioning(speaker_semantic, speaker_lengths)
        speaker_base_emo = self.get_emotion_conditioning(speaker_semantic, speaker_lengths)
        
        # 提取/计算情感控制 (z_emo)
        audio_emo = self.get_emotion_conditioning(emo_semantic, emo_lengths) if emo_semantic is not None else None
        vector_emo = self.get_emovec_from_vector(emo_vector) if emo_vector is not None else None
        
        # DAC: 融合情感信号
        emovec = self.fuse_emotion_vectors(audio_emo, vector_emo, speaker_base_emo, emo_alpha, emo_merge_alpha)
        
        # Fast Agent: 可选 α 校准
        alpha = torch.tensor(emo_alpha, device=device)
        if use_fast_calibration and emovec is not None and target_embedding is not None:
            alpha, _ = self.fast_loop_calibrate_alpha(speaker_cond, emovec, target_embedding, emo_alpha)
        
        # 构建最终条件
        duration_emb = self.speed_emb(torch.zeros(batch_size, dtype=torch.long, device=device))
        duration_emb_half = self.speed_emb(torch.ones(batch_size, dtype=torch.long, device=device))
        conditioning = speaker_cond + emovec.unsqueeze(1) if emovec is not None else speaker_cond
        conditioning = torch.cat([conditioning, duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)], dim=1)
        
        return {'conditioning': conditioning, 'emovec': emovec, 'alpha': alpha, 'z_id': speaker_cond, 'z_emo_base': speaker_base_emo}