#!/usr/bin/env python3
"""
Reusability Calculator for Schema Induction Pipeline

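This script calculates the reusability metric by:
1. Running the build_corpus pipeline on test data (or loading an existing
   test corpus in fast mode)
2. Collecting all unique codes from all datapoints
3. Computing the reusability metric as the number of unique test codes that
   also exist in the training corpus, divided by the number of unique
   training codes

Usage:
    python calculate_reusability.py --test_data path/to/test.csv \
        --question "research question" --train_corpus path/to/train.parquet \
        --output results.json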
"""

import asyncio
import argparse
import json
import sys
import pandas as pd
from pathlib import Path

# Add current directory to path for imports
sys.path.append('.')

from build_corpus import OptimizedTestInferenceCorpusBuilder

def _json_default(value):
    """Convert values that ``json`` cannot encode natively into plain data.

    Used as the ``default=`` hook when saving results so that NumPy arrays or
    scalars (anything exposing ``tolist()``) and sets do not abort the dump.
    """
    if hasattr(value, 'tolist'):
        return value.tolist()
    if isinstance(value, (set, frozenset)):
        return list(value)
    raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')


def _save_results(results, output_path):
    """Write ``results`` to ``output_path`` as indented JSON."""
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=_json_default)


def _reusability_ratio(codes_in_train, unique_train_codes):
    """Return codes_in_train / unique_train_codes (0 for an empty training set)."""
    return codes_in_train / unique_train_codes if unique_train_codes > 0 else 0


def _chunk_code_lists(codes, num_chunks):
    """Split raw code entries into ``num_chunks`` consecutive slices.

    Each slice is reduced to the non-empty ``'code'`` values of its entries.
    The last slice absorbs the remainder of the integer division.
    """
    per_chunk = len(codes) // num_chunks
    slices = []
    for chunk_idx in range(num_chunks):
        start = chunk_idx * per_chunk
        end = len(codes) if chunk_idx == num_chunks - 1 else start + per_chunk
        slices.append([
            entry['code'] for entry in codes[start:end]
            if entry and entry.get('code')
        ])
    return slices


def _reusability_from_corpus(test_corpus_path, train_corpus_path, output_path):
    """Fast mode: compute the metric from an already-built test corpus parquet."""
    print(f'⚡ FAST MODE: Loading test corpus from {test_corpus_path}')
    test_df = pd.read_parquet(test_corpus_path)
    # Rows without a code carry no information; counting the null as a code
    # would inflate the totals and emit NaN/null entries in the output.
    test_df = test_df.dropna(subset=['code'])
    test_codes = set(test_df['code'].tolist())
    print(f'📊 Loaded {len(test_codes)} test codes from existing corpus')

    print(f'📂 Loading training corpus from {train_corpus_path}')
    train_df = pd.read_parquet(train_corpus_path)
    # Training codes live in the 'tag' column when present, otherwise in the
    # first column of the frame.
    train_column = train_df['tag'] if 'tag' in train_df.columns else train_df.iloc[:, 0]
    train_codes = set(train_column.dropna().tolist())

    # Reusability = codes_in_test_that_exist_in_train / total_codes_in_train
    codes_in_train = len(test_codes & train_codes)
    unique_train_codes = len(train_codes)
    reusability_metric = _reusability_ratio(codes_in_train, unique_train_codes)

    # The corpus has no per-chunk information, so each datapoint becomes one
    # synthetic chunk. A single groupby pass replaces re-filtering the whole
    # frame once per datapoint; sort=False keeps first-appearance order.
    chunk_to_codes = {}
    all_unique_codes = set()
    for datapoint_id, group in test_df.groupby('datapoint', sort=False):
        datapoint_codes = group['code'].tolist()
        chunk_to_codes[f'datapoint_{datapoint_id}_chunk_0'] = datapoint_codes
        all_unique_codes.update(datapoint_codes)

    results = {
        'reusability_metric': reusability_metric,
        # Same value under the key the slow path also exposes.
        'reusability_ratio': reusability_metric,
        'total_unique_codes': len(test_codes),
        'codes_in_train': codes_in_train,
        'unique_train_codes': unique_train_codes,
        'test_codes': list(test_codes),
        'chunk_to_codes_mapping': chunk_to_codes,
        'all_unique_codes': list(all_unique_codes),
        'unique_codes': list(all_unique_codes),
    }

    if output_path:
        _save_results(results, output_path)

    return results


async def calculate_reusability(test_data_path: str = None, question: str = None, 
                              train_corpus_path: str = None, hierarchical_tree_path: str = None,
                              output_path: str = None, max_datapoints: int = None, test_corpus_path: str = None):
    """
    Calculate reusability metric for the schema induction pipeline.

    The metric is the number of unique test codes that also exist in the
    training corpus divided by the number of unique training codes.

    Two modes are supported:
    - Fast mode (``test_corpus_path`` given): codes are read from an existing
      test corpus parquet file; ``question``, ``hierarchical_tree_path`` and
      ``max_datapoints`` are not used.
    - Slow mode (``test_data_path`` given): the corpus builder is run on the
      test datapoints and the codes it returns are used.

    Args:
        test_data_path: Path to test CSV or parquet file (slow mode)
        question: Research question
        train_corpus_path: Path to training corpus parquet file
        hierarchical_tree_path: Path to hierarchical tree JSON file
        output_path: Path to save results (optional)
        max_datapoints: Maximum number of datapoints to process (optional)
        test_corpus_path: Path to an existing test corpus parquet file with
            'datapoint' and 'code' columns (fast mode, optional)

    Returns:
        Dictionary with reusability metrics

    Raises:
        ValueError: If neither ``test_data_path`` nor ``test_corpus_path`` is
            given, or if fast mode is requested without ``train_corpus_path``.
    """
    print('🔄 Calculating Reusability Metric')
    print('=' * 60)

    # FAST MODE: Use existing test corpus if provided
    if test_corpus_path:
        if not train_corpus_path:
            raise ValueError('train_corpus_path is required when test_corpus_path is provided')
        return _reusability_from_corpus(test_corpus_path, train_corpus_path, output_path)

    # SLOW MODE: Build corpus from test data
    if not test_data_path:
        raise ValueError('Either test_data_path or test_corpus_path must be provided')

    print(f'📂 Loading test data from {test_data_path}')
    # str() so that path-like objects are accepted as well as plain strings.
    if str(test_data_path).endswith('.parquet'):
        df = pd.read_parquet(test_data_path)
    else:
        df = pd.read_csv(test_data_path)

    # Get text column (assume 'text' or first column)
    text_column = df['text'] if 'text' in df.columns else df.iloc[:, 0]
    test_datapoints = text_column.tolist()

    if max_datapoints:
        test_datapoints = test_datapoints[:max_datapoints]
        print(f'📊 Processing {len(test_datapoints)} test datapoints (limited from {len(df)})')
    else:
        print(f'📊 Processing {len(test_datapoints)} test datapoints')

    print('🔧 Initializing corpus builder...')
    builder = OptimizedTestInferenceCorpusBuilder(
        question=question,
        train_corpus_path=train_corpus_path,
        hierarchical_tree_path=hierarchical_tree_path
    )

    print('⚡ Processing datapoints...')
    results, code_embeddings = await builder.process_datapoints(test_datapoints)

    print('\n🎯 Reusability Calculation:')
    print('=' * 60)

    chunk_to_codes = {}
    all_unique_codes = set()
    datapoint_stats = []

    for i, result in enumerate(results):
        datapoint_id = f'datapoint_{i}'
        codes = result.get('codes') or []
        reported_chunks = result.get('chunks', 1)
        # A result reporting zero (or no) chunks must still contribute its
        # codes; iterating over range(0) would silently drop all of them.
        num_chunks = max(int(reported_chunks or 0), 1)

        valid_codes = [entry['code'] for entry in codes if entry and entry.get('code')]
        all_unique_codes.update(valid_codes)

        # Per-chunk results are not available, so the datapoint's codes are
        # spread evenly over the number of chunks that were processed.
        for chunk_idx, chunk_codes in enumerate(_chunk_code_lists(codes, num_chunks)):
            if chunk_codes:
                chunk_to_codes[f'{datapoint_id}_chunk_{chunk_idx}'] = chunk_codes

        datapoint_stats.append({
            'datapoint_id': datapoint_id,
            'total_codes': len(codes),
            # Distinct code values, so duplicates within a datapoint count once.
            'unique_codes': len(set(valid_codes)),
            'success': result.get('success', False),
            'chunks_processed': reported_chunks
        })

    # Reusability ratio: codes_in_test_that_exist_in_train / total_codes_in_train
    train_codes_set = set(builder.train_codes)
    unique_train_codes = len(train_codes_set)
    codes_in_train = len(all_unique_codes & train_codes_set)
    reusability_ratio = _reusability_ratio(codes_in_train, unique_train_codes)

    total_unique_codes = len(all_unique_codes)
    total_chunks = len(chunk_to_codes)
    successful_datapoints = sum(1 for stat in datapoint_stats if stat['success'])

    stats = builder.stats
    # Share of generated codes that were replaced, as a percentage.
    success_rate = stats.codes_replaced / stats.codes_generated * 100 if stats.codes_generated > 0 else 0

    reusability_results = {
        'reusability_metric': reusability_ratio,
        'unique_train_codes': unique_train_codes,
        'codes_in_train': codes_in_train,
        'reusability_ratio': reusability_ratio,
        'total_datapoints': len(results),
        'successful_datapoints': successful_datapoints,
        'total_chunks': total_chunks,
        'total_unique_codes': total_unique_codes,
        'all_unique_codes': list(all_unique_codes),
        'unique_codes': list(all_unique_codes),
        'code_embeddings': code_embeddings,
        'chunk_to_codes_mapping': chunk_to_codes,
        'datapoint_statistics': datapoint_stats,
        'processing_stats': {
            'total_chunks': stats.total_chunks,
            'open_coding_requests': stats.open_coding_requests,
            'replacement_requests': stats.replacement_requests,
            'parent_retrieval_requests': stats.parent_retrieval_requests,
            'embedding_requests': stats.embedding_requests,
            'codes_generated': stats.codes_generated,
            'codes_replaced': stats.codes_replaced,
            'codes_discarded': stats.codes_discarded,
            'success_rate': success_rate
        }
    }

    # Print summary
    print('📊 Reusability Summary:')
    print(f'   Total datapoints processed: {len(results)}')
    print(f'   Successful datapoints: {successful_datapoints}')
    print(f'   Total chunks: {total_chunks}')
    print(f'   Total unique codes: {total_unique_codes}')
    print(f'   Unique training codes: {unique_train_codes}')
    print(f'   Test codes found in training: {codes_in_train}')
    print(f'   Success rate: {success_rate:.1f}%')

    print(f'\n🎯 REUSABILITY METRIC: {reusability_ratio:.4f}')
    print('   This represents the ratio of test codes that exist in the training corpus divided by total training codes')

    # Save results if output path provided
    if output_path:
        print(f'\n💾 Saving results to {output_path}')
        _save_results(reusability_results, output_path)
        print('✅ Results saved successfully!')

    return reusability_results

async def main():
    """Parse command-line arguments and run the reusability calculation.

    Exactly one of ``--test_data`` (slow mode: build the corpus from raw test
    data) or ``--test_corpus`` (fast mode: reuse an existing test corpus) must
    be given. ``--question`` is only needed in slow mode, where it is passed
    on to the corpus builder; fast mode does not use it.
    """
    parser = argparse.ArgumentParser(description='Calculate reusability metric for schema induction pipeline')
    parser.add_argument('--test_data', type=str, help='Path to test data CSV or parquet file (optional if --test_corpus provided)')
    parser.add_argument('--question', type=str, default=None, help='Research question (required with --test_data)')
    parser.add_argument('--train_corpus', type=str, required=True, help='Path to training corpus parquet file')
    parser.add_argument('--hierarchical_tree', type=str, default=None, help='Path to hierarchical tree JSON file (optional)')
    parser.add_argument('--test_corpus', type=str, help='Path to existing test corpus parquet file (for fast mode)')
    parser.add_argument('--output', type=str, help='Path to save results JSON file')
    parser.add_argument('--max_datapoints', type=int, help='Maximum number of datapoints to process')

    args = parser.parse_args()

    # Validate arguments
    if not args.test_corpus and not args.test_data:
        parser.error('Either --test_data or --test_corpus must be provided')
    if args.test_corpus and args.test_data:
        parser.error('Provide either --test_data or --test_corpus, not both')
    # The question is only consumed when the corpus has to be built.
    if args.test_data and not args.question:
        parser.error('--question is required when --test_data is provided')

    # Run reusability calculation
    results = await calculate_reusability(
        test_data_path=args.test_data,
        question=args.question,
        train_corpus_path=args.train_corpus,
        hierarchical_tree_path=args.hierarchical_tree,
        output_path=args.output,
        max_datapoints=args.max_datapoints,
        test_corpus_path=args.test_corpus
    )

    print('\n✅ Reusability calculation completed!')
    print(f'   Final reusability metric: {results["reusability_metric"]}')

# Script entry point: when executed directly (not imported), run the async
# CLI handler to completion on a fresh event loop.
if __name__ == '__main__':
    asyncio.run(main())
