#!/usr/bin/env python3
"""
Consistency and Stability Evaluation using the Jensen-Shannon distance (JSD),
i.e. the square root of the Jensen-Shannon divergence as returned by SciPy
Measures distribution similarity between TRAINING and TESTING sets

Consistency = JSD (how different train vs test distributions are)
Stability = 1 - JSD (how similar train vs test distributions are)

This measures how well the schema generalizes to unseen data.
"""

import json
import pandas as pd
import numpy as np
from scipy.spatial.distance import jensenshannon
from typing import Dict, List, Any, Tuple
import argparse
import os
from zipf_distribution_eval import ZipfDistributionEvaluator

class ConsistencyStabilityEvaluator:
    """Compare train and test code distributions with the Jensen-Shannon distance.

    The evaluator loads code labels for a training corpus and a test set, turns
    each into a probability distribution via ``ZipfDistributionEvaluator``,
    and reports:

    * ``consistency`` = JSD (0 means identical distributions),
    * ``stability`` = 1 - JSD,
    * set-overlap statistics between the unique train and test codes.
    """

    def __init__(self):
        """Create an evaluator; it holds no state and needs no configuration."""
        pass

    def calculate_jsd(self, p_dist: Dict[str, float], q_dist: Dict[str, float]) -> float:
        """Calculate the Jensen-Shannon distance between two distributions.

        Uses ``scipy.spatial.distance.jensenshannon``, which returns the
        Jensen-Shannon *distance* (the square root of the divergence). The
        logarithm base is set to 2 explicitly so that the result is bounded by
        1; SciPy's default is the natural logarithm, which would cap the value
        at sqrt(ln 2) ~= 0.83 and skew ``stability = 1 - JSD``.

        Args:
            p_dist: First distribution (e.g., train), mapping code -> weight.
            q_dist: Second distribution (e.g., test), mapping code -> weight.
                Codes missing from one mapping are treated as weight 0.

        Returns:
            A built-in float in [0, 1]: 0 = identical, 1 = completely
            different (disjoint support). If both inputs carry no probability
            mass the result is 0.0; if exactly one does, the result is 1.0.
        """
        # Fix one ordering of the union so both vectors are aligned.
        all_codes = list(set(p_dist) | set(q_dist))

        if not all_codes:
            return 0.0

        p_vec = np.array([p_dist.get(code, 0.0) for code in all_codes], dtype=float)
        q_vec = np.array([q_dist.get(code, 0.0) for code in all_codes], dtype=float)

        # An all-zero vector cannot be normalised (SciPy would divide by zero
        # and return NaN), so those cases are resolved explicitly.
        p_is_empty = p_vec.sum() == 0
        q_is_empty = q_vec.sum() == 0
        if p_is_empty and q_is_empty:
            return 0.0
        if p_is_empty or q_is_empty:
            # Judgment call: a distribution with no mass shares nothing with a
            # non-empty one, so it is reported as maximally different.
            return 1.0

        # jensenshannon normalises its inputs itself; base=2 bounds it by 1.
        return float(jensenshannon(p_vec, q_vec, base=2))

    def create_zipf_distribution(self, codes: List[str]) -> Dict[str, float]:
        """Create a Zipf-based probability distribution from a list of codes.

        Counts how often each code occurs, asks ``ZipfDistributionEvaluator``
        to fit Zipf parameters to the descending frequency list, and then asks
        it for the distribution corresponding to the fitted exponent.

        Args:
            codes: Code labels, one entry per occurrence.

        Returns:
            Whatever ``ZipfDistributionEvaluator.calculate_zipf_distribution``
            returns (presumably code -> probability; confirm against that
            module). An empty dict when ``codes`` is empty.
        """
        if not codes:
            return {}

        zipf_evaluator = ZipfDistributionEvaluator()

        # Imported locally: Counter is not among the module-level imports.
        from collections import Counter
        freq_dist = dict(Counter(codes))

        # Only the exponent is needed here; the other two fitted values
        # (named constant and r-squared by the original unpacking) are unused.
        sorted_freq = sorted(freq_dist.values(), reverse=True)
        s_exponent, _c_constant, _r_squared = zipf_evaluator.calculate_zipf_parameters(sorted_freq)

        return zipf_evaluator.calculate_zipf_distribution(freq_dist, s_exponent)

    @staticmethod
    def _clean_codes(values: List[Any]) -> List[str]:
        """Stringify and strip raw values, dropping missing and empty entries."""
        stripped = (str(value).strip() for value in values if pd.notna(value))
        return [code for code in stripped if code]

    def load_train_codes(self, train_corpus_path: str) -> List[str]:
        """Load training codes from the ``tag`` column of a corpus parquet file.

        Args:
            train_corpus_path: Path to the training corpus parquet file.

        Returns:
            The non-missing, non-empty tags as stripped strings, one per
            record. An empty list if the file has no ``tag`` column or if
            anything goes wrong while reading (the error is printed).
        """
        print(f"📂 Loading training codes from {train_corpus_path}")

        try:
            df = pd.read_parquet(train_corpus_path)
            print(f"📊 Loaded {len(df)} training records")

            # Read the column directly rather than via iterrows(): it is
            # faster and avoids iterrows' per-row dtype coercion.
            raw_tags = df['tag'].tolist() if 'tag' in df.columns else []
            train_codes = self._clean_codes(raw_tags)

            print(f"📊 Total training codes: {len(train_codes)}")
            print(f"📊 Unique training codes: {len(set(train_codes))}")
            return train_codes

        except Exception as e:
            # Deliberate best-effort: the caller treats an empty list as
            # "no data" and reports it.
            print(f"❌ Error loading training corpus: {e}")
            return []

    def load_test_codes_from_corpus(self, test_corpus_path: str) -> List[str]:
        """Load test codes from a generated corpus parquet file.

        For each record the ``code`` column is used when it holds a
        non-missing value; otherwise the ``tag`` column is used.

        Args:
            test_corpus_path: Path to the test corpus parquet file.

        Returns:
            The selected values as stripped, non-empty strings. An empty list
            if neither column exists or if reading fails (the error is
            printed).
        """
        print(f"📂 Loading test codes from {test_corpus_path}")

        try:
            df = pd.read_parquet(test_corpus_path)
            print(f"📊 Loaded {len(df)} test records")

            has_code = 'code' in df.columns
            has_tag = 'tag' in df.columns
            if has_code and has_tag:
                # Per record: prefer 'code', fall back to 'tag' when missing.
                raw_codes = [
                    code if pd.notna(code) else tag
                    for code, tag in zip(df['code'].tolist(), df['tag'].tolist())
                ]
            elif has_code:
                raw_codes = df['code'].tolist()
            elif has_tag:
                raw_codes = df['tag'].tolist()
            else:
                raw_codes = []
            test_codes = self._clean_codes(raw_codes)

            print(f"📊 Total test codes: {len(test_codes)}")
            print(f"📊 Unique test codes: {len(set(test_codes))}")
            return test_codes

        except Exception as e:
            # Deliberate best-effort, mirroring load_train_codes.
            print(f"❌ Error loading test corpus: {e}")
            return []

    def load_test_codes_from_results(self, results_file: str) -> List[str]:
        """Load test codes from a reusability results JSON file (fallback).

        Reads the ``chunk_to_codes_mapping`` entry and concatenates the codes
        of every chunk.

        Args:
            results_file: Path to the results JSON file.

        Returns:
            All codes from all chunks, in mapping order. An empty list when
            the mapping is missing or empty.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        print(f"📂 Loading test codes from {results_file}")

        with open(results_file, 'r', encoding='utf-8') as f:
            results = json.load(f)

        chunk_mapping = results.get('chunk_to_codes_mapping', {})
        if not chunk_mapping:
            print("❌ No chunk_to_codes_mapping found")
            return []

        test_codes = []
        for codes in chunk_mapping.values():
            if codes is None:
                # A chunk with no codes contributes nothing.
                continue
            if isinstance(codes, str):
                # A bare string is one code; extend() would split it into
                # individual characters.
                test_codes.append(codes)
            else:
                test_codes.extend(codes)

        print(f"📊 Total test codes: {len(test_codes)}")
        print(f"📊 Unique test codes: {len(set(test_codes))}")
        return test_codes

    def evaluate_consistency_stability(self, test_data_path: str, train_corpus_path: str) -> Dict[str, Any]:
        """Evaluate consistency and stability between train and test distributions.

        Args:
            test_data_path: Test data location. A path ending in ``.parquet``
                is read as a corpus; anything else is read as a results JSON
                file.
            train_corpus_path: Path to the training corpus parquet file.

        Returns:
            A dict with two keys: ``consistency_stability_metrics`` (JSD,
            consistency, stability, overlap counts and coverages, both
            distributions and the code lists) and ``interpretation``
            (human-readable verdicts). An empty dict when either side yields
            no codes.
        """
        train_codes = self.load_train_codes(train_corpus_path)

        if test_data_path.endswith('.parquet'):
            test_codes = self.load_test_codes_from_corpus(test_data_path)
        else:
            test_codes = self.load_test_codes_from_results(test_data_path)

        if not train_codes or not test_codes:
            print("❌ Missing train or test codes")
            return {}

        train_dist = self.create_zipf_distribution(train_codes)
        test_dist = self.create_zipf_distribution(test_codes)

        print(f"📊 Training distribution: {len(train_dist)} unique codes")
        print(f"📊 Test distribution: {len(test_dist)} unique codes")

        jsd = self.calculate_jsd(train_dist, test_dist)

        consistency = jsd  # Consistency = JSD (how different distributions are)
        stability = 1 - jsd  # Stability = 1 - JSD (how similar distributions are)

        train_codes_set = set(train_codes)
        test_codes_set = set(test_codes)

        common_codes = train_codes_set & test_codes_set
        train_only_codes = train_codes_set - test_codes_set
        test_only_codes = test_codes_set - train_codes_set

        # Share of each side's unique codes that also occurs on the other side.
        train_coverage = len(common_codes) / len(train_codes_set) if train_codes_set else 0
        test_coverage = len(common_codes) / len(test_codes_set) if test_codes_set else 0

        results = {
            'consistency_stability_metrics': {
                # JSD metrics
                'jsd': float(jsd),

                # Consistency and Stability
                'consistency': float(consistency),  # Consistency = JSD
                'stability': float(stability),     # Stability = 1 - JSD

                # Code overlap metrics
                'total_train_codes': len(train_codes),
                'unique_train_codes': len(train_codes_set),
                'total_test_codes': len(test_codes),
                'unique_test_codes': len(test_codes_set),
                'common_codes': len(common_codes),
                'train_only_codes': len(train_only_codes),
                'test_only_codes': len(test_only_codes),
                'train_coverage': float(train_coverage),
                'test_coverage': float(test_coverage),

                # Distribution details. The lists are sorted so the output is
                # reproducible across runs (set order is not); key=str keeps
                # sorting safe if JSON-sourced codes are not all strings.
                'train_distribution': train_dist,
                'test_distribution': test_dist,
                'common_codes_list': sorted(common_codes, key=str),
                'train_only_codes_list': sorted(train_only_codes, key=str),
                'test_only_codes_list': sorted(test_only_codes, key=str)
            },
            'interpretation': self._interpret_consistency_stability_results(consistency, stability, train_coverage, test_coverage)
        }

        return results

    def _interpret_consistency_stability_results(self, consistency: float, stability: float, train_coverage: float, test_coverage: float) -> Dict[str, str]:
        """Translate the numeric metrics into human-readable verdicts.

        Args:
            consistency: The JSD value (lower is more consistent).
            stability: 1 - JSD (higher is more stable).
            train_coverage: Fraction of unique train codes also seen in test.
            test_coverage: Fraction of unique test codes also seen in train.

        Returns:
            A dict with the keys ``consistency``, ``stability`` and
            ``generalization``, each mapping to a descriptive sentence.
        """
        interpretation = {}

        # Consistency bands: strict upper bounds at 0.1 / 0.3 / 0.5 / 0.7.
        if consistency < 0.1:
            interpretation['consistency'] = "Very consistent (train and test distributions are very similar)"
        elif consistency < 0.3:
            interpretation['consistency'] = "Moderately consistent (train and test distributions are somewhat similar)"
        elif consistency < 0.5:
            interpretation['consistency'] = "Moderately inconsistent (train and test distributions differ somewhat)"
        elif consistency < 0.7:
            interpretation['consistency'] = "Inconsistent (train and test distributions are quite different)"
        else:
            interpretation['consistency'] = "Very inconsistent (train and test distributions are very different)"

        # Stability bands: strict lower bounds at 0.9 / 0.7 / 0.5 / 0.3.
        if stability > 0.9:
            interpretation['stability'] = "Very stable (excellent generalization to test data)"
        elif stability > 0.7:
            interpretation['stability'] = "Stable (good generalization to test data)"
        elif stability > 0.5:
            interpretation['stability'] = "Moderately stable (moderate generalization to test data)"
        elif stability > 0.3:
            interpretation['stability'] = "Unstable (poor generalization to test data)"
        else:
            interpretation['stability'] = "Very unstable (very poor generalization to test data)"

        # Generalization requires BOTH coverages to clear the same threshold.
        if train_coverage > 0.8 and test_coverage > 0.8:
            interpretation['generalization'] = "Excellent generalization (high code overlap between train and test)"
        elif train_coverage > 0.6 and test_coverage > 0.6:
            interpretation['generalization'] = "Good generalization (moderate code overlap between train and test)"
        elif train_coverage > 0.4 and test_coverage > 0.4:
            interpretation['generalization'] = "Moderate generalization (some code overlap between train and test)"
        else:
            interpretation['generalization'] = "Poor generalization (low code overlap between train and test)"

        return interpretation

def main():
    """Command-line entry point for the train/test consistency evaluation.

    Parses ``--test_data``, ``--train_corpus`` and ``--output``, runs
    ``ConsistencyStabilityEvaluator.evaluate_consistency_stability`` and, when
    it returns a non-empty result, writes that result as indented JSON to the
    output path and prints a formatted summary. When the result is empty only
    a failure message is printed.
    """
    cli = argparse.ArgumentParser(
        description='Evaluate consistency and stability between train and test distributions using JSD'
    )
    cli.add_argument(
        '--test_data',
        type=str,
        required=True,
        help='Path to test data (corpus parquet file or results JSON file)',
    )
    cli.add_argument(
        '--train_corpus',
        type=str,
        required=True,
        help='Path to training corpus parquet file',
    )
    cli.add_argument(
        '--output',
        type=str,
        default='consistency_stability_results.json',
        help='Output file path',
    )
    options = cli.parse_args()

    evaluator = ConsistencyStabilityEvaluator()
    outcome = evaluator.evaluate_consistency_stability(options.test_data, options.train_corpus)

    # Guard clause: with no results there is nothing to save or summarise.
    if not outcome:
        print("❌ No results generated")
        return

    # Persist the full result before printing the summary.
    with open(options.output, 'w') as sink:
        json.dump(outcome, sink, indent=2)

    print("\n🎯 TRAIN vs TEST CONSISTENCY & STABILITY EVALUATION")
    print("=" * 60)

    stats = outcome['consistency_stability_metrics']
    notes = outcome['interpretation']

    print("📊 JSD METRICS:")
    print(f"   JSD (Train vs Test): {stats['jsd']:.4f}")

    print("\n📊 CONSISTENCY & STABILITY:")
    print(f"   Consistency: {stats['consistency']:.4f} (Consistency = JSD)")
    print(f"   Stability: {stats['stability']:.4f} (Stability = 1 - JSD)")

    print("\n📊 CODE OVERLAP METRICS:")
    print(f"   Total Train Codes: {stats['total_train_codes']}")
    print(f"   Unique Train Codes: {stats['unique_train_codes']}")
    print(f"   Total Test Codes: {stats['total_test_codes']}")
    print(f"   Unique Test Codes: {stats['unique_test_codes']}")
    print(f"   Common Codes: {stats['common_codes']}")
    print(f"   Train Only Codes: {stats['train_only_codes']}")
    print(f"   Test Only Codes: {stats['test_only_codes']}")
    print(f"   Train Coverage: {stats['train_coverage']:.3f} ({stats['train_coverage']*100:.1f}%)")
    print(f"   Test Coverage: {stats['test_coverage']:.3f} ({stats['test_coverage']*100:.1f}%)")

    print("\n📊 INTERPRETATION:")
    print(f"   Consistency: {notes['consistency']}")
    print(f"   Stability: {notes['stability']}")
    print(f"   Generalization: {notes['generalization']}")

    print(f"\n💾 Results saved to {options.output}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
