"""
Lipschitz constant tracker along the training trajectory.

Correct implementation: both gradients are computed on the SAME batch.

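At each step:
    1. Before optimizer.step(), save the parameters, the gradient and the batch
    2. After optimizer.step(), recompute the gradient on the saved batch
    3. Compare it with the gradient saved before step():

    L_t = ||g(θ_t, batch) - g(θ_{t-1}, batch)||_* / ||θ_t - θ_{t-1}||

    where ||·||_* is the dual of the norm used for the parameter step.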

Usage:
    tracker = LipschitzTracker(criterion, device)
    
    for images, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        
        # Call AFTER backward(), BEFORE step()
        tracker.save_state(model, images, labels)
        
        optimizer.step()
        
        # Call AFTER step()
        L_step = tracker.compute_lipschitz(model)
    
    stats = tracker.get_stats()
"""

import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Optional, Tuple


class LipschitzTracker:
    """
    Lipschitz constant tracker along the training trajectory.
    Correct version: gradients on the same batch.

    For every optimizer step the tracker estimates

        L_t = ||g(θ_t, B) - g(θ_{t-1}, B)||_* / ||θ_t - θ_{t-1}||

    for two norm pairs: L2 -> L2 (``'L_l2'``) and L∞ -> L1 (``'L_linf'``,
    the L1 norm being dual to L∞). ``B`` is the batch passed to
    ``save_state`` or, if one was set, the batch from ``set_fixed_batch``.

    The extra forward/backward passes used for measurement leave no trace on
    the model: parameter ``.grad`` tensors and module buffers (e.g. BatchNorm
    running statistics) are restored afterwards. Stochastic layers such as
    dropout still draw fresh randomness in every measurement pass when the
    model is in train mode, which adds noise to the estimate.
    """

    def __init__(self, criterion, device: str):
        """
        Args:
            criterion: loss function (e.g., nn.CrossEntropyLoss())
            device: device ('cuda:0' or 'cpu')
        """
        self.criterion = criterion
        self.device = device

        # Saved state (filled by save_state, consumed by compute_lipschitz)
        self.prev_params: Optional[torch.Tensor] = None
        self.prev_grad: Optional[torch.Tensor] = None
        self.saved_batch: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

        # Fixed batch for Lipschitz measurement (if set, always used)
        self.fixed_batch: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

        # History of L estimates (full history)
        self.L_history_l2: List[float] = []
        self.L_history_linf: List[float] = []

        # Current epoch buffer
        self.L_epoch_l2: List[float] = []
        self.L_epoch_linf: List[float] = []

        # Number of completed compute_lipschitz measurements (incl. skipped zero steps)
        self.step_count = 0

    def _flatten_params(self, model: nn.Module) -> torch.Tensor:
        """Collect all parameters into a single detached 1-D vector (a copy)."""
        # reshape (not view): parameters stored in a non-contiguous layout such
        # as channels_last cannot be viewed as 1-D. torch.cat always allocates
        # new storage, so the result never aliases the live parameters.
        return torch.cat([p.detach().reshape(-1) for p in model.parameters()])

    def _flatten_grads(self, model: nn.Module) -> torch.Tensor:
        """
        Collect all parameter gradients into a single 1-D vector (a copy).

        Parameters without a gradient contribute zeros (on the parameter's own
        device and dtype), so the layout always matches ``_flatten_params``.
        """
        return torch.cat([
            p.grad.detach().reshape(-1) if p.grad is not None
            else torch.zeros(p.numel(), device=p.device, dtype=p.dtype)
            for p in model.parameters()
        ])

    def _grad_on_batch(self, model: nn.Module, images: torch.Tensor,
                       labels: torch.Tensor) -> torch.Tensor:
        """
        Return the flattened gradient of ``criterion`` on (images, labels) at
        the model's current parameters, without disturbing the model.

        The parameters' existing ``.grad`` tensors and all module buffers are
        restored afterwards, even if the forward/backward pass raises.
        """
        params = list(model.parameters())
        saved_grads = [p.grad for p in params]
        saved_buffers = [b.detach().clone() for b in model.buffers()]
        try:
            # With p.grad = None, backward() stores fresh gradient tensors, so
            # the training gradients referenced in saved_grads stay untouched.
            for p in params:
                p.grad = None
            # enable_grad: measurement must work even inside torch.no_grad().
            with torch.enable_grad():
                loss = self.criterion(model(images), labels)
                loss.backward()
            return self._flatten_grads(model)
        finally:
            for p, g in zip(params, saved_grads):
                p.grad = g
            # Undo side effects of the measurement forward pass on buffers
            # (e.g. BatchNorm running stats), so tracking does not alter training.
            with torch.no_grad():
                for b, saved in zip(model.buffers(), saved_buffers):
                    b.copy_(saved)

    def set_fixed_batch(self, images: torch.Tensor, labels: torch.Tensor):
        """
        Set a fixed batch for Lipschitz measurement.
        After setting, this batch will be used for all measurements.

        Args:
            images: batch input data
            labels: batch labels
        """
        self.fixed_batch = (images.detach().clone(), labels.detach().clone())

    def save_state(self, model: nn.Module, images: torch.Tensor = None, labels: torch.Tensor = None):
        """
        Save state before optimizer.step()
        Call AFTER backward(), BEFORE step()

        If a fixed batch is set (via set_fixed_batch), then images and labels
        are ignored and the reference gradient is recomputed on the fixed batch
        (the training gradients in ``p.grad`` are left intact for step()).

        Without a fixed batch, the reference gradient is whatever is currently
        in ``p.grad``. Call this before any gradient clipping, and note that
        loss-scaled (AMP) gradients would not be comparable with the unscaled
        gradient computed later by ``compute_lipschitz``.

        Args:
            model: model (already with computed gradients)
            images: current batch input data (ignored if fixed_batch exists)
            labels: current batch labels (ignored if fixed_batch exists)

        Raises:
            ValueError: if no fixed batch is set and images/labels are missing.
        """
        # Validate before touching any state so a failed call changes nothing.
        if self.fixed_batch is None and (images is None or labels is None):
            raise ValueError("images and labels must be provided if fixed_batch is not set")

        self.prev_params = self._flatten_params(model)

        if self.fixed_batch is not None:
            images_fixed, labels_fixed = self.fixed_batch
            self.prev_grad = self._grad_on_batch(model, images_fixed, labels_fixed)
            self.saved_batch = self.fixed_batch
        else:
            self.prev_grad = self._flatten_grads(model)
            self.saved_batch = (images.detach().clone(), labels.detach().clone())

    def compute_lipschitz(self, model: nn.Module) -> Optional[Dict[str, float]]:
        """
        Compute Lipschitz constant after optimizer.step()
        Call AFTER step()

        The gradient on the saved batch is recomputed at the new parameters;
        the model's ``.grad`` tensors and buffers are restored afterwards.
        Steps whose parameter change is below 1e-12 (in the respective norm)
        are reported as None and not recorded.

        Args:
            model: model (with updated parameters)

        Returns:
            Dict with 'L_l2' and 'L_linf' (each float or None), or None if
            there is no saved state (save_state was not called since the last
            measurement).
        """
        if self.prev_params is None or self.saved_batch is None:
            return None

        # Current parameters (after step)
        curr_params = self._flatten_params(model)

        # Gradient on the SAVED batch with NEW parameters
        images, labels = self.saved_batch
        curr_grad = self._grad_on_batch(model, images, labels)

        # Parameter difference (optimizer step)
        delta_params = curr_params - self.prev_params
        # Gradient difference (on the SAME batch)
        delta_grad = curr_grad - self.prev_grad

        # L2 → L2
        delta_params_l2 = delta_params.norm(2).item()
        delta_grad_l2 = delta_grad.norm(2).item()

        # L∞ → L1 (dual)
        delta_params_linf = delta_params.abs().max().item()
        delta_grad_l1 = delta_grad.abs().sum().item()

        L_l2 = None
        L_linf = None

        if delta_params_l2 > 1e-12:
            L_l2 = delta_grad_l2 / delta_params_l2
            self.L_history_l2.append(L_l2)
            self.L_epoch_l2.append(L_l2)

        if delta_params_linf > 1e-12:
            L_linf = delta_grad_l1 / delta_params_linf
            self.L_history_linf.append(L_linf)
            self.L_epoch_linf.append(L_linf)

        self.step_count += 1

        # Clear saved batch so the same state is never measured twice
        self.saved_batch = None

        return {
            'L_l2': L_l2,
            'L_linf': L_linf,
        }

    @staticmethod
    def _summarize(values: List[float]) -> Dict[str, float]:
        """Return min/max/mean of a non-empty list of floats as plain floats."""
        arr = np.asarray(values, dtype=float)
        return {
            'min': float(np.min(arr)),
            'max': float(np.max(arr)),
            'mean': float(np.mean(arr)),
        }

    def _build_stats(self, l2: List[float], linf: List[float]) -> Dict[str, Dict[str, float]]:
        """Assemble the stats dict shared by get_epoch_stats and get_stats."""
        stats = {}
        if l2:
            stats['L2'] = self._summarize(l2)
        if linf:
            stats['Linf'] = self._summarize(linf)
        # Count of recorded L2 estimates (an int, not a nested dict)
        stats['n_steps'] = len(l2)
        return stats

    def get_epoch_stats(self) -> Dict[str, Dict[str, float]]:
        """
        Get statistics for current epoch and clear buffer.

        Returns:
            Dict with optional 'L2' and 'Linf' entries ({'min', 'max', 'mean'})
            and 'n_steps' (int, number of L2 estimates in the epoch).
        """
        stats = self._build_stats(self.L_epoch_l2, self.L_epoch_linf)

        # Clear epoch buffer
        self.L_epoch_l2 = []
        self.L_epoch_linf = []

        return stats

    def get_stats(self) -> Dict[str, Dict[str, float]]:
        """
        Get statistics for entire history.

        Returns:
            Dict with optional 'L2' and 'Linf' entries ({'min', 'max', 'mean'})
            and 'n_steps' (int, number of L2 estimates recorded so far).
        """
        return self._build_stats(self.L_history_l2, self.L_history_linf)
