#!/usr/bin/env python3
"""
Load all NLI results from all batch files and test conflict detection
"""

import os
import sys
import time
from collections import Counter
from typing import List, Dict, Any

import pandas as pd

# Add the current working directory to the module search path
sys.path.append('.')

from .conflict_detection import SparseScalableMatrix, RelationshipType

def load_all_nli_results(
    nli_dir: str = "temp_files/nli_classify",
    min_confidence: float = 0.5,
) -> List[Dict[str, Any]]:
    """Load high-confidence NLI results from every parquet batch file.

    Each ``*.parquet`` file in ``nli_dir`` is read with pandas. Files that
    cannot be read, or that lack any of the required columns, are reported
    and skipped. Rows whose ``confidence`` is below ``min_confidence`` are
    dropped. Progress and summary statistics are printed to stdout.

    Args:
        nli_dir: Directory containing the NLI batch parquet files.
        min_confidence: Inclusive lower bound on ``confidence`` for a row
            to be kept.

    Returns:
        One dict per retained row with the keys ``code_a``, ``code_b``,
        ``nli_label``, ``confidence`` and ``similarity``. Empty when the
        directory is missing or no row qualifies.
    """
    print("📂 Loading all NLI results...")

    if not os.path.isdir(nli_dir):
        print(f"❌ NLI directory not found: {nli_dir}")
        return []

    # os.listdir returns entries in arbitrary order; sorting makes the load
    # order (and therefore the order of the returned list) reproducible.
    nli_files = sorted(f for f in os.listdir(nli_dir) if f.endswith('.parquet'))
    print(f"📊 Found {len(nli_files)} NLI batch files")

    required_columns = ['code_a', 'code_b', 'nli_label', 'confidence', 'similarity']
    all_results: List[Dict[str, Any]] = []
    total_rows = 0

    for i, filename in enumerate(nli_files):
        if i % 100 == 0:
            print(f"   Loading file {i+1}/{len(nli_files)}: {filename}")

        file_path = os.path.join(nli_dir, filename)
        try:
            df = pd.read_parquet(file_path)

            # Check if this is a valid NLI results file
            if not all(col in df.columns for col in required_columns):
                print(f"   ⚠️ Skipping {filename} - missing required columns")
                continue

            # Build the records before touching any counter so that a file
            # which fails part-way contributes nothing to the totals.
            keep = df['confidence'] >= min_confidence
            records = df.loc[keep, required_columns].to_dict('records')
        except Exception as e:
            # Best-effort per-file boundary: one bad batch must not abort
            # the whole load, but the failure is always reported.
            print(f"   ❌ Error reading {filename}: {e}")
            continue

        total_rows += len(df)
        all_results.extend(records)

    high_conf_count = len(all_results)
    print(f"📊 Loaded {high_conf_count} high-confidence relationships")
    print(f"📊 Total rows processed: {total_rows}")
    if total_rows:
        share = high_conf_count / total_rows * 100
        print(f"📊 High confidence (≥{min_confidence}): {high_conf_count}/{total_rows} ({share:.1f}%)")
    else:
        # No rows were read at all; a percentage would divide by zero.
        print(f"📊 High confidence (≥{min_confidence}): 0/0 (n/a)")

    # Count label distribution (Counter keeps first-seen order)
    label_counts = Counter(result['nli_label'] for result in all_results)

    print("📊 Label distribution:")
    for label, count in label_counts.items():
        print(f"   {label}: {count}")

    return all_results

def test_conflict_detection_with_full_data(nli_results: List[Dict[str, Any]]):
    """Run conflict detection over a full set of NLI results and report on it.

    Builds a ``SparseScalableMatrix`` over every code that appears in
    ``nli_results``, adds one direct relationship per result whose
    ``nli_label`` is recognised, runs sparse inference with a 300 second
    timeout, and prints performance statistics plus up to ten sample
    conflicts.

    NOTE(review): because the name starts with ``test_``, pytest may try to
    collect this function and treat ``nli_results`` as a fixture — confirm
    how this module is meant to be invoked.

    Args:
        nli_results: Dicts with the keys ``code_a``, ``code_b``,
            ``nli_label`` and ``confidence``.

    Returns:
        True when inference completed and the report was printed; False when
        ``nli_results`` is empty or inference raised an exception.
    """
    print("\n🚀 Testing Conflict Detection with Full Dataset")
    print("=" * 60)

    if not nli_results:
        print("❌ No NLI results to process")
        return False

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # node order handed to the matrix is the same on every run (a plain set
    # of strings iterates in hash order, which varies between processes).
    codes_list = list(dict.fromkeys(
        code
        for result in nli_results
        for code in (result['code_a'], result['code_b'])
    ))
    print(f"📊 Found {len(codes_list)} unique codes")

    # Create conflict detection matrix
    print("🚀 Creating conflict detection matrix...")
    matrix = SparseScalableMatrix(codes_list, max_hops=8)

    # NLI label -> (relationship symbol, swap endpoints?). 'B_into_A' is
    # stored as a forward arrow from code_b to code_a. Built once, outside
    # the loop.
    nli_to_rel = {
        'A_into_B': ('→', False),
        'B_into_A': ('→', True),
        'mutual': ('↔', False),
        'not_mergeable': ('×', False),
    }

    # Add relationships
    print("📝 Adding relationships to matrix...")
    added_count = 0
    skipped_count = 0
    error_count = 0
    max_reported_errors = 5  # keep the log readable on large inputs

    for i, result in enumerate(nli_results):
        if i % 10000 == 0:
            print(f"   Processing relationship {i+1}/{len(nli_results)}")

        try:
            mapping = nli_to_rel.get(result['nli_label'])
            if mapping is None:
                # Unrecognised label: not an error, just not representable.
                skipped_count += 1
                continue

            rel_type, swap = mapping
            source, target = result['code_a'], result['code_b']
            if swap:
                source, target = target, source

            matrix.add_direct_relationship(source, target, rel_type, result['confidence'])
            added_count += 1
        except Exception as e:
            # Still best-effort (one bad row must not abort the run), but
            # failures are surfaced instead of being silently discarded.
            skipped_count += 1
            error_count += 1
            if error_count <= max_reported_errors:
                print(f"   ⚠️ Failed to add relationship {i+1}: {e}")

    if error_count > max_reported_errors:
        print(f"   ⚠️ {error_count - max_reported_errors} further errors not shown")

    print(f"📊 Added {added_count} relationships, skipped {skipped_count}")
    if error_count:
        print(f"📊 {error_count} of the skipped relationships raised errors")
    print(f"📊 Matrix has {matrix.relationship_count} relationships")

    # Run inference
    print("\n🚀 Running conflict detection inference...")
    start_time = time.time()

    try:
        matrix.run_sparse_inference(timeout_seconds=300)  # 5 minutes timeout
    except Exception as e:
        print(f"❌ Inference failed: {e}")
        return False

    inference_time = time.time() - start_time
    print(f"✅ Inference completed in {inference_time:.2f}s")

    # Analyze results
    stats = matrix.get_performance_stats()
    print("\n📊 Performance Stats:")
    print(f"   Total time: {stats['total_time']:.2f}s")
    print(f"   Inference time: {stats['inference_time']:.2f}s")
    print(f"   Node count: {stats['node_count']}")
    print(f"   Relationship count: {stats['relationship_count']}")
    print(f"   Inferences: {stats['inference_count']}")
    print(f"   Conflicts: {stats['conflict_count']}")
    print(f"   Density: {stats['density']:.6f}")

    # Check if results make sense
    print("\n🔍 Analysis:")
    print(f"   Input relationships: {added_count}")
    print(f"   Final relationships: {stats['relationship_count']}")
    print(f"   Inferences generated: {stats['inference_count']}")
    print(f"   Conflicts detected: {stats['conflict_count']}")

    # Ratios are relative to the relationships actually added; guard against
    # the case where every input row was skipped.
    inference_ratio = stats['inference_count'] / added_count if added_count > 0 else 0
    conflict_ratio = stats['conflict_count'] / added_count if added_count > 0 else 0

    print(f"   Inference ratio: {inference_ratio:.2f}")
    print(f"   Conflict ratio: {conflict_ratio:.2f}")

    # Show sample conflicts
    conflicts = matrix.detect_conflicts()
    print("\n🔍 Sample Conflicts (first 10):")
    for position, conflict in enumerate(conflicts[:10], start=1):
        print(f"   {position}. {conflict['node_a']} vs {conflict['node_b']}: {conflict['explanation']}")

    return True

if __name__ == "__main__":
    # Script entry point: load every NLI batch, then exercise conflict
    # detection on the combined data and print a pass/fail summary.
    divider = "=" * 60
    print("🔍 Full NLI Data Processing Test")
    print(divider)

    nli_results = load_all_nli_results()

    if not nli_results:
        # Nothing was loaded, so there is nothing to run detection on.
        print("\n❌ No NLI results loaded. Check the NLI directory.")
    else:
        success = test_conflict_detection_with_full_data(nli_results)
        verdict = '✅ PASS' if success else '❌ FAIL'

        print("\n" + divider)
        print("📋 Test Summary:")
        print(f"   Full data processing: {verdict}")

        closing_message = (
            "\n🎉 Full data processing completed successfully!"
            if success
            else "\n⚠️ Full data processing failed. Check the output above."
        )
        print(closing_message)