import pickle
import torch
import numpy as np
from tqdm import tqdm
from itertools import combinations
import argparse

def gram_linear(x):
    """Return the linear-kernel Gram matrix of ``x``: entry (i, j) is row_i · row_j."""
    x_transposed = x.t()
    return torch.mm(x, x_transposed)

def center_gram(gram, unbiased=False):
    """
    Center a square Gram matrix.

    Args:
        gram: Square 2-D tensor of shape (n, n).
        unbiased: If True and n > 2, apply the centering used by the unbiased HSIC
            estimator (the diagonal is zeroed before and after centering). For
            n <= 2 this silently falls back to the biased centering.

    Returns:
        A new tensor holding the centered matrix. ``gram`` itself is not modified.
        The biased result equals H @ gram @ H with H = I - ones(n, n) / n.
    """
    n = gram.shape[0]

    if unbiased and n > 2:
        # Unbiased estimator (for n > 2)
        gram_centered = gram.clone()
        gram_centered.fill_diagonal_(0)
        # Accumulate column sums in float64 to limit rounding error on large n.
        means = torch.sum(gram_centered, dim=0, dtype=torch.float64) / (n - 2)
        means -= torch.sum(means) / (2 * (n - 1))
        gram_centered -= means[:, None]
        gram_centered -= means[None, :]
        gram_centered.fill_diagonal_(0)
        return gram_centered

    # Biased estimator (standard CKA). Expanding H K H gives
    #   K - column means - row means + grand mean,
    # which broadcasting computes in O(n^2). The previous form needed three
    # O(n^3) matmuls with a dense 1/n matrix.
    return (
        gram
        - gram.mean(dim=0, keepdim=True)
        - gram.mean(dim=1, keepdim=True)
        + gram.mean()
    )

def calculate_cka_similarity(feat1, feat2, use_gpu=True):
    """
    Calculate linear CKA similarity between two feature representations.

    Each input is treated as a ``[num_examples, feature_dim]`` matrix. Rows must
    correspond to the same examples in both inputs; the feature dims may differ.
    Any leading dimensions are flattened into the example axis, so
    ``[batch, spatial_dim, feature_dim]`` becomes
    ``[batch * spatial_dim, feature_dim]``. For ``batch == 1`` this is exactly
    dropping the batch dimension.

    Args:
        feat1: Tensor of shape [batch, spatial_dim, feature_dim1] or [spatial_dim, feature_dim1]
        feat2: Tensor of shape [batch, spatial_dim, feature_dim2] or [spatial_dim, feature_dim2]
        use_gpu: Whether to use GPU for computation (ignored when CUDA is unavailable)

    Returns:
        cka_score: Float representing CKA similarity. Returns 0.0 in two cases:
        when either centered Gram matrix has (near-)zero squared Frobenius norm,
        or when the computation fails (for example, mismatched row counts). In
        the failure case the error is printed.
    """
    # Flatten leading dims into rows. The former squeeze(0) only handled batch == 1.
    # With a larger batch it left a 3-D tensor, torch.mm raised, and 0.0 came back.
    if feat1.dim() > 2:
        feat1 = feat1.reshape(-1, feat1.shape[-1])
    if feat2.dim() > 2:
        feat2 = feat2.reshape(-1, feat2.shape[-1])

    # Convert to float for numerical stability
    feat1 = feat1.float()
    feat2 = feat2.float()

    # Move to GPU if available and requested
    if use_gpu and torch.cuda.is_available():
        feat1 = feat1.cuda()
        feat2 = feat2.cuda()

    try:
        if feat1.shape[0] != feat2.shape[0]:
            # CKA compares the same examples under two representations. Without this
            # check the mismatch only surfaced as an opaque broadcasting error.
            raise ValueError(
                f"row counts differ ({feat1.shape[0]} vs {feat2.shape[0]}); "
                "both representations must cover the same examples"
            )

        # Compute and center the Gram matrices
        gram_1_centered = center_gram(gram_linear(feat1))
        gram_2_centered = center_gram(gram_linear(feat2))

        # CKA = HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L))
        hsic = torch.sum(gram_1_centered * gram_2_centered)
        var1 = torch.sum(gram_1_centered * gram_1_centered)
        var2 = torch.sum(gram_2_centered * gram_2_centered)

        # Avoid division by zero
        if var1 > 1e-12 and var2 > 1e-12:
            # Take square roots before multiplying. var1 * var2 can overflow float32
            # to inf even when both factors are finite, which would give CKA = 0.
            cka = hsic / (torch.sqrt(var1) * torch.sqrt(var2))
            return cka.item()
        return 0.0

    except Exception as e:
        print(f"Error calculating CKA: {e}")
        return 0.0

def _summarize_cka_scores(scores):
    """Build the summary-statistics dict for one encoder pair's per-image scores.

    Only strictly positive scores are summarised. 0.0 is the value recorded for
    failed or degenerate calculations, so those entries count only toward
    'total_count'.
    """
    kept = [value for value in scores if value > 0]
    if not kept:
        return {
            'mean': 0.0,
            'std': 0.0,
            'min': 0.0,
            'max': 0.0,
            'median': 0.0,
            'valid_count': 0,
            'total_count': len(scores),
            'all_scores': [],
        }
    return {
        'mean': np.mean(kept),
        'std': np.std(kept),
        'min': np.min(kept),
        'max': np.max(kept),
        'median': np.median(kept),
        'valid_count': len(kept),
        'total_count': len(scores),
        'all_scores': kept,
    }


def calculate_all_pairwise_cka(features_list, use_gpu=True, sample_size=None):
    """
    Compute CKA for every encoder pair on every image, then summarise per pair.

    Args:
        features_list: One entry per image. Each entry is indexed by encoder
            position and yields that encoder's features (a tensor, or anything
            ``torch.tensor`` accepts). The encoder count is read from the first
            entry.
        use_gpu: Forwarded to ``calculate_cka_similarity``. Also triggers
            ``torch.cuda.empty_cache()`` after every 100 images.
        sample_size: If truthy and smaller than the number of images, only a
            random subset of that many images is evaluated. The subset is drawn
            without replacement and kept in original order.

    Returns:
        Dict mapping each pair ``(i, j)`` with ``i < j`` to a stats dict with keys
        'mean', 'std', 'min', 'max', 'median', 'valid_count', 'total_count'
        and 'all_scores'.
    """
    image_count = len(features_list)
    encoder_count = len(features_list[0]) if features_list else 0
    print(f"Calculating CKA similarity for {image_count} images, {encoder_count} encoders")

    encoder_pairs = list(combinations(range(encoder_count), 2))
    print(f"Encoder pairs to analyze: {encoder_pairs}")

    # Optionally restrict the run to a random subset of images (order preserved).
    if sample_size and sample_size < image_count:
        picked = sorted(np.random.choice(image_count, sample_size, replace=False))
        features_list = [features_list[k] for k in picked]
        print(f"Sampled {sample_size} images for analysis")

    scores_by_pair = {}
    for pair in encoder_pairs:
        scores_by_pair[pair] = []

    print("Calculating CKA similarities...")
    for img_idx, features in enumerate(tqdm(features_list)):
        for pair in encoder_pairs:
            left, right = pair
            try:
                reps = []
                for raw in (features[left], features[right]):
                    reps.append(raw if isinstance(raw, torch.Tensor) else torch.tensor(raw))
                score = calculate_cka_similarity(reps[0], reps[1], use_gpu)
            except Exception as e:
                print(f"Error processing image {img_idx}, pair {pair}: {e}")
                score = 0.0
            scores_by_pair[pair].append(score)

        # Periodically release cached GPU memory.
        if use_gpu and (img_idx + 1) % 100 == 0:
            torch.cuda.empty_cache()

    return {pair: _summarize_cka_scores(scores_by_pair[pair]) for pair in encoder_pairs}

def print_cka_results(cka_results):
    """
    Print CKA results: a mean-similarity matrix, per-pair statistics, and a guide.

    Args:
        cka_results: Dict mapping ``(enc1, enc2)`` encoder-index pairs to stats
            dicts containing 'mean', 'std', 'min', 'max', 'median', 'valid_count'
            and 'total_count'.
    """
    print("\n" + "="*80)
    print("CKA SIMILARITY RESULTS")
    print("="*80)

    # Create similarity matrix for visualization
    encoder_pairs = list(cka_results.keys())
    if encoder_pairs:
        max_encoder = max(max(pair) for pair in encoder_pairs)
        num_encoders = max_encoder + 1

        # Diagonal entries (and any pair absent from cka_results) stay at 1.0.
        similarity_matrix = np.ones((num_encoders, num_encoders))

        for pair, stats in cka_results.items():
            enc1, enc2 = pair
            similarity_matrix[enc1, enc2] = stats['mean']
            similarity_matrix[enc2, enc1] = stats['mean']  # Symmetric

        # Every cell (row label, column header, value) gets the same width so the
        # headers sit directly above their values. Previously the header indent
        # (5) and the cell widths (7 vs 6) disagreed with the rows.
        cell = 7
        print("\nCKA Similarity Matrix:")
        print(" " * cell + "".join(f"Enc{i:2d}".rjust(cell) for i in range(num_encoders)))
        for i in range(num_encoders):
            values = "".join(f"{similarity_matrix[i, j]:{cell}.3f}" for j in range(num_encoders))
            print(f"Enc{i:2d}".ljust(cell) + values)

    print("\nDetailed Statistics:")
    for pair, stats in cka_results.items():
        enc1, enc2 = pair
        print(f"\nEncoders {enc1} ↔ {enc2}:")
        print(f"  Mean CKA: {stats['mean']:.4f} ± {stats['std']:.4f}")
        print(f"  Range: [{stats['min']:.4f}, {stats['max']:.4f}]")
        print(f"  Median: {stats['median']:.4f}")
        print(f"  Valid calculations: {stats['valid_count']}/{stats['total_count']}")

    # Interpretation
    print("\n" + "="*50)
    print("INTERPRETATION GUIDE:")
    print("="*50)
    print("CKA > 0.8: Very similar representations (highly redundant)")
    print("CKA 0.6-0.8: Moderately similar (some redundancy)")
    print("CKA 0.3-0.6: Low similarity (complementary features)")
    print("CKA < 0.3: Very different (highly complementary)")

def save_cka_results(cka_results, output_path):
    """Write ``cka_results`` to ``output_path`` as a pickle, then report the path."""
    with open(output_path, 'wb') as out_file:
        pickle.dump(cka_results, out_file)
    print(f"\nResults saved to {output_path}")

def main():
    """
    Command-line entry point: load features, compute pairwise CKA, print and save.

    NOTE(review): the features file is read with ``pickle.load``, which can run
    arbitrary code during loading. Only pass files from a trusted source.
    """
    parser = argparse.ArgumentParser(description='Calculate CKA similarity between encoders')
    parser.add_argument('features_file', help='Path to combined features pickle file')
    parser.add_argument('--output', '-o', help='Output file for results',
                        default='cka_similarity_results.pkl')
    parser.add_argument('--sample', '-s', type=int, help='Sample size (if less than total images)')
    parser.add_argument('--no-gpu', action='store_true', help='Disable GPU usage')
    parser.add_argument('--save-scores', action='store_true', help='Save individual scores (not just stats)')
    args = parser.parse_args()

    print(f"Loading features from {args.features_file}...")
    with open(args.features_file, 'rb') as fh:
        features_list = pickle.load(fh)

    print(f"Loaded {len(features_list)} feature sets")
    if features_list:
        print(f"Each set contains {len(features_list[0])} encoder features")

    cka_results = calculate_all_pairwise_cka(
        features_list=features_list,
        use_gpu=not args.no_gpu,
        sample_size=args.sample,
    )

    print_cka_results(cka_results)

    # Unless requested, drop the per-image score lists to keep the saved file small.
    if not args.save_scores:
        for stats in cka_results.values():
            del stats['all_scores']

    save_cka_results(cka_results, args.output)

if __name__ == "__main__":
    import sys

    if len(sys.argv) <= 1:
        # No CLI arguments: run a hard-coded development configuration instead of
        # argparse. Unlike main(), per-image 'all_scores' are kept in the output.
        default_features_path = "MER/VLMEvalKit/features/converted_features_eagle_x4_mme.pkl"

        print("Loading features...")
        with open(default_features_path, 'rb') as fh:
            loaded_features = pickle.load(fh)
        print(f"Loaded {len(loaded_features)} feature sets")

        # Evaluate every image (no sampling) on the GPU when available.
        dev_results = calculate_all_pairwise_cka(loaded_features, use_gpu=True)

        print_cka_results(dev_results)
        save_cka_results(dev_results, 'cka_similarity_results.pkl')
    else:
        main()