import torch
import numpy as np
from typing import List, Dict, Any, Union, Optional, Tuple
import os
import json
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def reduce_dimensions(embeddings: List[Union[torch.Tensor, np.ndarray]],
                      method: str = 'pca',
                      n_components: int = 2,
                      random_state: Optional[int] = None) -> np.ndarray:
    """
    Reduce dimensionality of embeddings for visualization.

    Args:
        embeddings: List of embedding vectors. Torch tensors are detached and
            copied to CPU; each embedding is flattened into one row before
            reduction, so multi-dimensional embeddings (e.g. shape (1, dim))
            are accepted.
        method: Reduction method ('pca' or 'tsne'), case-insensitive
        n_components: Number of components in output
        random_state: Optional seed forwarded to the sklearn estimator for
            reproducible output; None keeps sklearn's default behavior.

    Returns:
        Reduced embeddings as numpy array, one row per input embedding

    Raises:
        ValueError: If `method` is unknown, `embeddings` is empty, or t-SNE is
            requested with fewer than two embeddings.
    """
    method_key = method.lower()
    # Validate cheap arguments before doing any conversion work.
    if method_key not in ('pca', 'tsne'):
        raise ValueError(f"Unknown reduction method: {method}")
    if len(embeddings) == 0:
        raise ValueError("Cannot reduce dimensions of empty embedding list")

    # Convert to numpy arrays and stack
    numpy_embeddings = [
        emb.detach().cpu().numpy() if isinstance(emb, torch.Tensor) else np.asarray(emb)
        for emb in embeddings
    ]
    stacked = np.stack(numpy_embeddings)
    n_samples = stacked.shape[0]
    # sklearn estimators require a 2-D (n_samples, n_features) matrix.
    stacked = stacked.reshape(n_samples, -1)

    # Apply dimensionality reduction
    if method_key == 'pca':
        reducer = PCA(n_components=n_components, random_state=random_state)
    else:
        # t-SNE needs 0 < perplexity < n_samples, which is impossible with one sample.
        if n_samples < 2:
            raise ValueError("t-SNE requires at least 2 embeddings")
        reducer = TSNE(n_components=n_components,
                       perplexity=min(30, n_samples - 1),
                       random_state=random_state)

    return reducer.fit_transform(stacked)

def visualize_embeddings(embeddings: List[Union[torch.Tensor, np.ndarray]],
                       labels: Optional[List[str]] = None,
                       colors: Optional[List[str]] = None,
                       method: str = 'pca',
                       title: str = 'Embedding Visualization',
                       save_path: Optional[str] = None) -> None:
    """
    Visualize embeddings in 2D space.

    Args:
        embeddings: List of embedding vectors
        labels: Optional list of text labels; label i annotates point i.
            Extra labels beyond the number of points are ignored.
        colors: Optional colors for points. Named/RGB colors are used as-is;
            a 1-D sequence of numbers is mapped through the colormap and a
            colorbar is added.
        method: Reduction method ('pca' or 'tsne')
        title: Plot title
        save_path: Optional path to save the visualization. When given, the
            figure is saved and closed instead of being shown.
    """
    # Reduce dimensions
    reduced = reduce_dimensions(embeddings, method, n_components=2)

    # Set up plot
    fig, ax = plt.subplots(figsize=(10, 8))

    # Determine colors
    if colors is None:
        ax.scatter(reduced[:, 0], reduced[:, 1])
    else:
        scatter = ax.scatter(reduced[:, 0], reduced[:, 1], c=colors)
        color_values = np.asarray(colors)
        # A colorbar is only meaningful when `colors` are scalar values mapped
        # through a colormap, not for named colors or RGB(A) rows.
        if color_values.ndim == 1 and np.issubdtype(color_values.dtype, np.number):
            fig.colorbar(scatter, ax=ax)

    # Add labels if provided (zip avoids IndexError when labels outnumber points)
    if labels is not None:
        for label, (x, y) in zip(labels, reduced):
            ax.annotate(label, (x, y), fontsize=9, alpha=0.7)

    ax.set_title(f'{title} ({method.upper()})')
    ax.set_xlabel('Component 1')
    ax.set_ylabel('Component 2')
    fig.tight_layout()

    # Save or show
    if save_path:
        fig.savefig(save_path)
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated save-only calls accumulate open figures.
        plt.close(fig)
    else:
        plt.show()
        
def interpolate_embeddings(emb1: Union[torch.Tensor, np.ndarray],
                          emb2: Union[torch.Tensor, np.ndarray],
                          steps: int = 10) -> List[Union[torch.Tensor, np.ndarray]]:
    """
    Create a smooth interpolation between two embeddings.

    The output container type follows `emb1`: if it is a torch.Tensor (including
    subclasses such as nn.Parameter), tensors on `emb1`'s device are returned;
    otherwise numpy arrays are returned. Float32/float16 inputs keep their
    precision.

    Args:
        emb1: First embedding (alpha = 0)
        emb2: Second embedding (alpha = 1)
        steps: Number of interpolation steps, endpoints included

    Returns:
        List of `steps` interpolated embeddings
    """
    # Convert to numpy for consistent processing
    emb1_np = emb1.detach().cpu().numpy() if isinstance(emb1, torch.Tensor) else emb1
    emb2_np = emb2.detach().cpu().numpy() if isinstance(emb2, torch.Tensor) else emb2

    # .tolist() yields plain Python floats. NumPy treats those as "weak" scalars,
    # so float32 inputs stay float32; an np.float64 alpha would promote them to
    # float64 under NumPy >= 2 promotion rules (NEP 50).
    alphas = np.linspace(0.0, 1.0, steps).tolist()
    interpolations_np = [(1.0 - alpha) * emb1_np + alpha * emb2_np for alpha in alphas]

    # Convert back to tensors if needed; isinstance also covers Tensor subclasses.
    if isinstance(emb1, torch.Tensor):
        return [torch.tensor(interp, device=emb1.device) for interp in interpolations_np]
    return interpolations_np

def compute_centroid(embeddings: List[Union[torch.Tensor, np.ndarray]]) -> Union[torch.Tensor, np.ndarray]:
    """
    Compute the centroid (element-wise average) of a set of embeddings.

    The result type follows the first embedding: if it is a torch.Tensor
    (including subclasses such as nn.Parameter), a tensor on that embedding's
    device is returned; otherwise a numpy array is returned.

    Args:
        embeddings: Non-empty list of embedding vectors of equal shape

    Returns:
        Centroid embedding

    Raises:
        ValueError: If `embeddings` is empty.
    """
    # len() rather than truthiness so an ndarray argument doesn't raise an
    # "ambiguous truth value" error.
    if len(embeddings) == 0:
        raise ValueError("Cannot compute centroid of empty embedding list")

    first = embeddings[0]

    # Convert to numpy for consistent processing
    numpy_embeddings = [
        emb.detach().cpu().numpy() if isinstance(emb, torch.Tensor) else emb
        for emb in embeddings
    ]

    # Compute average
    centroid = np.mean(numpy_embeddings, axis=0)

    # Convert back to a tensor if needed; isinstance also covers Tensor subclasses.
    if isinstance(first, torch.Tensor):
        return torch.tensor(centroid, device=first.device)
    return centroid

def compute_embedding_stats(embeddings: List[Union[torch.Tensor, np.ndarray]]) -> Dict[str, Any]:
    """
    Compute statistics for a set of embeddings.

    Args:
        embeddings: Non-empty list of embedding vectors of equal shape

    Returns:
        Dictionary with element-wise "mean", "std", "min", "max" (each shaped
        like a single embedding) and scalar "norm_mean"/"norm_std", the mean
        and standard deviation of the per-embedding Euclidean (L2) norms.

    Raises:
        ValueError: If `embeddings` is empty.
    """
    if len(embeddings) == 0:
        raise ValueError("Cannot compute stats of empty embedding list")

    # Convert to numpy for consistent processing
    numpy_embeddings = [
        emb.detach().cpu().numpy() if isinstance(emb, torch.Tensor) else emb
        for emb in embeddings
    ]
    stacked = np.stack(numpy_embeddings)

    # One norm per embedding over all of its elements. Flattening first makes
    # this correct for multi-dimensional embeddings, e.g. shape (1, dim).
    norms = np.linalg.norm(stacked.reshape(len(stacked), -1), axis=1)

    # Compute statistics
    return {
        "mean": np.mean(stacked, axis=0),
        "std": np.std(stacked, axis=0),
        "min": np.min(stacked, axis=0),
        "max": np.max(stacked, axis=0),
        "norm_mean": np.mean(norms),
        "norm_std": np.std(norms),
    }