

import logging
from pathlib import Path

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.modeling_utils import ModelMixin


class SoftmaxWithTemperature(nn.Module):
    """Softmax whose logits are first divided by a learnable temperature.

    Args:
        dim: Dimension along which the softmax is taken.
        temperature: Initial temperature, stored as a 1-element trainable parameter.
            Values above 1 flatten the output distribution; values below 1 sharpen it.
    """

    def __init__(self, dim, temperature=2.0):
        super().__init__()
        self.dim = dim
        self.temperature = nn.Parameter(torch.full((1,), float(temperature)))

    def forward(self, x):
        scaled_logits = x.div(self.temperature)
        return scaled_logits.softmax(dim=self.dim)


class PoseNet0401(ModelMixin):
    """Convolutional encoder mapping 3-channel pose images to latent-sized feature maps.

    Three kernel-4 / stride-2 convolutions halve the spatial size three times (8x overall
    for sizes divisible by 8). A 1x1 projection then maps 128 channels to
    ``noise_latent_channels``, and the result is multiplied by a learnable ``scale``
    (initially 2). Because the projection is zero-initialised, a freshly constructed encoder
    outputs all-zero features.

    Args:
        noise_latent_channels: Output channel count of ``final_proj``.
        predict_mask: If True, ``forward`` additionally returns a sigmoid mask computed by
            ``mask_head`` from the 128-channel features.
    """

    def __init__(self, noise_latent_channels=320, predict_mask=False):
        super().__init__()
        self.predict_mask = predict_mask

        # (in_channels, out_channels, kernel_size, stride). Every conv uses padding=1 and is
        # followed by a SiLU, so Sequential indices match an explicit conv/SiLU listing.
        conv_specs = (
            (3, 3, 3, 1),
            (3, 16, 4, 2),
            (16, 16, 3, 1),
            (16, 32, 4, 2),
            (32, 32, 3, 1),
            (32, 64, 4, 2),
            (64, 64, 3, 1),
            (64, 128, 3, 1),
        )
        stack = []
        for c_in, c_out, k, s in conv_specs:
            stack.append(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=s, padding=1)
            )
            stack.append(nn.SiLU())
        self.conv_layers = nn.Sequential(*stack)

        self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)

        if predict_mask:
            self.mask_head = nn.Sequential(
                nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1),
            )

        self._initialize_weights()
        self.scale = nn.Parameter(torch.full((1,), 2.0))

    def _initialize_weights(self):
        """He-normal init for the conv stack; zero init for the output projection."""
        convs = (layer for layer in self.conv_layers if isinstance(layer, nn.Conv2d))
        for conv in convs:
            fan_in = conv.in_channels * conv.kernel_size[0] * conv.kernel_size[1]
            init.normal_(conv.weight, mean=0.0, std=np.sqrt(2.0 / fan_in))
            if conv.bias is not None:
                init.zeros_(conv.bias)
        # A zero projection makes the encoder start out contributing nothing.
        init.zeros_(self.final_proj.weight)
        if self.final_proj.bias is not None:
            init.zeros_(self.final_proj.bias)

    def forward(self, x):
        """Encode ``x``; a 5-D input is folded from (b, f, c, h, w) to (b*f, c, h, w) first.

        Returns the scaled pose features, or ``(features, mask)`` when ``predict_mask`` is set.
        The frame axis is not restored for 5-D inputs.
        """
        if x.ndim == 5:
            x = einops.rearrange(x, "b f c h w -> (b f) c h w")
        hidden = self.conv_layers(x)
        pose_features = self.final_proj(hidden) * self.scale
        if not self.predict_mask:
            return pose_features
        pred_mask = torch.sigmoid(self.mask_head(hidden))
        return pose_features, pred_mask


class MultiPersonPoseNet0406(ModelMixin):
    """Encode several per-person pose videos and merge them into one conditioning tensor.

    One ``PoseNet0401`` encoder is created per person. Each person's frames (optionally
    fused with a per-person mask) are encoded, refined by ``pose_enhancer``,
    layer-normalised, and merged. The merge is a 1x1 ``integration_layer`` over the
    concatenated persons plus a residual mean over persons.

    Args:
        noise_latent_channels: Channel count of the produced features.
        num_persons: Number of per-person encoders; the leading dim of ``person_poses``.
        attention_temperature: Initial (learnable) temperature of the attention softmax.
        mask_predict: Accepted for compatibility; not read anywhere in this class.
        residual_alpha: Initial value of the learnable residual mixing coefficient.
    """

    def __init__(self, noise_latent_channels=320, num_persons=2, attention_temperature=2.0, mask_predict=False, residual_alpha=0.5):
        super().__init__()
        self.num_persons = num_persons

        # One independent pose encoder per person.
        self.pose_nets = nn.ModuleList([
            PoseNet0401(noise_latent_channels=noise_latent_channels)
            for _ in range(num_persons)
        ])

        # Normalises over (C, 64, 64), so per-person features must be exactly 64x64.
        # Assumption (TODO confirm): this corresponds to 512x512 pose frames, given the
        # three stride-2 convs in PoseNet0401.
        self.feature_norm = nn.LayerNorm([noise_latent_channels, 64, 64])

        # 1x1 conv merging the channel-concatenated persons back to one feature map.
        self.integration_layer = nn.Conv2d(
            in_channels=noise_latent_channels * num_persons,
            out_channels=noise_latent_channels,
            kernel_size=1
        )

        # Per-pixel attention logits over persons, followed by a temperature softmax.
        self.attention_pre = nn.Sequential(
            nn.Conv2d(noise_latent_channels * num_persons, 128, kernel_size=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, num_persons, kernel_size=1)
        )
        self.attention = SoftmaxWithTemperature(dim=1, temperature=attention_temperature)

        # Lifts a 1-channel mask to 3 channels so it can be blended with RGB pose frames.
        self.mask_processor = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.BatchNorm2d(4),
            nn.Conv2d(4, 3, kernel_size=3, padding=1)
        )

        # Per-pixel importance gates (in [0, 1]) for the pose and mask inputs.
        self.pose_gate = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(8, 1, kernel_size=1),
            nn.Sigmoid()
        )

        self.mask_gate = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(8, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Share of the gated blend given to the pose (the mask gets 1 - pose_weight).
        self.pose_weight = nn.Parameter(torch.ones(1) * 0.8)

        self.residual_alpha = nn.Parameter(torch.ones(1) * residual_alpha)

        self.pose_enhancer = nn.Sequential(
            nn.Conv2d(noise_latent_channels, noise_latent_channels, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.BatchNorm2d(noise_latent_channels),
            nn.Conv2d(noise_latent_channels, noise_latent_channels, kernel_size=1)
        )

        # Learnable gain applied to each person's features before the enhancer residual.
        self.pose_scale = nn.Parameter(torch.ones(1) * 1.5)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out) init with zero bias for the integration, attention and enhancer convs."""
        nn.init.kaiming_normal_(self.integration_layer.weight, mode='fan_out', nonlinearity='relu')
        nn.init.zeros_(self.integration_layer.bias)

        for m in self.attention_pre:
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.zeros_(m.bias)

        for m in self.pose_enhancer:
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.zeros_(m.bias)

    @staticmethod
    def _log_stats(logger, label, tensor):
        """Log mean/std of ``tensor`` at DEBUG level.

        ``.item()`` forces a device-to-host sync. The statistics are therefore only computed
        when DEBUG logging is enabled, rather than on every forward pass.
        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("%s: mean=%.6f, std=%.6f", label, tensor.mean().item(), tensor.std().item())

    @staticmethod
    def _run_pose_net(pose_net, x):
        """Call ``pose_net`` and normalise its output to ``(features, mask_or_None)``."""
        result = pose_net(x)
        if isinstance(result, tuple):
            return result[0], (result[1] if len(result) > 1 else None)
        return result, None

    @staticmethod
    def _placeholder_mask(poses_flat):
        """Build an all-ones mask at 1/8 of the pose resolution, for when no mask is predicted."""
        return torch.ones(
            (poses_flat.size(0), 1, poses_flat.size(2) // 8, poses_flat.size(3) // 8),
            device=poses_flat.device,
            dtype=poses_flat.dtype
        )

    def _fuse_pose_and_mask(self, poses_flat, mask_features, pose_boost=1.0, mask_boost=1.0):
        """Blend pose frames with processed mask features using the learned gates.

        ``mask_features`` is resized to the pose resolution when needed so that both terms
        can be added element-wise. ``pose_boost`` and ``mask_boost`` are extra constant
        multipliers on the two gate outputs.
        """
        if mask_features.shape[-2:] != poses_flat.shape[-2:]:
            mask_features = F.interpolate(
                mask_features,
                size=poses_flat.shape[-2:],
                mode='bilinear',
                align_corners=False
            )
        pose_importance = self.pose_gate(poses_flat) * self.pose_weight * pose_boost
        mask_importance = self.mask_gate(mask_features) * (1 - self.pose_weight) * mask_boost
        return poses_flat * pose_importance + mask_features * mask_importance

    def _encode_person(self, pose_net, poses_flat, mask_flat):
        """Encode one person's flattened frames and return ``(features, mask)``.

        Mask-fusion paths:
        - Supplied mask (``mask_flat`` not None): the gated pose/mask fusion is encoded
          again and added residually (scaled by ``residual_alpha``).
        - Encoder-predicted mask: the fusion is blended in as a 5% correction.
        - Neither: the raw features are returned with an all-ones placeholder mask.
        """
        if mask_flat is not None:
            raw_features, _ = self._run_pose_net(pose_net, poses_flat)
            fused_input = self._fuse_pose_and_mask(poses_flat, self.mask_processor(mask_flat))
            fused_features, _ = self._run_pose_net(pose_net, fused_input)
            features = raw_features + self.residual_alpha * fused_features
            return features, self._placeholder_mask(poses_flat)

        features, pred_mask = self._run_pose_net(pose_net, poses_flat)
        if pred_mask is None:
            return features, self._placeholder_mask(poses_flat)

        # Favour the pose (x1.2) over the predicted mask (x0.8) in the gated blend.
        fused_input = self._fuse_pose_and_mask(
            poses_flat, self.mask_processor(pred_mask), pose_boost=1.2, mask_boost=0.8
        )
        fused_features, _ = self._run_pose_net(pose_net, fused_input)
        return features * 0.95 + fused_features * 0.05, pred_mask

    def forward(self, person_poses, person_masks=None):
        """Encode and merge the pose streams of all persons.

        Args:
            person_poses: Tensor of shape [num_persons, batch_size, frames, 3, H, W].
            person_masks: Optional tensor of shape [num_persons, batch_size, frames, 1, H, W].

        Returns:
            Features of shape [batch_size, frames, C, h, w]. In eval mode, returns a tuple
            ``(features, masks)`` instead, where ``masks`` stacks one mask per person along a
            new leading dim. Each mask has shape [batch_size * frames, 1, h', w'].
        """
        logger = logging.getLogger(__name__)
        logger.debug("in multi posenet, person_poses.shape=%s", tuple(person_poses.shape))
        logger.debug(
            "in multi posenet, person_masks.shape=%s",
            None if person_masks is None else tuple(person_masks.shape),
        )

        batch_size = person_poses.shape[1]
        num_frames = person_poses.shape[2]

        person_features = []
        person_pred_masks = []
        for i, pose_net in enumerate(self.pose_nets):
            poses_flat = einops.rearrange(person_poses[i], "b f c h w -> (b f) c h w")
            mask_flat = None
            if person_masks is not None:
                mask_flat = einops.rearrange(person_masks[i], "b f c h w -> (b f) c h w")

            features, pred_mask = self._encode_person(pose_net, poses_flat, mask_flat)
            person_pred_masks.append(pred_mask)

            enhanced_features = self.pose_enhancer(features)
            person_features.append(
                features * self.pose_scale + enhanced_features * (1 - self.residual_alpha)
            )

        normalized_features = [self.feature_norm(feat) for feat in person_features]
        for norm_feat in normalized_features:
            self._log_stats(logger, "Features", norm_feat)

        combined_features = torch.cat(normalized_features, dim=1)
        self._log_stats(logger, "Combined features", combined_features)

        # Small noise on the attention input during training only.
        attention_input = combined_features
        if self.training:
            attention_input = combined_features + torch.randn_like(combined_features) * 0.005
        attention_weights = self.attention(self.attention_pre(attention_input))
        self._log_stats(logger, "Attention weights", attention_weights)
        # NOTE(review): the attention weights never reach the output. The previous
        # attention-weighted sum over persons was computed and then discarded, so it has
        # been removed. The head is still run to keep its BatchNorm statistics and the RNG
        # stream unchanged. Confirm whether the integrated features were meant to use it.

        integrated_features = self.integration_layer(combined_features)
        mean_features = torch.mean(torch.stack(normalized_features), dim=0)
        integrated_features = integrated_features * 0.95 + self.residual_alpha * mean_features
        self._log_stats(logger, "Integrated features", integrated_features)

        final_features = einops.rearrange(
            integrated_features,
            "(b f) c h w -> b f c h w",
            b=batch_size, f=num_frames
        )
        logger.debug("final_features.shape=%s", tuple(final_features.shape))

        if not self.training and person_pred_masks:
            return final_features, torch.stack(person_pred_masks, dim=0)
        return final_features


