"""
Domain + Model-Aware Coreset Replay Sampler

This module implements a diversity-preserving replay buffer selection strategy
for continual learning. Instead of naive random replay, it uses:

1. Domain-aware proportional allocation with floor/cap constraints
2. Per-model diversity limits to avoid over-representing frequent models  
3. Farthest-point sampling in embedding space for maximum diversity

CL Motivation:
- Random replay can undersample long-tail domains/models, leading to 
  catastrophic forgetting on those subpopulations.
- Coreset selection ensures balanced coverage across domains and models,
  while maximizing embedding-space diversity within those constraints.
"""

import os
import json
import hashlib
import random
from pathlib import Path
from typing import List, Dict, Optional, Any, Tuple
from collections import defaultdict

import numpy as np
from sentence_transformers import SentenceTransformer
import torch


# ============================================================================
# Embedding Cache Utilities
# ============================================================================

def get_cache_dir(cache_dir: Optional[str] = None) -> Path:
    """
    Resolve the embedding cache directory and make sure it exists.

    Args:
        cache_dir: Explicit directory to use. When None, the default is
            ``cache/embeddings`` under the parent of this module's directory.

    Returns:
        The resolved directory as a ``Path``, created (with parents) if missing.
    """
    resolved = (
        Path(__file__).parent.parent / "cache" / "embeddings"
        if cache_dir is None
        else Path(cache_dir)
    )
    resolved.mkdir(parents=True, exist_ok=True)
    return resolved


def compute_cache_key(examples: List[Dict[str, Any]], embedding_source: str) -> str:
    """
    Compute a stable cache key for an ordered list of examples.

    The key depends on the embedding source and on the *full* instruction text
    and model name of every example, in order. Order matters because the cached
    index mapping (original position -> embedding row) and any cached coreset
    store positions into ``examples``. Two permutations of the same examples
    must therefore not share a cache entry. Likewise, examples that differ only
    late in a long instruction must not collide.

    Args:
        examples: Training examples. ``instruction`` and ``model_name`` are read;
            missing or None values are treated as empty strings.
        embedding_source: Name of the embedding backend, folded into the key.

    Returns:
        Hex MD5 digest identifying this (examples, embedding_source) pair.
    """
    hasher = hashlib.md5()
    hasher.update(json.dumps(embedding_source).encode("utf-8"))
    for ex in examples:
        record = [ex.get("instruction") or "", ex.get("model_name") or ""]
        # JSON-encode each record (newlines inside strings are escaped), so the
        # "\n" separator can never be confused with field content.
        hasher.update(b"\n")
        hasher.update(json.dumps(record).encode("utf-8"))
    return hasher.hexdigest()


def load_cached_embeddings(
    cache_dir: Path, 
    cache_key: str
) -> Optional[Tuple[np.ndarray, Dict[int, int]]]:
    """
    Fetch previously cached embeddings together with their index mapping.

    Both ``<cache_key>_embeddings.npy`` and ``<cache_key>_mapping.json`` must be
    present in ``cache_dir``. JSON object keys are converted back to ``int``.

    Returns:
        ``(embeddings, mapping)`` on success. Returns ``None`` if either file is
        missing or if loading fails (a warning is printed in that case).
    """
    emb_file = cache_dir / f"{cache_key}_embeddings.npy"
    map_file = cache_dir / f"{cache_key}_mapping.json"

    if not (emb_file.exists() and map_file.exists()):
        return None

    try:
        vectors = np.load(emb_file)
        with open(map_file, 'r') as handle:
            raw_mapping = json.load(handle)
        mapping = {int(key): value for key, value in raw_mapping.items()}
    except Exception as e:
        # Best-effort cache: a corrupt entry just means "recompute".
        print(f"Warning: Failed to load cached embeddings: {e}")
        return None
    return vectors, mapping


def save_cached_embeddings(
    cache_dir: Path,
    cache_key: str,
    embeddings: np.ndarray,
    index_mapping: Dict[int, int]
):
    """
    Persist embeddings and their index mapping under ``cache_key``.

    Writes ``<cache_key>_embeddings.npy`` (numpy format) and
    ``<cache_key>_mapping.json``; JSON turns the int keys into strings.
    Failures are reported with a printed warning instead of being raised,
    so caching stays best-effort.
    """
    target_npy = cache_dir / f"{cache_key}_embeddings.npy"
    target_json = cache_dir / f"{cache_key}_mapping.json"

    try:
        np.save(target_npy, embeddings)
        with open(target_json, 'w') as out:
            json.dump(index_mapping, out)
    except Exception as err:
        print(f"Warning: Failed to cache embeddings: {err}")


# ============================================================================
# Embedding Computation
# ============================================================================

def _encode_with_sentence_transformer(
    prompts: List[str],
    batch_size: int,
    device: str
) -> np.ndarray:
    """Encode ``prompts`` with the ``all-mpnet-base-v2`` SentenceTransformer model."""
    model = SentenceTransformer(
        "all-mpnet-base-v2",
        device=device,
        cache_folder=os.environ.get("HF_HOME", None)
    )
    return model.encode(
        prompts,
        batch_size=batch_size,
        show_progress_bar=True,
        convert_to_numpy=True
    )


def _encode_with_flagembedding(
    prompts: List[str],
    batch_size: int,
    device: str
) -> np.ndarray:
    """Encode ``prompts`` with the project's BGE-M3 retriever, keeping dense vectors only."""
    # Imported lazily so the optional retriever package is only needed when used.
    from ..retrievers_carve.bgem3 import BGEM3Retriever
    retriever = BGEM3Retriever(device=device)
    encoded = retriever.model.encode(
        prompts,
        batch_size=batch_size,
        return_dense=True,
        return_sparse=False,
        return_colbert_vecs=False
    )
    # asarray so downstream fancy-indexing and np.save work even if a list comes back.
    return np.asarray(encoded["dense_vecs"])


def compute_embeddings(
    examples: List[Dict[str, Any]],
    embedding_source: str = "sentence_transformer",
    cache_dir: Optional[str] = None,
    batch_size: int = 64,
    device: Optional[str] = None
) -> Tuple[np.ndarray, Dict[int, int]]:
    """
    Compute or load cached embeddings for examples.

    Only examples with a non-blank ``instruction`` are embedded. Blank, missing
    or None instructions are skipped and get no entry in the index mapping.

    Args:
        examples: List of training examples with an 'instruction' field
        embedding_source: "sentence_transformer" or "flagembedding". Any other
            value falls back to the sentence transformer, with a warning.
        cache_dir: Optional cache directory path (see ``get_cache_dir``)
        batch_size: Batch size for embedding computation
        device: Device for computation ('cuda' or 'cpu'). Auto-detected if None.

    Returns:
        Tuple of:
        - embeddings: array with one row per embedded example. On the path
          where nothing is embeddable it is an empty 1-D array, and that
          result is not cached.
        - index_mapping: Dict mapping original example index -> embedding row
    """
    cache_path = get_cache_dir(cache_dir)
    cache_key = compute_cache_key(examples, embedding_source)

    # Try loading from cache
    cached = load_cached_embeddings(cache_path, cache_key)
    if cached is not None:
        print(f"  Loaded cached embeddings for {len(cached[0])} examples")
        return cached

    print(f"  Computing embeddings with {embedding_source} for {len(examples)} examples...")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Collect non-blank prompts and remember where each came from.
    prompts: List[str] = []
    index_mapping: Dict[int, int] = {}
    for i, ex in enumerate(examples):
        # `or ""` also covers an explicit None instruction.
        prompt = (ex.get("instruction") or "").strip()
        if prompt:
            index_mapping[i] = len(prompts)
            prompts.append(prompt)

    if not prompts:
        return np.array([]), {}

    if embedding_source == "flagembedding":
        embeddings = _encode_with_flagembedding(prompts, batch_size, device)
    else:
        if embedding_source != "sentence_transformer":
            print(f"  Warning: unknown embedding_source {embedding_source!r}; "
                  f"falling back to sentence_transformer")
        embeddings = _encode_with_sentence_transformer(prompts, batch_size, device)

    # Cache results
    save_cached_embeddings(cache_path, cache_key, embeddings, index_mapping)
    print(f"  Cached embeddings for future use")

    return embeddings, index_mapping


# ============================================================================
# Farthest-Point Sampling
# ============================================================================

def farthest_point_sampling(
    embeddings: np.ndarray,
    indices: List[int],
    k: int,
    seed: Optional[int] = None,
    existing_selected: Optional[List[int]] = None
) -> List[int]:
    """
    Select k points using greedy farthest-point sampling for maximum diversity.

    Distances are cosine-based. Each candidate tracks its highest cosine
    similarity to the reference set (points picked so far, plus any
    ``existing_selected``). Each step picks the candidate whose highest
    similarity is lowest, i.e. the one farthest from its nearest reference
    point. Similarities are updated incrementally (one matrix-vector product
    per step) instead of recomputing the full similarity matrix.

    Args:
        embeddings: Embedding matrix. Rows are addressed by the values in
            ``indices`` and ``existing_selected``.
        indices: Row indices of the candidate points.
        k: Number of points to select. ``k <= 0`` returns ``[]``.
        seed: Seed for choosing the random starting point. The start is only
            drawn when ``existing_selected`` is empty. A private
            ``random.Random`` is used, so the global ``random`` state is not
            modified. When None, the global ``random`` module is used.
        existing_selected: Row indices of points already chosen elsewhere.
            They shape the distances (and replace the random start) but are
            not themselves added to the result.

    Returns:
        Selected row indices (values taken from ``indices``) in pick order.
        If ``len(indices) <= k``, all candidates are returned unchanged. Fewer
        than ``k`` items are returned only if ``indices`` contains duplicates.
    """
    if k <= 0:
        return []
    if len(indices) <= k:
        return list(indices)

    index_array = np.asarray(indices)
    candidate_emb = embeddings[index_array]
    # Normalize once so dot products are cosine similarities (1e-8 avoids /0).
    candidate_norm = candidate_emb / (
        np.linalg.norm(candidate_emb, axis=1, keepdims=True) + 1e-8
    )
    n_candidates = len(indices)

    selected: List[int] = []
    # False once a candidate slot was picked (or holds the same row as a pick).
    available = np.ones(n_candidates, dtype=bool)

    if existing_selected is not None and len(existing_selected) > 0:
        reference = embeddings[existing_selected]
        reference_norm = reference / (
            np.linalg.norm(reference, axis=1, keepdims=True) + 1e-8
        )
        # Highest similarity of each candidate to any already-chosen point.
        max_similarities = np.max(candidate_norm @ reference_norm.T, axis=1)
    else:
        rng = random.Random(seed) if seed is not None else random
        start_pos = rng.randint(0, n_candidates - 1)
        selected.append(indices[start_pos])
        available[index_array == indices[start_pos]] = False
        max_similarities = candidate_norm @ candidate_norm[start_pos]

    while len(selected) < k:
        if not available.any():
            break  # only reachable when `indices` contains duplicate rows
        # Farthest point = minimum "similarity to nearest selected point".
        scores = np.where(available, max_similarities, np.inf)
        pos = int(np.argmin(scores))
        selected.append(indices[pos])
        available[index_array == indices[pos]] = False
        # Fold in similarity to the newly selected point.
        max_similarities = np.maximum(max_similarities, candidate_norm @ candidate_norm[pos])

    return selected


# ============================================================================
# Main Coreset Builder
# ============================================================================

def compute_coreset_cache_key(
    examples: List[Dict[str, Any]],
    replay_ratio: float,
    min_per_domain: int,
    max_per_domain: Optional[int],
    max_per_model: int,
    embedding_source: str,
    seed: Optional[int]
) -> str:
    """
    Derive the cache key for a coreset selection.

    The key combines the example-set key from ``compute_cache_key`` with every
    selection parameter. Changing any of them therefore yields a different
    cached coreset.

    Returns:
        Hex MD5 digest of ``"coreset_<example key>_<params>"``.
    """
    example_key = compute_cache_key(examples, embedding_source)
    param_str = "_".join(
        str(value)
        for value in (replay_ratio, min_per_domain, max_per_domain, max_per_model, seed)
    )
    return hashlib.md5(f"coreset_{example_key}_{param_str}".encode()).hexdigest()


def load_cached_coreset(
    cache_dir: Path,
    cache_key: str
) -> Optional[List[int]]:
    """
    Read cached coreset indices from ``<cache_key>_coreset.json``.

    Returns:
        The decoded JSON content (expected to be a list of example indices).
        Returns ``None`` if the file is missing or cannot be read or parsed;
        a warning is printed in the latter case.
    """
    coreset_file = cache_dir / f"{cache_key}_coreset.json"
    if not coreset_file.exists():
        return None

    try:
        with open(coreset_file, 'r') as fh:
            return json.load(fh)
    except Exception as exc:
        print(f"Warning: Failed to load cached coreset: {exc}")
        return None


def save_cached_coreset(
    cache_dir: Path,
    cache_key: str,
    indices: List[int]
):
    """
    Write coreset indices as JSON to ``<cache_key>_coreset.json``.

    Errors are reported with a printed warning rather than raised, so a
    failed cache write never aborts coreset construction.
    """
    destination = cache_dir / f"{cache_key}_coreset.json"

    try:
        with open(destination, 'w') as fh:
            json.dump(indices, fh)
    except Exception as exc:
        print(f"Warning: Failed to cache coreset: {exc}")


def _resolve_domain(example: Dict[str, Any]) -> str:
    """Return ``example['domain']``, else ``example['api_data']['domain']``, else "unknown"."""
    domain = example.get("domain")
    if not domain:
        api_data = example.get("api_data")
        if isinstance(api_data, dict):
            domain = api_data.get("domain")
    return domain or "unknown"


def build_domain_model_coreset_replay(
    apibench_examples: List[Dict[str, Any]],
    replay_ratio: float,
    min_per_domain: int,
    max_per_domain: Optional[int],
    max_per_model: int,
    embedding_source: str = "sentence_transformer",
    cache_dir: Optional[str] = None,
    seed: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Build a domain + model-aware coreset replay buffer.

    Algorithm:
    1. Compute total replay budget B = int(replay_ratio * len(examples))
    2. Group examples by domain and compute proportional quotas with a floor
       (min_per_domain) and caps (max_per_domain, domain size)
    3. Within each domain, group by model and select up to max_per_model
       diverse examples per model
    4. Use farthest-point sampling (FPS) in embedding space for diversity
    5. Trim each domain down to its quota (FPS), or fill it up to its quota
       (FPS relative to the examples already picked)
    6. If the union still exceeds B, trim globally with FPS

    Examples without an embedding (blank instruction) are only chosen through
    random-sampling fallbacks. When ``seed`` is given, the global ``random`` and
    ``np.random`` states are seeded.

    Args:
        apibench_examples: Full list of APIBench training examples
        replay_ratio: Fraction of examples to include in replay (e.g., 0.1)
        min_per_domain: Minimum examples per domain (floor)
        max_per_domain: Optional maximum per domain (cap)
        max_per_model: Maximum examples per model within a domain
        embedding_source: Embedding model to use for diversity sampling
        cache_dir: Optional embedding/coreset cache directory
        seed: Random seed for reproducibility

    Returns:
        List of selected examples for replay (may be served from the coreset cache)
    """
    if not apibench_examples:
        return []

    n_examples = len(apibench_examples)

    # Step 0: Check for cached coreset
    cache_path = get_cache_dir(cache_dir)
    coreset_cache_key = compute_coreset_cache_key(
        apibench_examples, replay_ratio, min_per_domain, max_per_domain,
        max_per_model, embedding_source, seed
    )

    cached_indices = load_cached_coreset(cache_path, coreset_cache_key)
    if cached_indices is not None:
        print(f"\n=== Loading Cached Coreset Replay ===")
        print(f"  Loaded {len(cached_indices)} cached coreset indices")
        # Reject stale/corrupt entries; negative ints would otherwise
        # silently index from the end of the list.
        result = [
            apibench_examples[i] for i in cached_indices
            if isinstance(i, int) and 0 <= i < n_examples
        ]
        print(f"  Final replay size: {len(result)}")
        return result

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # Step 1: Compute total budget
    total_budget = int(replay_ratio * n_examples)
    print(f"\n=== Building Domain+Model Coreset Replay ===")
    print(f"  Total examples: {n_examples}")
    print(f"  Replay ratio: {replay_ratio}")
    print(f"  Target budget: {total_budget}")

    # Step 2: Compute embeddings (with caching)
    embeddings, index_mapping = compute_embeddings(
        apibench_examples,
        embedding_source=embedding_source,
        cache_dir=cache_dir
    )

    # Reverse mapping: embedding row -> original example index
    reverse_mapping = {v: k for k, v in index_mapping.items()}

    # Step 3: Group by domain
    by_domain: Dict[str, List[int]] = defaultdict(list)
    for i, ex in enumerate(apibench_examples):
        by_domain[_resolve_domain(ex)].append(i)

    print(f"  Domains found: {len(by_domain)}")

    # Step 4: Compute domain quotas (proportional, then floor, then caps)
    domain_quotas: Dict[str, int] = {}
    for domain, indices in by_domain.items():
        raw_quota = total_budget * len(indices) / n_examples
        quota = max(min_per_domain, int(round(raw_quota)))
        if max_per_domain is not None:
            quota = min(quota, max_per_domain, len(indices))
        else:
            quota = min(quota, len(indices))
        domain_quotas[domain] = quota

    # Step 5: Build coreset per domain
    all_selected: List[int] = []
    domain_stats: Dict[str, Dict[str, int]] = {}

    for domain, domain_indices in by_domain.items():
        quota = domain_quotas[domain]

        # Group by model within domain
        by_model: Dict[str, List[int]] = defaultdict(list)
        for idx in domain_indices:
            by_model[apibench_examples[idx].get("model_name", "unknown")].append(idx)

        # First pass: up to max_per_model per model, chosen with FPS when possible
        domain_coreset: List[int] = []
        for model_indices in by_model.values():
            model_emb_indices = [
                index_mapping[i] for i in model_indices if i in index_mapping
            ]
            if not model_emb_indices:
                # No embeddings for this model: random sample instead
                n_select = min(max_per_model, len(model_indices))
                domain_coreset.extend(random.sample(model_indices, n_select))
            else:
                n_select = min(max_per_model, len(model_emb_indices))
                selected_emb_indices = farthest_point_sampling(
                    embeddings, model_emb_indices, n_select, seed=seed
                )
                domain_coreset.extend(reverse_mapping[e] for e in selected_emb_indices)

        # Adjust to quota
        if len(domain_coreset) > quota:
            # Downsample with FPS over the embedded part of the coreset
            coreset_emb_indices = [
                index_mapping[i] for i in domain_coreset if i in index_mapping
            ]
            if len(coreset_emb_indices) >= quota:
                selected_emb_indices = farthest_point_sampling(
                    embeddings, coreset_emb_indices, quota, seed=seed
                )
                domain_coreset = [reverse_mapping[e] for e in selected_emb_indices]
            else:
                domain_coreset = random.sample(domain_coreset, quota)

        elif len(domain_coreset) < quota:
            # Fill remaining slots from unused examples in this domain
            used_set = set(domain_coreset)
            remaining = [i for i in domain_indices if i not in used_set]
            needed = quota - len(domain_coreset)
            remaining_emb_indices = [
                index_mapping[i] for i in remaining if i in index_mapping
            ]

            if len(remaining_emb_indices) >= needed:
                # FPS measured against what this domain already holds
                existing_emb = [index_mapping[i] for i in domain_coreset if i in index_mapping]
                selected_emb_indices = farthest_point_sampling(
                    embeddings,
                    remaining_emb_indices,
                    needed,
                    seed=seed,
                    existing_selected=existing_emb
                )
                domain_coreset.extend(reverse_mapping[e] for e in selected_emb_indices)
            else:
                # Too few embedded candidates: take all of them, then top up
                # randomly from the un-embedded rest so the domain stops
                # exactly at its quota instead of absorbing every leftover.
                domain_coreset.extend(reverse_mapping[e] for e in remaining_emb_indices)
                unembedded = [i for i in remaining if i not in index_mapping]
                shortfall = needed - len(remaining_emb_indices)
                domain_coreset.extend(
                    random.sample(unembedded, min(shortfall, len(unembedded)))
                )

        all_selected.extend(domain_coreset)
        domain_stats[domain] = {
            "quota": quota,
            "total_available": len(domain_indices),
        }

    # Step 6: Global adjustment if needed.
    # NOTE(review): this trim is domain-agnostic, so it can push domains below
    # their min_per_domain floor when the floors sum to more than the budget.
    if len(all_selected) > total_budget:
        print(f"  Trimming {len(all_selected)} -> {total_budget} globally")
        all_emb_indices = [
            index_mapping[i] for i in all_selected if i in index_mapping
        ]
        if len(all_emb_indices) >= total_budget:
            selected_emb_indices = farthest_point_sampling(
                embeddings, all_emb_indices, total_budget, seed=seed
            )
            all_selected = [reverse_mapping[e] for e in selected_emb_indices]
        else:
            all_selected = random.sample(all_selected, total_budget)

    # Remove duplicates while preserving order
    unique_selected = list(dict.fromkeys(all_selected))

    # Build final result
    result = [apibench_examples[i] for i in unique_selected]

    # Cache the coreset indices for future runs
    save_cached_coreset(cache_path, coreset_cache_key, unique_selected)
    print(f"  Cached coreset selection for future use")

    # Per-domain stats reflect the FINAL selection (after global trim/dedup)
    domain_of = {idx: d for d, idxs in by_domain.items() for idx in idxs}
    final_counts: Dict[str, int] = defaultdict(int)
    final_models: Dict[str, set] = defaultdict(set)
    for idx in unique_selected:
        d = domain_of[idx]
        final_counts[d] += 1
        final_models[d].add(apibench_examples[idx].get("model_name", "unknown"))
    for domain, stats in domain_stats.items():
        stats["count"] = final_counts[domain]
        stats["unique_models"] = len(final_models[domain])

    # Log statistics
    print(f"\n  Final replay size: {len(result)}")
    print(f"\n  Per-domain statistics:")
    for domain, stats in sorted(domain_stats.items(), key=lambda x: -x[1]["count"]):
        print(f"    {domain}: {stats['count']} selected (quota={stats['quota']}, "
              f"available={stats['total_available']}, unique_models={stats['unique_models']})")

    # Log top models
    model_counts: Dict[str, int] = defaultdict(int)
    for ex in result:
        model_counts[ex.get("model_name", "unknown")] += 1

    print(f"\n  Top-10 models by replay frequency:")
    for model, count in sorted(model_counts.items(), key=lambda x: -x[1])[:10]:
        print(f"    {model}: {count}")

    return result

