"""
Human-Prior Correction (HPC) Core Implementation

This module implements the main HPC algorithm as described in:
"Human-Prior Correction: Scalable Post-hoc Calibration that Aligns Vision Models with Human Uncertainty"

Key equation: p' = normalize(p ⊙ C_α[y_pred,:])
where C_α = (1-α)I + αC
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, Union, Tuple
import warnings


class HumanPriorCorrection:
    """
    Human-Prior Correction (HPC) calibration method.

    HPC re-weights a model's softmax output with the row of a human confusion
    matrix selected by the model's own top-1 prediction:

        p' = normalize(p ⊙ C_α[y_pred, :]),   C_α = (1-α)I + αC

    The accompanying paper motivates this with the objective
    min_p KL(p||human) + λ·KL(p||model), whose closed-form solution is stated
    as p* ∝ human^α · model^(1-α). This class implements the element-wise
    re-weighting above.

    Attributes:
        device: Device all internal tensors live on.
        confusion_matrix: Row-normalized (K, K) human confusion matrix.
        alpha: Global mixing parameter used when adaptive alpha is inactive.
        temperature: Divisor applied to the logits before the softmax.
        adaptive_alpha: Whether per-sample alpha is computed from features.
        gating_threshold: Scale γ inside the exponential disagreement gate.
        regularized_matrix: Cached C_α for the current global ``alpha``
            (refreshed by ``set_alpha``).
        alpha_net: Lazily created ``AdaptiveAlphaNet``; ``None`` until the
            first adaptive forward pass.
    """

    def __init__(
        self,
        confusion_matrix: torch.Tensor,
        alpha: float = 0.3,
        temperature: float = 1.0,
        adaptive_alpha: bool = False,
        gating_threshold: float = 0.1,
        device: str = 'cpu'
    ):
        """
        Initialize HPC calibrator.

        Args:
            confusion_matrix: Human confusion matrix C of shape (K, K)
                             C[i,j] = P(human predicts j | true class i)
            alpha: Mixing parameter α ∈ [0,1]. Higher values increase human prior influence
            temperature: Temperature parameter β for logit scaling before HPC
            adaptive_alpha: Whether to use instance-adaptive α(x)
            gating_threshold: Threshold γ for disagreement-based gating
            device: Device to run computations on

        Raises:
            ValueError: If ``confusion_matrix`` is not a square 2-D tensor.
        """
        self.device = device
        self.confusion_matrix = confusion_matrix.to(device)
        self.alpha = alpha
        self.temperature = temperature
        self.adaptive_alpha = adaptive_alpha
        self.gating_threshold = gating_threshold

        # Validate (and, if needed, row-normalize) the confusion matrix
        self._validate_confusion_matrix()

        # Precompute regularized confusion matrix for the global alpha
        self.regularized_matrix = self._compute_regularized_matrix(alpha)

        # Always define the attribute so that flipping `adaptive_alpha` to True
        # after construction does not hit an AttributeError. The network itself
        # is created lazily once the feature dimension is known.
        self.alpha_net: Optional[torch.nn.Module] = None

    def _validate_confusion_matrix(self):
        """
        Check that the confusion matrix is square and row-normalized.

        Integer (count) matrices are converted to the default float dtype.
        Rows that do not sum to 1 (atol=1e-3) trigger a warning and are
        L1-normalized in place on ``self.confusion_matrix``.

        Raises:
            ValueError: If the matrix is not 2-D and square.
        """
        shape = tuple(self.confusion_matrix.shape)
        # Explicit exception instead of `assert`, which is stripped under -O.
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError(f"Confusion matrix must be square, got shape {shape}")

        # Raw count matrices arrive as integer tensors; normalization needs floats.
        if not torch.is_floating_point(self.confusion_matrix):
            self.confusion_matrix = self.confusion_matrix.to(torch.get_default_dtype())

        # Check row normalization (within tolerance)
        row_sums = self.confusion_matrix.sum(dim=1)
        if not torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-3):
            warnings.warn("Confusion matrix rows don't sum to 1. Normalizing...")
            self.confusion_matrix = F.normalize(self.confusion_matrix, p=1, dim=1)

    def _compute_regularized_matrix(self, alpha: float) -> torch.Tensor:
        """
        Compute regularized confusion matrix: C_α = (1-α)I + αC

        Args:
            alpha: Mixing parameter

        Returns:
            Regularized confusion matrix of shape (K, K)
        """
        K = self.confusion_matrix.shape[0]
        identity = torch.eye(K, device=self.device)
        return (1 - alpha) * identity + alpha * self.confusion_matrix

    def forward(
        self,
        logits: torch.Tensor,
        features: Optional[torch.Tensor] = None,
        return_info: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """
        Apply HPC correction to model logits.

        Args:
            logits: Raw model logits of shape (batch_size, num_classes)
            features: Optional features for adaptive alpha (batch_size, feature_dim).
                If ``adaptive_alpha`` is enabled but no features are given, the
                global ``alpha`` is used for every sample.
            return_info: Whether to return additional information

        Returns:
            Corrected probabilities of shape (batch_size, num_classes). When
            ``return_info`` is True, a tuple ``(probs, info)`` where
            ``info['alpha_values']`` is a NumPy array of the per-sample alphas.
        """
        # Inputs are moved to the configured device so they can be combined
        # with the confusion matrix stored there.
        logits = logits.to(self.device)

        # Apply temperature scaling
        model_probs = F.softmax(logits / self.temperature, dim=1)

        # Get predicted classes
        y_pred = torch.argmax(model_probs, dim=1)

        # Compute instance-specific alpha if adaptive
        if self.adaptive_alpha and features is not None:
            alpha_values = self._compute_adaptive_alpha(
                features.to(self.device), model_probs
            )
        else:
            alpha_values = torch.full((logits.shape[0],), self.alpha, device=self.device)

        # `detach()` is required: adaptive alphas come out of `alpha_net` and
        # carry grad, and NumPy conversion of such tensors raises RuntimeError.
        info_dict = (
            {'alpha_values': alpha_values.detach().cpu().numpy()} if return_info else {}
        )

        # Vectorized per-sample rows of C_α: (1-α_i)·e_{y_i} + α_i·C[y_i, :]
        num_classes = self.confusion_matrix.shape[0]
        identity_rows = torch.eye(num_classes, device=self.device)[y_pred]
        human_rows = self.confusion_matrix[y_pred]
        alpha_col = alpha_values.unsqueeze(1)
        reg_rows = (1 - alpha_col) * identity_rows + alpha_col * human_rows

        # Element-wise re-weighting followed by renormalization
        weighted = model_probs * reg_rows
        totals = weighted.sum(dim=1, keepdim=True)

        # If the prior wipes out all mass (e.g. α=1 with C[y, y]=0 and a
        # saturated softmax) dividing would yield NaN; keep the model's own
        # distribution for those rows instead.
        has_mass = totals > 0
        safe_totals = torch.where(has_mass, totals, torch.ones_like(totals))
        corrected_probs = torch.where(
            has_mass, weighted / safe_totals, model_probs.to(weighted.dtype)
        )

        if return_info:
            return corrected_probs, info_dict
        return corrected_probs

    def _compute_adaptive_alpha(
        self,
        features: torch.Tensor,
        model_probs: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute instance-adaptive α(x) with disagreement gating.

        The base alpha comes from ``self.alpha_net`` (created on first use with
        fresh, untrained weights sized to ``features.shape[1]``). It is then
        multiplied by a gate ``min(1, exp(-γ · KL))`` where KL is
        ``F.kl_div(log(model_probs), human_row)`` summed over classes and
        ``human_row`` is the confusion-matrix row of the predicted class.

        Args:
            features: Input features (batch_size, feature_dim)
            model_probs: Model probabilities (batch_size, num_classes)

        Returns:
            Adaptive alpha values (batch_size,)
        """
        # Initialize adaptive network if needed
        if self.alpha_net is None:
            self.alpha_net = AdaptiveAlphaNet(features.shape[1]).to(self.device)

        # Compute base adaptive alpha
        alpha_base = self.alpha_net(features)

        # Compute disagreement gating
        y_pred = torch.argmax(model_probs, dim=1)
        human_probs = self.confusion_matrix[y_pred]

        # Per-sample divergence; the 1e-8 keeps log() finite for zero probabilities
        kl_div = F.kl_div(
            torch.log(model_probs + 1e-8),
            human_probs,
            reduction='none'
        ).sum(dim=1)

        # Apply exponential gating, capped at 1
        gate = torch.minimum(
            torch.ones_like(kl_div),
            torch.exp(-self.gating_threshold * kl_div)
        )

        return alpha_base * gate

    def set_alpha(self, alpha: float):
        """Update the alpha parameter and recompute regularized matrix."""
        self.alpha = alpha
        self.regularized_matrix = self._compute_regularized_matrix(alpha)

    def set_temperature(self, temperature: float):
        """Update temperature parameter."""
        self.temperature = temperature


class AdaptiveAlphaNet(torch.nn.Module):
    """
    Two-layer MLP that maps a feature vector to a scalar α(x) in (0, 1).

    Layer stack: Linear -> ReLU -> Linear -> Sigmoid
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64):
        """
        Build the layer stack.

        Args:
            input_dim: Size of each input feature vector.
            hidden_dim: Width of the single hidden layer.
        """
        super().__init__()
        nn = torch.nn
        # Layers are created in application order and stored under `net`.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP on ``x`` and drop the trailing size-1 output dimension."""
        alpha = self.net(x)
        return alpha.squeeze(-1)


def temperature_scale_logits(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Convert logits to probabilities after dividing them by a temperature.

    Args:
        logits: Raw logits (batch_size, num_classes)
        temperature: Temperature parameter β used as the divisor

    Returns:
        Softmax over dimension 1 of the temperature-scaled logits
    """
    softened = logits / temperature
    return torch.softmax(softened, dim=1)


def find_optimal_temperature(
    logits: torch.Tensor,
    targets: torch.Tensor,
    temperature_range: Tuple[float, float] = (0.1, 5.0),
    num_temps: int = 100
) -> float:
    """
    Find optimal temperature for temperature scaling using grid search.

    Each candidate temperature on an evenly spaced grid is scored by the mean
    negative log-likelihood of ``targets`` under ``softmax(logits / T)``; the
    first candidate with the lowest NLL wins.

    The NLL is computed with ``F.cross_entropy`` on the scaled logits rather
    than ``log(softmax(...) + eps)``: adding an epsilon caps the per-sample
    loss (at about 18.4 for eps=1e-8), which under-penalizes confidently wrong
    predictions and can bias the search toward very small temperatures.

    Args:
        logits: Validation logits (N, K)
        targets: True labels (N,)
        temperature_range: Range of temperatures to search
        num_temps: Number of temperature values to try

    Returns:
        Optimal temperature value. Falls back to 1.0 when no candidate yields
        a finite NLL lower than infinity (e.g. ``num_temps == 0``).
    """
    low, high = temperature_range
    candidates = torch.linspace(low, high, num_temps).tolist()

    best_temp = 1.0
    best_nll = float('inf')

    # Pure evaluation: no autograd graph is needed for the search.
    with torch.no_grad():
        for temp in candidates:
            nll = F.cross_entropy(logits / temp, targets).item()

            # A NaN score (e.g. from temp == 0) never compares as smaller, so
            # such candidates are skipped automatically.
            if nll < best_nll:
                best_nll = nll
                best_temp = temp

    return best_temp


# Example usage and testing
if __name__ == "__main__":
    # Create synthetic confusion matrix for CIFAR-10
    K = 10
    confusion_matrix = torch.eye(K) * 0.7

    # Add semantic confusions
    # Animals: cat(3) <-> dog(5); bird(2) -> airplane(0)
    confusion_matrix[3, 5] = 0.25  # cat -> dog
    confusion_matrix[5, 3] = 0.20  # dog -> cat
    confusion_matrix[2, 0] = 0.15  # bird -> airplane

    # Vehicles: automobile(1) <-> truck(9)
    confusion_matrix[1, 9] = 0.20  # automobile -> truck
    confusion_matrix[9, 1] = 0.15  # truck -> automobile

    # Normalize rows so each one sums to 1
    confusion_matrix = F.normalize(confusion_matrix, p=1, dim=1)

    # Initialize HPC
    hpc = HumanPriorCorrection(confusion_matrix, alpha=0.3, temperature=1.2)

    # Test with synthetic logits
    batch_size = 32
    logits = torch.randn(batch_size, K) * 2

    # Apply HPC correction
    corrected_probs = hpc.forward(logits)

    print(f"Original logits shape: {logits.shape}")
    print(f"Corrected probabilities shape: {corrected_probs.shape}")
    print(f"Probabilities sum to 1: {torch.allclose(corrected_probs.sum(dim=1), torch.ones(batch_size))}")

    # Test with adaptive alpha
    features = torch.randn(batch_size, 512)  # Typical CNN feature size
    hpc_adaptive = HumanPriorCorrection(
        confusion_matrix,
        alpha=0.3,
        adaptive_alpha=True
    )

    # Inference only: without no_grad the alpha values produced by the alpha
    # network require grad, and converting them to NumPy for the info dict
    # raises a RuntimeError.
    with torch.no_grad():
        corrected_probs_adaptive, info = hpc_adaptive.forward(
            logits,
            features=features,
            return_info=True
        )

    print(f"Adaptive alpha values range: [{info['alpha_values'].min():.3f}, {info['alpha_values'].max():.3f}]")
