"""
X-Sample Contrastive Loss (X-CLR) for Continual Learning

Implementation of X-CLR (ICLR 2025) as a text-side extension of contrastive learning.
Instead of a hard "one positive" target, X-CLR uses a soft target distribution
derived from a similarity graph and minimizes cross-entropy between the target
distribution and the learned similarity distribution.

Key Features:
- Soft graph construction based on taxonomy (same model_id, same domain) or model-card embeddings
- Prompt embeddings extracted from router model's hidden states (not separate vision model)
- Compatible with self-instruct data where near-duplicate prompts may map to different models
- Apply-to policy supporting "all" or "replay_only" modes for continual learning

Reference:
    X-CLR: X-Sample Contrastive Loss (ICLR 2025)
    
Usage:
    L_total = L_supervised + λ * L_xclr
"""

import os
import json
import hashlib
import random
from collections import deque
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================================
# Domain Canonicalization
# ============================================================================

def canonicalize_domain(domain_str: Optional[str]) -> str:
    """
    Return a normalized form of a domain label suitable for equality checks.

    The value is converted with ``str()``, lowercased, stripped of leading and
    trailing whitespace, and every internal run of whitespace is replaced by a
    single space. ``None`` and values that are empty after normalization map
    to the placeholder ``"unknown"``.

    Args:
        domain_str: Raw domain label (may be None or a non-string value)

    Returns:
        Canonicalized domain string, or "unknown" when nothing usable remains
    """
    if domain_str is None:
        return "unknown"

    # split() with no separator drops leading/trailing whitespace and splits on
    # whitespace runs, so re-joining with one space collapses them.
    tokens = str(domain_str).lower().split()
    normalized = " ".join(tokens)

    return normalized if normalized else "unknown"


# ============================================================================
# Prompt Embedding Projection Head
# ============================================================================

class PromptProjectionHead(nn.Module):
    """
    Linear projection head that maps pooled hidden states to unit-length
    embeddings for the contrastive objective.

    The head is a bias-free linear layer, optionally followed by LayerNorm,
    followed by L2 normalization along the last dimension.
    """

    def __init__(
        self,
        hidden_dim: int,
        proj_dim: int = 128,
        use_layernorm: bool = True,
    ):
        """
        Build the projection layers.

        Args:
            hidden_dim: Size of the incoming hidden-state vectors
            proj_dim: Size of the projected embedding
            use_layernorm: If True, a LayerNorm over proj_dim is applied
                after the linear projection
        """
        super().__init__()
        self.proj = nn.Linear(hidden_dim, proj_dim, bias=False)
        self.use_layernorm = use_layernorm
        # The LayerNorm submodule only exists when it is actually used, so the
        # state_dict of a head built with use_layernorm=False has no such keys.
        if self.use_layernorm:
            self.layernorm = nn.LayerNorm(proj_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project the input and scale each embedding to unit L2 norm.

        Args:
            x: Input tensor of shape [B, hidden_dim]

        Returns:
            L2-normalized embeddings of shape [B, proj_dim]
        """
        projected = self.proj(x)
        if not self.use_layernorm:
            return F.normalize(projected, p=2, dim=-1)
        return F.normalize(self.layernorm(projected), p=2, dim=-1)

# ============================================================================
# Soft Graph Construction
# ============================================================================

def build_taxonomy_soft_graph(
    batch_model_ids: List[str],
    batch_domains: List[str],
    alpha_domain: float = 0.3,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor:
    """
    Build a taxonomy-based soft similarity graph.

    Graph weights:
    - 1.0 if same gold model_id
    - alpha_domain if same (canonicalized) domain but different model_id
    - 0.0 otherwise

    Diagonal is set to -inf to exclude self-comparisons.

    The pairwise comparisons are done with broadcasting over integer codes
    instead of a Python double loop, so the cost no longer scales with B^2
    individual tensor writes.

    Args:
        batch_model_ids: List of model IDs for each example in batch
            (entries must be hashable)
        batch_domains: List of domain strings for each example in batch;
            must have the same length as batch_model_ids
        alpha_domain: Similarity weight for same-domain, different-model pairs
        device: Device for the output tensor (torch.device or string;
            defaults to CPU)

    Returns:
        Soft graph tensor [B, B] with diagonal set to -inf

    Raises:
        ValueError: If batch_model_ids and batch_domains differ in length
    """
    batch_size = len(batch_model_ids)
    if len(batch_domains) != batch_size:
        raise ValueError(
            f"batch_domains has {len(batch_domains)} entries, "
            f"expected {batch_size} to match batch_model_ids"
        )

    if device is None:
        device = torch.device('cpu')
    elif isinstance(device, str):
        device = torch.device(device)

    def _encode(values) -> torch.Tensor:
        # Map each distinct value to a small integer so equality can be
        # evaluated on tensors; equal values always receive the same code.
        codes: dict = {}
        return torch.tensor(
            [codes.setdefault(value, len(codes)) for value in values],
            dtype=torch.long,
            device=device,
        )

    model_codes = _encode(batch_model_ids)
    # Canonicalize domains so formatting differences do not break matches
    domain_codes = _encode([canonicalize_domain(d) for d in batch_domains])

    same_model = model_codes.unsqueeze(1) == model_codes.unsqueeze(0)    # [B, B]
    same_domain = domain_codes.unsqueeze(1) == domain_codes.unsqueeze(0)  # [B, B]

    G_soft = torch.zeros(batch_size, batch_size, device=device)
    # Order matters: a same-model match overrides a same-domain match.
    G_soft[same_domain] = alpha_domain
    G_soft[same_model] = 1.0

    # Set diagonal to -inf for softmax masking
    G_soft.fill_diagonal_(float('-inf'))

    return G_soft


def build_prompt_similarity_soft_graph(
    retrieved_scores: List[List[float]],
    device: Optional[Union[torch.device, str]] = None,
    taxonomy_blend: float = 0.0,
    anchor_domains: Optional[List[str]] = None,
    anchor_models: Optional[List[str]] = None,
    candidate_domains: Optional[List[List[str]]] = None,
    candidate_models: Optional[List[List[str]]] = None,
    alpha_domain: float = 0.3,
) -> torch.Tensor:
    """
    Build a prompt similarity-based soft graph from retrieval scores.

    Each anchor contributes one row holding the retrieval scores (BM25,
    embedding similarity, etc.) of its candidates. When taxonomy_blend > 0
    the scores are min-max normalized per row and blended with a taxonomy
    graph (1.0 for same model, alpha_domain for same canonicalized domain,
    0.0 otherwise).

    Args:
        retrieved_scores: List of lists of similarity scores [N_anchors, N_candidates]
                         Each inner list contains scores for one anchor's candidates
        device: Device for the output tensor (torch.device or string; defaults to CPU)
        taxonomy_blend: Weight for blending taxonomy graph (0 = pure prompt similarity, 1 = pure taxonomy)
                       Values <= 0 return the raw score matrix unchanged
        anchor_domains: Optional list of anchor domains [N_anchors] (required if taxonomy_blend > 0)
        anchor_models: Optional list of anchor model IDs [N_anchors] (required if taxonomy_blend > 0)
        candidate_domains: Optional list of lists of candidate domains [N_anchors, N_candidates]
                          (required if taxonomy_blend > 0)
        candidate_models: Optional list of lists of candidate model IDs [N_anchors, N_candidates]
                         (required if taxonomy_blend > 0)
        alpha_domain: Alpha for taxonomy graph - weight for same-domain, different-model pairs

    Returns:
        Soft graph tensor [N_anchors, N_candidates] with scores (no diagonal masking needed
        since anchors and candidates are separate)

    Raises:
        ValueError: If retrieved_scores is empty, an anchor has no candidates,
            rows have differing lengths, or taxonomy_blend > 0 without all four
            metadata lists.
    """
    n_anchors = len(retrieved_scores)
    if n_anchors == 0:
        raise ValueError("retrieved_scores cannot be empty")

    n_candidates = len(retrieved_scores[0])
    if n_candidates == 0:
        raise ValueError("Each anchor must have at least one candidate")

    # Verify all anchors have same number of candidates
    for i, scores in enumerate(retrieved_scores):
        if len(scores) != n_candidates:
            raise ValueError(
                f"Anchor {i} has {len(scores)} candidates, expected {n_candidates}"
            )

    # Default to CPU for consistency in tests
    if device is None:
        device = torch.device('cpu')
    elif isinstance(device, str):
        device = torch.device(device)

    dtype = torch.get_default_dtype()

    # Build the score matrix in one call instead of element-by-element writes.
    # float() keeps ints / numpy scalars / 0-d tensors acceptable as scores.
    G_soft = torch.tensor(
        [[float(score) for score in scores] for scores in retrieved_scores],
        dtype=dtype,
        device=device,
    )

    if not taxonomy_blend > 0:
        return G_soft

    if anchor_domains is None or anchor_models is None or candidate_domains is None or candidate_models is None:
        raise ValueError(
            "anchor_domains, anchor_models, candidate_domains, and candidate_models "
            "are required when taxonomy_blend > 0"
        )

    # Build taxonomy weights [N_anchors, N_candidates] on the Python side, then
    # convert once. Metadata lists shorter than the score matrix are tolerated:
    # missing anchors/candidates simply keep a weight of 0.0.
    taxonomy_rows = []
    for i in range(n_anchors):
        anchor_domain = anchor_domains[i] if i < len(anchor_domains) else None
        anchor_model = anchor_models[i] if i < len(anchor_models) else None
        anchor_domain_canonical = canonicalize_domain(anchor_domain) if anchor_domain else "unknown"

        cand_domains = candidate_domains[i] if i < len(candidate_domains) else []
        cand_models = candidate_models[i] if i < len(candidate_models) else []

        row = [0.0] * n_candidates
        # zip() stops at the shortest input, so j never exceeds n_candidates
        # or either metadata list.
        for j, cand_domain, cand_model in zip(range(n_candidates), cand_domains, cand_models):
            if anchor_model and cand_model and anchor_model == cand_model:
                # Same model -> strongest similarity
                row[j] = 1.0
            elif cand_domain and canonicalize_domain(cand_domain) == anchor_domain_canonical:
                # Same domain, different model -> partial similarity
                row[j] = alpha_domain
        taxonomy_rows.append(row)

    G_taxonomy = torch.tensor(taxonomy_rows, dtype=dtype, device=device)

    # Per-row min-max normalization to [0, 1] (handles different retriever
    # score ranges), done for all rows at once.
    row_min = G_soft.amin(dim=1, keepdim=True)
    row_max = G_soft.amax(dim=1, keepdim=True)
    has_spread = row_max > row_min
    # Avoid dividing by zero on constant rows; their quotient is discarded below.
    safe_range = torch.where(has_spread, row_max - row_min, torch.ones_like(row_max))
    # Rows whose scores are all equal are kept as-is (raw scale), which adds
    # the same offset to every candidate of that anchor.
    G_prompt_norm = torch.where(has_spread, (G_soft - row_min) / safe_range, G_soft)

    # Blend: (1 - blend) * prompt_similarity + blend * taxonomy
    return (1 - taxonomy_blend) * G_prompt_norm + taxonomy_blend * G_taxonomy


def build_modelcard_soft_graph(
    batch_model_ids: List[str],
    modelcard_embed_cache: Dict[str, torch.Tensor],
    taxonomy_blend: float = 0.0,
    batch_domains: Optional[List[str]] = None,
    alpha_domain: float = 0.3,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor:
    """
    Build a model-card embedding-based soft similarity graph.

    Computes cosine similarity between model card embeddings for each pair.
    Optionally blends with taxonomy weights.

    Model IDs missing from the cache are represented by a zero vector, so
    they get cosine similarity 0 with every other example (including other
    examples that share the same missing ID).

    Args:
        batch_model_ids: List of model IDs for each example in batch
        modelcard_embed_cache: Dict mapping model_id -> embedding tensor [D]
        taxonomy_blend: Weight for blending taxonomy graph (0 = pure modelcard, 1 = pure taxonomy)
        batch_domains: Optional domains for taxonomy blending. Blending is
            skipped when this is None, even if taxonomy_blend > 0.
        alpha_domain: Alpha for taxonomy graph (if blending)
        device: Device for the output tensor (torch.device or string).
            Defaults to CUDA when available, otherwise CPU.

    Returns:
        Soft graph tensor [B, B] with diagonal set to -inf. An empty batch
        yields an empty [0, 0] tensor.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    elif isinstance(device, str):
        device = torch.device(device)

    # torch.stack() rejects an empty list, so handle the empty batch up front.
    if len(batch_model_ids) == 0:
        return torch.zeros(0, 0, device=device)

    # Get embeddings for each model_id in batch
    embed_dim = None
    embeddings = []

    for model_id in batch_model_ids:
        if model_id in modelcard_embed_cache:
            emb = modelcard_embed_cache[model_id]
            if embed_dim is None:
                embed_dim = emb.shape[-1]
            embeddings.append(emb)
            continue

        # Fallback: zero embedding (will have 0 similarity with everything)
        if embed_dim is None:
            # No hit seen yet: infer the width from any cached embedding,
            # else fall back to 768 (default SBERT dimension).
            any_cached = next(iter(modelcard_embed_cache.values()), None)
            embed_dim = any_cached.shape[-1] if any_cached is not None else 768
        embeddings.append(torch.zeros(embed_dim))

    # Stack into [B, D]; non-tensor cache entries are converted on the fly
    stacked = torch.stack([
        e.to(device) if isinstance(e, torch.Tensor) else torch.tensor(e, device=device)
        for e in embeddings
    ])

    # Normalize embeddings so the matrix product below is cosine similarity
    stacked = F.normalize(stacked, p=2, dim=-1)

    # Compute pairwise cosine similarities
    G_soft = torch.mm(stacked, stacked.T)  # [B, B]

    # Blend with taxonomy if requested
    if taxonomy_blend > 0 and batch_domains is not None:
        G_taxonomy = build_taxonomy_soft_graph(
            batch_model_ids, batch_domains, alpha_domain, device
        )
        # G_taxonomy has -inf on its diagonal; zero it for the arithmetic
        # blend, the diagonal is re-masked below.
        G_taxonomy_blend = G_taxonomy.clone()
        G_taxonomy_blend[G_taxonomy_blend == float('-inf')] = 0

        G_soft = (1 - taxonomy_blend) * G_soft + taxonomy_blend * G_taxonomy_blend

    # Set diagonal to -inf for softmax masking
    G_soft.fill_diagonal_(float('-inf'))

    return G_soft

# ============================================================================
# Utility Functions
# ============================================================================

def enable_hidden_states(model: nn.Module):
    """
    Turn on ``output_hidden_states`` on every config reachable from a model.

    The flag is set on the ``config`` attribute (when present) of the model
    itself, of ``model.model``, of ``model.base_model`` and of
    ``model.base_model.model``. Objects without a ``config`` are skipped.

    Args:
        model: The model to configure
    """
    # Gather every object that may carry a config, outermost first.
    config_holders = [model]
    if hasattr(model, 'model'):
        config_holders.append(model.model)
    # For PEFT models
    if hasattr(model, 'base_model'):
        wrapped = model.base_model
        config_holders.append(wrapped)
        if hasattr(wrapped, 'model'):
            config_holders.append(wrapped.model)

    for holder in config_holders:
        if hasattr(holder, 'config'):
            holder.config.output_hidden_states = True


def get_last_hidden_states(model_outputs: Any) -> Optional[torch.Tensor]:
    """
    Return the final entry of ``model_outputs.hidden_states``.

    Args:
        model_outputs: Output from model forward pass

    Returns:
        The last element of ``hidden_states`` (documented by the original
        code as the last-layer states, [B, T, D]), or None when the attribute
        is missing or set to None
    """
    per_layer_states = getattr(model_outputs, 'hidden_states', None)
    if per_layer_states is None:
        return None
    # hidden_states is a tuple of (embedding_output, *layer_outputs)
    return per_layer_states[-1]




