import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv2dWithResidual(nn.Module):
    """Conv2d + BatchNorm2d unit with an optional skip connection and ReLU.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.
        residual: If true, the block input is added to the normalised
            convolution output, so input and output shapes must match.
        use_act: If true, the result is passed through ReLU before returning.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True):
        super().__init__()
        layers = [
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        ]
        # Sub-module indices 0/1 define the state-dict key names.
        self.conv_block = nn.Sequential(*layers)
        self.act = nn.ReLU()
        self.residual = residual
        self.use_act = use_act

    def forward(self, x):
        """Apply conv + batch norm, then the optional skip connection and ReLU."""
        y = self.conv_block(x)
        if self.residual:
            # In-place accumulation of the skip connection.
            y += x
        if not self.use_act:
            return y
        return self.act(y)


class FeatureGate(nn.Module):
    """Feature-selection gate.

    A two-layer bottleneck MLP (``dim -> dim // 2 -> dim``) ending in a
    sigmoid produces one weight in (0, 1) per feature; the input is scaled
    element-wise by these weights.
    """

    def __init__(self, dim):
        super().__init__()
        hidden = dim // 2
        self.gate = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled element-wise by its own gate weights."""
        weights = self.gate(x)
        return x * weights


class KProductFusion(nn.Module):
    """Structured feature fusion based on pairwise Kronecker (outer) products.

    The three inputs are projected to small dimensions (64 / 32 / 32), all
    pairwise outer products are flattened, concatenated and projected to one
    part of the output width.  A second "direct" path projects the raw
    concatenated inputs to the remaining part.  Both parts are concatenated,
    scaled element-wise by a learned sigmoid gate and passed through a final
    linear layer.

    Args:
        audio_dim: Size of the last dimension of ``audio_features``.
        style_dim: Size of the last dimension of ``style_features``.
        emotion_dim: Size of the last dimension of ``emotion_features``.
        output_dim: Size of the last dimension of the returned tensor.
            May be odd; the direct path then receives the extra unit.
    """

    def __init__(self, audio_dim, style_dim, emotion_dim, output_dim):
        super().__init__()

        # Project down first so the pairwise products stay small.
        self.audio_proj = nn.Linear(audio_dim, 64)
        self.style_proj = nn.Linear(style_dim, 32)
        self.emotion_proj = nn.Linear(emotion_dim, 32)

        # The two paths must add up to exactly output_dim.  Using
        # ``output_dim // 2`` for both would drop one unit for odd sizes and
        # break the gate's input shape; for even sizes the split is unchanged.
        k_dim = output_dim // 2
        direct_dim = output_dim - k_dim

        # Projects the concatenated pairwise products
        # (audio x style, audio x emotion, style x emotion) to k_dim.
        self.k_product_projection = nn.Sequential(
            nn.Linear(64*32 + 64*32 + 32*32, 128),
            nn.ReLU(),
            nn.Linear(128, k_dim)
        )

        # Direct path over the raw, un-projected features.
        self.direct_projection = nn.Sequential(
            nn.Linear(audio_dim + style_dim + emotion_dim, direct_dim),
            nn.ReLU()
        )

        # Gate producing one weight in (0, 1) per fused feature.
        self.fusion_gate = nn.Sequential(
            nn.Linear(output_dim, output_dim // 2),
            nn.ReLU(),
            nn.Linear(output_dim // 2, output_dim),
            nn.Sigmoid()
        )

        # Final output adjustment.
        self.output_layer = nn.Linear(output_dim, output_dim)

    def kron_product(self, x, y):
        """Return the flattened outer product of ``x`` and ``y``.

        Args:
            x: Tensor of shape ``[..., m]``.
            y: Tensor of shape ``[..., n]`` with leading dimensions that
                broadcast against those of ``x``.

        Returns:
            Tensor of shape ``[..., m * n]`` where entry ``i * n + j`` is
            ``x[..., i] * y[..., j]``.
        """
        # Broadcasting works for any number of leading dimensions, unlike
        # ``torch.bmm`` which only accepts exactly one batch dimension.
        outer = x.unsqueeze(-1) * y.unsqueeze(-2)
        return outer.flatten(start_dim=-2)

    def forward(self, audio_features, style_features, emotion_features):
        """Fuse the three feature tensors.

        Args:
            audio_features: Tensor of shape ``[..., audio_dim]``.
            style_features: Tensor of shape ``[..., style_dim]``.
            emotion_features: Tensor of shape ``[..., emotion_dim]``.

        Returns:
            Tensor of shape ``[..., output_dim]``.
        """
        # 1. Dimensionality-reducing projections.
        audio_proj = self.audio_proj(audio_features)
        style_proj = self.style_proj(style_features)
        emotion_proj = self.emotion_proj(emotion_features)

        # 2. Pairwise K-product interactions.
        k_prod_as = self.kron_product(audio_proj, style_proj)      # [..., 64*32]
        k_prod_ae = self.kron_product(audio_proj, emotion_proj)    # [..., 64*32]
        k_prod_se = self.kron_product(style_proj, emotion_proj)    # [..., 32*32]
        k_prod = torch.cat([k_prod_as, k_prod_ae, k_prod_se], dim=-1)
        k_prod_features = self.k_product_projection(k_prod)

        # 3. Direct path over the raw features.
        concat_features = torch.cat([audio_features, style_features, emotion_features], dim=-1)
        direct_features = self.direct_projection(concat_features)

        # 4. Combine both paths and gate them.
        combined = torch.cat([k_prod_features, direct_features], dim=-1)
        gate_weights = self.fusion_gate(combined)

        # The previous form ``combined * g + combined * (1 - g)`` is
        # algebraically just ``combined``: the gate was a no-op and its
        # parameters received no useful gradient.  Apply the gate the same way
        # FeatureGate does so it actually re-weights the two paths.
        # NOTE: this changes outputs relative to checkpoints trained with the
        # old identity behaviour.
        gated_features = combined * gate_weights

        # 5. Final output.
        output = self.output_layer(gated_features)
        return output


class TransformerTemporalBlock(nn.Module):
    """Transformer-based temporal modelling block.

    Applies a LayerNorm to the input sequence and then runs it through a
    stack of pre-norm, GELU-activated, batch-first Transformer encoder
    layers.

    Args:
        input_dim: Feature size of each time step (the encoder's ``d_model``).
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward network.
    """

    def __init__(self, input_dim, nhead=4, num_layers=2, dim_feedforward=512):
        super().__init__()
        layer_kwargs = dict(
            d_model=input_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Normalise ``x`` (shape ``[B, T, C]``) and encode it; shape is preserved."""
        return self.encoder(self.norm(x))


class EmotionSyncK(nn.Module):
    """Audio-driven expression predictor with K-product feature fusion.

    Pipeline implemented by ``forward``:
      1. A convolutional audio encoder (initialised from a Wav2Lip
         checkpoint) turns each ``1 x 80 x 16`` audio window into a
         512-dimensional feature.
      2. Emotion features are predicted from the audio features and style
         features from the 64-dimensional reference vector.
      3. The three feature groups are fused twice, by ``KProductFusion`` and
         by ``FeatureGate``, and blended with a learned scalar.
      4. A Transformer models the sequence; a global decoder (56 values) and
         an attention-based eye decoder (8 values) produce 64 coefficients.
      5. These are mixed with a direct audio-to-expression mapping and
         smoothed along time with a 1-D convolution.

    Args:
        emotion_dim: Width of the predicted emotion feature.
        style_dim: Width of the encoded style feature.
        wav2lip_ckpt: Path of the Wav2Lip checkpoint used to initialise the
            audio encoder.  Pass ``None`` to skip loading and keep the random
            initialisation (e.g. when a full model checkpoint is restored
            right afterwards).

    Note:
        ``512 + style_dim + emotion_dim`` must be divisible by 4 (heads of
        the temporal Transformer) and that value plus 64 must be divisible
        by 2 (heads of the eye attention).
    """

    # Checkpoint location used when the caller does not supply one.
    DEFAULT_WAV2LIP_CKPT = '/home/shizhaoxin/codebase/AudioDrivenGaussian/checkpoints/wav2lip.pth'

    def __init__(self, emotion_dim=8, style_dim=16, wav2lip_ckpt=DEFAULT_WAV2LIP_CKPT):
        super().__init__()

        # For an 80 x 16 input the strides/paddings below reduce the spatial
        # size to 1 x 1, leaving a 512-dimensional vector per window.
        self.audio_encoder = nn.Sequential(
            Conv2dWithResidual(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2dWithResidual(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2dWithResidual(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2dWithResidual(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2dWithResidual(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2dWithResidual(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2dWithResidual(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2dWithResidual(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2dWithResidual(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2dWithResidual(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2dWithResidual(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2dWithResidual(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2dWithResidual(512, 512, kernel_size=1, stride=1, padding=0),
        )

        if wav2lip_ckpt is not None:
            self._load_pretrained_weights(wav2lip_ckpt)

        self.emotion_encoder = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, emotion_dim)
        )

        self.style_encoder = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, style_dim)
        )

        combined_dim = 512 + style_dim + emotion_dim

        # K-product fusion module (main fusion path).
        self.feature_fusion = KProductFusion(
            audio_dim=512,
            style_dim=style_dim,
            emotion_dim=emotion_dim,
            output_dim=combined_dim
        )

        # Plain feature gate kept as an auxiliary fusion path.
        self.feature_gate = FeatureGate(combined_dim)

        # Learned blend between the two fusion paths (passed through sigmoid).
        self.fusion_alpha = nn.Parameter(torch.tensor(0.7))

        self.temporal_model = TransformerTemporalBlock(
            input_dim=combined_dim,
            nhead=4,
            num_layers=4,
            dim_feedforward=512
        )

        self.global_decoder = nn.Sequential(
            nn.Linear(combined_dim + 64, 128),
            nn.ReLU(),
            nn.Linear(128, 56)
        )

        self.eye_attention = nn.MultiheadAttention(embed_dim=combined_dim + 64, num_heads=2, batch_first=True)
        self.eye_decoder = nn.Sequential(
            nn.Linear(combined_dim + 64, 64),
            nn.ReLU(),
            nn.Linear(64, 8)
        )

        self.direct_mapping = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )

        # Learned blend between direct and decoded output (through sigmoid).
        self.alpha = nn.Parameter(torch.tensor(0.5))

        self.temporal_smooth = nn.Conv1d(64, 64, kernel_size=3, padding=1)

    def _load_pretrained_weights(self, ckpt_path=None):
        """Initialise ``self.audio_encoder`` from a Wav2Lip checkpoint.

        Args:
            ckpt_path: Checkpoint file containing a ``'state_dict'`` entry.
                Defaults to ``DEFAULT_WAV2LIP_CKPT``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            RuntimeError: If the checkpoint tensors do not match the encoder.
        """
        if ckpt_path is None:
            ckpt_path = self.DEFAULT_WAV2LIP_CKPT

        # Load onto the CPU so checkpoints saved from a GPU can be read on
        # CPU-only hosts; load_state_dict copies into the existing parameters.
        # NOTE(review): torch.load unpickles the file - only point this at
        # trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        wav2lip_state_dict = checkpoint['state_dict']
        state_dict = self.audio_encoder.state_dict()

        wrapper_prefix = 'module.'
        encoder_prefix = 'audio_encoder.'
        for k, v in wav2lip_state_dict.items():
            # Accept keys saved with or without a leading 'module.' wrapper.
            name = k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k
            if name.startswith(encoder_prefix):
                state_dict[name[len(encoder_prefix):]] = v

        self.audio_encoder.load_state_dict(state_dict)

    def forward(self, audio, ref):
        """Predict a sequence of 64 expression coefficients.

        Args:
            audio: Tensor whose first two dimensions are batch and sequence
                length and whose remaining elements reshape to ``1 x 80 x 16``
                per step.
            ref: Reference vector with 64 features, shaped ``[B, 64]``,
                ``[B, 1, 64]`` or ``[B, T, 64]``.

        Returns:
            Tensor of shape ``[B, T, 64]``.
        """
        batch_size, seq_len = audio.size(0), audio.size(1)
        n_steps = batch_size * seq_len

        audio_flat = audio.reshape(n_steps, 1, 80, 16)
        audio_features = self.audio_encoder(audio_flat).view(batch_size, seq_len, -1)

        direct_output = self.direct_mapping(audio_features)

        emotion_features = self.emotion_encoder(audio_features)

        # A 2-D reference would be tiled along the wrong axis, so give it an
        # explicit time axis first; expand then broadcasts it over the
        # sequence without copying.
        if ref.dim() == 2:
            ref = ref.unsqueeze(1)
        ref_repeated = ref.expand(-1, seq_len, -1)
        style_features = self.style_encoder(ref_repeated.reshape(n_steps, -1))
        style_features = style_features.view(batch_size, seq_len, -1)

        # Plain concatenation of the raw features.
        combined_features = torch.cat([
            audio_features,
            style_features,
            emotion_features
        ], dim=-1)

        # K-product fusion, applied per time step on flattened [B*T, C] input.
        k_fused_features = self.feature_fusion(
            audio_features.reshape(n_steps, -1),
            style_features.reshape(n_steps, -1),
            emotion_features.reshape(n_steps, -1)
        ).reshape(batch_size, seq_len, -1)

        # Auxiliary gated version of the raw concatenation.
        gated_features = self.feature_gate(combined_features)

        # Blend the two fused representations.
        fusion_weight = torch.sigmoid(self.fusion_alpha)
        fused_features = fusion_weight * k_fused_features + (1 - fusion_weight) * gated_features

        temporal_features = self.temporal_model(fused_features)

        cat_features_seq = torch.cat([temporal_features, ref_repeated], dim=-1)

        global_out = self.global_decoder(
            cat_features_seq.reshape(n_steps, -1)
        ).view(batch_size, seq_len, -1)

        eye_feat, _ = self.eye_attention(cat_features_seq, cat_features_seq, cat_features_seq)
        eye_out = self.eye_decoder(eye_feat.reshape(n_steps, -1)).reshape(batch_size, seq_len, -1)

        # Interleave the 8 eye values into the 56 global values.
        complex_output = torch.cat([
            global_out[:, :, :22],   # 0-21
            eye_out,                 # 22-29
            global_out[:, :, 22:],   # 30-63
        ], dim=-1)

        output_weight = torch.sigmoid(self.alpha)
        expressions = 0.6 * output_weight * direct_output + 0.4 * (1 - output_weight) * complex_output

        # Conv1d expects [B, C, T]; smooth along time and restore [B, T, C].
        expressions = expressions.transpose(1, 2)
        expressions = self.temporal_smooth(expressions)
        expressions = expressions.transpose(1, 2)

        return expressions
