"""
Quantization methods for CACTUS.

This module implements quantization techniques including:
- FP16 casting and simulated INT8 quantization
- Adversarial Weight Perturbation (AWP) as a differentiable proxy for quantization
- Quantization-aware training utilities
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
from abc import ABC, abstractmethod


class QuantizationWrapper(nn.Module):
    """Wrap a model and convert it to a reduced-precision representation.

    ``'fp16'`` casts the wrapped module with ``Module.half()``. ``'int8'``
    does not produce integer kernels: it snaps every trainable parameter onto
    a symmetric grid with levels in ``[-127, 127]`` while keeping the
    parameter's floating-point dtype.

    The wrapper stores a reference to ``model`` (no copy is made), so
    :meth:`quantize` modifies the module that was passed in.
    """

    def __init__(self, model: nn.Module, precision: str = 'fp16'):
        """
        Args:
            model: PyTorch model to quantize
            precision: Quantization precision ('fp16', 'int8')
        """
        super().__init__()
        self.model = model
        self.precision = precision
        # Flipped to True by quantize(); forward() uses it to decide whether
        # inputs have to be cast to match the converted weights.
        self._quantized = False

    def quantize(self):
        """Apply quantization to the wrapped model.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            ValueError: If ``precision`` is neither ``'fp16'`` nor ``'int8'``.
        """
        if self.precision == 'fp16':
            self.model = self.model.half()
        elif self.precision == 'int8':
            # For INT8, we simulate quantization by rounding weights
            # In practice, you'd use torch.quantization for actual INT8
            self._apply_int8_simulation()
        else:
            raise ValueError(f"Unsupported precision: {self.precision}")

        self._quantized = True
        return self

    def _apply_int8_simulation(self):
        """Round trainable parameters to a per-tensor symmetric INT8 grid.

        Only parameters with ``requires_grad=True`` are touched. Empty and
        all-zero tensors are skipped: they have no usable scale, and dividing
        by a zero scale would overwrite the parameter with NaN.
        """
        with torch.no_grad():
            for param in self.model.parameters():
                if not param.requires_grad or param.numel() == 0:
                    continue

                max_abs = param.abs().max()
                if max_abs == 0:
                    # Already exactly representable; nothing to round.
                    continue

                # Simple symmetric quantization to [-127, 127]
                scale = max_abs / 127.0
                param.data.copy_(torch.round(param / scale) * scale)

    def forward(self, x):
        """Run the wrapped model on ``x``.

        Once the model has been converted to FP16, floating-point tensor
        inputs are cast to half so they match the weights. Before conversion
        (and for INT8 simulation, which keeps the original dtype) the input is
        passed through unchanged. Non-floating inputs are never cast.
        """
        needs_half = self.precision == 'fp16' and self._quantized
        if needs_half and torch.is_tensor(x) and x.is_floating_point():
            x = x.half()

        return self.model(x)


class AdversarialWeightPerturbation:
    """
    Adversarial Weight Perturbation as differentiable proxy for quantization.

    This implements the AWP technique described in the CACTUS paper,
    which approximates quantization effects through adversarial perturbations.

    The perturbation is found by signed-gradient ascent on the loss and is
    clamped elementwise to ``[-eta, eta]`` after every step.
    """

    def __init__(self, eta: float = 0.25, steps: int = 1, lr: float = 0.1):
        """
        Args:
            eta: Maximum perturbation magnitude (L∞ bound)
            steps: Number of gradient steps for finding perturbation
            lr: Learning rate for perturbation optimization
        """
        self.eta = eta
        self.steps = steps
        self.lr = lr

    def create_perturbation(self, model: nn.Module, inputs: torch.Tensor,
                          targets: torch.Tensor, loss_fn) -> Dict[str, torch.Tensor]:
        """
        Create adversarial weight perturbation.

        The model is left exactly as it was found: weights are restored from a
        snapshot (even if the forward pass raises) and the parameters'
        ``.grad`` attributes are neither read nor modified.

        Args:
            model: PyTorch model
            inputs: Input batch
            targets: Target labels
            loss_fn: Loss function to maximize

        Returns:
            Dictionary mapping parameter names to perturbations. Only
            parameters with ``requires_grad=True`` get an entry; the dictionary
            is empty when the model has none.
        """
        trainable = [
            (name, param)
            for name, param in model.named_parameters()
            if param.requires_grad
        ]
        perturbations = {
            name: torch.zeros_like(param.data) for name, param in trainable
        }
        if not trainable:
            # Nothing to differentiate with respect to.
            return perturbations

        params = [param for _, param in trainable]
        # Snapshot used for restoring: (w + d) - d is not guaranteed to equal
        # w in floating point, so subtracting would let the weights drift.
        originals = [param.detach().clone() for param in params]

        for _ in range(self.steps):
            self._apply_perturbations(model, perturbations)
            try:
                loss = loss_fn(model(inputs), targets)
                # autograd.grad leaves param.grad untouched, so gradients the
                # caller has accumulated neither leak into the ascent
                # direction nor get wiped afterwards.
                grads = torch.autograd.grad(loss, params, allow_unused=True)
            finally:
                with torch.no_grad():
                    for param, original in zip(params, originals):
                        # Write through .data, like the apply/remove helpers,
                        # so autograd graphs the caller built earlier on these
                        # parameters stay usable.
                        param.data.copy_(original)

            for (name, _), grad in zip(trainable, grads):
                if grad is None:
                    # Parameter did not take part in this forward pass.
                    continue
                # Gradient ascent step followed by projection to the L∞ ball
                stepped = perturbations[name] + self.lr * grad.sign()
                perturbations[name] = torch.clamp(stepped, -self.eta, self.eta)

        return perturbations

    def _apply_perturbations(self, model: nn.Module, perturbations: Dict[str, torch.Tensor]):
        """Add each perturbation in place to the parameter of the same name.

        Parameters without an entry in ``perturbations`` are left alone.
        """
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in perturbations:
                    param.data += perturbations[name]

    def _remove_perturbations(self, model: nn.Module, perturbations: Dict[str, torch.Tensor]):
        """Subtract each perturbation in place from the parameter of the same name.

        This is the arithmetic inverse of ``_apply_perturbations``; because of
        floating-point rounding the result may differ slightly from the
        values the parameters had before the perturbation was applied.
        """
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in perturbations:
                    param.data -= perturbations[name]


def quantize_model(model: nn.Module, precision: str = 'fp16') -> QuantizationWrapper:
    """
    Wrap ``model`` in a QuantizationWrapper and quantize it immediately.

    The wrapper is built around the given module object itself, so the
    quantization step operates on ``model`` rather than on a copy.

    Args:
        model: PyTorch model to quantize
        precision: Target precision ('fp16', 'int8')

    Returns:
        The wrapper, after its ``quantize()`` method has been called
    """
    return QuantizationWrapper(model, precision).quantize()


def create_awp_perturbation(model: nn.Module, inputs: torch.Tensor,
                          targets: torch.Tensor, loss_fn,
                          eta: float = 0.25) -> Dict[str, torch.Tensor]:
    """
    One-shot helper that builds an AWP perturbation for ``model``.

    A fresh AdversarialWeightPerturbation is constructed with the given
    ``eta`` (its step count and learning rate stay at their defaults) and
    asked for a perturbation on the supplied batch.

    Args:
        model: PyTorch model
        inputs: Input batch
        targets: Target labels
        loss_fn: Loss function
        eta: Perturbation magnitude

    Returns:
        Dictionary of perturbations keyed by parameter name
    """
    perturber = AdversarialWeightPerturbation(eta=eta)
    perturbations = perturber.create_perturbation(model, inputs, targets, loss_fn)
    return perturbations


def get_quantization_error(original_model: nn.Module, quantized_model: nn.Module,
                          test_inputs: torch.Tensor) -> float:
    """
    Compute quantization error between original and quantized models.

    Both models are evaluated in eval mode under ``torch.no_grad()`` on the
    same ``test_inputs``. The training/eval flag of every submodule is put
    back afterwards, so calling this in the middle of training does not
    silently leave the models in eval mode.

    Args:
        original_model: Original full-precision model
        quantized_model: Quantized model
        test_inputs: Test inputs for evaluation

    Returns:
        Mean squared error between outputs
    """
    # Remember the mode per submodule (not just per model) so deliberately
    # mixed setups, e.g. frozen sub-blocks kept in eval, are restored as-is.
    saved_modes = [
        (module, module.training)
        for model in (original_model, quantized_model)
        for module in model.modules()
    ]

    original_model.eval()
    quantized_model.eval()

    try:
        with torch.no_grad():
            original_outputs = original_model(test_inputs)
            quantized_outputs = quantized_model(test_inputs)

            # Compare in the wider of the two dtypes so neither output is
            # rounded down just to make the comparison possible.
            common_dtype = torch.promote_types(
                original_outputs.dtype, quantized_outputs.dtype
            )
            mse = F.mse_loss(
                original_outputs.to(common_dtype),
                quantized_outputs.to(common_dtype),
            )
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return mse.item()


class QuantizationAwareTrainer:
    """Trainer for quantization-aware training with AWP."""

    def __init__(self, model: nn.Module, eta: float = 0.25,
                 lambda_awp: float = 0.5):
        """
        Args:
            model: Model to train
            eta: AWP perturbation magnitude
            lambda_awp: Weight for AWP loss term
        """
        self.model = model
        self.awp = AdversarialWeightPerturbation(eta=eta)
        self.lambda_awp = lambda_awp

    def compute_loss(self, inputs: torch.Tensor, targets: torch.Tensor,
                    loss_fn) -> torch.Tensor:
        """
        Compute loss with AWP regularization.

        The result is ``(1 - lambda_awp) * standard + lambda_awp * perturbed``,
        where ``perturbed`` is the loss of the model evaluated at
        ``weights + perturbation``.

        The perturbed forward pass is run functionally on ``param + delta``
        instead of editing the weights in place and undoing the edit before
        the caller runs ``backward()``. That way the backward pass sees the
        same perturbed weights the forward pass used, gradients still flow to
        the real parameters, and the model's weights are never modified here,
        so an exception cannot leave them perturbed.

        Args:
            inputs: Input batch
            targets: Target labels
            loss_fn: Loss function

        Returns:
            Combined loss with AWP term
        """
        try:
            from torch.func import functional_call
        except ImportError:  # older PyTorch releases without torch.func
            from torch.nn.utils.stateless import functional_call

        # Standard loss
        outputs = self.model(inputs)
        standard_loss = loss_fn(outputs, targets)

        # Adversarial direction for the current batch
        perturbations = self.awp.create_perturbation(
            self.model, inputs, targets, loss_fn
        )

        # Parameters without a perturbation entry are not overridden and keep
        # using the module's own tensors.
        perturbed_params = {
            name: param + perturbations[name]
            for name, param in self.model.named_parameters()
            if name in perturbations
        }
        perturbed_outputs = functional_call(
            self.model, perturbed_params, (inputs,)
        )
        awp_loss = loss_fn(perturbed_outputs, targets)

        # Combined loss
        total_loss = (1 - self.lambda_awp) * standard_loss + self.lambda_awp * awp_loss

        return total_loss


def simulate_quantization_noise(tensor: torch.Tensor, precision: str,
                               noise_scale: float = 0.1) -> torch.Tensor:
    """
    Return ``tensor`` plus Gaussian noise sized for the target precision.

    The noise is standard-normal, multiplied by ``noise_scale`` and then by a
    precision-specific factor: ``1e-3`` for 'fp16' and ``1e-2`` for 'int8'.
    Any other precision adds an all-zero tensor, i.e. no noise, and draws no
    random numbers.

    Args:
        tensor: Input tensor
        precision: Target precision ('fp16', 'int8')
        noise_scale: Scale of noise to add

    Returns:
        A new tensor holding the input with the noise added
    """
    if precision not in ('fp16', 'int8'):
        # Unrecognised precision: pass the values through unchanged.
        return tensor + torch.zeros_like(tensor)

    # FP16 gets the smaller perturbation, INT8 the larger one.
    magnitude = 1e-3 if precision == 'fp16' else 1e-2
    gaussian = torch.randn_like(tensor)
    return tensor + gaussian * noise_scale * magnitude


if __name__ == "__main__":
    # Smoke test for the helpers in this module.
    import copy

    print("Testing quantization methods...")

    # Create test model and data
    # NOTE: this is a relative import, so the block only works when the file
    # is executed as a module inside its package (python -m ...), not as a
    # standalone script.
    from ..models import create_cnn7_mnist

    model = create_cnn7_mnist()
    test_input = torch.randn(4, 1, 28, 28)
    test_targets = torch.randint(0, 10, (4,))

    # Test quantization wrapper
    # quantize_model() converts the module it is given in place, so each
    # precision gets its own copy; otherwise the FP16 conversion would leave
    # `model` in half precision for the float32 checks below.
    print("Testing quantization wrapper...")
    fp16_model = quantize_model(copy.deepcopy(model), 'fp16')
    int8_model = quantize_model(copy.deepcopy(model), 'int8')

    # Test AWP
    print("Testing Adversarial Weight Perturbation...")
    awp = AdversarialWeightPerturbation(eta=0.1)
    loss_fn = nn.CrossEntropyLoss()

    perturbations = awp.create_perturbation(model, test_input, test_targets, loss_fn)
    print(f"Generated perturbations for {len(perturbations)} parameters")

    # Test quantization-aware trainer
    print("Testing quantization-aware trainer...")
    trainer = QuantizationAwareTrainer(model, eta=0.1)
    loss = trainer.compute_loss(test_input, test_targets, loss_fn)
    print(f"AWP loss: {loss.item():.4f}")

    print("Quantization tests passed!")