"""
Cross-Modal Adversarial Training (CMAT) Model Architecture
Generated by AI Research Agent for Agents4Science 2025
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class ProposedModel(nn.Module):
    """CMAT classifier: modality encoders -> cross-modal attention -> adaptive fusion -> MLP head.

    Args:
        num_classes: Number of output logits.
        face_dim: Output width of the face encoder's projection layer.
        voice_dim: Output width of the voice encoder.
        behavioral_dim: Output width of the behavioral encoder.
        attention_dim: Shared width used by cross-modal attention and fusion
            (nn.MultiheadAttention requires it to be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        dropout: Dropout rate for attention, fusion and the classifier head.
        pretrained: Load ImageNet weights into the ResNet-50 face backbone
            (torchvision downloads them on first use). Defaults to True, as before.
        voice_in_channels: Channel count of the voice input (13 MFCCs by default).
        behavioral_in_dim: Width of the behavioral input vector (30 by default).
    """

    def __init__(self, num_classes=10000, face_dim=2048, voice_dim=512, behavioral_dim=64,
                 attention_dim=64, num_heads=8, dropout=0.5, pretrained=True,
                 voice_in_channels=13, behavioral_in_dim=30):
        super(ProposedModel, self).__init__()

        self.num_classes = num_classes
        self.face_dim = face_dim
        self.voice_dim = voice_dim
        self.behavioral_dim = behavioral_dim
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.pretrained = pretrained
        self.voice_in_channels = voice_in_channels
        self.behavioral_in_dim = behavioral_in_dim

        # Modality-specific encoders (the builders read the attributes set above)
        self.face_encoder = self._build_face_encoder()
        self.voice_encoder = self._build_voice_encoder()
        self.behavioral_encoder = self._build_behavioral_encoder()

        # Cross-modal attention mechanism
        self.cross_modal_attention = CrossModalAttention(
            face_dim, voice_dim, behavioral_dim, attention_dim, num_heads, dropout
        )

        # Adaptive fusion with adversarial detection
        self.adaptive_fusion = AdaptiveFusion(attention_dim, dropout)

        # Final classifier. The input width (attention_dim * 3) must match what
        # adaptive_fusion returns — TODO(review): confirm, since AdaptiveFusion
        # sums over the modality axis and appears to return attention_dim features.
        self.classifier = nn.Sequential(
            nn.Linear(attention_dim * 3, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )

    def _build_face_encoder(self):
        """ResNet-50 backbone (classification head removed) plus a projection to ``face_dim``."""
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: ``pretrained=`` is deprecated; IMAGENET1K_V1 is
            # the checkpoint that ``pretrained=True`` maps to, so results are unchanged.
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1 if self.pretrained else None)
        else:
            resnet = models.resnet50(pretrained=self.pretrained)
        # Remove the final classification layer; the remaining avgpool emits [batch, 2048, 1, 1]
        face_encoder = nn.Sequential(*list(resnet.children())[:-1])
        # Flatten to [batch, 2048] first: nn.Linear acts on the last dim (size 1 here),
        # so projecting the 4-D tensor directly raises a shape-mismatch error.
        face_encoder.add_module('flatten', nn.Flatten())
        # Add projection layer to desired dimension
        face_encoder.add_module('projection', nn.Linear(2048, self.face_dim))
        return face_encoder

    def _build_voice_encoder(self):
        """1D CNN over a [batch, voice_in_channels, time] sequence (MFCCs by default)."""
        return nn.Sequential(
            nn.Conv1d(self.voice_in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            # Global average over time -> [batch, 256, 1], then flatten to [batch, 256]
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(256, self.voice_dim)
        )

    def _build_behavioral_encoder(self):
        """MLP mapping [batch, behavioral_in_dim] features to [batch, behavioral_dim]."""
        return nn.Sequential(
            nn.Linear(self.behavioral_in_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, self.behavioral_dim)
        )

    def forward(self, face, voice, behavioral):
        """Encode each modality, attend across modalities, fuse, and classify.

        Args:
            face: Image batch for the ResNet-50 backbone (the __main__ demo
                uses [batch, 3, 224, 224]).
            voice: [batch, voice_in_channels, time] sequence for the 1-D CNN.
            behavioral: [batch, behavioral_in_dim] feature vectors.

        Returns:
            Tuple ``(logits, attended_features)``. ``logits`` comes from the
            classifier head; ``attended_features`` is the per-modality output
            of cross-modal attention (before fusion), [batch, 3, attention_dim].
        """
        # Extract features from each modality
        face_features = self.face_encoder(face)  # [batch_size, face_dim]
        voice_features = self.voice_encoder(voice)  # [batch_size, voice_dim]
        behavioral_features = self.behavioral_encoder(behavioral)  # [batch_size, behavioral_dim]

        # Cross-modal attention -> [batch_size, 3, attention_dim]
        attended_features = self.cross_modal_attention(
            face_features, voice_features, behavioral_features
        )

        # Adaptive fusion
        fused_features = self.adaptive_fusion(attended_features)

        # Classification
        logits = self.classifier(fused_features)

        return logits, attended_features

class CrossModalAttention(nn.Module):
    """Multi-head self-attention over one projected token per modality.

    Each modality vector is linearly projected to ``attention_dim``. The three
    projections are stacked into a length-3 sequence and mixed with
    ``nn.MultiheadAttention``. The result is added back to the stacked tokens
    (residual) and layer-normalized.
    """

    def __init__(self, face_dim, voice_dim, behavioral_dim, attention_dim, num_heads, dropout):
        super().__init__()

        self.attention_dim = attention_dim
        self.num_heads = num_heads
        # Per-head width; kept as an attribute only, forward does not read it.
        self.head_dim = attention_dim // num_heads

        # Map every modality into the shared attention space.
        self.face_proj = nn.Linear(face_dim, attention_dim)
        self.voice_proj = nn.Linear(voice_dim, attention_dim)
        self.behavioral_proj = nn.Linear(behavioral_dim, attention_dim)

        # batch_first=True: sequences are laid out as [batch, tokens, features].
        self.attention = nn.MultiheadAttention(
            attention_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.layer_norm = nn.LayerNorm(attention_dim)

    def forward(self, face_features, voice_features, behavioral_features):
        """Attend across modalities.

        Args:
            face_features: [batch, face_dim] tensor.
            voice_features: [batch, voice_dim] tensor.
            behavioral_features: [batch, behavioral_dim] tensor.

        Returns:
            [batch, 3, attention_dim] tensor with tokens in the order
            (face, voice, behavioral).
        """
        tokens = torch.stack(
            (
                self.face_proj(face_features),
                self.voice_proj(voice_features),
                self.behavioral_proj(behavioral_features),
            ),
            dim=1,
        )
        # Attention weights are not needed by callers, so discard them.
        mixed, _ = self.attention(tokens, tokens, tokens)
        return self.layer_norm(mixed + tokens)

class AdaptiveFusion(nn.Module):
    """Fuse per-modality tokens with gate weights down-weighted by an adversarial detector.

    For each modality token, ``gate`` produces a score in [0, 1]. The scores are
    softmax-normalized across modalities. Each weight is then scaled by
    ``1 - adversarial_score`` and the result is renormalized to sum to 1, so a
    modality the detector fully flags contributes nothing. If the detector flags
    every modality (all adjusted weights ~0), the un-adjusted gate weights are
    used instead of dividing by ~0.

    Args:
        attention_dim: Feature width of each modality token.
        dropout: Dropout rate inside the gate and detector MLPs.
        eps: Threshold below which the adjusted-weight sum is treated as zero.
    """

    def __init__(self, attention_dim, dropout, eps=1e-8):
        super(AdaptiveFusion, self).__init__()
        self.eps = eps

        # Gating mechanism for adaptive weights (one score per modality token)
        self.gate = nn.Sequential(
            nn.Linear(attention_dim, attention_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(attention_dim // 2, 1),
            nn.Sigmoid()
        )

        # Adversarial detection (per-token score in [0, 1]; higher = more suspicious)
        self.adversarial_detector = nn.Sequential(
            nn.Linear(attention_dim, attention_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(attention_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, attended_features):
        """Return the weighted sum of modality tokens.

        Args:
            attended_features: [batch, num_modalities, attention_dim] tensor.

        Returns:
            [batch, attention_dim] tensor.

        Raises:
            ValueError: if ``attended_features`` is not 3-D.
        """
        if attended_features.dim() != 3:
            raise ValueError(
                f"expected [batch, num_modalities, dim] input, got shape {tuple(attended_features.shape)}"
            )

        # nn.Linear acts on the last dim, so every (batch, modality) token is scored
        # independently in one call — equivalent to looping over modalities.
        # NOTE(review): softmax over sigmoid outputs in [0, 1] limits the max weight
        # ratio between modalities to e; kept as-is to preserve the gate design.
        weights = F.softmax(self.gate(attended_features).squeeze(-1), dim=1)  # [batch, M]
        adversarial_scores = self.adversarial_detector(attended_features).squeeze(-1)  # [batch, M]

        # Reduce weight for inputs the detector considers adversarial
        adjusted_weights = weights * (1 - adversarial_scores)

        # Renormalize to sum to 1. (A softmax here would squash these [0, 1] values
        # toward uniform and undo the down-weighting.) Fall back to the gate
        # weights when every modality is flagged, avoiding a division by ~0.
        total = adjusted_weights.sum(dim=1, keepdim=True)
        adjusted_weights = torch.where(
            total > self.eps, adjusted_weights / total.clamp_min(self.eps), weights
        )

        # Apply weights to features
        return torch.sum(attended_features * adjusted_weights.unsqueeze(-1), dim=1)

class AdversarialLoss(nn.Module):
    """Clean-input cross-entropy plus an MSE penalty tying adversarial logits to clean ones.

    loss = CE(logits_clean, labels) + alpha * MSE(logits_clean, logits_adv)

    Args:
        alpha: Weight of the clean/adversarial logit-matching term.
    """

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, logits_clean, logits_adv, labels):
        """Return the combined scalar loss."""
        classification_term = self.ce_loss(logits_clean, labels)
        # Encourage the adversarial prediction to stay close to the clean one.
        agreement_term = F.mse_loss(logits_clean, logits_adv)
        return classification_term + self.alpha * agreement_term

class ConsistencyLoss(nn.Module):
    """Scaled mean-squared error between clean and adversarial feature tensors.

    Args:
        beta: Multiplier applied to the MSE.
    """

    def __init__(self, beta=0.05):
        super().__init__()
        self.beta = beta

    def forward(self, features_clean, features_adv):
        """Return ``beta * MSE(features_clean, features_adv)``."""
        discrepancy = F.mse_loss(features_clean, features_adv)
        return self.beta * discrepancy

if __name__ == '__main__':
    # Smoke test: build the model, push random inputs through it, report shapes.
    print("Testing CMAT Model...")

    cmat = ProposedModel(num_classes=100, face_dim=2048, voice_dim=512, behavioral_dim=64)

    # Random stand-ins for each modality (created in face, voice, behavioral order).
    n_samples = 4
    dummy_inputs = (
        torch.randn(n_samples, 3, 224, 224),  # face images
        torch.randn(n_samples, 13, 100),      # MFCC features
        torch.randn(n_samples, 30),           # behavioral features
    )

    out_logits, out_features = cmat(*dummy_inputs)
    param_count = sum(param.numel() for param in cmat.parameters())

    print(f"Model output shape: {out_logits.shape}")
    print(f"Feature shape: {out_features.shape}")
    print(f"Number of parameters: {param_count:,}")
    print("Model test complete.")
