import torch
import json
import pandas as pd
import numpy as np
import re
import argparse
from typing import List, Dict, Any, Tuple, Optional
import csv
from pathlib import Path
import random

def normalized_laplacian(W: torch.Tensor, form: str = 'sym') -> torch.Tensor:
    """Compute the normalized Laplacian from a weighted adjacency matrix.

    With ``D`` the diagonal matrix of row sums of ``W``:

    * ``form='sym'`` returns ``I - D^{-1/2} W D^{-1/2}``
    * ``form='rw'``  returns ``I - D^{-1} W``

    Rows whose degree is exactly zero get a scaling factor of zero, so an
    isolated node contributes only its identity entry.

    Args:
        W: (n x n) weighted adjacency matrix (symmetric for the usual
            interpretation of the 'sym' form).
        form: Form of the Laplacian, either 'sym' (symmetric) or 'rw'
            (random walk).

    Returns:
        The (n x n) normalized Laplacian, on the same device as ``W``.

    Raises:
        ValueError: If ``form`` is neither 'sym' nor 'rw'.
    """
    # Validate up front so a bad ``form`` fails before any tensor work.
    if form not in ('sym', 'rw'):
        raise ValueError("form must be 'sym' or 'rw'")

    deg = W.sum(dim=1)  # Degree vector

    # deg ** exponent is inf where deg == 0; overwrite those entries with 0
    # so isolated nodes do not inject inf/NaN into the result.
    exponent = -0.5 if form == 'sym' else -1.0
    scale = torch.pow(deg, exponent)
    scale[deg == 0] = 0.0

    # Multiplying by a diagonal matrix is just row/column scaling, so use
    # broadcasting (O(n^2)) instead of materialising D and calling matmul
    # (O(n^3)).
    normalized = scale[:, None] * W
    if form == 'sym':
        normalized = normalized * scale[None, :]

    # Build the identity in the result's dtype so the output dtype follows
    # the input rather than torch's global default.
    identity = torch.eye(W.shape[0], dtype=normalized.dtype, device=W.device)
    return identity - normalized

def effective_rank(A: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Compute the effective rank of a matrix.

    The singular values are normalised to sum to one, treated as a
    probability distribution, and the exponential of that distribution's
    Shannon entropy is returned.

    Args:
        A: Input matrix
        eps: Small value added to the normalising sum and inside the
            logarithm to avoid division by zero and ``log(0)``

    Returns:
        The effective rank as a 0-dim tensor
    """
    # Only the singular values are needed; svdvals skips computing U and Vh.
    singular_values = torch.linalg.svdvals(A)

    # Normalize to get probabilities
    p = singular_values / (singular_values.sum() + eps)

    # Shannon entropy of the normalised spectrum
    entropy = -torch.sum(p * torch.log(p + eps))

    # Effective rank
    return torch.exp(entropy)

def load_embeddings(embeddings_path: str) -> Dict[str, torch.Tensor]:
    """Read the saved embeddings object from disk with ``torch.load``.

    NOTE(review): ``torch.load`` relies on pickle-based deserialization (how
    restricted it is depends on the installed torch version's
    ``weights_only`` default), so only point this at files from a trusted
    source.

    Args:
        embeddings_path: Location of the file written by ``torch.save``

    Returns:
        Whatever object the file contains; callers treat it as a dict of
        {key: embedding_tensor}
    """
    loaded = torch.load(embeddings_path)
    return loaded

# Theme fragments (matched case-insensitively) that mark an interpretation as
# a false positive; compiled once instead of once per item.
_EXCLUDED_THEME_RE = re.compile(
    r'standard ai assistant|storytelling|repetiti|incomplete|echo|punctuation'
    r'|diverse persona generation|neutral',
    re.IGNORECASE,
)
# Fallback extractors used when the content cannot be parsed as JSON at all.
_FLUENCY_RE = re.compile(r'"fluency_score":\s*(\d+)')
_CONSISTENCY_RE = re.compile(r'"consistency_score":\s*(\d+)')
_THEME_RE = re.compile(r'"theme":\s*"([^"]*)"')


def _parse_interp_content(raw: Any) -> Dict[str, Any]:
    """Turn one item's ``content`` field into a dict of interpretation fields."""
    # Sometimes the content is already a decoded JSON object.
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        raise TypeError(f"unsupported content type: {type(raw).__name__}")

    # Escaped variant: raw newlines inside JSON string values are illegal, so
    # turn them into the two-character escape and collapse double escapes.
    escaped = raw.replace('\n', '\\n').replace('\\\\n', '\\n')

    # Try the text as-is first: escaping newlines up front would corrupt
    # valid multi-line (pretty-printed) JSON.
    for candidate in (raw, escaped):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if not isinstance(parsed, dict):
            raise TypeError(f"content decoded to {type(parsed).__name__}, expected an object")
        return parsed

    # Last resort: pull the individual fields out with regexes.
    fluency_match = _FLUENCY_RE.search(escaped)
    consistency_match = _CONSISTENCY_RE.search(escaped)
    theme_match = _THEME_RE.search(escaped)
    return {
        'fluency_score': int(fluency_match.group(1)) if fluency_match else 0,
        'consistency_score': int(consistency_match.group(1)) if consistency_match else 0,
        'theme': theme_match.group(1) if theme_match else '',
    }


def _theme_is_informative(theme: Any) -> bool:
    """Return True when the theme is present and matches no exclusion rule."""
    if theme is None:
        return False
    text = str(theme)
    # 'NONE' is deliberately case-sensitive, unlike the other fragments.
    if 'NONE' in text or text.startswith('AI assistant persona'):
        return False
    return _EXCLUDED_THEME_RE.search(text) is None


def load_auto_interp(auto_interp_path: str) -> Dict[str, Dict[str, Any]]:
    """Load auto interpretation scores from a JSON file.

    The file must contain a list of objects, each with a ``key`` and a
    ``content`` entry. ``content`` is either a JSON object or a string holding
    one; strings that cannot be decoded fall back to regex extraction of
    ``fluency_score``, ``consistency_score`` and ``theme``.

    Items whose theme is missing or matches one of the exclusion rules keep
    their fluency score but have ``consistency_score`` forced to 0. Items that
    cannot be processed at all are reported on stdout and stored with zero
    scores and an empty theme.

    Args:
        auto_interp_path: Path to the JSON file with interpretation data

    Returns:
        Dict mapping each key to a dict with ``fluency_score``,
        ``consistency_score`` and ``theme``

    Raises:
        KeyError: If an item has no ``key`` entry (there is nothing to store
            the result under).
    """
    with open(auto_interp_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    scores = {}
    parsed_count = 0
    for item_idx, item in enumerate(data):
        key = item['key']

        try:
            content = _parse_interp_content(item['content'])
            fluency_score = content.get('fluency_score', 0)
            consistency_score = content.get('consistency_score', 0)
            theme = content.get('theme', '')
        except Exception as e:
            # Deliberate best-effort: one malformed item must not abort the load.
            print(f"Error processing item {item_idx} with key {key}: {e}")
            # str() first: the content may be None or a dict, which cannot be sliced.
            print(f"Content sample: {str(item.get('content', ''))[:100]}...")
            scores[key] = {
                'fluency_score': 0,
                'consistency_score': 0,
                'theme': ''
            }
            continue

        if not _theme_is_informative(theme):
            # Filtered-out themes are treated as false positives.
            consistency_score = 0

        scores[key] = {
            'fluency_score': fluency_score,
            'consistency_score': consistency_score,
            'theme': theme
        }
        parsed_count += 1

    # Report only items that were actually parsed, not the error placeholders.
    print(f"Successfully processed {parsed_count} out of {len(data)} items")
    return scores

def compute_gaussian_kernel(X: torch.Tensor, length_scale: float) -> torch.Tensor:
    """Compute Gaussian kernel matrix from embeddings.

    ``K[i, j] = exp(-||x_i - x_j||^2 / (2 * length_scale^2))``.

    The squared norms are taken from the Gram matrix diagonal rather than
    assumed to be 1, so the result is also correct for rows that are not
    unit-normalized. For unit-norm rows this matches the ``2 - 2 <x_i, x_j>``
    shortcut up to floating-point rounding.

    Args:
        X: Embedding matrix, shape (n, d)
        length_scale: Length scale for the Gaussian kernel

    Returns:
        Gaussian kernel matrix of shape (n, n)
    """
    gram = torch.mm(X, X.t())

    # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>,
    # where ||x_i||^2 is the i-th diagonal entry of the Gram matrix.
    sq_norms = gram.diagonal()
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * gram

    # Avoid negative values due to numerical errors
    sq_dists = torch.clamp(sq_dists, min=0.0)

    gamma = 1.0 / (2.0 * length_scale ** 2)
    return torch.exp(-gamma * sq_dists)

def analyze_single_response(embeddings: Dict[str, torch.Tensor], 
                        scores: Optional[Dict[str, Dict[str, Any]]],
                        response_idx: int,
                        sample_sizes: List[int],
                        length_scales: List[float],
                        n_samples: int,
                        min_fluency: float = 9) -> pd.DataFrame:
    """Perform analysis 1: analyze clusters in a specific response index.

    For every (sample size, length scale) pair, keys are drawn with
    replacement ``n_samples`` times. Each draw is optionally filtered by
    fluency, turned into a Gaussian kernel over the unit-normalized
    embeddings, and scored as ``len(sample) - effective_rank(L_sym)`` where
    ``L_sym`` is the symmetric normalized Laplacian of that kernel. The
    scores are averaged over the draws.

    Args:
        embeddings: Dict of embeddings; ``embeddings[key][response_idx]`` is
            the vector used for ``key``
        scores: Dict of interpretation scores (or None if no scores are
            available); each entry must provide ``'fluency_score'``
        response_idx: Index of the response to analyze
        sample_sizes: List of sample sizes to try
        length_scales: List of length scales for Gaussian kernel
        n_samples: Number of subsamples to draw
        min_fluency: Minimum ``fluency_score`` a sampled key needs to be kept
            when ``scores`` is given

    Returns:
        DataFrame with one row per (sample size, length scale) pair that had
        at least one usable draw; empty if no embedding is usable.
    """
    print(f"\nAnalyzing clusters at response index {response_idx}")

    # Collect unit-normalized embeddings for the requested response index.
    X = {}
    for key, emb in embeddings.items():
        # Accept Python-style negative indices, but reject anything that
        # would raise IndexError.
        if not -emb.shape[0] <= response_idx < emb.shape[0]:
            print(f"Warning: response_idx {response_idx} out of bounds for key {key}. Skipping.")
            continue

        vec = emb[response_idx]
        norm = torch.norm(vec)
        # A zero or non-finite norm would turn the vector into NaN/inf and
        # poison every kernel it is sampled into.
        if not 0.0 < norm.item() < float('inf'):
            print(f"Warning: embedding for key {key} has zero or non-finite norm. Skipping.")
            continue

        X[key] = vec / norm

    valid_keys = list(X)
    print(f"Found {len(valid_keys)} keys with valid embeddings for response index {response_idx}")

    if not X:
        print(f"No valid embeddings found for response index {response_idx}.")
        return pd.DataFrame()

    # The fluency decision depends only on the key, so resolve it once
    # instead of once per sampled key per draw.
    fluent_keys = None
    if scores is not None:
        fluent_keys = {
            key for key in valid_keys
            if key in scores and scores[key]['fluency_score'] >= min_fluency
        }

    results = []

    for sample_size in sample_sizes:
        if sample_size > len(valid_keys):
            print(f"Warning: sample size {sample_size} larger than available keys ({len(valid_keys)}). Using all keys.")
            sample_size = len(valid_keys)

        for length_scale in length_scales:
            print(f"Processing sample size: {sample_size}, length scale: {length_scale}")

            sample_results = []
            actual_sizes = []
            for i in range(n_samples):
                # Sample keys with replacement, then drop the non-fluent draws.
                sampled_keys = random.choices(valid_keys, k=sample_size)
                if fluent_keys is not None:
                    sampled_keys = [key for key in sampled_keys if key in fluent_keys]
                    if not sampled_keys:
                        print(f"Warning: Sample {i} has no keys meeting fluency criteria.")
                        continue
                elif not sampled_keys:
                    # Nothing was drawn (sample_size <= 0).
                    continue

                emb_matrix = torch.stack([X[key] for key in sampled_keys])
                kernel = compute_gaussian_kernel(emb_matrix, length_scale)
                lap = normalized_laplacian(kernel, form='sym')

                # Soft number of clusters = sample count minus effective rank.
                eff_rank = effective_rank(lap)
                soft_num_clusters = len(sampled_keys) - eff_rank.item()

                sample_results.append(soft_num_clusters)
                actual_sizes.append(len(sampled_keys))

            # Skip length_scale if all samples were skipped
            if not sample_results:
                print(f"Skipping length_scale {length_scale} for sample_size {sample_size} - no valid samples")
                continue

            results.append({
                'analysis_type': 'single_response',
                'response_idx': response_idx,
                'sample_size': sample_size,
                'actual_avg_sample_size': np.mean(actual_sizes),
                'length_scale': length_scale,
                'soft_num_clusters': np.mean(sample_results)
            })

    return pd.DataFrame(results)

def analyze_averaged_responses(embeddings: Dict[str, torch.Tensor], 
                             scores: Optional[Dict[str, Dict[str, Any]]],
                             sample_sizes: List[int],
                             length_scales: List[float],
                             n_samples: int,
                             min_fluency: float = 9) -> pd.DataFrame:
    """Perform analysis 2: analyze with embeddings averaged across all response indices.

    Each key's embeddings are averaged over dim 0 and unit-normalized. For
    every (sample size, length scale) pair, keys are drawn with replacement
    ``n_samples`` times. Each draw is optionally filtered by fluency, turned
    into a Gaussian kernel, and scored as ``len(sample) - effective_rank(L_sym)``
    where ``L_sym`` is the symmetric normalized Laplacian of that kernel. The
    scores are averaged over the draws.

    Args:
        embeddings: Dict of embeddings; each tensor is averaged over dim 0
        scores: Dict of interpretation scores (or None if no scores are
            available); each entry must provide ``'fluency_score'``
        sample_sizes: List of sample sizes to try
        length_scales: List of length scales for Gaussian kernel
        n_samples: Number of subsamples to draw
        min_fluency: Minimum ``fluency_score`` a sampled key needs to be kept
            when ``scores`` is given

    Returns:
        DataFrame with one row per (sample size, length scale) pair that had
        at least one usable draw; empty if no embedding is usable.
    """
    print("\nAnalyzing clusters with embeddings averaged across all response indices")

    # Average over response indices and unit-normalize, one vector per key.
    avg_embeddings = {}
    for key, emb in embeddings.items():
        # The mean of an empty tensor is NaN; skip instead of propagating it.
        if emb.numel() == 0:
            print(f"Warning: no embeddings available for key {key}. Skipping.")
            continue

        mean_emb = torch.mean(emb, dim=0)
        norm = torch.norm(mean_emb)
        # A zero or non-finite norm would turn the vector into NaN/inf and
        # poison every kernel it is sampled into.
        if not 0.0 < norm.item() < float('inf'):
            print(f"Warning: averaged embedding for key {key} has zero or non-finite norm. Skipping.")
            continue

        avg_embeddings[key] = mean_emb / norm

    valid_keys = list(avg_embeddings)
    print(f"Created averaged embeddings for {len(valid_keys)} keys")

    if not avg_embeddings:
        print("No valid embeddings for averaging.")
        return pd.DataFrame()

    # The fluency decision depends only on the key, so resolve it once
    # instead of once per sampled key per draw.
    fluent_keys = None
    if scores is not None:
        fluent_keys = {
            key for key in valid_keys
            if key in scores and scores[key]['fluency_score'] >= min_fluency
        }

    results = []

    for sample_size in sample_sizes:
        if sample_size > len(valid_keys):
            print(f"Warning: sample size {sample_size} larger than available keys ({len(valid_keys)}). Using all keys.")
            sample_size = len(valid_keys)

        for length_scale in length_scales:
            print(f"Processing sample size: {sample_size}, length scale: {length_scale}")

            sample_results = []
            actual_sizes = []
            for i in range(n_samples):
                # Sample keys with replacement, then drop the non-fluent draws.
                sampled_keys = random.choices(valid_keys, k=sample_size)
                if fluent_keys is not None:
                    sampled_keys = [key for key in sampled_keys if key in fluent_keys]
                    if not sampled_keys:
                        print(f"Warning: Sample {i} has no keys meeting fluency criteria.")
                        continue
                elif not sampled_keys:
                    # Nothing was drawn (sample_size <= 0).
                    continue

                emb_matrix = torch.stack([avg_embeddings[key] for key in sampled_keys])
                kernel = compute_gaussian_kernel(emb_matrix, length_scale)
                lap = normalized_laplacian(kernel, form='sym')

                # Soft number of clusters = sample count minus effective rank.
                eff_rank = effective_rank(lap)
                soft_num_clusters = len(sampled_keys) - eff_rank.item()

                sample_results.append(soft_num_clusters)
                actual_sizes.append(len(sampled_keys))

            # Skip length_scale if all samples were skipped
            if not sample_results:
                print(f"Skipping length_scale {length_scale} for sample_size {sample_size} - no valid samples")
                continue

            results.append({
                'analysis_type': 'averaged_responses',
                'sample_size': sample_size,
                'actual_avg_sample_size': np.mean(actual_sizes),
                'length_scale': length_scale,
                'soft_num_clusters': np.mean(sample_results)
            })

    return pd.DataFrame(results)

def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: run both cluster analyses and write CSVs.

    Loads the embeddings (optionally restricted to keys containing
    ``'scale0'``) and the optional auto-interpretation scores, runs the
    single-response and averaged-response analyses, and writes the non-empty
    results into ``--output_dir`` as ``analysis1_results_idx<idx><suffix>.csv``
    and ``analysis2_results_averaged<suffix>.csv``. When both are non-empty,
    ``combined_results<suffix>.csv`` is written as well. ``<suffix>`` is
    ``_scale0`` when ``--scale0_only`` is set, otherwise empty.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Analyze diversity of LLM responses.')
    parser.add_argument('--embeddings_db', type=str, required=True, help='Path to embeddings database')
    parser.add_argument('--auto_interp', type=str, required=False, help='Path to auto interpretation JSON file')
    parser.add_argument('--response_idx', type=int, default=0, help='Response index to analyze')
    parser.add_argument('--sample_sizes', type=str, default='10,20,30,40,50', help='Comma-separated list of sample sizes')
    parser.add_argument('--length_scales', type=str, default='0.1,0.2,0.5,1.0', help='Comma-separated list of length scales')
    parser.add_argument('--n_samples', type=int, default=10, help='Number of subsamples to draw')
    parser.add_argument('--output_dir', type=str, default='results', help='Output directory')
    parser.add_argument('--scale0_only', action='store_true', help='Filter to only include scale0 keys')
    parser.add_argument('--seed', type=int, default=None, help='Random seed for reproducible subsampling')

    args = parser.parse_args(argv)

    # Blank tokens (e.g. from a trailing comma) are ignored instead of
    # crashing in int('') / float('').
    sample_sizes = [int(tok) for tok in args.sample_sizes.split(',') if tok.strip()]
    length_scales = [float(tok) for tok in args.length_scales.split(',') if tok.strip()]

    # The analyses draw their subsamples with the ``random`` module, so
    # seeding it is enough to make a run repeatable.
    if args.seed is not None:
        random.seed(args.seed)

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Add scale0 info to every output filename if filtered
    scale_suffix = '_scale0' if args.scale0_only else ''

    print(f"Loading embeddings from {args.embeddings_db}...")
    embeddings = load_embeddings(args.embeddings_db)
    print(f"Loaded {len(embeddings)} embedding sets")

    # Filter to scale0 if requested
    if args.scale0_only:
        orig_count = len(embeddings)
        embeddings = {k: v for k, v in embeddings.items() if 'scale0' in k}
        print(f"Filtered to {len(embeddings)} scale0 keys out of {orig_count} total keys")

    # Load auto interpretation scores if provided
    scores = None
    if args.auto_interp:
        print(f"Loading auto interpretation from {args.auto_interp}...")
        scores = load_auto_interp(args.auto_interp)
        print(f"Loaded scores for {len(scores)} keys")
    else:
        print("No auto_interp provided. Running analysis without fluency filtering.")

    # Perform analysis 1: single response index
    print("\nPerforming Analysis 1: Clusters in a specific response index...")
    analysis1_results = analyze_single_response(
        embeddings=embeddings,
        scores=scores,
        response_idx=args.response_idx,
        sample_sizes=sample_sizes,
        length_scales=length_scales,
        n_samples=args.n_samples
    )

    if not analysis1_results.empty:
        analysis1_path = output_dir / f'analysis1_results_idx{args.response_idx}{scale_suffix}.csv'
        analysis1_results.to_csv(analysis1_path, index=False)
        print(f"Analysis 1 results saved to {analysis1_path}")

    # Perform analysis 2: averaged responses
    print("\nPerforming Analysis 2: Clusters with averaged embeddings across all response indices...")
    analysis2_results = analyze_averaged_responses(
        embeddings=embeddings,
        scores=scores,
        sample_sizes=sample_sizes,
        length_scales=length_scales,
        n_samples=args.n_samples
    )

    if not analysis2_results.empty:
        analysis2_path = output_dir / f'analysis2_results_averaged{scale_suffix}.csv'
        analysis2_results.to_csv(analysis2_path, index=False)
        print(f"Analysis 2 results saved to {analysis2_path}")

    # Combine results for easier comparison
    if not analysis1_results.empty and not analysis2_results.empty:
        combined_results = pd.concat([analysis1_results, analysis2_results])
        combined_path = output_dir / f'combined_results{scale_suffix}.csv'
        combined_results.to_csv(combined_path, index=False)
        print(f"Combined results saved to {combined_path}")

    print("\nAnalysis complete!")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()