"""
Embedding Dispersion Analysis Utility

This module provides functions to calculate the dispersion of embeddings,
defined as the average pairwise cosine distance:

    Dispersion = 2/(N(N-1)) * sum_{i<j} (1 - cos(x_i, x_j))

Higher dispersion indicates embeddings are more spread out in the space.
"""

import numpy as np
import torch
from typing import Union, Dict, List, Optional
from sklearn.metrics.pairwise import cosine_similarity
import os
import sys

# Handle imports for both module usage and direct script execution
# Add parent directory to path to allow imports when running as script
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from loguru import logger
from utils.load_data import load_npy


def _pairwise_distance_sum_gpu(embeddings, batch_size):
    """Sum ``1 - cos(x_i, x_j)`` over all pairs ``i < j`` on the CUDA device."""
    with torch.no_grad():
        if isinstance(embeddings, torch.Tensor):
            # detach() keeps autograd from tracking the similarity products
            device_emb = embeddings.detach().cuda().float()
        else:
            device_emb = torch.tensor(embeddings, dtype=torch.float32).cuda()

        # Unit-length rows turn the dot product into cosine similarity
        unit_rows = torch.nn.functional.normalize(device_emb, p=2, dim=1)

        n = unit_rows.shape[0]
        step = n if batch_size is None else batch_size
        total = 0.0
        for start in range(0, n, step):
            # Only columns >= start can hold pairs (i, j) with j > i for the
            # rows of this batch, so earlier columns are never computed.
            sim = torch.mm(unit_rows[start:start + step], unit_rows[start:].t())
            # Local row r and local column c map to global start + r and
            # start + c, so the strict upper triangle (c > r) is exactly the
            # set of pairs with j > i. One reduction per batch means one
            # host sync per batch; accumulate in float64 to limit rounding.
            total += torch.triu(1.0 - sim, diagonal=1).sum(dtype=torch.float64).item()
    return total


def _pairwise_distance_sum_cpu(embeddings, batch_size):
    """Sum ``1 - cos(x_i, x_j)`` over all pairs ``i < j`` using sklearn."""
    n = len(embeddings)
    step = n if batch_size is None else batch_size
    total = 0.0
    for start in range(0, n, step):
        # Same column restriction as the GPU path: pairs with j > i for the
        # rows of this batch live in columns >= start.
        sim = cosine_similarity(embeddings[start:start + step], embeddings[start:])
        total += float(np.triu(1.0 - sim, k=1).sum(dtype=np.float64))
    return total


def compute_dispersion(
    embeddings: Union[np.ndarray, torch.Tensor],
    batch_size: Optional[int] = None,
    use_gpu: bool = False
) -> float:
    """
    Compute the dispersion of embeddings as average pairwise cosine distance.

    Args:
        embeddings: Embedding matrix of shape (N, d) where N is number of samples
                   and d is the embedding dimension
        batch_size: If specified, compute in row batches of this size to reduce
                   memory usage. Useful for large N. If None, compute all at once.
        use_gpu: Whether to use GPU for computation. The GPU path is taken only
                when ``torch.cuda.is_available()`` is true; otherwise the CPU
                path is used.

    Returns:
        dispersion: Float value representing the average pairwise cosine distance.
                   Returns 0.0 (and logs a warning) when N < 2.

    Raises:
        ValueError: If ``batch_size`` is given but is smaller than 1.

    Formula:
        Dispersion = 2/(N(N-1)) * sum_{i<j} (1 - cos(x_i, x_j))

    Example:
        >>> emb = np.random.randn(1000, 512)
        >>> disp = compute_dispersion(emb)
        >>> print(f"Dispersion: {disp:.4f}")
    """
    run_on_gpu = use_gpu and torch.cuda.is_available()

    # The CPU path works on numpy arrays; detach() first so tensors that
    # require grad can be converted without raising.
    if isinstance(embeddings, torch.Tensor) and not run_on_gpu:
        embeddings = embeddings.detach().cpu().numpy()

    N = len(embeddings)

    if N < 2:
        logger.warning(f"Cannot compute dispersion with N={N} < 2")
        return 0.0

    # A negative step would make range() empty and silently yield 0.0.
    if batch_size is not None and batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer or None, got {batch_size}"
        )

    if run_on_gpu:
        total_distance = _pairwise_distance_sum_gpu(embeddings, batch_size)
    else:
        total_distance = _pairwise_distance_sum_cpu(embeddings, batch_size)

    # N(N-1)/2 unordered pairs, hence the 2/(N(N-1)) normalisation
    dispersion = (2.0 / (N * (N - 1))) * total_distance

    return float(dispersion)


def _discover_dataset_names(data_dir: str, model_name: str) -> List[str]:
    """Return sorted dataset identifiers inferred from ``.npy`` files in data_dir."""
    suffix = '.npy'
    discovered = set()

    # A missing path or a non-directory simply yields no datasets
    if not os.path.isdir(data_dir):
        return []

    for filename in os.listdir(data_dir):
        if not filename.endswith(suffix) or model_name not in filename:
            continue

        # Strip only the trailing extension, not every '.npy' occurrence
        stem = filename[:-len(suffix)]

        # Locate the model name in the raw stem instead of in '_'-split parts,
        # so model names that themselves contain underscores are still found.
        match_at = stem.find(model_name)
        if match_at < 0:
            continue

        # Widen the match to the '_'-delimited segment(s) it touches
        segment_start = stem.rfind('_', 0, match_at) + 1
        segment_end = stem.find('_', match_at + len(model_name))
        if segment_end < 0:
            # Model segment is the last one: nothing left to be a dataset name
            continue

        if segment_start > 0:
            # Something precedes the model segment (e.g. "corpus_embeddings"):
            # keep the full stem, which is also a loadable key as-is.
            discovered.add(stem)
        else:
            # File starts with the model segment: keep only what follows it
            discovered.add(stem[segment_end + 1:])

    return sorted(discovered)


def compute_dataset_dispersion(
    data_dir: str,
    model_name: str,
    dataset_names: Optional[List[str]] = None,
    batch_size: Optional[int] = 1000,
    use_gpu: bool = False
) -> Dict[str, float]:
    """
    Compute dispersion for all datasets for a given model.

    Args:
        data_dir: Directory containing embedding .npy files
        model_name: Name of the embedding model (used to identify files).
                   May contain underscores.
        dataset_names: List of dataset names to analyze. If None, names are
                      discovered from the .npy files in data_dir whose name
                      contains model_name. A discovered name is the full file
                      stem when text precedes the model segment, otherwise
                      the text after the model segment.
        batch_size: Batch size for dispersion computation
        use_gpu: Whether to use GPU for computation

    Returns:
        results: Dictionary mapping dataset names to dispersion values.
                Datasets that cannot be loaded, or whose computation raises,
                are logged and left out.

    Example:
        >>> results = compute_dataset_dispersion(
        ...     data_dir="embeddings/",
        ...     model_name="sentence-transformers",
        ...     dataset_names=["arxiv", "big_patent"]
        ... )
        >>> for dataset, disp in results.items():
        ...     print(f"{dataset}: {disp:.4f}")
    """
    results = {}

    # If dataset_names not provided, try to find all matching .npy files
    if dataset_names is None:
        logger.debug(f"Scanning {data_dir} for embedding files with model {model_name}")
        dataset_names = _discover_dataset_names(data_dir, model_name)

        if not dataset_names:
            logger.warning(f"No datasets found in {data_dir} for model {model_name}")
            return results

    logger.debug(f"Computing dispersion for {len(dataset_names)} datasets with model {model_name}")

    for dataset_name in dataset_names:
        try:
            # Candidate file keys, tried in this order:
            #   1. {prefix}_{model}_{last}  (model inserted before the last
            #      '_'-separated token, e.g. corpus_embeddings_openai_arguana)
            #   2. {dataset}                (exact name)
            #   3. {dataset}_{model}
            #   4. {model}_{dataset}
            candidate_keys = []
            if '_' in dataset_name:
                head, tail = dataset_name.rsplit('_', 1)
                candidate_keys.append(f"{head}_{model_name}_{tail}")
            candidate_keys.extend([
                dataset_name,
                f"{dataset_name}_{model_name}",
                f"{model_name}_{dataset_name}",
            ])

            # NOTE(review): assumes load_npy returns None when no file matches
            # the key -- confirm against utils.load_data.
            embeddings = None
            for key in candidate_keys:
                embeddings = load_npy(data_dir, key)
                if embeddings is not None:
                    logger.debug(f"Loaded {dataset_name} embeddings from {key}.npy")
                    break

            if embeddings is None:
                logger.warning(f"Could not load embeddings for dataset {dataset_name}")
                continue

            # Compute dispersion
            logger.debug(f"Computing dispersion for {dataset_name} (N={len(embeddings)}, d={embeddings.shape[1]})")
            dispersion = compute_dispersion(embeddings, batch_size=batch_size, use_gpu=use_gpu)
            results[dataset_name] = dispersion
            logger.debug(f"{dataset_name}: dispersion = {dispersion:.6f}")

        except Exception as e:
            # Best-effort: one failing dataset must not abort the others
            logger.error(f"Error computing dispersion for {dataset_name}: {e}")
            continue

    return results


def compare_model_dispersions(
    data_dir: str,
    model_names: List[str],
    dataset_names: Optional[List[str]] = None,
    batch_size: Optional[int] = 1000,
    use_gpu: bool = False,
    output_csv: Optional[str] = None
) -> Dict[str, Dict[str, float]]:
    """
    Run the per-model dispersion analysis for several models and tabulate it.

    Each model is processed with ``compute_dataset_dispersion``; a
    dataset-by-model summary table is then logged at debug level and,
    optionally, written to a CSV file.

    Args:
        data_dir: Directory containing embedding .npy files
        model_names: List of model names to compare
        dataset_names: List of dataset names to analyze (passed through
                      unchanged for every model)
        batch_size: Batch size for dispersion computation
        use_gpu: Whether to use GPU for computation
        output_csv: If provided, path of the CSV file to write. Columns are
                   'dataset' followed by one column per model; missing
                   values are written as empty cells.

    Returns:
        results: Nested dictionary {model_name: {dataset_name: dispersion}}

    Example:
        >>> results = compare_model_dispersions(
        ...     data_dir="embeddings/",
        ...     model_names=["bert", "sentence-transformers", "openai"],
        ...     dataset_names=["arxiv", "big_patent"]
        ... )
    """
    rule = "=" * 60

    # Per-model analysis
    results = {}
    for model_name in model_names:
        logger.debug("\n" + rule)
        logger.debug(f"Processing model: {model_name}")
        logger.debug(rule)

        results[model_name] = compute_dataset_dispersion(
            data_dir=data_dir,
            model_name=model_name,
            dataset_names=dataset_names,
            batch_size=batch_size,
            use_gpu=use_gpu
        )

    # Summary banner
    logger.debug("\n" + rule)
    logger.debug("DISPERSION SUMMARY")
    logger.debug(rule)

    # Union of every dataset seen for any model, in sorted order
    all_datasets = sorted(
        {name for per_model in results.values() for name in per_model}
    )

    def lookup(model, dataset):
        # None when this model has no value for this dataset
        return results.get(model, {}).get(dataset, None)

    # Table header
    header = f"{'Dataset':<20} " + " ".join(f"{m:<15}" for m in model_names)
    logger.debug(header)
    logger.debug("-" * len(header))

    # One table line per dataset
    for dataset in all_datasets:
        cells = []
        for model in model_names:
            value = lookup(model, dataset)
            if value is None:
                cells.append(f" {'N/A':<15}")
            else:
                cells.append(f" {value:<15.6f}")
        logger.debug(f"{dataset:<20}" + "".join(cells))

    # Optional CSV export
    if output_csv:
        import csv
        with open(output_csv, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['dataset'] + model_names)
            for dataset in all_datasets:
                values = [lookup(model, dataset) for model in model_names]
                writer.writerow(
                    [dataset] + ['' if v is None else v for v in values]
                )
        logger.debug(f"\nResults saved to {output_csv}")

    return results


if __name__ == "__main__":
    # Command-line entry point: analyse one model (--model) or compare
    # several (--models). When both are given, --models takes precedence.
    import argparse

    parser = argparse.ArgumentParser(description="Compute embedding dispersion metrics")
    for flags, options in (
        ("--data_dir", dict(type=str, required=True,
                            help="Directory containing embedding .npy files")),
        ("--model", dict(type=str, default=None,
                         help="Model name (for single model analysis)")),
        ("--models", dict(type=str, nargs='+', default=None,
                          help="Multiple model names (for comparison)")),
        ("--datasets", dict(type=str, nargs='+', default=None,
                            help="Dataset names to analyze (if None, analyze all)")),
        ("--batch_size", dict(type=int, default=1000,
                              help="Batch size for computation")),
        ("--use_gpu", dict(action="store_true",
                           help="Use GPU for computation")),
        ("--output_csv", dict(type=str, default=None,
                              help="Output CSV file for results")),
    ):
        parser.add_argument(flags, **options)

    args = parser.parse_args()

    # parser.error() prints usage and exits, so this acts as a guard clause
    if not (args.models or args.model):
        parser.error("Either --model or --models must be specified")

    # Arguments shared by both analysis modes
    shared_kwargs = dict(
        data_dir=args.data_dir,
        dataset_names=args.datasets,
        batch_size=args.batch_size,
        use_gpu=args.use_gpu,
    )

    if args.models:
        # Compare multiple models (CSV export is handled by the callee)
        results = compare_model_dispersions(
            model_names=args.models,
            output_csv=args.output_csv,
            **shared_kwargs
        )
    else:
        # Analyze single model
        results = compute_dataset_dispersion(model_name=args.model, **shared_kwargs)

        # Print results
        logger.debug("\nDispersion Results:")
        for dataset, disp in results.items():
            logger.debug(f"  {dataset}: {disp:.6f}")

        # Save to CSV if requested
        if args.output_csv:
            import csv
            with open(args.output_csv, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(['dataset', 'dispersion'])
                # Each (dataset, dispersion) item becomes one CSV row
                writer.writerows(results.items())
            logger.debug(f"Results saved to {args.output_csv}")
