"""
Conformal Prediction Integration with Human-Prior Correction

This module implements conformal prediction methods that can be combined
with HPC for enhanced uncertainty quantification and coverage guarantees.
Includes:
- Standard conformal prediction
- Adaptive conformal prediction
- HPC-aware conformal sets
- Risk-controlling prediction sets
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple, List, Dict, Optional, Callable
import math
from scipy.stats import binom
import warnings


class ConformalPredictor:
    """
    Base conformal prediction class with standard CP methods.
    
    Implements split conformal prediction with various non-conformity scores
    that can be applied after HPC calibration.
    """
    
    def __init__(
        self,
        alpha: float = 0.1,
        score_function: str = "adaptive_prediction_sets"
    ):
        """
        Initialize conformal predictor.
        
        Args:
            alpha: Miscoverage level (1-alpha coverage guarantee)
            score_function: Non-conformity score function type
        """
        self.alpha = alpha
        self.score_function = score_function
        self.quantile = None
        self.is_fitted = False
    
    def _compute_nonconformity_scores(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor,
        score_type: str = "adaptive_prediction_sets"
    ) -> torch.Tensor:
        """
        Compute non-conformity scores for calibration data.
        
        Args:
            probabilities: Predicted probabilities (N, K)
            targets: True labels (N,)
            score_type: Type of non-conformity score
            
        Returns:
            Non-conformity scores (N,)
        """
        if score_type == "simple":
            # Simple score: 1 - P(y_true)
            true_class_probs = probabilities[torch.arange(len(targets)), targets]
            return 1.0 - true_class_probs
            
        elif score_type == "adaptive_prediction_sets":
            # Adaptive Prediction Sets (APS) score
            sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=1)
            
            # Find position of true class in sorted order
            true_positions = torch.zeros(len(targets), dtype=torch.long)
            for i, (target, indices) in enumerate(zip(targets, sorted_indices)):
                true_positions[i] = torch.where(indices == target)[0][0]
            
            # Compute APS scores
            scores = torch.zeros(len(targets))
            for i in range(len(targets)):
                pos = true_positions[i]
                if pos == 0:
                    scores[i] = torch.rand(1).item()  # Randomization for ties
                else:
                    scores[i] = cumsum_probs[i, pos - 1] + torch.rand(1).item() * probabilities[i, targets[i]]
            
            return scores
            
        elif score_type == "regularized_adaptive":
            # Regularized APS with penalty for large sets
            base_scores = self._compute_nonconformity_scores(probabilities, targets, "adaptive_prediction_sets")
            
            # Add regularization based on entropy
            entropy = -torch.sum(probabilities * torch.log(probabilities + 1e-8), dim=1)
            max_entropy = math.log(probabilities.shape[1])
            normalized_entropy = entropy / max_entropy
            
            regularization = 0.1 * normalized_entropy
            return base_scores + regularization
            
        else:
            raise ValueError(f"Unknown score type: {score_type}")
    
    def fit(
        self,
        cal_probabilities: torch.Tensor,
        cal_targets: torch.Tensor
    ):
        """
        Fit conformal predictor on calibration data.
        
        Args:
            cal_probabilities: Calibration probabilities (N_cal, K)
            cal_targets: Calibration labels (N_cal,)
        """
        # Compute non-conformity scores
        scores = self._compute_nonconformity_scores(
            cal_probabilities, cal_targets, self.score_function
        )
        
        # Compute quantile
        n = len(scores)
        q_level = math.ceil((n + 1) * (1 - self.alpha)) / n
        self.quantile = torch.quantile(scores, q_level)
        
        self.is_fitted = True
        print(f"Conformal predictor fitted with quantile: {self.quantile:.4f}")
    
    def predict_sets(
        self,
        probabilities: torch.Tensor,
        return_sizes: bool = False
    ) -> Tuple[List[List[int]], Optional[torch.Tensor]]:
        """
        Generate conformal prediction sets.
        
        Args:
            probabilities: Test probabilities (N_test, K)
            return_sizes: Whether to return set sizes
            
        Returns:
            (prediction_sets, set_sizes)
        """
        if not self.is_fitted:
            raise ValueError("Must fit conformal predictor before prediction")
        
        n_test = probabilities.shape[0]
        num_classes = probabilities.shape[1]
        
        prediction_sets = []
        set_sizes = []
        
        for i in range(n_test):
            probs = probabilities[i]
            
            if self.score_function == "simple":
                # Include all classes with probability >= 1 - quantile
                threshold = 1.0 - self.quantile
                prediction_set = torch.where(probs >= threshold)[0].tolist()
                
            elif self.score_function in ["adaptive_prediction_sets", "regularized_adaptive"]:
                # APS-style prediction sets
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumsum_probs = torch.cumsum(sorted_probs, dim=0)
                
                # Find classes to include
                prediction_set = []
                for j, (prob, class_idx) in enumerate(zip(sorted_probs, sorted_indices)):
                    if j == 0:
                        score = torch.rand(1).item()
                    else:
                        score = cumsum_probs[j-1] + torch.rand(1).item() * prob
                    
                    if score <= self.quantile:
                        prediction_set.append(class_idx.item())
                    else:
                        break
                
                # Ensure at least one class is included
                if not prediction_set:
                    prediction_set = [sorted_indices[0].item()]
            
            prediction_sets.append(prediction_set)
            set_sizes.append(len(prediction_set))
        
        if return_sizes:
            return prediction_sets, torch.tensor(set_sizes)
        return prediction_sets, None


class AdaptiveConformalPredictor:
    """
    Adaptive conformal prediction that adjusts to distribution shift.

    Updates the quantile based on recent performance to maintain
    coverage under distribution drift. Every ``update_freq`` labelled
    predictions, the score threshold moves by
    ``gamma * (recent_error_rate - alpha)``. More misses than the target
    ``alpha`` raise the threshold (larger sets); fewer misses lower it.
    The threshold is clamped to ``[0, 1]`` after each update.
    """

    def __init__(
        self,
        alpha: float = 0.1,
        gamma: float = 0.005,
        update_freq: int = 100
    ):
        """
        Initialize adaptive conformal predictor.

        Args:
            alpha: Target miscoverage level
            gamma: Learning rate for quantile updates
            update_freq: Number of labelled predictions between quantile updates
        """
        self.alpha = alpha
        self.gamma = gamma
        self.update_freq = update_freq

        self.quantile = 0.5  # Initial quantile, replaced by fit()
        self.score_function = "adaptive_prediction_sets"  # Replaced by fit()
        self.step_count = 0
        self.recent_errors = []
        self.is_fitted = False

    def _update_quantile(self, error: float):
        """
        Move the score threshold toward the target miscoverage ``alpha``.

        Args:
            error: Observed miscoverage rate over the most recent window.
        """
        # A larger threshold gives larger sets. Undercoverage (error > alpha) must
        # therefore increase it, and overcoverage must decrease it.
        self.quantile = self.quantile + self.gamma * (error - self.alpha)
        self.quantile = max(0.0, min(1.0, self.quantile))  # Clamp to [0, 1]

    def fit(
        self,
        cal_probabilities: torch.Tensor,
        cal_targets: torch.Tensor,
        score_function: str = "adaptive_prediction_sets"
    ):
        """
        Initial fitting on calibration data.

        Args:
            cal_probabilities: Calibration probabilities (N_cal, K)
            cal_targets: Calibration labels (N_cal,)
            score_function: Non-conformity score, also used for later predictions
        """
        # Use standard conformal prediction for the initial quantile.
        standard_cp = ConformalPredictor(alpha=self.alpha, score_function=score_function)
        standard_cp.fit(cal_probabilities, cal_targets)
        self.quantile = standard_cp.quantile.item()
        # Remember the score so prediction sets are built the same way they were calibrated.
        self.score_function = score_function
        self.is_fitted = True

    def predict_and_update(
        self,
        probabilities: torch.Tensor,
        targets: Optional[torch.Tensor] = None
    ) -> Tuple[List[List[int]], Optional[torch.Tensor]]:
        """
        Generate predictions and update quantile if targets provided.

        Args:
            probabilities: Test probabilities (N_test, K)
            targets: Optional true labels for quantile updating

        Returns:
            (prediction_sets, set_sizes)

        Raises:
            ValueError: If called before ``fit``.
        """
        if not self.is_fitted:
            raise ValueError("Must fit adaptive conformal predictor before prediction")

        # Generate prediction sets with the current quantile and the calibrated score.
        cp = ConformalPredictor(alpha=self.alpha, score_function=self.score_function)
        cp.quantile = torch.tensor(self.quantile)
        cp.is_fitted = True

        prediction_sets, set_sizes = cp.predict_sets(probabilities, return_sizes=True)

        if targets is not None:
            for pred_set, target in zip(prediction_sets, targets):
                # 1.0 when the true label was missed, 0.0 when covered.
                error = 1.0 if target.item() not in pred_set else 0.0
                self.recent_errors.append(error)

                self.step_count += 1

                # Update the quantile once per window of update_freq labelled steps.
                if self.step_count % self.update_freq == 0:
                    recent_error_rate = np.mean(self.recent_errors[-self.update_freq:])
                    self._update_quantile(recent_error_rate)
                    print(f"Updated quantile to {self.quantile:.4f} (error rate: {recent_error_rate:.4f})")

        return prediction_sets, set_sizes


class HPCAwareConformalPredictor:
    """
    HPC-aware conformal prediction that incorporates human prior information
    into the non-conformity scores and set construction.

    With ``use_human_scores=True``, the score of label ``y`` is
    ``(1 - w) * (1 - p_model(y)) + w * (1 - p_human(y))``, where ``w`` is
    ``human_weight``. This equals ``1 - ((1 - w) * p_model(y) + w * p_human(y))``.
    Otherwise the score is ``1 - p_model(y)``. Prediction sets keep every class
    whose correspondingly mixed probability is ``>= 1 - quantile``, so the
    test-time rule matches the calibration score.
    """

    def __init__(
        self,
        alpha: float = 0.1,
        human_weight: float = 0.3,
        use_human_scores: bool = True
    ):
        """
        Initialize HPC-aware conformal predictor.

        Args:
            alpha: Miscoverage level
            human_weight: Weight for human prior influence
            use_human_scores: Whether to use human distributions in scoring
        """
        self.alpha = alpha
        self.human_weight = human_weight
        self.use_human_scores = use_human_scores
        self.quantile = None
        self.is_fitted = False

    def _compute_human_aware_scores(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor,
        human_distributions: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute non-conformity scores that incorporate human distributions.

        Args:
            probabilities: Model probabilities (N, K)
            targets: True labels (N,)
            human_distributions: Human label distributions (N, K)

        Returns:
            Human-aware non-conformity scores (N,)
        """
        # Standard model-based scores
        model_scores = 1.0 - probabilities[torch.arange(len(targets)), targets]

        if not self.use_human_scores:
            return model_scores

        # Human-based scores (how much human probability mass on true class)
        human_scores = 1.0 - human_distributions[torch.arange(len(targets)), targets]

        # Combine model and human scores
        return (1 - self.human_weight) * model_scores + self.human_weight * human_scores

    def fit(
        self,
        cal_probabilities: torch.Tensor,
        cal_targets: torch.Tensor,
        cal_human_distributions: torch.Tensor
    ):
        """
        Fit HPC-aware conformal predictor.

        Args:
            cal_probabilities: Calibration probabilities (N_cal, K)
            cal_targets: Calibration labels (N_cal,)
            cal_human_distributions: Calibration human distributions (N_cal, K)

        Raises:
            ValueError: If the calibration set is empty.
        """
        scores = self._compute_human_aware_scores(
            cal_probabilities, cal_targets, cal_human_distributions
        )

        n = len(scores)
        if n == 0:
            raise ValueError("Calibration set must contain at least one example")

        # Split conformal needs the ceil((n + 1)(1 - alpha))-th smallest score.
        rank = math.ceil((n + 1) * (1 - self.alpha))
        if rank > n:
            # Too few calibration points for this alpha: fall back to full sets.
            warnings.warn(
                f"Calibration set of size {n} is too small for alpha={self.alpha}; "
                "using an infinite quantile (full prediction sets).",
                RuntimeWarning,
            )
            self.quantile = torch.tensor(float("inf"))
        else:
            self.quantile = torch.quantile(scores, max(rank, 0) / n)

        self.is_fitted = True
        print(f"HPC-aware conformal predictor fitted with quantile: {self.quantile:.4f}")

    def predict_sets(
        self,
        probabilities: torch.Tensor,
        human_distributions: Optional[torch.Tensor] = None,
        return_human_coverage: bool = False
    ) -> Tuple[List[List[int]], Dict]:
        """
        Generate HPC-aware prediction sets.

        Args:
            probabilities: Test probabilities (N_test, K)
            human_distributions: Test human distributions (N_test, K)
            return_human_coverage: Whether to compute human coverage metrics

        Returns:
            (prediction_sets, metrics_dict). ``metrics_dict`` holds ``set_sizes``,
            ``mean_set_size``, and ``human_coverages``: the human probability
            mass inside each set, or None when it was not requested or no
            human distributions were given.

        Raises:
            ValueError: If called before ``fit``.
        """
        if not self.is_fitted:
            raise ValueError("Must fit HPC-aware conformal predictor before prediction")

        if self.use_human_scores and human_distributions is not None:
            # Same mixture the calibration score used, so the threshold is comparable.
            combined = (
                (1 - self.human_weight) * probabilities
                + self.human_weight * human_distributions
            )
        else:
            if self.use_human_scores:
                warnings.warn(
                    "Predictor was calibrated with human distributions but none were "
                    "given at prediction time; sets use model probabilities only and "
                    "the coverage guarantee may not hold.",
                    RuntimeWarning,
                )
            combined = probabilities

        # Include classes with combined probability >= 1 - quantile
        threshold = 1.0 - self.quantile
        prediction_sets = []
        for row in combined:
            members = torch.where(row >= threshold)[0].tolist()
            if not members:
                # Ensure at least one class is included
                members = [torch.argmax(row).item()]
            prediction_sets.append(members)

        set_sizes = [len(s) for s in prediction_sets]

        human_coverages = []
        if return_human_coverage and human_distributions is not None:
            human_coverages = [
                sum(human_row[j].item() for j in members)
                for human_row, members in zip(human_distributions, prediction_sets)
            ]

        metrics = {
            'set_sizes': torch.tensor(set_sizes),
            'mean_set_size': torch.tensor(set_sizes).float().mean().item(),
            'human_coverages': torch.tensor(human_coverages) if human_coverages else None
        }

        return prediction_sets, metrics


class RiskControllingPredictor:
    """
    Risk-controlling prediction sets that bound different types of risk
    beyond just coverage (e.g., class-conditional coverage, fairness).

    Scores are ``1 - p(y)``. In class-conditional (Mondrian) mode a separate
    quantile is fitted on the calibration examples of each true class. At
    prediction time each candidate label ``k`` is then kept when
    ``p(k) >= 1 - q_k``, which is what per-class coverage requires. A global
    quantile is always fitted as a fallback for classes that had no
    calibration examples.
    """

    def __init__(
        self,
        alpha: float = 0.1,
        risk_type: str = "coverage",
        class_conditional: bool = False
    ):
        """
        Initialize risk-controlling predictor.

        Args:
            alpha: Risk level to control
            risk_type: Type of risk ("coverage", "size", "fairness"); stored,
                but not used by fit/predict_sets in this implementation
            class_conditional: Whether to ensure class-conditional coverage
        """
        self.alpha = alpha
        self.risk_type = risk_type
        self.class_conditional = class_conditional
        self.quantiles = {}
        self.is_fitted = False

    def _conformal_quantile(self, scores: torch.Tensor) -> torch.Tensor:
        """Return the split-conformal quantile of ``scores`` (inf if too few points)."""
        n = len(scores)
        rank = math.ceil((n + 1) * (1 - self.alpha))
        if rank > n:
            # Not enough points for this alpha: only the full set keeps the guarantee.
            warnings.warn(
                f"Only {n} calibration score(s) for alpha={self.alpha}; "
                "using an infinite quantile (label always included).",
                RuntimeWarning,
            )
            return torch.tensor(float("inf"))
        return torch.quantile(scores, max(rank, 0) / n)

    def fit(
        self,
        cal_probabilities: torch.Tensor,
        cal_targets: torch.Tensor,
        cal_groups: Optional[torch.Tensor] = None
    ):
        """
        Fit risk-controlling predictor.

        Args:
            cal_probabilities: Calibration probabilities (N_cal, K)
            cal_targets: Calibration labels (N_cal,)
            cal_groups: Optional group memberships for fairness (N_cal,);
                currently unused

        Raises:
            ValueError: If the calibration set is empty.
        """
        if len(cal_targets) == 0:
            raise ValueError("Calibration set must contain at least one example")

        # Refitting replaces any quantiles from a previous fit.
        self.quantiles = {}
        scores = 1.0 - cal_probabilities[torch.arange(len(cal_targets)), cal_targets]
        self.quantiles['global'] = self._conformal_quantile(scores)

        if self.class_conditional:
            # One quantile per true class, from that class's calibration scores only.
            for class_idx in range(cal_probabilities.shape[1]):
                class_scores = scores[cal_targets == class_idx]
                if len(class_scores) > 0:
                    self.quantiles[class_idx] = self._conformal_quantile(class_scores)

        self.is_fitted = True
        print(f"Risk-controlling predictor fitted with {len(self.quantiles)} quantile(s)")

    def predict_sets(
        self,
        probabilities: torch.Tensor,
        predicted_classes: Optional[torch.Tensor] = None
    ) -> Tuple[List[List[int]], torch.Tensor]:
        """
        Generate risk-controlling prediction sets.

        Args:
            probabilities: Test probabilities (N_test, K)
            predicted_classes: Accepted for backward compatibility and ignored.
                Class-conditional sets threshold every candidate label with
                that label's own quantile.

        Returns:
            (prediction_sets, set_sizes)

        Raises:
            ValueError: If called before ``fit``.
        """
        if not self.is_fitted:
            raise ValueError("Must fit risk-controlling predictor before prediction")

        if self.class_conditional:
            # Per-label quantile vector; labels unseen at calibration use the fallback.
            fallback = self.quantiles.get('global', torch.tensor(0.5))
            per_class = [
                torch.as_tensor(self.quantiles.get(k, fallback))
                for k in range(probabilities.shape[1])
            ]
            quantile = torch.stack(per_class).to(
                device=probabilities.device, dtype=probabilities.dtype
            )
        else:
            quantile = self.quantiles['global']

        # Broadcasts the (K,) or scalar threshold across all rows.
        keep = probabilities >= (1.0 - quantile)

        prediction_sets = []
        for row_probs, row_keep in zip(probabilities, keep):
            members = torch.where(row_keep)[0].tolist()
            if not members:
                # Ensure at least one class
                members = [torch.argmax(row_probs).item()]
            prediction_sets.append(members)

        set_sizes = torch.tensor([len(s) for s in prediction_sets])
        return prediction_sets, set_sizes


# Utility functions for conformal prediction evaluation
def evaluate_coverage(
    prediction_sets: List[List[int]],
    true_labels: torch.Tensor
) -> Dict[str, float]:
    """
    Evaluate conformal prediction coverage and efficiency.

    Args:
        prediction_sets: List of prediction sets (class indices)
        true_labels: Ground truth labels, one per set; a tensor or any
            sequence of values convertible with ``int()``

    Returns:
        Dictionary with ``coverage``, ``mean_set_size``, ``median_set_size``,
        ``std_set_size`` and ``efficiency`` (coverage per unit set size).
        All values are NaN for empty input. ``efficiency`` is NaN when every
        set is empty.

    Raises:
        ValueError: If the number of sets and labels differ.
    """
    if len(prediction_sets) != len(true_labels):
        raise ValueError(
            f"Got {len(prediction_sets)} prediction sets but {len(true_labels)} labels"
        )

    if len(prediction_sets) == 0:
        nan = float("nan")
        return {
            'coverage': nan,
            'mean_set_size': nan,
            'median_set_size': nan,
            'std_set_size': nan,
            'efficiency': nan,
        }

    # int() accepts Python ints as well as 0-dim tensors.
    coverages = [
        float(int(label) in pred_set)
        for pred_set, label in zip(prediction_sets, true_labels)
    ]
    set_sizes = [len(pred_set) for pred_set in prediction_sets]

    coverage = float(np.mean(coverages))
    mean_set_size = float(np.mean(set_sizes))

    return {
        'coverage': coverage,
        'mean_set_size': mean_set_size,
        'median_set_size': float(np.median(set_sizes)),
        'std_set_size': float(np.std(set_sizes)),
        # Coverage per unit size; undefined when all sets are empty.
        'efficiency': coverage / mean_set_size if mean_set_size > 0 else float("nan"),
    }


# Example usage and testing
if __name__ == "__main__":
    # Smoke test on synthetic data: runs each predictor end to end and prints
    # coverage / set-size summaries. Nothing is asserted.
    print("Testing conformal prediction with HPC integration...")
    
    # Create synthetic data
    n_cal = 500
    n_test = 200
    num_classes = 10
    
    # Generate synthetic probabilities and labels.
    # The labels are drawn independently of the logits, so the "model" carries no
    # signal; the printed numbers exercise the conformal machinery, not model quality.
    torch.manual_seed(42)
    cal_logits = torch.randn(n_cal, num_classes)
    cal_probs = F.softmax(cal_logits, dim=1)
    cal_targets = torch.randint(0, num_classes, (n_cal,))
    
    test_logits = torch.randn(n_test, num_classes)
    test_probs = F.softmax(test_logits, dim=1)
    test_targets = torch.randint(0, num_classes, (n_test,))
    
    # Generate synthetic human distributions (also independent of the labels)
    cal_human_dists = F.softmax(torch.randn(n_cal, num_classes), dim=1)
    test_human_dists = F.softmax(torch.randn(n_test, num_classes), dim=1)
    
    print(f"Data: {n_cal} calibration, {n_test} test samples")
    
    # Test standard conformal prediction (set sizes are requested but discarded)
    print("\n1. Testing standard conformal prediction...")
    cp = ConformalPredictor(alpha=0.1, score_function="adaptive_prediction_sets")
    cp.fit(cal_probs, cal_targets)
    pred_sets, _ = cp.predict_sets(test_probs, return_sizes=True)
    
    metrics = evaluate_coverage(pred_sets, test_targets)
    print(f"   Coverage: {metrics['coverage']:.3f}")
    print(f"   Mean set size: {metrics['mean_set_size']:.2f}")
    
    # Test adaptive conformal prediction
    print("\n2. Testing adaptive conformal prediction...")
    acp = AdaptiveConformalPredictor(alpha=0.1, gamma=0.01)
    acp.fit(cal_probs, cal_targets)
    
    # Simulate online prediction with feedback: one labelled example per call.
    # NOTE(review): update_freq keeps its default of 100 here, so with only 50
    # labelled steps no quantile update happens and the "final quantile" printed
    # below is still the one from fit().
    adaptive_pred_sets = []
    for i in range(min(50, n_test)):  # Test first 50 samples
        batch_probs = test_probs[i:i+1]
        batch_targets = test_targets[i:i+1]
        pred_sets_batch, _ = acp.predict_and_update(batch_probs, batch_targets)
        adaptive_pred_sets.extend(pred_sets_batch)
    
    # Score only the test examples that were actually fed to the adaptive predictor.
    metrics = evaluate_coverage(adaptive_pred_sets, test_targets[:len(adaptive_pred_sets)])
    print(f"   Adaptive coverage: {metrics['coverage']:.3f}")
    print(f"   Final quantile: {acp.quantile:.3f}")
    
    # Test HPC-aware conformal prediction
    print("\n3. Testing HPC-aware conformal prediction...")
    hpc_cp = HPCAwareConformalPredictor(alpha=0.1, human_weight=0.3)
    hpc_cp.fit(cal_probs, cal_targets, cal_human_dists)
    
    hpc_pred_sets, hpc_metrics = hpc_cp.predict_sets(
        test_probs, test_human_dists, return_human_coverage=True
    )
    
    metrics = evaluate_coverage(hpc_pred_sets, test_targets)
    print(f"   HPC coverage: {metrics['coverage']:.3f}")
    print(f"   HPC mean set size: {metrics['mean_set_size']:.2f}")
    # human_coverages: human probability mass inside each prediction set
    if hpc_metrics['human_coverages'] is not None:
        print(f"   Mean human coverage: {hpc_metrics['human_coverages'].mean():.3f}")
    
    print("\nConformal prediction integration test completed!")
