"""
Proxy Prior Construction for Human-Prior Correction

This module implements methods to construct confusion priors from foundation models
when human annotations are not available, as described in the HPC paper.

Three main approaches:
1. CLIP-derived priors using vision-language alignment
2. Self-supervised priors (DINO, SimCLR) using visual features
3. Few-shot human priors combining limited annotations with proxy priors
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Optional, Tuple, Dict, Union
import clip
from torchvision import transforms
from sklearn.neighbors import NearestNeighbors
import warnings

try:
    import timm
    TIMM_AVAILABLE = True
except ImportError:
    TIMM_AVAILABLE = False
    warnings.warn("timm not available. Some self-supervised models may not work.")


class CLIPPriorConstructor:
    """
    Constructs confusion priors using CLIP's vision-language alignment.

    The confusion matrix is computed as:
    C_CLIP[i,j] = exp(sim(v_i, v_j)/τ) / Σ_k exp(sim(v_i, v_k)/τ)

    where v_i is the average CLIP embedding for class i, re-normalized to unit
    length, and sim is cosine similarity.
    """

    def __init__(
        self,
        model_name: str = "ViT-B/32",
        temperature: float = 0.07,
        device: str = "cpu"
    ):
        """
        Initialize CLIP prior constructor.

        Args:
            model_name: CLIP model architecture passed to ``clip.load``
            temperature: Temperature τ for softmax (lower = more peaked)
            device: Device to run on

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        # Validate before the (expensive) model load.
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        self.device = device
        self.temperature = temperature

        # Load CLIP model
        self.model, self.preprocess = clip.load(model_name, device=device)
        self.model.eval()

        # Standard prompt templates for text encoding (prompt ensembling)
        self.prompt_templates = [
            "A photo of a {}.",
            "A picture of a {}.",
            "An image of a {}.",
            "A rendering of a {}.",
            "A cropped photo of a {}.",
            "The photo of a {}.",
            "A photo of one {}."
        ]

    def _embeddings_to_confusion(self, class_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Turn per-class embeddings of shape (num_classes, dim) into a row-stochastic matrix.

        The embeddings are cast to float32 and L2-normalized, so the dot products
        are true cosine similarities. This matters because an average of unit
        vectors is generally shorter than unit length. Computing in float32 keeps
        the softmax precise regardless of the model's output dtype.
        """
        embeddings = F.normalize(class_embeddings.float(), dim=-1)
        similarities = embeddings @ embeddings.t()
        return F.softmax(similarities / self.temperature, dim=1)

    def construct_text_based_prior(self, class_names: List[str]) -> torch.Tensor:
        """
        Construct confusion matrix using CLIP text encodings only.

        This approach requires only class names, no training images. Underscores
        in class names are replaced by spaces before prompting, so that
        "aquarium_fish" is rendered as "A photo of a aquarium fish.".

        Args:
            class_names: List of class names (e.g., ["airplane", "car", "bird"])

        Returns:
            Confusion matrix of shape (num_classes, num_classes); each row sums to 1.

        Raises:
            ValueError: if ``class_names`` is empty.
        """
        if not class_names:
            raise ValueError("class_names must be non-empty")

        class_embeddings = []
        with torch.no_grad():
            for class_name in class_names:
                # CLIP prompts are natural language; identifiers like "lawn_mower" are not.
                readable_name = class_name.replace("_", " ")
                prompts = [template.format(readable_name) for template in self.prompt_templates]
                text_tokens = clip.tokenize(prompts).to(self.device)

                text_features = F.normalize(self.model.encode_text(text_tokens), dim=-1)
                # Prompt ensemble: mean of per-prompt unit vectors (re-normalized in the helper).
                class_embeddings.append(text_features.mean(dim=0))

        return self._embeddings_to_confusion(torch.stack(class_embeddings))

    def construct_image_based_prior(
        self,
        images_by_class: Dict[str, List[torch.Tensor]],
        class_names: List[str]
    ) -> torch.Tensor:
        """
        Construct confusion matrix using CLIP image encodings.

        Images are passed to ``encode_image`` as-is; ``self.preprocess`` is NOT
        applied here. They are therefore assumed to be already CLIP-preprocessed
        tensors (TODO confirm with callers). A 3-D tensor is treated as a single
        image. A 4-D tensor is treated as a batch, and every image in it counts
        toward the class mean.

        Args:
            images_by_class: Dict mapping class names to lists of images
            class_names: Ordered list of class names (defines row/column order)

        Returns:
            Confusion matrix of shape (num_classes, num_classes); each row sums to 1.

        Raises:
            ValueError: if a class has no entry or an empty image list.
        """
        class_embeddings = []
        with torch.no_grad():
            for class_name in class_names:
                if class_name not in images_by_class or len(images_by_class[class_name]) == 0:
                    raise ValueError(f"No images provided for class {class_name}")

                per_image_features = []
                for image in images_by_class[class_name]:
                    if image.dim() == 3:
                        image = image.unsqueeze(0)
                    image_features = self.model.encode_image(image.to(self.device))
                    per_image_features.append(F.normalize(image_features, dim=-1))

                # Concatenate along the batch axis so batched entries are averaged per image.
                class_embeddings.append(torch.cat(per_image_features, dim=0).mean(dim=0))

        return self._embeddings_to_confusion(torch.stack(class_embeddings))


class SelfSupervisedPriorConstructor:
    """
    Constructs confusion priors using self-supervised visual features (DINO, SimCLR, etc.).

    The confusion matrix is computed using k-nearest neighbors (cosine metric)
    in the feature space:
    C_SSL[i,j] = (1/(k·|I_i|)) Σ_{x∈I_i} |{n ∈ NN_k(x) : label(n) = j}|

    i.e. the fraction of the k nearest neighbours of class-i images (each image
    excluded from its own neighbour list) that belong to class j.
    """

    def __init__(
        self,
        model_name: str = "dinov2_vitb14",
        k_neighbors: int = 5,
        device: str = "cpu"
    ):
        """
        Initialize self-supervised prior constructor.

        Args:
            model_name: Model name passed to ``timm.create_model``
            k_neighbors: Number of nearest neighbors for confusion computation
            device: Device to run on

        Raises:
            ValueError: if ``k_neighbors`` < 1.
            ImportError: if timm is not installed.
        """
        if k_neighbors < 1:
            raise ValueError(f"k_neighbors must be >= 1, got {k_neighbors}")

        self.device = device
        self.k_neighbors = k_neighbors

        if not TIMM_AVAILABLE:
            raise ImportError("timm is required for self-supervised models")

        # NOTE(review): "dinov2_vitb14" looks like the torch.hub identifier for
        # DINOv2. timm appears to register DINOv2 under names such as
        # "vit_base_patch14_dinov2", possibly with a non-224 native input size.
        # Confirm that the default resolves in the installed timm and accepts the
        # 224x224 crops produced below.
        # num_classes=0 asks timm for a headless model that outputs features.
        self.model = timm.create_model(model_name, pretrained=True, num_classes=0)
        self.model = self.model.to(device).eval()

        # Standard ImageNet preprocessing. Normalize presumably expects float
        # tensors scaled to [0, 1] — TODO confirm with callers.
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _extract_features(self, images: List[torch.Tensor]) -> np.ndarray:
        """Preprocess ``images``, embed them in one forward pass, and return L2-normalized features."""
        # 3-D tensors are single images; 4-D tensors are already batches.
        batch = torch.cat([
            self.preprocess(image.unsqueeze(0) if image.dim() == 3 else image)
            for image in images
        ])
        with torch.no_grad():
            features = self.model(batch.to(self.device))
        return F.normalize(features, dim=-1).cpu().numpy()

    def construct_prior(
        self,
        images_by_class: Dict[str, List[torch.Tensor]],
        class_names: List[str],
        max_images_per_class: int = 100
    ) -> torch.Tensor:
        """
        Construct confusion matrix using self-supervised features.

        Args:
            images_by_class: Dict mapping class names to lists of images
            class_names: Ordered list of class names (defines row/column order)
            max_images_per_class: Maximum images to use per class (for efficiency)

        Returns:
            Confusion matrix of shape (num_classes, num_classes). Rows are
            L1-normalized. A class with an empty image list gets an all-zero row.

        Raises:
            ValueError: if a class is missing from ``images_by_class``, or fewer
                than two images are available in total (no neighbours exist).
        """
        num_classes = len(class_names)
        feature_blocks = []
        label_blocks = []

        for class_idx, class_name in enumerate(class_names):
            if class_name not in images_by_class:
                raise ValueError(f"No images provided for class {class_name}")

            images = images_by_class[class_name][:max_images_per_class]
            if len(images) == 0:
                continue

            features = self._extract_features(images)
            feature_blocks.append(features)
            # One label per feature row keeps labels aligned even for batched entries.
            label_blocks.append(np.full(len(features), class_idx, dtype=np.int64))

        num_samples = sum(len(block) for block in feature_blocks)
        if num_samples < 2:
            raise ValueError("At least two images are required to build a k-NN confusion prior")

        all_features = np.vstack(feature_blocks)
        all_labels = np.concatenate(label_blocks)

        # Each point can have at most num_samples - 1 neighbours other than itself.
        k = min(self.k_neighbors, num_samples - 1)
        if k < self.k_neighbors:
            warnings.warn(
                f"Only {num_samples} images available; using k={k} neighbours "
                f"instead of {self.k_neighbors}."
            )

        knn = NearestNeighbors(n_neighbors=k, metric='cosine').fit(all_features)
        # With no query argument, kneighbors returns neighbours of every indexed
        # point and never lists a point as its own neighbour.
        neighbor_indices = knn.kneighbors(return_distance=False)
        neighbor_labels = all_labels[neighbor_indices]  # (num_samples, k)

        # counts[true_class, neighbour_class] += 1 for every (sample, neighbour) pair.
        counts = np.zeros((num_classes, num_classes), dtype=np.float64)
        np.add.at(counts, (np.repeat(all_labels, k), neighbor_labels.ravel()), 1.0)

        confusion_matrix = torch.from_numpy(counts).float()
        return F.normalize(confusion_matrix, p=1, dim=1)


class FewShotHumanPriorConstructor:
    """
    Constructs confusion priors by combining limited human annotations with proxy priors.

    The hybrid confusion matrix is, row by row:
    C_hybrid[i] = γ_i * C_human[i] + (1-γ_i) * C_proxy[i]

    where γ_i weights the human contribution for class i based on its
    annotation count. Rows without usable human labels are copied from the proxy.
    """

    def __init__(self, min_annotations_for_human: int = 5, annotation_cap: int = 50):
        """
        Args:
            min_annotations_for_human: Minimum annotations a class needs before its
                human-derived row is used at all.
            annotation_cap: Annotation count at which the "proportional" strategy
                gives the human row full weight (γ = min(count / cap, 1)).

        Raises:
            ValueError: if ``annotation_cap`` is not positive.
        """
        if annotation_cap <= 0:
            raise ValueError(f"annotation_cap must be positive, got {annotation_cap}")
        self.min_annotations_for_human = min_annotations_for_human
        self.annotation_cap = annotation_cap

    def construct_hybrid_prior(
        self,
        human_annotations: Dict[str, List[int]],  # class -> list of human labels
        proxy_prior: torch.Tensor,
        class_names: List[str],
        human_weight_strategy: str = "proportional"
    ) -> torch.Tensor:
        """
        Construct hybrid confusion matrix combining human and proxy data.

        Args:
            human_annotations: Dict mapping class names to lists of human label
                indices. Labels outside [0, num_classes) are ignored.
            proxy_prior: Pre-computed proxy confusion matrix, shape (num_classes, num_classes)
            class_names: Ordered list of class names (defines row/column order)
            human_weight_strategy: How to weight human vs proxy:
                "proportional" -> γ = min(count / annotation_cap, 1);
                "threshold"    -> γ = 1 if count >= min_annotations_for_human else 0.

        Returns:
            Hybrid confusion matrix on the same device and dtype as ``proxy_prior``.

        Raises:
            ValueError: on an unknown strategy or a ``proxy_prior`` of the wrong shape.
        """
        num_classes = len(class_names)
        if human_weight_strategy not in ("proportional", "threshold"):
            raise ValueError(f"Unknown weighting strategy: {human_weight_strategy}")
        if proxy_prior.shape != (num_classes, num_classes):
            raise ValueError(
                f"proxy_prior has shape {tuple(proxy_prior.shape)}, "
                f"expected ({num_classes}, {num_classes})"
            )

        # Allocate on proxy_prior's device/dtype so the final mix never combines
        # tensors that live on different devices (e.g. a CUDA proxy prior).
        human_confusion = torch.zeros_like(proxy_prior)
        weights = proxy_prior.new_zeros(num_classes)

        for class_idx, class_name in enumerate(class_names):
            annotations = human_annotations.get(class_name)
            if annotations is None:
                continue

            count = len(annotations)
            if human_weight_strategy == "proportional":
                weights[class_idx] = min(count / self.annotation_cap, 1.0)
            else:  # "threshold"
                weights[class_idx] = 1.0 if count >= self.min_annotations_for_human else 0.0

            if count < self.min_annotations_for_human:
                continue  # too few labels to trust: this row stays on the proxy

            # Empirical label distribution over in-range labels.
            valid_labels = [label for label in annotations if 0 <= label < num_classes]
            if valid_labels:
                label_counts = torch.bincount(
                    torch.tensor(valid_labels, dtype=torch.long), minlength=num_classes
                )
                human_confusion[class_idx] = label_counts.to(human_confusion) / len(valid_labels)

        # γ is forced to 0 where no human row exists, so those rows equal the proxy exactly.
        has_human = human_confusion.sum(dim=1) > 0
        gamma = torch.where(has_human, weights, torch.zeros_like(weights)).unsqueeze(1)
        return gamma * human_confusion + (1 - gamma) * proxy_prior


def construct_cifar10_clip_prior(device: str = "cpu") -> torch.Tensor:
    """
    Build a CLIP text-embedding confusion prior for the ten CIFAR-10 classes.

    Args:
        device: Device on which CLIP is loaded and run.

    Returns:
        A (10, 10) confusion matrix whose rows and columns follow the order of
        the class-name tuple below.
    """
    label_names = (
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )
    constructor = CLIPPriorConstructor(device=device)
    prior = constructor.construct_text_based_prior(list(label_names))
    return prior


def construct_cifar100_clip_prior(device: str = "cpu") -> torch.Tensor:
    """
    Build a CLIP text-embedding confusion prior for the 100 CIFAR-100 fine labels.

    Args:
        device: Device on which CLIP is loaded and run.

    Returns:
        A (100, 100) confusion matrix whose rows and columns follow the order
        of the fine-label tuple below.
    """
    fine_labels = (
        "apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle",
        "bicycle", "bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel",
        "can", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock",
        "cloud", "cockroach", "couch", "crab", "crocodile", "cup", "dinosaur",
        "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster",
        "house", "kangaroo", "keyboard", "lamp", "lawn_mower", "leopard", "lion",
        "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain", "mouse",
        "mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear",
        "pickup_truck", "pine_tree", "plain", "plate", "poppy", "porcupine",
        "possum", "rabbit", "raccoon", "ray", "road", "rocket", "rose",
        "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake",
        "spider", "squirrel", "streetcar", "sunflower", "sweet_pepper", "table",
        "tank", "telephone", "television", "tiger", "tractor", "train", "trout",
        "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman",
        "worm",
    )
    constructor = CLIPPriorConstructor(device=device)
    prior = constructor.construct_text_based_prior(list(fine_labels))
    return prior


# Example usage and testing
def _run_demo() -> None:
    """Smoke-test the CLIP text prior and the few-shot hybrid prior on CIFAR-10."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Test CLIP prior construction for CIFAR-10
    print("Constructing CIFAR-10 CLIP prior...")
    # Bring the prior to CPU float32. Tensor.numpy() (used for printing below)
    # only works on CPU tensors, and this keeps the rest of the demo on one device.
    cifar10_prior = construct_cifar10_clip_prior(device).float().cpu()
    print(f"CIFAR-10 prior shape: {cifar10_prior.shape}")
    print(f"Row sums (should be ~1.0): {cifar10_prior.sum(dim=1)}")

    # Show some semantic similarities
    class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
    print("\nTop confusions for some classes:")
    for i, class_name in enumerate(class_names[:5]):
        top_confusions = torch.topk(cifar10_prior[i], k=3)
        print(f"{class_name}: {[class_names[j] for j in top_confusions.indices.tolist()]} "
              f"({top_confusions.values.numpy()})")

    # Test few-shot constructor
    print("\nTesting few-shot human prior...")
    # Simulate some human annotations
    human_annotations = {
        "cat": [3, 3, 3, 5, 3],  # Mostly cat, some dog
        "dog": [5, 5, 3, 5, 5],  # Mostly dog, some cat
        # Only 4 labels: below the default min_annotations_for_human of 5,
        # so the bird row falls back to the proxy prior.
        "bird": [2, 2, 0, 2],    # Mostly bird, some airplane
    }

    few_shot_constructor = FewShotHumanPriorConstructor()
    hybrid_prior = few_shot_constructor.construct_hybrid_prior(
        human_annotations,
        cifar10_prior,
        class_names
    )

    print(f"Hybrid prior shape: {hybrid_prior.shape}")
    print("Human-influenced rows:")
    for class_name in human_annotations:
        i = class_names.index(class_name)
        top_confusions = torch.topk(hybrid_prior[i], k=3)
        print(f"{class_name}: {[class_names[j] for j in top_confusions.indices.tolist()]} "
              f"({top_confusions.values.numpy()})")


if __name__ == "__main__":
    _run_demo()
