"""
Training Framework for Cross-Modal Adversarial Training (CMAT)
Generated by AI Research Agent for Agents4Science 2025
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import os
import time
from tqdm import tqdm
import matplotlib.pyplot as plt

class Trainer:
    """Adversarial training loop for a three-input (face/voice/behavioral) model.

    The model is called as ``model(face, voice, behavioral)`` and must return a
    ``(logits, features)`` pair. Batches are dictionaries with the keys
    ``'face'``, ``'voice'``, ``'behavioral'`` and ``'label'``.

    The trainer moves each batch to ``device`` but does NOT move the model;
    callers are expected to place the model on ``device`` themselves.

    Args:
        model: Network to train.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device every batch tensor is moved to.
        output_dir: Directory for checkpoints and plots (created if missing).
        scheduler: Optional learning-rate scheduler, stepped once per epoch
            with no arguments.
    """

    def __init__(self, model, optimizer, device, output_dir, scheduler=None):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.output_dir = output_dir
        self.scheduler = scheduler

        # Loss functions
        self.ce_loss = nn.CrossEntropyLoss()
        self.adversarial_loss = AdversarialLoss(alpha=0.1)
        self.consistency_loss = ConsistencyLoss(beta=0.05)

        # Training history
        self.train_losses = []
        self.val_accuracies = []
        self.best_val_acc = 0.0
        self.patience_counter = 0
        self.early_stopping_patience = 10

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

    def train_epoch(self, train_loader, preprocessor):
        """Train for one epoch on clean and PGD-perturbed batches.

        Args:
            train_loader: Iterable of batch dictionaries.
            preprocessor: Currently unused; accepted so the call signature
                matches ``train``.

        Returns:
            Mean total loss over the epoch's batches, or ``nan`` when the
            loader yields no batches. The value is also appended to
            ``self.train_losses``.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc="Training")
        for batch in pbar:
            # Move data to device
            face = batch['face'].to(self.device)
            voice = batch['voice'].to(self.device)
            behavioral = batch['behavioral'].to(self.device)
            labels = batch['label'].to(self.device)

            # Forward pass for clean data
            logits_clean, features_clean = self.model(face, voice, behavioral)

            # Generate adversarial examples
            face_adv, voice_adv, behavioral_adv = self._generate_adversarial_examples(
                face, voice, behavioral, labels
            )

            # Forward pass for adversarial data
            logits_adv, features_adv = self.model(face_adv, voice_adv, behavioral_adv)

            # Compute losses.
            # NOTE(review): self.adversarial_loss already contains a
            # cross-entropy term on the clean logits, so the clean CE is
            # effectively counted twice in the total below. Confirm this
            # weighting is intended before changing it.
            ce_loss = self.ce_loss(logits_clean, labels)
            adv_loss = self.adversarial_loss(logits_clean, logits_adv, labels)
            cons_loss = self.consistency_loss(features_clean, features_adv)

            total_loss_batch = ce_loss + adv_loss + cons_loss

            # Backward pass
            self.optimizer.zero_grad()
            total_loss_batch.backward()
            self.optimizer.step()

            # Read the scalar once; each .item() forces a device sync.
            batch_loss = total_loss_batch.item()
            total_loss += batch_loss
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'CE': f'{ce_loss.item():.4f}',
                'Adv': f'{adv_loss.item():.4f}',
                'Cons': f'{cons_loss.item():.4f}'
            })

        # An empty loader has no defined mean loss; report nan instead of
        # raising ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        self.train_losses.append(avg_loss)

        return avg_loss

    def validate_epoch(self, val_loader, preprocessor):
        """Compute top-1 accuracy on clean validation data.

        Args:
            val_loader: Iterable of batch dictionaries.
            preprocessor: Currently unused; accepted so the call signature
                matches ``train``.

        Returns:
            Accuracy in percent (0-100), or ``0.0`` when the loader yields no
            samples. The value is also appended to ``self.val_accuracies``.
        """
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                face = batch['face'].to(self.device)
                voice = batch['voice'].to(self.device)
                behavioral = batch['behavioral'].to(self.device)
                labels = batch['label'].to(self.device)

                logits, _ = self.model(face, voice, behavioral)
                predicted = logits.argmax(dim=1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against an empty loader instead of dividing by zero.
        accuracy = 100 * correct / total if total else 0.0
        self.val_accuracies.append(accuracy)

        return accuracy

    def _generate_adversarial_examples(self, face, voice, behavioral, labels, epsilon=0.03, alpha=0.01, num_steps=10):
        """Generate adversarial examples for all three modalities using PGD.

        Each step takes a signed-gradient ascent step of size ``alpha`` on the
        cross-entropy loss, projects back into the L-infinity ball of radius
        ``epsilon`` around the clean input, then clamps to a fixed value range.

        Args:
            face, voice, behavioral: Clean input tensors.
            labels: Target class indices for the cross-entropy loss.
            epsilon: Maximum per-element perturbation.
            alpha: Step size per iteration.
            num_steps: Number of PGD iterations; ``0`` returns detached copies
                of the clean inputs.

        Returns:
            Tuple ``(face_adv, voice_adv, behavioral_adv)`` of detached tensors.
        """
        # Work on detached copies so the attack never touches the caller's graph.
        face_adv = face.clone().detach()
        voice_adv = voice.clone().detach()
        behavioral_adv = behavioral.clone().detach()

        # The epsilon-ball bounds do not change between steps; compute them once.
        face_ref = face.detach()
        voice_ref = voice.detach()
        behavioral_ref = behavioral.detach()
        face_lo, face_hi = face_ref - epsilon, face_ref + epsilon
        voice_lo, voice_hi = voice_ref - epsilon, voice_ref + epsilon
        behavioral_lo, behavioral_hi = behavioral_ref - epsilon, behavioral_ref + epsilon

        for _ in range(num_steps):
            # Each iteration starts from fresh leaf tensors so no autograd
            # history is carried from one step to the next.
            face_adv.requires_grad_(True)
            voice_adv.requires_grad_(True)
            behavioral_adv.requires_grad_(True)

            # Forward pass
            logits, _ = self.model(face_adv, voice_adv, behavioral_adv)
            loss = self.ce_loss(logits, labels)

            # One backward pass yields all three input gradients; model
            # parameters do not accumulate .grad here.
            grad_face, grad_voice, grad_behavioral = torch.autograd.grad(
                loss, [face_adv, voice_adv, behavioral_adv]
            )

            # Update adversarial examples
            face_adv = face_adv.detach() + alpha * grad_face.sign()
            voice_adv = voice_adv.detach() + alpha * grad_voice.sign()
            behavioral_adv = behavioral_adv.detach() + alpha * grad_behavioral.sign()

            # Project to epsilon ball
            face_adv = torch.clamp(face_adv, face_lo, face_hi)
            voice_adv = torch.clamp(voice_adv, voice_lo, voice_hi)
            behavioral_adv = torch.clamp(behavioral_adv, behavioral_lo, behavioral_hi)

            # Clamp to valid ranges
            face_adv = torch.clamp(face_adv, -1, 1).detach()
            voice_adv = torch.clamp(voice_adv, -3, 3).detach()  # Typical MFCC range
            behavioral_adv = torch.clamp(behavioral_adv, -1, 1).detach()

        return face_adv.detach(), voice_adv.detach(), behavioral_adv.detach()

    def train(self, train_loader, val_loader, preprocessor, num_epochs=100):
        """Run the full training loop with early stopping.

        Saves a checkpoint whenever validation accuracy improves, a final
        checkpoint for the last completed epoch, and a plot of the curves.

        Args:
            train_loader: Iterable of training batches.
            val_loader: Iterable of validation batches.
            preprocessor: Forwarded to ``train_epoch`` / ``validate_epoch``.
            num_epochs: Maximum number of epochs to run.

        Returns:
            Tuple ``(train_losses, val_accuracies)`` with the recorded history.
        """
        print(f"Starting training for {num_epochs} epochs...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        start_time = time.time()

        # Tracked explicitly so the final save still works (and is skipped)
        # when the loop body never runs, e.g. num_epochs <= 0.
        last_epoch = None

        for epoch in range(num_epochs):
            last_epoch = epoch
            print(f"\nEpoch {epoch+1}/{num_epochs}")

            # Training
            train_loss = self.train_epoch(train_loader, preprocessor)

            # Validation
            val_acc = self.validate_epoch(val_loader, preprocessor)

            # Learning rate scheduling
            if self.scheduler is not None:
                self.scheduler.step()

            # Early stopping check
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.patience_counter = 0
                self.save_checkpoint(epoch, is_best=True)
            else:
                self.patience_counter += 1

            # Print progress
            print(f"Train Loss: {train_loss:.4f}, Val Acc: {val_acc:.2f}%")
            print(f"Best Val Acc: {self.best_val_acc:.2f}%, Patience: {self.patience_counter}")

            # Early stopping
            if self.patience_counter >= self.early_stopping_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds")
        print(f"Best validation accuracy: {self.best_val_acc:.2f}%")

        # Save final model (only if at least one epoch actually ran)
        if last_epoch is not None:
            self.save_checkpoint(last_epoch, is_best=False)

        # Plot training curves
        self.plot_training_curves()

        return self.train_losses, self.val_accuracies

    def save_checkpoint(self, epoch, is_best=False):
        """Write a checkpoint for ``epoch`` into ``self.output_dir``.

        The file is named ``checkpoint_epoch_{epoch}.pth``. When ``is_best`` is
        true the same payload is also written to ``best_model.pth``.

        Args:
            epoch: Zero-based epoch index stored in the checkpoint.
            is_best: Whether to additionally save as the best model.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_accuracies': self.val_accuracies,
            'best_val_acc': self.best_val_acc,
            # Stored so a resumed run keeps its early-stopping state.
            'patience_counter': self.patience_counter,
        }

        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        # Save regular checkpoint
        checkpoint_path = os.path.join(self.output_dir, f'checkpoint_epoch_{epoch}.pth')
        torch.save(checkpoint, checkpoint_path)

        # Save best model
        if is_best:
            best_path = os.path.join(self.output_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

    def load_checkpoint(self, checkpoint_path):
        """Restore model, optimizer, scheduler and history from a checkpoint.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.

        Returns:
            The epoch index stored in the checkpoint.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # sources you trust.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'scheduler_state_dict' in checkpoint and self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # History keys are optional so checkpoints lacking them still load.
        self.train_losses = checkpoint.get('train_losses', [])
        self.val_accuracies = checkpoint.get('val_accuracies', [])
        self.best_val_acc = checkpoint.get('best_val_acc', 0.0)
        self.patience_counter = checkpoint.get('patience_counter', 0)

        return checkpoint['epoch']

    def plot_training_curves(self):
        """Plot training loss and validation accuracy side by side.

        The figure is written to ``training_curves.png`` in ``self.output_dir``.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        # Training loss
        ax1.plot(self.train_losses)
        ax1.set_title('Training Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.grid(True)

        # Validation accuracy
        ax2.plot(self.val_accuracies)
        ax2.set_title('Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy (%)')
        ax2.grid(True)

        fig.tight_layout()
        fig.savefig(os.path.join(self.output_dir, 'training_curves.png'), dpi=300, bbox_inches='tight')
        # Close this specific figure so repeated calls do not accumulate figures.
        plt.close(fig)

class AdversarialLoss(nn.Module):
    """Cross-entropy on clean logits plus a weighted clean/adversarial logit gap.

    ``forward`` returns ``CE(logits_clean, labels) + alpha * MSE(logits_clean,
    logits_adv)``.

    Args:
        alpha: Weight applied to the mean-squared logit difference.
    """

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, logits_clean, logits_adv, labels):
        """Return the combined classification and logit-agreement loss."""
        # Penalise disagreement between clean and adversarial predictions.
        logit_gap = nn.functional.mse_loss(logits_clean, logits_adv)
        # Standard supervised term, computed on the clean logits only.
        classification_term = self.ce_loss(logits_clean, labels)
        return classification_term + self.alpha * logit_gap

class ConsistencyLoss(nn.Module):
    """Weighted mean-squared distance between clean and adversarial features.

    Args:
        beta: Scale factor applied to the mean-squared error.
    """

    def __init__(self, beta=0.05):
        super().__init__()
        self.beta = beta

    def forward(self, features_clean, features_adv):
        """Return ``beta * MSE(features_clean, features_adv)``."""
        # Pulls the two feature sets together so perturbed inputs map to
        # similar representations as their clean counterparts.
        feature_gap = nn.functional.mse_loss(features_clean, features_adv)
        return self.beta * feature_gap

if __name__ == '__main__':
    # Smoke test: run the trainer for two epochs using the project's own
    # model and dataset modules.
    print("Testing Trainer...")

    # Project-local modules, imported here so that importing this file as a
    # library does not require them.
    from model import ProposedModel
    from dataset import ResearchDataset
    from preprocessor import Preprocessor

    # Pick the device first: Trainer moves every batch to `device` but does
    # not move the model, so the model must be placed there explicitly.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Model (moved to the device before the optimizer is built on its parameters)
    model = ProposedModel(num_classes=100).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

    # Data
    dataset = ResearchDataset(num_subjects=100, samples_per_subject=10)
    preprocessor = Preprocessor(augment=True)

    # DataLoader (the same dataset serves as train and validation data here,
    # so the reported accuracy is not a held-out estimate)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=32, shuffle=False)

    # Trainer
    trainer = Trainer(model, optimizer, device, './test_output')

    # Test training for a few epochs
    trainer.train(train_loader, val_loader, preprocessor, num_epochs=2)

    print("Trainer test complete.")
