"""
CACTUS training algorithm implementation.

This module implements the core CACTUS (Compression Aware Certified Training
Using network Sets) algorithm that jointly optimizes for accuracy, certified
robustness, and compressibility.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import copy
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
from tqdm import tqdm

from ..compression import (
    create_pruning_mask, apply_pruning_mask, 
    AdversarialWeightPerturbation, get_sparsity
)
from ..certification import SABRLoss, compute_sabr_loss
from .compression_set import CompressionSet
from .trainer_utils import TrainingConfig, evaluate_model


class CACTUSLoss(nn.Module):
    """
    CACTUS loss function implementing Equation (4) from the paper:
    
    L_CACTUS = (1/|C(f_θ)|) Σ [λ * L_std(C_ψ(x), y) + (1-λ) * L_cert(C_ψ(x), y)]

    The sum runs over the configurations yielded by
    ``compression_set.get_configs()`` for the current step, and the average is
    taken over exactly those configurations (which may be a sampled subset of
    the full compression set).
    """
    
    def __init__(self, compression_set: CompressionSet, lmbda: float = 0.5,
                 epsilon: float = 0.1, certified_loss_type: str = 'sabr'):
        """
        Args:
            compression_set: Set of compression configurations
            lmbda: Weight of the standard loss; the certified loss is
                weighted by (1 - lmbda)
            epsilon: L∞ perturbation budget for certified loss
            certified_loss_type: Type of certified loss. Only 'sabr' is
                currently implemented.

        Raises:
            ValueError: If ``certified_loss_type`` is not supported.
        """
        super().__init__()
        self.compression_set = compression_set
        self.lmbda = lmbda
        self.epsilon = epsilon
        self.certified_loss_type = certified_loss_type
        
        # Initialize certified loss
        if certified_loss_type == 'sabr':
            self.certified_loss = SABRLoss(epsilon=epsilon)
        else:
            raise ValueError(f"Unknown certified loss type: {certified_loss_type}")
    
    def forward(self, model: nn.Module, inputs: torch.Tensor,
                targets: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        Compute CACTUS loss across compression set.
        
        Args:
            model: Neural network model
            inputs: Input batch
            targets: Target labels
            
        Returns:
            Tuple of (total_loss, loss_breakdown). ``total_loss`` is the mean
            of the per-configuration combined losses; ``loss_breakdown`` holds
            per-configuration standard/certified loss values and the configs.

        Raises:
            ValueError: If the compression set yields no configurations.
        """
        total_loss = 0.0
        num_configs = 0
        loss_breakdown = {
            'standard_losses': [],
            'certified_losses': [],
            'compression_configs': []
        }
        
        # Refresh compression set with current model weights
        self.compression_set.refresh(model)
        
        for config in self.compression_set.get_configs():
            compressed_model = self._create_compressed_model(model, config)
            
            # Standard (clean) loss of the compressed model
            outputs = compressed_model(inputs)
            std_loss = F.cross_entropy(outputs, targets)
            
            # Certified loss
            if config['type'] == 'quantization':
                # Quantization uses a differentiable proxy on the full model
                cert_loss = self._compute_awp_loss(
                    model, inputs, targets, config['eta']
                )
            else:
                # Pruning and full-model configs use SABR directly
                cert_loss = self.certified_loss(compressed_model, inputs, targets)
            
            # Eq. (4): weight the two terms exactly once per configuration
            combined_loss = self.lmbda * std_loss + (1 - self.lmbda) * cert_loss
            total_loss = total_loss + combined_loss
            num_configs += 1
            
            # Store for logging
            loss_breakdown['standard_losses'].append(std_loss.item())
            loss_breakdown['certified_losses'].append(cert_loss.item())
            loss_breakdown['compression_configs'].append(config)
        
        if num_configs == 0:
            raise ValueError(
                "Compression set yielded no configurations; "
                "cannot compute CACTUS loss"
            )
        
        # Average over the configurations actually evaluated this step
        # (get_configs() may sample a subset of the full set).
        return total_loss / num_configs, loss_breakdown
    
    def _create_compressed_model(self, model: nn.Module, config: Dict) -> nn.Module:
        """Create compressed version of model based on config.

        For pruning, returns a deep copy with the mask applied whose
        parameters are re-expressed in terms of ``model``'s parameters, so
        losses computed through the copy backpropagate into ``model``.
        For 'quantization' and 'full', returns ``model`` itself.

        Raises:
            ValueError: If ``config['type']`` is not a known compression type.
        """
        compression_type = config['type']
        if compression_type == 'pruning':
            compressed_model = copy.deepcopy(model)
            apply_pruning_mask(compressed_model, config['mask'])
            # A deep copy owns fresh leaf parameters; without re-tying them,
            # gradients of the pruned losses would never reach `model`.
            self._tie_to_original(compressed_model, model)
            return compressed_model
        
        if compression_type == 'quantization':
            # Quantization is handled by the proxy in _compute_awp_loss
            return model
        
        if compression_type == 'full':
            # Uncompressed model
            return model
        
        raise ValueError(f"Unknown compression type: {config['type']}")
    
    @staticmethod
    def _tie_to_original(compressed_model: nn.Module, model: nn.Module) -> None:
        """Rebind ``compressed_model``'s parameters to masked views of ``model``'s.

        Each same-named, same-shaped parameter of the copy is replaced by
        ``original * keep`` where ``keep`` marks entries that are non-zero
        after pruning. This follows the ``torch.nn.utils.prune``
        reparametrization: the original parameter is registered on the copy as
        ``<name>_orig`` and ``<name>`` becomes a plain tensor attribute.
        Parameters without a counterpart in ``model`` (or whose ``_orig``
        alias already exists) are left untouched.

        NOTE(review): entries that were exactly zero before pruning are also
        treated as pruned (their gradient is masked out); value-wise this is
        identical.
        """
        original_modules = dict(model.named_modules())
        for module_name, module in compressed_model.named_modules():
            source = original_modules.get(module_name)
            if source is None:
                continue
            for param_name, pruned in list(module._parameters.items()):
                original = source._parameters.get(param_name)
                alias = f'{param_name}_orig'
                if (pruned is None or original is None
                        or original.shape != pruned.shape
                        or hasattr(module, alias)):
                    continue
                keep = pruned.detach().ne(0).to(dtype=original.dtype)
                del module._parameters[param_name]
                # Keep the trained parameter visible via .parameters() on the copy
                module.register_parameter(alias, original)
                setattr(module, param_name, original * keep)
    
    def _compute_awp_loss(self, model: nn.Module, inputs: torch.Tensor,
                         targets: torch.Tensor, eta: float) -> torch.Tensor:
        """Compute the certified-loss term used as the quantization proxy.

        Returns only the certified component. ``forward`` applies the
        lambda / (1 - lambda) weighting itself, so combining with the standard
        loss here would count the standard loss twice.

        NOTE(review): ``eta`` is currently unused. No adversarial weight
        perturbation is applied yet; this is plain SABR loss on ``model``.
        """
        # TODO: perturb weights with AdversarialWeightPerturbation(eta=eta)
        # before computing the certified loss.
        return compute_sabr_loss(model, inputs, targets, self.epsilon)


class CACTUSTrainer:
    """
    CACTUS trainer implementing Algorithm 1 from the paper.
    
    This trainer jointly optimizes for accuracy, certified robustness,
    and compressibility using the CACTUS loss function.
    """
    
    def __init__(self, model: nn.Module, config: TrainingConfig):
        """
        Args:
            model: Neural network model to train
            config: Training configuration
        """
        self.model = model
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Move model to device
        self.model.to(self.device)
        
        # Create compression set
        self.compression_set = CompressionSet(
            compression_configs=config.compression_configs,
            sampling_strategy=config.compression_sampling
        )
        
        # Initialize CACTUS loss
        self.cactus_loss = CACTUSLoss(
            compression_set=self.compression_set,
            lmbda=config.lambda_std,
            epsilon=config.epsilon,
            certified_loss_type=config.certified_loss_type
        )
        
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()
        
        # Training state
        self.current_epoch = 0
        # First epoch index the next call to `train` runs; set by
        # `load_checkpoint` so a resumed run continues after the saved epoch.
        self._start_epoch = 0
        self.training_history = {
            'train_losses': [],
            'train_acc': [],
            'val_acc': [],
            'cert_acc': []
        }
    
    def _create_optimizer(self) -> optim.Optimizer:
        """Create the optimizer named by ``config.optimizer`` ('adam' or 'sgd').

        Raises:
            ValueError: If the optimizer name is not recognized.
        """
        if self.config.optimizer == 'adam':
            return optim.Adam(
                self.model.parameters(),
                lr=self.config.learning_rate,
                weight_decay=self.config.weight_decay
            )
        elif self.config.optimizer == 'sgd':
            return optim.SGD(
                self.model.parameters(),
                lr=self.config.learning_rate,
                momentum=self.config.momentum,
                weight_decay=self.config.weight_decay
            )
        else:
            raise ValueError(f"Unknown optimizer: {self.config.optimizer}")
    
    def _create_scheduler(self) -> Optional[optim.lr_scheduler._LRScheduler]:
        """Create the per-epoch LR scheduler ('step' or 'cosine'), or None."""
        if self.config.scheduler == 'step':
            return optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=self.config.step_size,
                gamma=self.config.gamma
            )
        elif self.config.scheduler == 'cosine':
            return optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self.config.epochs
            )
        else:
            return None
    
    def train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
        """Train for one epoch.

        Returns:
            Dict with the mean per-batch CACTUS ``'loss'`` and the training
            ``'accuracy'`` (percent) of the uncompressed model, measured in
            eval mode after each optimizer step. Both are 0.0 for an empty loader.
        """
        self.model.train()
        
        total_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {self.current_epoch}')
        
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            
            self.optimizer.zero_grad()
            loss, _ = self.cactus_loss(self.model, inputs, targets)
            loss.backward()
            
            # Gradient clipping if specified (None or <= 0 disables it)
            grad_clip = self.config.grad_clip
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), grad_clip
                )
            
            self.optimizer.step()
            
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            
            # Measure accuracy of the uncompressed model in eval mode so this
            # extra forward pass neither applies dropout nor updates BatchNorm
            # running statistics a second time for the same batch.
            self.model.eval()
            with torch.no_grad():
                _, predicted = self.model(inputs).max(1)
            self.model.train()
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            if batch_idx % 10 == 0:
                pbar.set_postfix({
                    'Loss': f'{batch_loss:.4f}',
                    'Acc': f'{100.*correct/total:.2f}%'
                })
        
        # Scheduler steps once per epoch
        if self.scheduler is not None:
            self.scheduler.step()
        
        return {
            'loss': total_loss / num_batches if num_batches else 0.0,
            'accuracy': 100. * correct / total if total else 0.0
        }
    
    def evaluate(self, val_loader: DataLoader, 
                compute_certified: bool = True) -> Dict[str, float]:
        """Evaluate model on validation set (certified metrics optional)."""
        return evaluate_model(
            self.model, val_loader, self.device,
            epsilon=self.config.epsilon if compute_certified else None
        )
    
    def train(self, train_loader: DataLoader, 
              val_loader: Optional[DataLoader] = None) -> Dict:
        """
        Full training loop.

        After ``load_checkpoint`` the loop resumes at the epoch following the
        saved one; otherwise it runs epochs ``0 .. config.epochs - 1``.
        
        Args:
            train_loader: Training data loader
            val_loader: Optional validation data loader
            
        Returns:
            Training history dictionary
        """
        print(f"Starting CACTUS training for {self.config.epochs} epochs...")
        print(f"Compression configs: {len(self.compression_set.configs)}")
        print(f"Device: {self.device}")
        
        start_epoch = self._start_epoch
        # Resume point is consumed once; a later call without another
        # load_checkpoint runs from epoch 0.
        self._start_epoch = 0
        
        for epoch in range(start_epoch, self.config.epochs):
            self.current_epoch = epoch
            
            train_metrics = self.train_epoch(train_loader)
            
            if val_loader is not None:
                val_metrics = self.evaluate(val_loader)
                cert_acc = val_metrics.get('certified_accuracy', 0)
                
                print(f"Epoch {epoch}: "
                      f"Loss={train_metrics['loss']:.4f}, "
                      f"Train Acc={train_metrics['accuracy']:.2f}%, "
                      f"Val Acc={val_metrics['standard_accuracy']:.2f}%, "
                      f"Cert Acc={cert_acc:.2f}%")
                
                self.training_history['val_acc'].append(val_metrics['standard_accuracy'])
                self.training_history['cert_acc'].append(cert_acc)
            else:
                print(f"Epoch {epoch}: "
                      f"Loss={train_metrics['loss']:.4f}, "
                      f"Train Acc={train_metrics['accuracy']:.2f}%")
            
            self.training_history['train_losses'].append(train_metrics['loss'])
            self.training_history['train_acc'].append(train_metrics['accuracy'])
        
        print("Training completed!")
        return self.training_history
    
    def save_checkpoint(self, filepath: str):
        """Save model/optimizer/scheduler state, config, and history.

        ``'epoch'`` is the index of the most recently started epoch.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
            'training_history': self.training_history
        }
        
        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        
        torch.save(checkpoint, filepath)
        print(f"Checkpoint saved to {filepath}")
    
    def load_checkpoint(self, filepath: str):
        """Load a checkpoint written by ``save_checkpoint`` and arm resuming.

        The next ``train`` call continues from the epoch after the saved one.

        Security: the checkpoint contains a pickled ``TrainingConfig``, so it
        is loaded with full unpickling. Only load checkpoints from trusted
        sources.
        """
        import inspect
        
        load_kwargs = {'map_location': self.device}
        # torch >= 2.6 defaults to weights_only=True, which rejects the pickled
        # config object; opt out where the argument exists (older versions
        # already unpickle fully and don't accept it).
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        checkpoint = torch.load(filepath, **load_kwargs)
        
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.current_epoch = checkpoint['epoch']
        self._start_epoch = checkpoint['epoch'] + 1
        self.training_history = checkpoint['training_history']
        
        if self.scheduler is not None and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
        print(f"Checkpoint loaded from {filepath}")


if __name__ == "__main__":
    # Smoke test: train a small MNIST CNN for two epochs on random data.
    # Run as a module (python -m ...) so the relative imports resolve.
    print("Testing CACTUS trainer...")
    
    from torch.utils.data import TensorDataset
    from ..models import create_cnn7_mnist
    
    # Create test model and config
    model = create_cnn7_mnist()
    config = TrainingConfig(
        epochs=2,
        learning_rate=0.001,
        compression_configs=[
            {'type': 'full'},
            {'type': 'pruning', 'method': 'global_l1', 'sparsity': 0.5}
        ]
    )
    
    trainer = CACTUSTrainer(model, config)
    
    # 20 random 1x28x28 samples -> 5 batches of 4. Each dataset item must be
    # a single sample because the DataLoader adds the batch dimension itself.
    train_data = TensorDataset(
        torch.randn(20, 1, 28, 28), torch.randint(0, 10, (20,))
    )
    train_loader = DataLoader(train_data, batch_size=4)
    
    history = trainer.train(train_loader)
    
    print("CACTUS trainer test completed!")
    print(f"Final train accuracy: {history['train_acc'][-1]:.2f}%")