#!/usr/bin/env python3
"""
Zipf Distribution Evaluation for Train vs Test Comparison
Calculates and compares Zipf distributions between training and test datasets
"""

import json
import pandas as pd
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy import stats
from typing import Dict, List, Any, Tuple
import argparse
import math
from collections import Counter

class ZipfDistributionEvaluator:
    """Evaluates Zipf distributions between train and test datasets.

    The evaluator keeps no state: every method derives its result from its
    arguments alone, so a single instance can be reused for many comparisons.
    """

    # R² cut-offs for the qualitative fit label, checked from highest to lowest;
    # anything at or below the last cut-off is labelled "Poor".
    _QUALITY_LEVELS = ((0.9, "Excellent"), (0.8, "Good"), (0.7, "Moderate"))

    def calculate_zipf_parameters(self, frequencies: List[int]) -> Tuple[float, float, float]:
        """Fit a Zipf law ``f(r) = C / r**s`` to rank-ordered frequencies.

        The fit is an ordinary least-squares regression on the log-log
        rank/frequency plot: ``log f = log C - s * log r``.

        Args:
            frequencies: Frequencies sorted by rank (descending). Non-positive
                entries are ignored because their logarithm is undefined.

        Returns:
            Tuple of (s_exponent, C_constant, r_squared) as Python floats.
            ``(0.0, 0.0, 0.0)`` when there is no positive frequency, and
            ``(0.0, f1, 0.0)`` when exactly one positive frequency ``f1``
            exists, since no slope can be fitted through a single point.
        """
        positive = [f for f in frequencies if f > 0]
        if not positive:
            return 0.0, 0.0, 0.0
        if len(positive) == 1:
            # A single point fixes C (log r = 0 at rank 1) but leaves s
            # undefined; linregress would yield NaN slope/intercept here.
            return 0.0, float(positive[0]), 0.0

        log_ranks = np.log(np.arange(1, len(positive) + 1))
        log_freq = np.log(positive)

        fit = stats.linregress(log_ranks, log_freq)
        s_exponent = -fit.slope  # slope is negative for a decaying rank/frequency curve
        C_constant = np.exp(fit.intercept)
        r_squared = fit.rvalue ** 2

        return float(s_exponent), float(C_constant), float(r_squared)

    def create_frequency_distribution(self, codes: List[str]) -> Dict[str, int]:
        """Count how often each code occurs in ``codes``."""
        return dict(Counter(codes))

    def calculate_zipf_distribution(self, freq_dist: Dict[str, int], s: float = 1.0) -> Dict[str, float]:
        """Build the theoretical Zipf probability distribution for ``freq_dist``.

        Items are ranked by observed frequency (descending; ties keep their
        insertion order). The item at rank ``r`` receives ``r**-s / H(n, s)``,
        where ``H(n, s) = sum_{k=1..n} k**-s`` is the generalized harmonic
        number, so the probabilities sum to 1.

        Args:
            freq_dist: Mapping of item to observed frequency.
            s: Zipf exponent (default 1.0).

        Returns:
            Mapping of item to theoretical Zipf probability; empty when
            ``freq_dist`` is empty.
        """
        if not freq_dist:
            return {}

        ranked_items = sorted(freq_dist, key=freq_dist.get, reverse=True)
        # For s == 1 the generalized harmonic number reduces to the ordinary H_n,
        # so no special case is needed.
        normalizer = 1.0 / sum(1 / (k ** s) for k in range(1, len(ranked_items) + 1))

        return {item: normalizer / (rank ** s) for rank, item in enumerate(ranked_items, 1)}

    def compare_zipf_distributions(self, train_codes: List[str], test_codes: List[str]) -> Dict[str, Any]:
        """Compare Zipf distributions between train and test datasets.

        Args:
            train_codes: Code occurrences from the training set (one entry per
                occurrence).
            test_codes: Code occurrences from the test set.

        Returns:
            Dict with keys 'zipf_parameters' (per-split fit and counts),
            'distribution_distances', 'zipf_metrics' and 'interpretation'.
        """
        print("📊 Calculating frequency distributions...")
        train_freq = self.create_frequency_distribution(train_codes)
        test_freq = self.create_frequency_distribution(test_codes)

        print(f"   Train: {len(train_codes)} total codes, {len(train_freq)} unique codes")
        print(f"   Test: {len(test_codes)} total codes, {len(test_freq)} unique codes")

        # Zipf fitting needs the frequencies in rank order (most frequent first).
        train_sorted_freq = sorted(train_freq.values(), reverse=True)
        test_sorted_freq = sorted(test_freq.values(), reverse=True)

        print("📊 Calculating Zipf parameters...")
        train_s, train_C, train_r2 = self.calculate_zipf_parameters(train_sorted_freq)
        test_s, test_C, test_r2 = self.calculate_zipf_parameters(test_sorted_freq)

        print(f"   Train Zipf: s={train_s:.4f}, C={train_C:.4f}, R²={train_r2:.4f}")
        print(f"   Test Zipf: s={test_s:.4f}, C={test_C:.4f}, R²={test_r2:.4f}")

        print("📊 Calculating distribution distances...")
        distances = self._calculate_distribution_distances(train_freq, test_freq)

        zipf_metrics = self._calculate_zipf_metrics(train_s, test_s, train_r2, test_r2)

        return {
            'zipf_parameters': {
                'train': {
                    's_exponent': train_s,
                    'C_constant': train_C,
                    'r_squared': train_r2,
                    'total_codes': len(train_codes),
                    'unique_codes': len(train_freq)
                },
                'test': {
                    's_exponent': test_s,
                    'C_constant': test_C,
                    'r_squared': test_r2,
                    'total_codes': len(test_codes),
                    'unique_codes': len(test_freq)
                }
            },
            'distribution_distances': distances,
            'zipf_metrics': zipf_metrics,
            'interpretation': self._interpret_zipf_results(train_s, test_s, train_r2, test_r2, distances)
        }

    @staticmethod
    def _normalize(counts: List[int]) -> List[float]:
        """Scale counts to probabilities; an all-zero (or empty) vector is returned unchanged."""
        total = sum(counts)
        return [c / total for c in counts] if total > 0 else counts

    def _calculate_distribution_distances(self, train_freq: Dict[str, int], test_freq: Dict[str, int]) -> Dict[str, float]:
        """Compute distance metrics between two frequency distributions.

        Both inputs are aligned on the union of their codes (a code missing
        from one side counts as 0) and normalized to probabilities first.

        Note on naming (keys kept for backward compatibility):
            * ``jensen_shannon_divergence`` holds scipy's Jensen-Shannon
              *distance*, i.e. the square root of the divergence (natural log).
            * ``earth_mover_distance`` is a simplification: the L1 distance
              between the two rank-ordered probability profiles, which ignores
              which code holds which probability.

        Returns:
            Dict with 'jensen_shannon_divergence', 'kullback_leibler_divergence',
            'chi_square_distance' and 'earth_mover_distance'.
        """
        # dict.fromkeys keeps first-seen order, so the aligned vectors (and the
        # floating-point summation order) are identical on every run, unlike a
        # set of strings whose order depends on hash randomization.
        all_codes = list(dict.fromkeys([*train_freq, *test_freq]))

        train_prob = self._normalize([train_freq.get(code, 0) for code in all_codes])
        test_prob = self._normalize([test_freq.get(code, 0) for code in all_codes])

        # 1. Jensen-Shannon distance (see note above about the key name).
        jsd = float(jensenshannon(train_prob, test_prob))

        # 2. Kullback-Leibler divergence D(train || test). Epsilon smoothing keeps
        #    the logarithm finite for codes present on only one side.
        epsilon = 1e-10
        kl_div = sum(
            (p + epsilon) * math.log((p + epsilon) / (q + epsilon))
            for p, q in zip(train_prob, test_prob)
        )

        # 3. Chi-square distance (without the conventional 1/2 factor).
        chi_square = sum((p - q) ** 2 / (p + q + epsilon) for p, q in zip(train_prob, test_prob))

        # 4. Simplified EMD. Both vectors span all_codes, so the sorted profiles
        #    already have equal length and need no padding.
        emd = sum(
            abs(p - q)
            for p, q in zip(sorted(train_prob, reverse=True), sorted(test_prob, reverse=True))
        )

        return {
            'jensen_shannon_divergence': jsd,
            'kullback_leibler_divergence': kl_div,
            'chi_square_distance': chi_square,
            'earth_mover_distance': emd
        }

    @classmethod
    def _zipf_quality(cls, r_squared: float) -> str:
        """Map an R² value to "Excellent", "Good", "Moderate" or "Poor"."""
        for threshold, label in cls._QUALITY_LEVELS:
            if r_squared > threshold:
                return label
        return "Poor"

    def _calculate_zipf_metrics(self, train_s: float, test_s: float, train_r2: float, test_r2: float) -> Dict[str, Any]:
        """Calculate Zipf-specific comparison metrics.

        Returns:
            Dict with absolute exponent and R² differences, a quality label per
            split, and two booleans: whether both fits exceed R² 0.7 and whether
            the exponents differ by less than 0.2.
        """
        s_difference = abs(train_s - test_s)
        r2_difference = abs(train_r2 - test_r2)

        return {
            'exponent_difference': s_difference,
            'r_squared_difference': r2_difference,
            'train_zipf_quality': self._zipf_quality(train_r2),
            'test_zipf_quality': self._zipf_quality(test_r2),
            # bool() turns numpy bools into plain bools so json.dump accepts them.
            'both_follow_zipf': bool(train_r2 > 0.7 and test_r2 > 0.7),
            'similar_exponents': bool(s_difference < 0.2)
        }

    def _interpret_zipf_results(self, train_s: float, test_s: float, train_r2: float, test_r2: float, distances: Dict[str, float]) -> Dict[str, str]:
        """Translate the numeric comparison results into human-readable verdicts.

        Returns:
            Dict with 'distribution_similarity' (from the Jensen-Shannon value),
            'zipf_adherence' (from both R² values) and 'exponent_similarity'
            (from the absolute exponent difference).
        """
        interpretation = {}

        # Overall distribution similarity
        jsd = distances['jensen_shannon_divergence']
        if jsd < 0.1:
            interpretation['distribution_similarity'] = "Very similar distributions"
        elif jsd < 0.3:
            interpretation['distribution_similarity'] = "Moderately similar distributions"
        elif jsd < 0.5:
            interpretation['distribution_similarity'] = "Moderately different distributions"
        else:
            interpretation['distribution_similarity'] = "Very different distributions"

        # Zipf adherence
        if train_r2 > 0.8 and test_r2 > 0.8:
            interpretation['zipf_adherence'] = "Both datasets follow Zipf distribution well"
        elif train_r2 > 0.7 or test_r2 > 0.7:
            interpretation['zipf_adherence'] = "One or both datasets moderately follow Zipf distribution"
        else:
            interpretation['zipf_adherence'] = "Neither dataset follows Zipf distribution well"

        # Exponent similarity
        s_diff = abs(train_s - test_s)
        if s_diff < 0.1:
            interpretation['exponent_similarity'] = "Very similar Zipf exponents"
        elif s_diff < 0.3:
            interpretation['exponent_similarity'] = "Moderately similar Zipf exponents"
        else:
            interpretation['exponent_similarity'] = "Different Zipf exponents"

        return interpretation

def _extract_codes(df: pd.DataFrame, label: str) -> List[Any]:
    """Return the code column of ``df`` as a list, preferring 'tag' over 'code'.

    Args:
        df: Loaded corpus.
        label: Corpus name used in the error message ("Training" or "Test").

    Raises:
        ValueError: if ``df`` has neither a 'tag' nor a 'code' column.
    """
    for column in ('tag', 'code'):
        if column in df.columns:
            return df[column].tolist()
    raise ValueError(f"{label} corpus must have either 'tag' or 'code' column")


def main():
    """Command-line entry point: compare the Zipf distributions of two parquet corpora.

    Reads the train and test corpora, evaluates them with
    ZipfDistributionEvaluator, writes the full results as JSON to ``--output``,
    and prints a human-readable summary.
    """
    parser = argparse.ArgumentParser(description='Evaluate Zipf distributions between train and test datasets')
    parser.add_argument('--train_corpus', type=str, required=True,
                        help='Path to training corpus parquet file')
    parser.add_argument('--test_corpus', type=str, required=True,
                        help='Path to test corpus parquet file')
    parser.add_argument('--output', type=str, default='zipf_distribution_results.json',
                        help='Output file path')

    args = parser.parse_args()

    # Load data
    print("📂 Loading datasets...")
    train_df = pd.read_parquet(args.train_corpus)
    test_df = pd.read_parquet(args.test_corpus)

    # Extract codes (handle both 'tag' and 'code' columns)
    train_codes = _extract_codes(train_df, "Training")
    test_codes = _extract_codes(test_df, "Test")

    # Evaluate Zipf distributions
    evaluator = ZipfDistributionEvaluator()
    results = evaluator.compare_zipf_distributions(train_codes, test_codes)

    # Save results
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Display results
    print("\n🎯 ZIPF DISTRIBUTION EVALUATION RESULTS")
    print("=" * 60)

    split_params = results['zipf_parameters']
    for header, params in (("📊 TRAIN DATASET:", split_params['train']),
                           ("\n📊 TEST DATASET:", split_params['test'])):
        print(header)
        print(f"   Zipf exponent (s): {params['s_exponent']:.4f}")
        print(f"   R-squared: {params['r_squared']:.4f}")
        print(f"   Total codes: {params['total_codes']}")
        print(f"   Unique codes: {params['unique_codes']}")

    print("\n📊 DISTRIBUTION DISTANCES:")
    distances = results['distribution_distances']
    print(f"   Jensen-Shannon Divergence: {distances['jensen_shannon_divergence']:.4f}")
    print(f"   Kullback-Leibler Divergence: {distances['kullback_leibler_divergence']:.4f}")
    print(f"   Chi-square Distance: {distances['chi_square_distance']:.4f}")
    print(f"   Earth Mover's Distance: {distances['earth_mover_distance']:.4f}")

    print("\n📊 INTERPRETATION:")
    interpretation = results['interpretation']
    print(f"   Distribution Similarity: {interpretation['distribution_similarity']}")
    print(f"   Zipf Adherence: {interpretation['zipf_adherence']}")
    print(f"   Exponent Similarity: {interpretation['exponent_similarity']}")

    print(f"\n💾 Results saved to {args.output}")

if __name__ == "__main__":
    main()
