import torch
import numpy as np
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
import seaborn as sns

class PositionalBiasAnalyzer:
    """
    Utility class for analyzing positional bias in text embeddings.

    The ``analyze_*`` methods turn embedding tensors into dictionaries of
    summary statistics, the ``plot_*`` methods draw those dictionaries with
    matplotlib, and ``generate_experiment_report`` formats them as text.
    """

    def __init__(self):
        # Container for accumulated results; no method of this class
        # populates it, so filling it is left to the caller.
        self.experiment_results = []

    def analyze_embedding_similarity(self, original_embeddings: torch.Tensor,
                                     modified_embeddings: torch.Tensor,
                                     original_prompts: List[str],
                                     modified_prompts: List[str]) -> Dict:
        """
        Analyze similarity between original and modified embeddings.

        All metrics are computed along the last dimension, pairing each
        original embedding with the modified embedding at the same index.

        Args:
            original_embeddings: Embeddings from original prompts
            modified_embeddings: Embeddings from prompts with space prefixes
            original_prompts: List of original prompt texts
            modified_prompts: List of modified prompt texts

        Returns:
            Dictionary with per-pair NumPy arrays ('cosine_similarities',
            'l2_distances', 'magnitude_changes'), their scalar means
            ('mean_cosine_similarity', 'mean_l2_distance',
            'mean_magnitude_change'), and the two prompt lists passed
            through unchanged ('original_prompts', 'modified_prompts').
        """
        # Detach first: embeddings taken straight from a model may still carry
        # autograd history, and Tensor.numpy() refuses such tensors. Detaching
        # does not change any value.
        original = original_embeddings.detach()
        modified = modified_embeddings.detach()

        # Cosine similarity = dot product of the unit-normalized vectors.
        unit_original = torch.nn.functional.normalize(original, dim=-1)
        unit_modified = torch.nn.functional.normalize(modified, dim=-1)
        similarities = (unit_original * unit_modified).sum(dim=-1)

        # Euclidean distance between each pair.
        l2_distances = torch.norm(original - modified, dim=-1)

        # Signed change in vector length (modified minus original).
        magnitude_changes = torch.norm(modified, dim=-1) - torch.norm(original, dim=-1)

        return {
            'cosine_similarities': similarities.cpu().numpy(),
            'l2_distances': l2_distances.cpu().numpy(),
            'magnitude_changes': magnitude_changes.cpu().numpy(),
            'mean_cosine_similarity': similarities.mean().item(),
            'mean_l2_distance': l2_distances.mean().item(),
            'mean_magnitude_change': magnitude_changes.mean().item(),
            'original_prompts': original_prompts,
            'modified_prompts': modified_prompts
        }

    def analyze_positional_shift_impact(self, embeddings: torch.Tensor,
                                        batch_size: int,
                                        num_references_per_batch: int = 6) -> Dict:
        """
        Analyze how positional shifts affect embeddings across different positions.

        The first ``batch_size`` rows of ``embeddings`` are skipped; the
        remaining rows are cut into consecutive groups of
        ``num_references_per_batch`` rows. For each group the variance across
        its rows is averaged into one number. Trailing rows that do not fill
        a complete group are ignored.

        Args:
            embeddings: All embeddings (original + modified)
            batch_size: Number of original prompts
            num_references_per_batch: Number of references per batch (default: 6)

        Returns:
            Dictionary containing positional analysis results:
            'positional_variations' (list with one float per group) and the
            'mean_', 'std_', 'max_' and 'min_positional_variance' of that list.

        Raises:
            ValueError: If ``num_references_per_batch`` is not positive, or if
                fewer than ``num_references_per_batch`` rows remain after the
                first ``batch_size`` rows (no complete group to analyze).
        """
        if num_references_per_batch <= 0:
            raise ValueError(
                f"num_references_per_batch must be positive, got {num_references_per_batch}"
            )

        # Only the rows after the originals take part in the analysis.
        experimental_embeddings = embeddings[batch_size:]
        num_batches = len(experimental_embeddings) // num_references_per_batch

        if num_batches == 0:
            # Fail with a clear message instead of letting NumPy reject the
            # empty list in the statistics below.
            raise ValueError(
                "Not enough experimental embeddings to form a group: "
                f"{len(experimental_embeddings)} row(s) remain after the first "
                f"{batch_size}, but each group needs {num_references_per_batch}."
            )

        # One value per group: variance across the group's rows, averaged over
        # the remaining dimensions.
        # NOTE(review): with groups of a single row torch.var's default
        # (sample) estimator is expected to yield NaN — confirm if size 1 matters.
        covered_rows = num_batches * num_references_per_batch
        positional_variations = [
            torch.var(
                experimental_embeddings[start:start + num_references_per_batch], dim=0
            ).mean().item()
            for start in range(0, covered_rows, num_references_per_batch)
        ]

        return {
            'positional_variations': positional_variations,
            'mean_positional_variance': np.mean(positional_variations),
            'std_positional_variance': np.std(positional_variations),
            'max_positional_variance': np.max(positional_variations),
            'min_positional_variance': np.min(positional_variations)
        }

    def plot_similarity_analysis(self, analysis_results: Dict, save_path: str = None):
        """
        Create visualization plots for similarity analysis.

        Draws a 2x2 figure: three histograms (cosine similarity, L2 distance,
        magnitude change), each with a dashed line at its mean, plus a scatter
        plot of cosine similarity against L2 distance.

        Args:
            analysis_results: Results from analyze_embedding_similarity
            save_path: Optional path to save the plot
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # (axis, key of per-pair values, key of their mean, axis label, color)
        histogram_specs = [
            (axes[0, 0], 'cosine_similarities', 'mean_cosine_similarity', 'Cosine Similarity', 'blue'),
            (axes[0, 1], 'l2_distances', 'mean_l2_distance', 'L2 Distance', 'green'),
            (axes[1, 0], 'magnitude_changes', 'mean_magnitude_change', 'Magnitude Change', 'orange'),
        ]
        for ax, values_key, mean_key, label, color in histogram_specs:
            ax.hist(analysis_results[values_key], bins=30, alpha=0.7, color=color)
            mean_value = analysis_results[mean_key]
            ax.axvline(mean_value, color='red', linestyle='--',
                       label=f'Mean: {mean_value:.3f}')
            ax.set_title(f'{label} Distribution')
            ax.set_xlabel(label)
            ax.set_ylabel('Frequency')
            ax.legend()

        # Scatter plot: Cosine similarity vs L2 distance
        scatter_ax = axes[1, 1]
        scatter_ax.scatter(analysis_results['cosine_similarities'],
                           analysis_results['l2_distances'], alpha=0.6)
        scatter_ax.set_xlabel('Cosine Similarity')
        scatter_ax.set_ylabel('L2 Distance')
        scatter_ax.set_title('Cosine Similarity vs L2 Distance')

        plt.tight_layout()

        # Save before show so the file is written while the figure is current.
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def plot_positional_analysis(self, analysis_results: Dict, save_path: str = None):
        """
        Create visualization plots for positional analysis.

        Draws a 1x2 figure: a line plot of the per-group variance with a
        dashed line at the mean, and a box plot of the same values.

        Args:
            analysis_results: Results from analyze_positional_shift_impact
            save_path: Optional path to save the plot
        """
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        trend_ax, box_ax = axes
        variations = analysis_results['positional_variations']

        # Positional variance across batches
        trend_ax.plot(variations, marker='o', linewidth=2, markersize=8)
        mean_variance = analysis_results['mean_positional_variance']
        trend_ax.axhline(mean_variance, color='red', linestyle='--',
                         label=f'Mean: {mean_variance:.3f}')
        trend_ax.set_title('Positional Variance Across Batches')
        trend_ax.set_xlabel('Batch Index')
        trend_ax.set_ylabel('Embedding Variance')
        trend_ax.legend()
        trend_ax.grid(True, alpha=0.3)

        # Box plot of positional variations
        box_ax.boxplot(variations)
        box_ax.set_title('Distribution of Positional Variances')
        box_ax.set_ylabel('Embedding Variance')
        box_ax.grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def generate_experiment_report(self, similarity_results: Dict, positional_results: Dict,
                                   space_prefix_length: int) -> str:
        """
        Generate a comprehensive experiment report.

        Args:
            similarity_results: Results from analyze_embedding_similarity
            positional_results: Results from analyze_positional_shift_impact
            space_prefix_length: Number of spaces added as prefix

        Returns:
            Formatted report string. It begins with a newline and ends with
            the indentation that precedes the closing quotes of the template.
        """
        # The template lines are deliberately flush-left: any indentation
        # inside the literal would become part of the report text.
        report = f"""
=== Positional Bias Experiment Report ===
Space Prefix Length: {space_prefix_length}

SIMILARITY ANALYSIS:
- Mean Cosine Similarity: {similarity_results['mean_cosine_similarity']:.4f}
- Mean L2 Distance: {similarity_results['mean_l2_distance']:.4f}
- Mean Magnitude Change: {similarity_results['mean_magnitude_change']:.4f}
- Cosine Similarity Range: [{similarity_results['cosine_similarities'].min():.4f}, {similarity_results['cosine_similarities'].max():.4f}]
- L2 Distance Range: [{similarity_results['l2_distances'].min():.4f}, {similarity_results['l2_distances'].max():.4f}]

POSITIONAL ANALYSIS:
- Mean Positional Variance: {positional_results['mean_positional_variance']:.4f}
- Std Positional Variance: {positional_results['std_positional_variance']:.4f}
- Max Positional Variance: {positional_results['max_positional_variance']:.4f}
- Min Positional Variance: {positional_results['min_positional_variance']:.4f}

INTERPRETATION:
- High cosine similarity (>0.9) suggests minimal positional bias
- Low cosine similarity (<0.7) suggests significant positional bias
- High positional variance indicates that position changes affect embeddings significantly
- Low positional variance suggests embeddings are robust to positional changes

EXPERIMENT SETUP:
- Original prompts: {len(similarity_results['original_prompts'])}
- Modified prompts: {len(similarity_results['modified_prompts'])}
- Space prefix: "{' ' * space_prefix_length}"
        """

        return report

def compare_experiments(experiment_results: List[Dict], space_lengths: List[int]) -> Dict:
    """
    Gather per-experiment summary metrics into parallel lists.

    Args:
        experiment_results: One result dictionary per experiment. Each must
            contain 'mean_cosine_similarity' and 'mean_l2_distance';
            'mean_positional_variance' is optional and counts as 0 when absent.
        space_lengths: Space prefix lengths used; stored as given under
            'space_lengths'.

    Returns:
        Dictionary with keys 'space_lengths', 'mean_cosine_similarities',
        'mean_l2_distances' and 'mean_positional_variances', where each
        metric list has one entry per experiment, in input order.
    """
    def required_metric(key):
        # Mandatory metric: a missing key surfaces as KeyError.
        return [entry[key] for entry in experiment_results]

    # Keys are inserted in the order callers see when iterating the result.
    comparison = {'space_lengths': space_lengths}
    comparison['mean_cosine_similarities'] = required_metric('mean_cosine_similarity')
    comparison['mean_l2_distances'] = required_metric('mean_l2_distance')
    comparison['mean_positional_variances'] = [
        entry.get('mean_positional_variance', 0) for entry in experiment_results
    ]
    return comparison

def plot_experiment_comparison(comparison_results: Dict, save_path: str = None):
    """
    Plot comparison across different experiments.

    Draws a 1x3 figure with one line plot per metric (mean cosine similarity,
    mean L2 distance, mean positional variance), each against the space
    prefix length.

    Args:
        comparison_results: Results from compare_experiments
        save_path: Optional path to save the plot
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # One entry per panel: (metric key, title, y-axis label, extra line style).
    # The first panel passes no color and so uses matplotlib's default.
    panels = (
        ('mean_cosine_similarities', 'Cosine Similarity vs Space Length',
         'Mean Cosine Similarity', {'marker': 'o'}),
        ('mean_l2_distances', 'L2 Distance vs Space Length',
         'Mean L2 Distance', {'marker': 's', 'color': 'orange'}),
        ('mean_positional_variances', 'Positional Variance vs Space Length',
         'Mean Positional Variance', {'marker': '^', 'color': 'green'}),
    )

    for panel_ax, (metric_key, title, y_label, line_style) in zip(axes, panels):
        panel_ax.plot(comparison_results['space_lengths'], comparison_results[metric_key],
                      linewidth=2, markersize=8, **line_style)
        panel_ax.set_title(title)
        panel_ax.set_xlabel('Space Prefix Length')
        panel_ax.set_ylabel(y_label)
        panel_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # Write the file (when requested) before displaying the figure.
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()