import pickle
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from itertools import combinations
import argparse

def calculate_cosine_similarity(feat1, feat2, pooling_method='mean', use_gpu=True):
    """
    Calculate cosine similarity between two feature representations.

    The spatial axis (dim 0 after batch removal) of each input is pooled with
    ``pooling_method``. If the pooled vectors have the same length, their cosine
    similarity is returned. Otherwise (different feature dimensions) the Pearson
    correlation between the per-location L2 norms of the two inputs is returned
    instead; that path requires both inputs to share the spatial length.

    Args:
        feat1: Tensor of shape [batch, spatial_dim, feature_dim1] or [spatial_dim, feature_dim1].
            A 3-D input is passed through ``squeeze(0)``, which only removes the
            batch axis when its size is 1.
        feat2: Tensor of shape [batch, spatial_dim, feature_dim2] or [spatial_dim, feature_dim2]
        pooling_method: How to pool spatial dimensions ('mean', 'max', 'sum')
        use_gpu: Whether to use GPU for computation (ignored if CUDA is unavailable)

    Returns:
        cosine_score: Float similarity. 0.0 if the norm patterns have (near-)zero
        variance or if the computation raises.

    Raises:
        ValueError: If ``pooling_method`` is not 'mean', 'max' or 'sum'.
    """
    # Validate before entering the try block so an invalid argument surfaces as
    # an error instead of being caught below and turned into a 0.0 score.
    if pooling_method not in ('mean', 'max', 'sum'):
        raise ValueError(f"Unknown pooling method: {pooling_method}")

    # Handle different input shapes
    if len(feat1.shape) == 3:
        feat1 = feat1.squeeze(0)  # Remove batch dimension if present (size 1 only)
    if len(feat2.shape) == 3:
        feat2 = feat2.squeeze(0)

    # Convert to float for numerical stability
    feat1 = feat1.float()
    feat2 = feat2.float()

    # Move to GPU if available and requested
    if use_gpu and torch.cuda.is_available():
        feat1 = feat1.cuda()
        feat2 = feat2.cuda()

    try:
        # Pool spatial dimensions to get feature vectors
        if pooling_method == 'mean':
            vec1 = feat1.mean(dim=0)  # [feature_dim1]
            vec2 = feat2.mean(dim=0)  # [feature_dim2]
        elif pooling_method == 'max':
            vec1 = feat1.max(dim=0)[0]  # [feature_dim1]
            vec2 = feat2.max(dim=0)[0]  # [feature_dim2]
        else:  # 'sum' -- the only remaining option after validation above
            vec1 = feat1.sum(dim=0)  # [feature_dim1]
            vec2 = feat2.sum(dim=0)  # [feature_dim2]

        if vec1.shape[0] != vec2.shape[0]:
            # Pooled vectors live in different spaces and cannot be compared
            # directly; instead correlate the per-location activation strength.
            norms1 = torch.norm(feat1, dim=1)  # [spatial_dim]
            norms2 = torch.norm(feat2, dim=1)  # [spatial_dim]

            # Pearson correlation coefficient of the two norm patterns
            centered1 = norms1 - norms1.mean()
            centered2 = norms2 - norms2.mean()

            numerator = torch.sum(centered1 * centered2)
            denominator = torch.sqrt(torch.sum(centered1**2) * torch.sum(centered2**2))

            # Constant norm patterns have no variance -> correlation undefined
            if denominator > 1e-8:
                return (numerator / denominator).item()
            return 0.0

        # Same dimensionality - direct cosine similarity
        cosine_sim = F.cosine_similarity(vec1.unsqueeze(0), vec2.unsqueeze(0), dim=1)
        return cosine_sim.item()

    except Exception as e:
        # Best-effort scoring: report and fall back to a neutral score.
        print(f"Error calculating cosine similarity: {e}")
        return 0.0

def calculate_spatial_cosine_similarity(feat1, feat2, use_gpu=True):
    """
    Calculate cosine similarity between the spatial activation patterns of two features.

    Each input is reduced to one value per spatial location by averaging over its
    feature axis, so inputs with different feature dimensions can be compared;
    the cosine similarity of the two resulting [spatial_dim] vectors is returned.
    If that raises (e.g. the spatial lengths differ), falls back to the Pearson
    correlation of the per-location L2 norms, and finally to 0.0.

    Args:
        feat1: Tensor of shape [spatial_dim, feature_dim1] or [1, spatial_dim, feature_dim1].
            A 3-D input is passed through ``squeeze(0)``, matching
            ``calculate_cosine_similarity``.
        feat2: Tensor of shape [spatial_dim, feature_dim2] or [1, spatial_dim, feature_dim2].
        use_gpu: Whether to use GPU for computation (ignored if CUDA is unavailable).

    Returns:
        Float similarity; 0.0 if every strategy fails.
    """
    try:
        # Remove a size-1 leading batch axis so that dim=1 below is the feature
        # axis; otherwise a [1, S, D] input would be averaged over the spatial axis.
        if len(feat1.shape) == 3:
            feat1 = feat1.squeeze(0)
        if len(feat2.shape) == 3:
            feat2 = feat2.squeeze(0)

        # Move to GPU if available and requested
        if use_gpu and torch.cuda.is_available():
            feat1 = feat1.cuda()
            feat2 = feat2.cuda()

        # Convert to float32 for numerical stability
        feat1 = feat1.float()
        feat2 = feat2.float()

        # Average activation per spatial location (works for any feature dim)
        spatial_pattern1 = feat1.mean(dim=1)  # [spatial_dim]
        spatial_pattern2 = feat2.mean(dim=1)  # [spatial_dim]

        # Both patterns are [spatial_dim]; this raises if the spatial lengths
        # differ (and neither is 1), which sends us to the fallback below.
        cos_sim = F.cosine_similarity(
            spatial_pattern1.unsqueeze(0),
            spatial_pattern2.unsqueeze(0),
            dim=1,
        )
        return cos_sim.item()

    except Exception as e:
        print(f"Error in spatial cosine similarity: {e}")
        # Fallback: Pearson correlation of per-location L2 norms
        try:
            norms1 = torch.norm(feat1, dim=1)  # [spatial_dim]
            norms2 = torch.norm(feat2, dim=1)  # [spatial_dim]

            centered1 = norms1 - norms1.mean()
            centered2 = norms2 - norms2.mean()

            numerator = torch.sum(centered1 * centered2)
            denominator = torch.sqrt(torch.sum(centered1**2) * torch.sum(centered2**2))

            if denominator > 1e-8:
                return (numerator / denominator).item()
            return 0.0
        except Exception as e2:
            print(f"Fallback also failed: {e2}")
            return 0.0

def calculate_cca_based_similarity(feat1, feat2, num_components=50, use_gpu=True):
    """
    Calculate a CCA-like similarity from rank-matched SVD components.

    Both inputs are centred along the spatial axis and decomposed with a thin SVD.
    The top ``k`` left singular vectors (one value per spatial location) of each
    input are paired by rank. The mean absolute cosine similarity of the pairs is
    returned. Absolute values are used because singular vectors are only defined
    up to sign. This compares spatial structure, so both inputs must share the
    spatial length; their feature dimensions may differ.

    If the SVD path raises, falls back to the cosine similarity of
    [mean, std, max, min] of the per-feature L2 norms, and finally to 0.0.

    Args:
        feat1: Tensor of shape [spatial_dim, feature_dim1] or [1, spatial_dim, feature_dim1].
            A 3-D input is passed through ``squeeze(0)``, matching
            ``calculate_cosine_similarity``.
        feat2: Tensor of shape [spatial_dim, feature_dim2] or [1, spatial_dim, feature_dim2].
        num_components: Maximum number of singular components to compare.
        use_gpu: Whether to use GPU for computation (ignored if CUDA is unavailable).

    Returns:
        Float similarity in [0, 1] on the SVD path (0.0 if no components exist);
        otherwise the fallback score, or 0.0 if every strategy fails.
    """
    try:
        # Remove a size-1 leading batch axis; centring over a size-1 batch axis
        # would otherwise zero out the whole input.
        if len(feat1.shape) == 3:
            feat1 = feat1.squeeze(0)
        if len(feat2.shape) == 3:
            feat2 = feat2.squeeze(0)

        # Move to GPU if available and requested
        if use_gpu and torch.cuda.is_available():
            feat1 = feat1.cuda()
            feat2 = feat2.cuda()

        # Convert to float32 (required for SVD)
        feat1 = feat1.float()
        feat2 = feat2.float()

        # Center each feature column over the spatial axis
        feat1_centered = feat1 - feat1.mean(dim=0, keepdim=True)
        feat2_centered = feat2 - feat2.mean(dim=0, keepdim=True)

        # No random "regularization" noise is added: SVD handles rank-deficient
        # input, and random noise would make the score non-deterministic.

        # Thin SVD: U is [spatial_dim, min(spatial_dim, feature_dim)]
        try:
            U1 = torch.linalg.svd(feat1_centered, full_matrices=False)[0]
            U2 = torch.linalg.svd(feat2_centered, full_matrices=False)[0]
        except RuntimeError as svd_error:
            # Fallback to CPU if GPU SVD fails
            print(f"GPU SVD failed, trying CPU: {svd_error}")
            U1 = torch.linalg.svd(feat1_centered.cpu(), full_matrices=False)[0]
            U2 = torch.linalg.svd(feat2_centered.cpu(), full_matrices=False)[0]
            if use_gpu and torch.cuda.is_available():
                U1, U2 = U1.cuda(), U2.cuda()

        # Take top components
        k = min(num_components, U1.shape[1], U2.shape[1], U1.shape[0])
        if k <= 0:
            return 0.0

        # Column-wise cosine between rank-matched singular vectors -> [k];
        # raises if the spatial lengths differ, which triggers the fallback.
        component_sims = F.cosine_similarity(U1[:, :k], U2[:, :k], dim=0)
        return component_sims.abs().mean().item()

    except Exception as e:
        print(f"Error in CCA-based similarity: {e}")
        # Fallback to simple correlation of feature norms
        try:
            # Convert to float32 and move to appropriate device
            feat1 = feat1.float()
            feat2 = feat2.float()

            if use_gpu and torch.cuda.is_available():
                feat1 = feat1.cuda()
                feat2 = feat2.cuda()

            # Calculate feature-wise norms
            norms1 = torch.norm(feat1, dim=0)  # [feature_dim1]
            norms2 = torch.norm(feat2, dim=0)  # [feature_dim2]

            # Compare fixed-size statistics of the norms (dimension-independent)
            stat1 = torch.stack([norms1.mean(), norms1.std(), norms1.max(), norms1.min()])
            stat2 = torch.stack([norms2.mean(), norms2.std(), norms2.max(), norms2.min()])

            similarity = F.cosine_similarity(stat1.unsqueeze(0), stat2.unsqueeze(0), dim=1)
            return similarity.item()

        except Exception as e2:
            print(f"Fallback similarity calculation failed: {e2}")
            return 0.0

def calculate_robust_cca_similarity(feat1, feat2, use_gpu=True):
    """
    Compare two features via summary statistics of their activation distributions.

    Despite the name, no CCA is performed. For each input, the mean, std, max
    and min of every feature column are taken over the spatial axis, giving a
    [4, feature_dim] matrix. That matrix is reduced to a fixed 6-element summary
    vector, and the cosine similarity of the two summary vectors is returned.
    Inputs may therefore differ in both feature dimension and spatial length.

    Args:
        feat1: Tensor of shape [spatial_dim, feature_dim1] or [1, spatial_dim, feature_dim1].
            A 3-D input is passed through ``squeeze(0)``, matching
            ``calculate_cosine_similarity``. With a single spatial location the
            (unbiased) std terms are NaN.
        feat2: Tensor of shape [spatial_dim, feature_dim2] or [1, spatial_dim, feature_dim2].
        use_gpu: Whether to use GPU for computation (ignored if CUDA is unavailable).

    Returns:
        Float cosine similarity of the summary vectors. If that raises, falls back
        to ``1 - |n1 - n2| / (n1 + n2)`` over the global L2 norms, and to 0.0 if
        the fallback also fails.
    """
    try:
        # Remove a size-1 leading batch axis; otherwise dim=0 below would be the
        # batch axis and the std over it would be NaN.
        if len(feat1.shape) == 3:
            feat1 = feat1.squeeze(0)
        if len(feat2.shape) == 3:
            feat2 = feat2.squeeze(0)

        # Convert to float32 and move to device
        feat1 = feat1.float()
        feat2 = feat2.float()

        if use_gpu and torch.cuda.is_available():
            feat1 = feat1.cuda()
            feat2 = feat2.cuda()

        # Per-feature statistics across spatial locations
        feat1_stats = torch.stack([
            feat1.mean(dim=0),    # [feature_dim1]
            feat1.std(dim=0),     # [feature_dim1]
            feat1.max(dim=0)[0],  # [feature_dim1]
            feat1.min(dim=0)[0],  # [feature_dim1]
        ])  # [4, feature_dim1]

        feat2_stats = torch.stack([
            feat2.mean(dim=0),    # [feature_dim2]
            feat2.std(dim=0),     # [feature_dim2]
            feat2.max(dim=0)[0],  # [feature_dim2]
            feat2.min(dim=0)[0],  # [feature_dim2]
        ])  # [4, feature_dim2]

        # Reduce [4, feature_dim] to a fixed-size vector so both encoders are comparable
        summary1 = torch.stack([
            feat1_stats.mean(dim=1).mean(),      # overall mean
            feat1_stats.std(dim=1).mean(),       # average std
            feat1_stats.max(dim=1)[0].mean(),    # average max
            feat1_stats.min(dim=1)[0].mean(),    # average min
            feat1_stats.mean(dim=0).std(),       # variability across features for mean
            feat1_stats.std(dim=0).std(),        # variability across features for std
        ])  # [6]

        summary2 = torch.stack([
            feat2_stats.mean(dim=1).mean(),      # overall mean
            feat2_stats.std(dim=1).mean(),       # average std
            feat2_stats.max(dim=1)[0].mean(),    # average max
            feat2_stats.min(dim=1)[0].mean(),    # average min
            feat2_stats.mean(dim=0).std(),       # variability across features for mean
            feat2_stats.std(dim=0).std(),        # variability across features for std
        ])  # [6]

        similarity = F.cosine_similarity(
            summary1.unsqueeze(0),
            summary2.unsqueeze(0),
            dim=1,
        )
        return similarity.item()

    except Exception as e:
        print(f"Error in robust CCA similarity: {e}")
        # Final fallback: compare just the global norms
        try:
            norm1 = torch.norm(feat1).item()
            norm2 = torch.norm(feat2).item()
            # Normalized difference in [0, 1]; converted so higher = more similar
            diff = abs(norm1 - norm2) / (norm1 + norm2 + 1e-8)
            return 1.0 - diff
        except Exception:
            # Best-effort scoring: never let the fallback itself raise
            return 0.0

def calculate_dimension_agnostic_similarity(feat1, feat2, use_gpu=True):
    """
    Calculate a similarity that works regardless of feature dimensions.

    The result is the average of three cosine similarities:
      1. spatial: per-location L2 norms (requires equal spatial lengths),
      2. global: [mean, std, max, min, L2 norm] over all elements,
      3. per-feature: mean/std of the column means and column stds.

    Args:
        feat1: Tensor of shape [spatial_dim, feature_dim1] or [1, spatial_dim, feature_dim1].
            A 3-D input is passed through ``squeeze(0)``, matching
            ``calculate_cosine_similarity``.
        feat2: Tensor of shape [spatial_dim, feature_dim2] or [1, spatial_dim, feature_dim2].
        use_gpu: Whether to use GPU for computation (ignored if CUDA is unavailable).

    Returns:
        Float in roughly [-1, 1]; 0.0 if any part raises (e.g. the spatial
        lengths differ, which breaks the spatial component).
    """

    def global_stats(t):
        # Fixed-size global statistics, scaled to unit length
        stats = torch.tensor([
            t.mean().item(),
            t.std().item(),
            t.max().item(),
            t.min().item(),
            torch.norm(t).item(),
        ])
        return stats / (torch.norm(stats) + 1e-8)

    def feature_summary(t):
        # Summaries of per-feature statistics; size is independent of feature_dim
        col_means = t.mean(dim=0)  # [feature_dim]
        col_stds = t.std(dim=0)    # [feature_dim]
        return torch.tensor([
            col_means.mean().item(),  # avg mean across features
            col_means.std().item(),   # variability of means
            col_stds.mean().item(),   # avg std across features
            col_stds.std().item(),    # variability of stds
        ])

    def cosine(a, b):
        return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0), dim=1).item()

    try:
        # Remove a size-1 leading batch axis so dim=1 below is the feature axis;
        # otherwise the spatial norms of [1, S, D] inputs would not line up.
        if len(feat1.shape) == 3:
            feat1 = feat1.squeeze(0)
        if len(feat2.shape) == 3:
            feat2 = feat2.squeeze(0)

        # Convert to float32 and move to device
        feat1 = feat1.float()
        feat2 = feat2.float()

        if use_gpu and torch.cuda.is_available():
            feat1 = feat1.cuda()
            feat2 = feat2.cuda()

        # Method 1: activation strength at each spatial location
        spatial_sim = cosine(torch.norm(feat1, dim=1), torch.norm(feat2, dim=1))

        # Method 2: dimension-agnostic global statistics
        global_sim = cosine(global_stats(feat1), global_stats(feat2))

        # Method 3: summarized per-feature statistics
        feat_sim = cosine(feature_summary(feat1), feature_summary(feat2))

        # Combine all similarities
        return (spatial_sim + global_sim + feat_sim) / 3.0

    except Exception as e:
        print(f"Error in dimension agnostic similarity: {e}")
        return 0.0
    

def _summarize_scores(scores):
    """Return summary statistics for one (method, pair) list of raw scores.

    NaN scores and scores with magnitude >= 10 are treated as invalid and are
    excluded from every statistic; ``total_count`` still counts them.
    """
    valid_scores = [s for s in scores if not np.isnan(s) and abs(s) < 10]
    if not valid_scores:
        return {
            'mean': 0.0,
            'std': 0.0,
            'min': 0.0,
            'max': 0.0,
            'median': 0.0,
            'valid_count': 0,
            'total_count': len(scores),
            'all_scores': [],
        }
    return {
        'mean': np.mean(valid_scores),
        'std': np.std(valid_scores),
        'min': np.min(valid_scores),
        'max': np.max(valid_scores),
        'median': np.median(valid_scores),
        'valid_count': len(valid_scores),
        'total_count': len(scores),
        'all_scores': valid_scores,
    }


def calculate_all_pairwise_cosine(features_list, methods=None,
                                  use_gpu=True, sample_size=None):
    """
    Calculate similarity for all encoder pairs across all images using multiple methods.

    Args:
        features_list: List of feature sets, one per image; each set is indexable
            by encoder index. Non-tensor entries are converted with ``torch.tensor``.
        methods: Similarity methods to use. Defaults to
            ['mean_pooling', 'spatial_pattern', 'cca_based'].
        use_gpu: Whether to use GPU (ignored if CUDA is unavailable).
        sample_size: If provided and smaller than the number of images, randomly
            sample this many images (without replacement).

    Returns:
        Dict ``method -> {(enc_i, enc_j): stats}``. ``stats`` has mean/std/min/max/
        median over valid scores, plus 'valid_count', 'total_count' and
        'all_scores'. Failed computations and unknown methods are recorded as NaN,
        so they are excluded from the statistics instead of being averaged in as 0.0.
    """
    # None sentinel instead of a mutable list default shared across calls
    if methods is None:
        methods = ['mean_pooling', 'spatial_pattern', 'cca_based']

    num_images = len(features_list)
    num_encoders = len(features_list[0]) if features_list else 0

    print(f"Calculating cosine similarity for {num_images} images, {num_encoders} encoders")

    # Generate all encoder pairs
    encoder_pairs = list(combinations(range(num_encoders), 2))
    print(f"Encoder pairs to analyze: {encoder_pairs}")
    print(f"Methods to use: {methods}")

    # Sample images if requested
    if sample_size and sample_size < num_images:
        indices = np.random.choice(num_images, sample_size, replace=False)
        features_list = [features_list[i] for i in sorted(indices)]
        print(f"Sampled {sample_size} images for analysis")

    # Method name -> scorer(feat1, feat2)
    scorers = {
        'mean_pooling': lambda a, b: calculate_cosine_similarity(a, b, 'mean', use_gpu),
        'max_pooling': lambda a, b: calculate_cosine_similarity(a, b, 'max', use_gpu),
        'spatial_pattern': lambda a, b: calculate_spatial_cosine_similarity(a, b, use_gpu),
        'cca_based': lambda a, b: calculate_cca_based_similarity(a, b, 50, use_gpu),
        'robust_cca': lambda a, b: calculate_robust_cca_similarity(a, b, use_gpu),
        'dimension_agnostic': lambda a, b: calculate_dimension_agnostic_similarity(a, b, use_gpu),
    }
    unknown_methods = [m for m in methods if m not in scorers]
    if unknown_methods:
        print(f"Warning: unknown methods will be recorded as NaN: {unknown_methods}")

    failed_score = float('nan')
    cosine_results = {method: {pair: [] for pair in encoder_pairs} for method in methods}

    print("Calculating cosine similarities...")
    for img_idx, features in enumerate(tqdm(features_list)):
        for pair in encoder_pairs:
            enc1_idx, enc2_idx = pair

            try:
                feat1 = features[enc1_idx]
                feat2 = features[enc2_idx]

                # Convert to torch tensor if needed
                if not isinstance(feat1, torch.Tensor):
                    feat1 = torch.tensor(feat1)
                if not isinstance(feat2, torch.Tensor):
                    feat2 = torch.tensor(feat2)
            except Exception as e:
                print(f"Error processing image {img_idx}, pair {pair}: {e}")
                for method in methods:
                    cosine_results[method][pair].append(failed_score)
                continue

            # One try per method: a failure in one method must not add extra
            # entries to methods that already recorded a score for this pair.
            for method in methods:
                scorer = scorers.get(method)
                if scorer is None:
                    score = failed_score
                else:
                    try:
                        score = scorer(feat1, feat2)
                    except Exception as e:
                        print(f"Error processing image {img_idx}, pair {pair}, method {method}: {e}")
                        score = failed_score
                cosine_results[method][pair].append(score)

        # Memory cleanup every 100 images
        if use_gpu and torch.cuda.is_available() and (img_idx + 1) % 100 == 0:
            torch.cuda.empty_cache()

    # Calculate statistics for each method
    return {
        method: {pair: _summarize_scores(cosine_results[method][pair]) for pair in encoder_pairs}
        for method in methods
    }

def print_cosine_results(cosine_results):
    """Print formatted cosine similarity results.

    For each method, prints a symmetric encoder-by-encoder matrix of mean scores
    (diagonal fixed at 1.0) and per-pair ``mean ± std [min, max] (n=valid_count)``.
    It then prints a short interpretation guide.

    Args:
        cosine_results: Mapping ``method -> {(enc_i, enc_j): stats}`` where each
            ``stats`` dict has at least 'mean', 'std', 'min', 'max' and
            'valid_count' (as produced by ``calculate_all_pairwise_cosine``).
    """
    label_width = 7  # row-label column: "Enc 0" plus padding
    cell_width = 7   # each matrix column; fits "Enc 0" and "-0.123"

    print("\n" + "=" * 80)
    print("COSINE SIMILARITY RESULTS")
    print("=" * 80)

    for method, pair_stats in cosine_results.items():
        print(f"\n--- {method.upper().replace('_', ' ')} ---")

        # Create similarity matrix for visualization
        if pair_stats:
            num_encoders = max(max(pair) for pair in pair_stats) + 1
            similarity_matrix = np.ones((num_encoders, num_encoders))

            for (enc1, enc2), stats in pair_stats.items():
                similarity_matrix[enc1, enc2] = stats['mean']
                similarity_matrix[enc2, enc1] = stats['mean']  # Symmetric

            # Header and rows use the same fixed widths so every value sits
            # right under its encoder label.
            print("\nSimilarity Matrix:")
            header = "".join(f"Enc{j:2d}".rjust(cell_width) for j in range(num_encoders))
            print(" " * label_width + header)
            for i in range(num_encoders):
                row = "".join(
                    f"{similarity_matrix[i, j]:.3f}".rjust(cell_width)
                    for j in range(num_encoders)
                )
                print(f"Enc{i:2d}".ljust(label_width) + row)

        print("\nDetailed Statistics:")
        for (enc1, enc2), stats in pair_stats.items():
            print(f"  Encoders {enc1} ↔ {enc2}: {stats['mean']:.4f} ± {stats['std']:.4f} "
                  f"[{stats['min']:.4f}, {stats['max']:.4f}] (n={stats['valid_count']})")

    # Interpretation
    print("\n" + "=" * 60)
    print("INTERPRETATION GUIDE:")
    print("=" * 60)
    print("Cosine > 0.8: Very similar representations (highly redundant)")
    print("Cosine 0.5-0.8: Moderately similar (some redundancy)")
    print("Cosine 0.2-0.5: Low similarity (complementary features)")
    print("Cosine < 0.2: Very different (highly complementary)")
    print("\nNote: Different methods may give different scales")
    print("- Mean/Max pooling: Direct feature comparison")
    print("  (falls back to norm correlation when feature dims differ)")
    print("- Spatial pattern: Activation pattern similarity")
    print("- CCA-based: Handles different dimensions better")
    print("- Robust CCA: Cosine of summary statistics of per-feature activations")
    print("- Dimension agnostic: Average of spatial, global and per-feature statistic similarities")

def save_cosine_results(cosine_results, output_path):
    """Serialize the results mapping to *output_path* with pickle, then report the path."""
    with open(output_path, mode='wb') as sink:
        pickle.dump(cosine_results, sink)
    destination_msg = f"\nResults saved to {output_path}"
    print(destination_msg)

def main():
    """Command-line entry point: load features, score all encoder pairs, report and save.

    NOTE(security): the features file is read with ``pickle.load``, which can
    execute arbitrary code. Only pass files you produced yourself or otherwise trust.
    """
    parser = argparse.ArgumentParser(description='Calculate cosine similarity between encoders')
    parser.add_argument('features_file', help='Path to combined features pickle file')
    parser.add_argument('--output', '-o', help='Output file for results',
                        default='cosine_similarity_results.pkl')
    parser.add_argument('--sample', '-s', type=int, help='Sample size (if less than total images)')
    parser.add_argument('--methods', '-m', nargs='+',
                        choices=['mean_pooling', 'max_pooling', 'spatial_pattern', 'cca_based', 'robust_cca', 'dimension_agnostic'],
                        default=['mean_pooling', 'spatial_pattern', 'dimension_agnostic'],
                        help='Similarity calculation methods to use')
    parser.add_argument('--no-gpu', action='store_true', help='Disable GPU usage')
    parser.add_argument('--save-scores', action='store_true', help='Save individual scores (not just stats)')

    args = parser.parse_args()

    # Load features (trusted input only -- see docstring)
    print(f"Loading features from {args.features_file}...")
    with open(args.features_file, 'rb') as f:
        features_list = pickle.load(f)

    print(f"Loaded {len(features_list)} feature sets")
    if features_list:
        print(f"Each set contains {len(features_list[0])} encoder features")
        print("Feature shapes:")
        for i, feat in enumerate(features_list[0]):
            # np.shape returns .shape when present and otherwise converts the
            # value, so non-tensor features (accepted by the scorer) don't crash here.
            print(f"  Encoder {i}: {np.shape(feat)}")

    # Calculate cosine similarities
    use_gpu = not args.no_gpu
    cosine_results = calculate_all_pairwise_cosine(
        features_list=features_list,
        methods=args.methods,
        use_gpu=use_gpu,
        sample_size=args.sample
    )

    # Print results
    print_cosine_results(cosine_results)

    # Remove individual scores to save space unless explicitly requested
    if not args.save_scores:
        for method_results in cosine_results.values():
            for stats in method_results.values():
                stats.pop('all_scores', None)

    save_cosine_results(cosine_results, args.output)

if __name__ == "__main__":
    import sys

    if len(sys.argv) <= 1:
        # No CLI arguments: quick smoke run against a hard-coded features file.
        default_features_path = "MER/VLMEvalKit/features/converted_features_eagle_x4_mme.pkl"

        print("Loading features...")
        with open(default_features_path, 'rb') as handle:
            loaded_features = pickle.load(handle)

        print(f"Loaded {len(loaded_features)} feature sets")

        smoke_methods = ['mean_pooling', 'spatial_pattern', 'dimension_agnostic']
        results = calculate_all_pairwise_cosine(
            loaded_features,
            methods=smoke_methods,
            use_gpu=True,
        )

        # Report, then persist next to the working directory
        print_cosine_results(results)
        save_cosine_results(results, 'cosine_similarity_results.pkl')
    else:
        main()