import torch
import torch.nn as nn
import math

def kronecker_fusion(a, b):
    """Fuse two per-timestep feature sequences through their flattened outer product.

    Args:
        a: Tensor of shape ``[B, T, D1]``.
        b: Tensor of shape ``[B, T, D2]``.

    Returns:
        Tensor of shape ``[B, T, D1 * D2]``; at each ``(batch, step)`` the
        entry ``i * D2 + j`` equals ``a[..., i] * b[..., j]``.
    """
    batch, steps, dim_a = a.shape
    dim_b = b.shape[2]
    outer = torch.einsum('bti,btj->btij', a, b)  # [B, T, D1, D2]
    flat_width = dim_a * dim_b
    return outer.reshape(batch, steps, flat_width)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``'pe'`` with shape ``[1, max_len, d_model]``. Even feature columns
    hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``, where
    ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Feature size of the inputs. May be odd; the last (unpaired)
            column then only receives a sine term.
        max_len: Number of positions precomputed; ``forward`` inputs must not
            be longer than this.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there are only floor(d_model / 2) odd columns, one fewer
        # than frequencies; without trimming, this assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``[B, T, d_model]`` with ``T <= max_len``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return x + self.pe[:, :x.size(1)]

class AudioPoseAttention(nn.Module):
    """Cross-attention in which pose features query audio features.

    Both inputs are linearly projected to ``attn_dim``. The projections are then
    fed to a ``batch_first`` ``nn.MultiheadAttention``: the pose projection is
    the query, and the audio projection is both key and value.

    Args:
        audio_dim: Feature size of ``audio_feats``.
        pose_dim: Feature size of ``pose_feats``.
        attn_dim: Attention width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (default 8).
    """

    def __init__(self, audio_dim=512, pose_dim=6, attn_dim=256, num_heads=8):
        super().__init__()
        self.audio_proj = nn.Linear(audio_dim, attn_dim)
        self.pose_proj = nn.Linear(pose_dim, attn_dim)
        self.attention = nn.MultiheadAttention(attn_dim, num_heads=num_heads, batch_first=True)

    def forward(self, audio_feats, pose_feats):
        """Attend from each pose step to the audio sequence.

        Args:
            audio_feats: ``[B, S, audio_dim]``, used as keys and values.
            pose_feats: ``[B, T, pose_dim]``, used as queries.

        Returns:
            ``[B, T, attn_dim]`` attention output; the attention weights are discarded.
        """
        audio_proj = self.audio_proj(audio_feats)
        pose_proj = self.pose_proj(pose_feats)

        attn_output, _ = self.attention(query=pose_proj, key=audio_proj, value=audio_proj)
        return attn_output

class TransformerPoseRefiner(nn.Module):
    """Generator that refines a retrieved pose sequence using audio and expression features.

    Forward pipeline:
      1. Embed pose, audio and expression features to ``hidden_dim``.
      2. Fuse pose and expression embeddings (flattened outer product for
         ``'kron'``, concatenation for ``'concat'``) and project back to ``hidden_dim``.
      3. Add sinusoidal positional encodings to the fused and audio streams.
      4. Encode both streams with ``transformer_encoder``. The same module is
         used for both, so its weights are shared between the streams.
      5. Let the pose stream cross-attend to the audio stream, then add the
         result residually and apply ``norm1``.
      6. Decode against the audio memory and apply ``norm2``. Predict a per-frame
         pose delta and add it to the retrieved pose (residual refinement).

    Args:
        pose_dim: Size of one pose vector.
        audio_dim: Size of one audio feature vector.
        exp_dim: Size of one expression feature vector.
        hidden_dim: Model width. It must be divisible by ``nhead`` and by 8,
            because ``AudioPoseAttention`` is built with 8 heads here.
        nhead: Attention heads in the encoder/decoder layers.
        num_encoder_layers: Depth of ``transformer_encoder``.
        num_decoder_layers: Depth of ``transformer_decoder``.
        dropout: Dropout used inside the transformer layers.
        fusion_type: ``'kron'`` or ``'concat'``. ``'kron'`` projects from
            ``hidden_dim ** 2`` features, i.e. ``fusion_proj`` has
            ``hidden_dim ** 3`` weights.

    Raises:
        ValueError: If ``fusion_type`` is not ``'kron'`` or ``'concat'``.
    """

    def __init__(self, pose_dim=6, audio_dim=512, exp_dim=64, hidden_dim=256, nhead=8,
                 num_encoder_layers=4, num_decoder_layers=4, dropout=0.1, fusion_type='kron'):
        super().__init__()
        self.pose_dim = pose_dim
        self.audio_dim = audio_dim
        self.exp_dim = exp_dim
        self.hidden_dim = hidden_dim
        self.fusion_type = fusion_type

        # Module creation order is significant: it fixes the random initialisation under a seed.
        self.pose_embedding = nn.Linear(pose_dim, hidden_dim)
        self.audio_embedding = nn.Linear(audio_dim, hidden_dim)
        self.exp_embedding = nn.Linear(exp_dim, hidden_dim)

        self.pos_encoder = PositionalEncoding(hidden_dim)

        # Feature fusion for pose+exp
        if fusion_type == 'kron':
            self.fusion_dim = hidden_dim * hidden_dim
        elif fusion_type == 'concat':
            self.fusion_dim = hidden_dim * 2
        else:
            raise ValueError("fusion_type must be 'kron' or 'concat'")
        self.fusion_proj = nn.Linear(self.fusion_dim, hidden_dim)

        self.cross_attention = AudioPoseAttention(hidden_dim, hidden_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.output = nn.Linear(hidden_dim, pose_dim)

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def _fuse_pose_exp(self, pose_embed, exp_embed):
        """Fuse pose and expression embeddings according to ``fusion_type`` and project to ``hidden_dim``."""
        if self.fusion_type == 'kron':
            fused = kronecker_fusion(pose_embed, exp_embed)  # [B, T, hidden_dim ** 2]
        else:
            fused = torch.cat([pose_embed, exp_embed], dim=-1)  # [B, T, 2 * hidden_dim]
        return self.fusion_proj(fused)  # [B, T, hidden_dim]

    def forward(self, retrieved_pose, audio_embed, exp_feat):
        """Refine ``retrieved_pose``.

        Args:
            retrieved_pose: ``[B, T, 1, pose_dim]``. A ``[B, T, pose_dim]``
                tensor is also accepted.
            audio_embed: ``[B, T, audio_dim]``.
            exp_feat: ``[B, T, exp_dim]``.

        Returns:
            ``[B, T, pose_dim]``: the retrieved pose plus the predicted delta.
        """
        # Drop the singleton axis only for 4-D input. An unconditional squeeze(2)
        # would also collapse the feature axis of 3-D input when pose_dim == 1.
        base_pose = retrieved_pose.squeeze(2) if retrieved_pose.dim() == 4 else retrieved_pose

        pose_embed = self.pose_embedding(base_pose)      # [B, T, hidden_dim]
        audio_embed = self.audio_embedding(audio_embed)  # [B, T, hidden_dim]
        exp_embed = self.exp_embedding(exp_feat)         # [B, T, hidden_dim]

        pose_exp_fused = self.pos_encoder(self._fuse_pose_exp(pose_embed, exp_embed))
        audio_embed = self.pos_encoder(audio_embed)

        # NOTE(review): both streams go through the same encoder instance (shared
        # weights). Confirm this is intended rather than a missing second encoder.
        pose_memory = self.transformer_encoder(pose_exp_fused)
        audio_memory = self.transformer_encoder(audio_embed)

        # Pose queries attend over audio keys/values; residual + LayerNorm.
        cross_features = self.cross_attention(audio_memory, pose_memory)
        enhanced_pose = self.norm1(pose_memory + cross_features)

        decoded = self.norm2(self.transformer_decoder(enhanced_pose, audio_memory))

        refined_pose_delta = self.output(decoded)  # [B, T, pose_dim]
        return base_pose + refined_pose_delta

class TransformerPoseDiscriminator(nn.Module):
    """Transformer discriminator that scores a pose sequence given audio and expression features.

    The forward pass works as follows:

    * Pose and expression embeddings are fused and projected to ``hidden_dim``.
    * Positional encodings are added to the fused stream and the audio stream,
      and the two streams are summed per frame.
    * A learned global token is prepended and the sequence is encoded.
    * Two heads produce the outputs:
      - a sequence-level score from the global token, shape ``[B, 1]``;
      - a per-frame score from the remaining positions, shape ``[B, T, 1]``.

    Both outputs are raw scores; no sigmoid is applied here.

    Args:
        pose_dim: Size of one pose vector.
        audio_dim: Size of one audio feature vector.
        exp_dim: Size of one expression feature vector.
        hidden_dim: Model width; must be divisible by ``nhead``.
        nhead: Attention heads per encoder layer.
        num_layers: Encoder depth.
        dropout: Dropout in the encoder layers and classifier heads.
        fusion_type: ``'kron'`` or ``'concat'``.

    Raises:
        ValueError: If ``fusion_type`` is not ``'kron'`` or ``'concat'``.
    """

    def __init__(self, pose_dim=6, audio_dim=512, exp_dim=64, hidden_dim=256, nhead=8, num_layers=4, dropout=0.1, fusion_type='kron'):
        super().__init__()
        # Module creation order is kept as-is: it fixes the random initialisation under a seed.
        self.pose_embedding = nn.Linear(pose_dim, hidden_dim)
        self.audio_embedding = nn.Linear(audio_dim, hidden_dim)
        self.exp_embedding = nn.Linear(exp_dim, hidden_dim)
        self.fusion_type = fusion_type

        self.pos_encoder = PositionalEncoding(hidden_dim)
        # Validate like TransformerPoseRefiner does, instead of silently treating
        # any unknown value as 'concat'.
        if fusion_type == 'kron':
            self.fusion_dim = hidden_dim * hidden_dim
        elif fusion_type == 'concat':
            self.fusion_dim = hidden_dim * 2
        else:
            raise ValueError("fusion_type must be 'kron' or 'concat'")
        self.fusion_proj = nn.Linear(self.fusion_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learned summary token; its encoded state feeds the sequence-level head.
        self.global_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.global_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

        self.local_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def _fuse_pose_exp(self, pose_embed, exp_embed):
        """Fuse pose and expression embeddings according to ``fusion_type`` and project to ``hidden_dim``."""
        if self.fusion_type == 'kron':
            fused = kronecker_fusion(pose_embed, exp_embed)
        else:
            fused = torch.cat([pose_embed, exp_embed], dim=-1)
        return self.fusion_proj(fused)

    def forward(self, pose, audio_embed, exp_feat):
        """Score ``pose`` conditioned on ``audio_embed`` and ``exp_feat``.

        Args:
            pose: ``[B, T, pose_dim]`` or ``[B, T, 1, pose_dim]``.
            audio_embed: ``[B, T, audio_dim]``.
            exp_feat: ``[B, T, exp_dim]``.

        Returns:
            Tuple ``(global_validity [B, 1], local_validity [B, T, 1])``.
        """
        # Drop the singleton axis only for 4-D input, so that a 3-D input with
        # pose_dim == 1 keeps its feature axis.
        if pose.dim() == 4:
            pose = pose.squeeze(2)
        batch_size = pose.shape[0]

        pose_embed = self.pose_embedding(pose)           # [B, T, hidden_dim]
        audio_embed = self.audio_embedding(audio_embed)  # [B, T, hidden_dim]
        exp_embed = self.exp_embedding(exp_feat)         # [B, T, hidden_dim]

        pose_exp_fused = self.pos_encoder(self._fuse_pose_exp(pose_embed, exp_embed))
        audio_embed = self.pos_encoder(audio_embed)

        features = pose_exp_fused + audio_embed

        # The global token is prepended after positional encoding, so it carries no position signal.
        global_tokens = self.global_token.expand(batch_size, -1, -1)        # [B, 1, hidden_dim]
        features_with_global = torch.cat([global_tokens, features], dim=1)  # [B, T+1, hidden_dim]

        encoded_features = self.transformer(features_with_global)  # [B, T+1, hidden_dim]

        global_feature = encoded_features[:, 0]    # [B, hidden_dim]
        local_features = encoded_features[:, 1:]   # [B, T, hidden_dim]

        global_validity = self.global_classifier(global_feature)  # [B, 1]
        local_validity = self.local_classifier(local_features)    # [B, T, 1]

        return global_validity, local_validity


class GANLoss(nn.Module):
    """Adversarial loss for vanilla (BCE-with-logits), LSGAN and WGAN objectives.

    Call it as ``loss = criterion(prediction, target_is_real)``:

    * ``'vanilla'``: BCE-with-logits against an all-ones (real) or all-zeros (fake) target.
    * ``'lsgan'``: mean squared error against the same 1/0 target.
    * ``'wgan'``: ``-mean(prediction)`` for real and ``mean(prediction)`` for fake.

    Args:
        gan_mode: One of ``'vanilla'``, ``'lsgan'`` or ``'wgan'``.

    Raises:
        ValueError: If ``gan_mode`` is not a supported mode. Previously, an
            unknown mode was accepted and the loss later returned ``None``.
    """

    _MODES = ('vanilla', 'lsgan', 'wgan')

    def __init__(self, gan_mode='vanilla'):
        super().__init__()
        if gan_mode not in self._MODES:
            raise ValueError(f"gan_mode must be one of {self._MODES}, got {gan_mode!r}")
        self.gan_mode = gan_mode
        # Only the vanilla objective uses a loss module; the others are computed inline.
        self.loss = nn.BCEWithLogitsLoss() if gan_mode == 'vanilla' else None

    def forward(self, prediction, target_is_real):
        """Return the scalar adversarial loss for ``prediction``.

        Implemented as ``forward`` rather than overriding ``__call__``, so that
        ``nn.Module`` hooks apply.

        Args:
            prediction: Raw discriminator output of any shape.
            target_is_real: ``True`` to score ``prediction`` as real, ``False`` as fake.
        """
        if self.gan_mode == 'wgan':
            return -torch.mean(prediction) if target_is_real else torch.mean(prediction)

        target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
        if self.gan_mode == 'vanilla':
            return self.loss(prediction, target)
        return torch.mean((prediction - target) ** 2)  # lsgan


class PoseGAN:
    """Trainer that pairs a ``TransformerPoseRefiner`` generator with a ``TransformerPoseDiscriminator``.

    The generator loss is the sum of three terms:

    * the adversarial loss on both discriminator outputs (global and per-frame);
    * ``lambda_recon`` times the L1 distance to the ground-truth pose;
    * ``lambda_smooth`` times the L1 norm of frame-to-frame pose differences.

    Args:
        pose_dim: Forwarded to both networks.
        audio_dim: Forwarded to both networks.
        exp_dim: Forwarded to both networks.
        hidden_dim: Forwarded to both networks.
        device: Passed to ``.to()`` for the networks and inputs, and used as
            ``map_location`` when loading checkpoints.
        fusion_type: ``'kron'`` or ``'concat'``, forwarded to both networks.
        gan_mode: ``'vanilla'`` (default), ``'lsgan'`` or ``'wgan'``; selects the
            adversarial objective. With ``'wgan'``, discriminator weights are
            clipped to ``[-clip_value, clip_value]`` after each update.
    """

    def __init__(self, pose_dim=6, audio_dim=512, exp_dim=64, hidden_dim=256, device=None, fusion_type='kron',
                 gan_mode='vanilla'):
        self.device = device
        self.netG = TransformerPoseRefiner(pose_dim, audio_dim, exp_dim, hidden_dim, fusion_type=fusion_type).to(self.device)
        self.netD = TransformerPoseDiscriminator(pose_dim, audio_dim, exp_dim, hidden_dim, fusion_type=fusion_type).to(self.device)

        self.gan_loss = GANLoss(gan_mode=gan_mode).to(self.device)
        self.l1_loss = nn.L1Loss()

        # The discriminator uses a 10x larger learning rate than the generator.
        self.optimG = torch.optim.AdamW(self.netG.parameters(), lr=1e-4, betas=(0.5, 0.999))
        self.optimD = torch.optim.AdamW(self.netD.parameters(), lr=1e-3, betas=(0.5, 0.999))

        self.clip_value = 0.01     # WGAN weight-clipping bound (only used when gan_mode == 'wgan')
        self.lambda_recon = 10.0   # weight of the L1 reconstruction term
        self.lambda_smooth = 1.0   # weight of the velocity (smoothness) term

    def train_step(self, retrieved_pose, real_pose, audio_embed, exp_feat):
        """Run one discriminator update followed by one generator update.

        Args:
            retrieved_pose: Generator input, ``[B, T, 1, pose_dim]``.
            real_pose: Ground truth, ``[B, T, 1, pose_dim]`` or ``[B, T, pose_dim]``.
            audio_embed: ``[B, T, audio_dim]``.
            exp_feat: ``[B, T, exp_dim]``.

        Returns:
            Dict of Python floats with keys ``d_loss``, ``g_loss``,
            ``g_loss_gan``, ``g_loss_recon`` and ``g_loss_smooth``.
        """
        # refine_pose() (or external code) may have switched a network to eval
        # mode; training must run with dropout active.
        self.netG.train()
        self.netD.train()

        retrieved_pose = retrieved_pose.to(self.device)
        real_pose = real_pose.to(self.device)
        if real_pose.dim() == 4:
            real_pose = real_pose.squeeze(2)  # [B, T, pose_dim], same layout as the generator output
        audio_embed = audio_embed.to(self.device)
        exp_feat = exp_feat.to(self.device)

        # ---- Discriminator update ----
        self.optimD.zero_grad()

        fake_pose = self.netG(retrieved_pose, audio_embed, exp_feat)

        real_global, real_local = self.netD(real_pose, audio_embed, exp_feat)
        # detach(): the discriminator loss must not backpropagate into the generator.
        fake_global, fake_local = self.netD(fake_pose.detach(), audio_embed, exp_feat)

        d_loss_global = self.gan_loss(real_global, True) + self.gan_loss(fake_global, False)
        d_loss_local = self.gan_loss(real_local, True) + self.gan_loss(fake_local, False)
        d_loss = d_loss_global + d_loss_local

        d_loss.backward()
        self.optimD.step()

        if self.gan_loss.gan_mode == 'wgan':
            with torch.no_grad():
                for p in self.netD.parameters():
                    p.clamp_(-self.clip_value, self.clip_value)

        # ---- Generator update ----
        self.optimG.zero_grad()

        # Freeze the discriminator so the generator's backward pass neither computes
        # nor accumulates gradients for its weights. Restore the previous flags afterwards.
        d_params = list(self.netD.parameters())
        d_requires_grad = [p.requires_grad for p in d_params]
        for p in d_params:
            p.requires_grad_(False)
        try:
            fake_global, fake_local = self.netD(fake_pose, audio_embed, exp_feat)

            g_loss_gan = self.gan_loss(fake_global, True) + self.gan_loss(fake_local, True)

            g_loss_recon = self.l1_loss(fake_pose, real_pose) * self.lambda_recon

            # Penalise frame-to-frame pose changes to encourage smooth motion.
            velocity = fake_pose[:, 1:] - fake_pose[:, :-1]
            g_loss_smooth = self.l1_loss(velocity, torch.zeros_like(velocity)) * self.lambda_smooth

            g_loss = g_loss_gan + g_loss_recon + g_loss_smooth
            g_loss.backward()
        finally:
            for p, flag in zip(d_params, d_requires_grad):
                p.requires_grad_(flag)

        self.optimG.step()

        return {
            'd_loss': d_loss.item(),
            'g_loss': g_loss.item(),
            'g_loss_gan': g_loss_gan.item(),
            'g_loss_recon': g_loss_recon.item(),
            'g_loss_smooth': g_loss_smooth.item()
        }

    def refine_pose(self, retrieved_pose, audio_embed, exp_feat):
        """Refine a pose sequence with the trained generator (inference only).

        The generator runs in eval mode without gradient tracking. Its previous
        train/eval mode is restored afterwards; the original code left it in
        eval mode, which silently disabled dropout in later training.

        Returns:
            The refined pose, ``[B, T, pose_dim]``, on the CPU.
        """
        was_training = self.netG.training
        self.netG.eval()
        try:
            with torch.no_grad():
                retrieved_pose = retrieved_pose.to(self.device)
                audio_embed = audio_embed.to(self.device)
                exp_feat = exp_feat.to(self.device)
                refined_pose = self.netG(retrieved_pose, audio_embed, exp_feat)
        finally:
            self.netG.train(was_training)
        return refined_pose.cpu()

    def save_models(self, path):
        """Save generator, discriminator and both optimizer states to ``path``."""
        torch.save({
            'generator': self.netG.state_dict(),
            'discriminator': self.netD.state_dict(),
            'optimG': self.optimG.state_dict(),
            'optimD': self.optimD.state_dict()
        }, path)

    def load_models(self, path):
        """Load a checkpoint written by ``save_models``, mapping tensors to ``self.device``."""
        # NOTE(security): torch.load relies on pickle unless weights_only=True;
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device)
        self.netG.load_state_dict(checkpoint['generator'])
        self.netD.load_state_dict(checkpoint['discriminator'])
        self.optimG.load_state_dict(checkpoint['optimG'])
        self.optimD.load_state_dict(checkpoint['optimD'])

        print(f"Loaded models from {path}")