#!/usr/bin/env python3
"""
Label Flipping Processing Utility

This module handles the conversion from 4-label NLI system to 3-label conflict detection system:
- 4 labels: A_into_B, B_into_A, mutual, not_mergeable
- 3 labels: IMPLIES (A_into_B), MUTUAL, CONTRADICTS (not_mergeable)

The main function flips B_into_A relationships to A_into_B by swapping code_a and code_b.
"""

from typing import List, Dict, Any
import logging

logger = logging.getLogger(__name__)

def flip_b_into_a_labels(nli_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rewrite every 'B_into_A' record as an equivalent 'A_into_B' record.

    A record labelled 'B_into_A' is shallow-copied, its 'code_a' and 'code_b'
    values are exchanged, and its label becomes 'A_into_B'. Records carrying
    any other label are passed through untouched (the very same dict object
    is placed in the output). Input dicts are never mutated.

    A short summary (number flipped, total, per-label counts) is printed to
    stdout.

    Args:
        nli_results: NLI classification records. Each record must provide
            'nli_label'; records labelled 'B_into_A' must also provide
            'code_a' and 'code_b'. Any further keys are carried over as-is.

    Returns:
        A new list, in input order, with no 'B_into_A' labels remaining.
    """
    print("🔄 Processing label flipping: B_into_A → A_into_B")

    output = []
    num_flipped = 0

    for entry in nli_results:
        if entry['nli_label'] != 'B_into_A':
            # Nothing to rewrite: forward the original object unchanged.
            output.append(entry)
            continue

        # Work on a copy so the caller's record keeps its original direction.
        swapped = entry.copy()
        swapped['code_a'], swapped['code_b'] = entry['code_b'], entry['code_a']
        swapped['nli_label'] = 'A_into_B'
        output.append(swapped)
        num_flipped += 1

    print(f"   Flipped {num_flipped} B_into_A relationships to A_into_B")
    print(f"   Total relationships: {len(output)}")

    # Tally labels in order of first appearance for the summary below.
    tally = {}
    for entry in output:
        name = entry['nli_label']
        tally.setdefault(name, 0)
        tally[name] += 1

    print(f"   Label distribution after flipping:")
    for name in tally:
        print(f"     {name}: {tally[name]}")

    return output

def validate_label_conversion(nli_results: List[Dict[str, Any]]) -> bool:
    """
    Check that every record uses a label the 3-label system accepts.

    Accepted labels are 'A_into_B', 'mutual' and 'not_mergeable'. Scanning
    stops at the first record with any other label; that label is logged at
    error level. When every record passes, the set of labels seen is logged
    at info level.

    Args:
        nli_results: Records to check; each must provide 'nli_label'.

    Returns:
        True when all labels are accepted (including for an empty list),
        False as soon as an unaccepted label is encountered.
    """
    allowed = {'A_into_B', 'mutual', 'not_mergeable'}
    seen = set()

    for entry in nli_results:
        label = entry['nli_label']
        seen.add(label)

        if label in allowed:
            continue

        # First offending label ends the scan.
        logger.error(f"Invalid label found: {label}")
        return False

    logger.info(f"Valid labels found: {seen}")
    return True

def process_nli_results_for_conflict_detection(nli_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Complete processing pipeline for NLI results to make them compatible with advanced conflict detection

    Steps:
        1. Sort records by their 'confidence' value, highest first. Records
           without a 'confidence' key are treated as 0.0; the sort is stable,
           so records with equal confidence keep their input order.
        2. Flip 'B_into_A' records to 'A_into_B' (swapping code_a / code_b).
        3. Validate that only 3-label-compatible labels remain.

    Progress and a short confidence summary are printed to stdout. An empty
    input is accepted: the confidence summary is skipped and an empty list is
    returned.

    Args:
        nli_results: Raw NLI classification results

    Returns:
        Processed NLI results ready for advanced conflict detection, sorted by confidence (highest to lowest)

    Raises:
        ValueError: If a label outside the 3-label system remains after
            flipping.
    """
    print("🔧 Processing NLI results for advanced conflict detection...")

    # Step 1: Sort by confidence (highest to lowest) to ensure highest confidence relationships are processed first
    print("📊 Sorting NLI results by confidence (highest to lowest)...")
    sorted_results = sorted(nli_results, key=lambda x: x.get('confidence', 0.0), reverse=True)

    # Show confidence distribution. min()/max() raise ValueError on an empty
    # sequence, so the summary is only produced when there is data to report.
    confidences = [r.get('confidence', 0.0) for r in sorted_results]
    if confidences:
        print(f"   Confidence range: {min(confidences):.3f} - {max(confidences):.3f}")
        print(f"   Top 5 confidence scores: {confidences[:5]}")
        print(f"   Bottom 5 confidence scores: {confidences[-5:]}")
    else:
        print("   No NLI results supplied; skipping confidence summary")

    # Step 2: Flip B_into_A labels
    flipped_results = flip_b_into_a_labels(sorted_results)

    # Step 3: Validate label conversion
    if not validate_label_conversion(flipped_results):
        raise ValueError("Invalid labels found after conversion")

    print("✅ NLI results processed successfully for advanced conflict detection")
    print(f"   📊 Processed {len(flipped_results)} relationships sorted by confidence")
    return flipped_results

def get_label_mapping_info() -> Dict[str, str]:
    """
    Describe how each 4-label NLI label maps onto the 3-label system.

    Returns:
        A freshly built dictionary keyed by 4-label name whose values are the
        corresponding 3-label names.
    """
    label_pairs = (
        ('A_into_B', 'IMPLIES'),           # Direct mapping
        ('B_into_A', 'IMPLIES'),           # Flipped to A_into_B then mapped
        ('mutual', 'MUTUAL'),              # Direct mapping
        ('not_mergeable', 'CONTRADICTS'),  # Direct mapping
    )
    return dict(label_pairs)

if __name__ == "__main__":
    # Manual smoke run of the full pipeline on a handful of hand-written
    # records covering all four input labels.
    # NOTE: the sample records carry 'similarity' but no 'confidence', so the
    # pipeline's confidence sort sees 0.0 for each and keeps the input order.
    sample_records = [
        {'code_a': first, 'code_b': second, 'nli_label': label, 'similarity': score}
        for first, second, label, score in (
            ('A', 'B', 'A_into_B', 0.8),
            ('C', 'D', 'B_into_A', 0.7),
            ('E', 'F', 'mutual', 0.9),
            ('G', 'H', 'not_mergeable', 0.6),
            ('I', 'J', 'B_into_A', 0.85),
        )
    ]

    print("🧪 Testing Label Flipping Processing")
    print("=" * 50)

    print("Original NLI results:")
    for record in sample_records:
        print(f"  {record['code_a']} {record['nli_label']} {record['code_b']}")

    print("\nProcessing...")
    converted_records = process_nli_results_for_conflict_detection(sample_records)

    print("\nProcessed NLI results:")
    for record in converted_records:
        print(f"  {record['code_a']} {record['nli_label']} {record['code_b']}")

    print(f"\nLabel mapping info:")
    label_map = get_label_mapping_info()
    for four_label in label_map:
        print(f"  {four_label} → {label_map[four_label]}")