#!/usr/bin/env python3
"""
Advanced conflict detection and resolution with multiple merging strategies
"""

import os
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Tuple
import logging

from .conflict_detection import OptimizedSparseMatrix, RelationshipType, ConflictResolutionStrategy
from .merging_strategies import MergingStrategy, MergingStrategyFactory

# Configure root logging once at import time (INFO level, timestamped records)
# and expose a module-scoped logger.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

def detect_and_resolve_conflicts_advanced(nli_results: List[Dict[str, Any]],
                                        output_dir: str,
                                        corpus_df: pd.DataFrame,
                                        merging_strategy: MergingStrategy = MergingStrategy.INHERITANCE_MERGING,
                                        conflict_resolution_strategy: ConflictResolutionStrategy = ConflictResolutionStrategy.OLD_REJECTION_BASED,
                                        **strategy_kwargs) -> Tuple[np.ndarray, List[str], Dict[str, List[str]]]:
    """
    Detect relationship conflicts from NLI results and persist the outcome.

    The function loads every NLI result into an ``OptimizedSparseMatrix``,
    runs its inference pass, converts the result into a dense integer matrix,
    prints a short analysis and writes several artifacts into ``output_dir``.
    No node merging happens here: the returned matrix, codes and datapoint
    mapping are the unmerged ones.

    Args:
        nli_results: NLI classification results. Each entry must provide
            'code_a', 'code_b' and 'nli_label'; 'confidence' is optional and
            defaults to 0.8.
        output_dir: Directory to save results (created if missing).
        corpus_df: Corpus dataframe with 'tag' and 'chunk_text' columns.
        merging_strategy: Only recorded in the saved relationship summary;
            merging itself is skipped in this function.
        conflict_resolution_strategy: Strategy handed to the sparse matrix for
            resolving conflicts during inference (default: OLD_REJECTION_BASED).
        **strategy_kwargs: Only recorded (stringified) in the saved summary.

    Returns:
        Tuple of (relationship_matrix, unique_codes, code_to_datapoints).
    """
    print("🚀 Running Advanced Sparse Scalable Conflict Detection...")

    # Group chunk texts by tag. Only two columns are needed, so iterate them
    # directly instead of materialising a Series per row with iterrows().
    code_to_datapoints: Dict[str, List[str]] = {}
    for code, chunk_text in zip(corpus_df['tag'], corpus_df['chunk_text']):
        code_to_datapoints.setdefault(code, []).append(chunk_text)
    # Report after the mapping exists, and count codes rather than corpus rows.
    print("   Built datapoint mapping for", len(code_to_datapoints), "codes")

    # Every code mentioned on either side of an NLI pair, in sorted order so
    # that matrix indices are deterministic.
    unique_codes = sorted({
        code
        for result in nli_results
        for code in (result['code_a'], result['code_b'])
    })
    print("   Processing", len(unique_codes), "unique codes")

    # Create sparse matrix with specified conflict resolution strategy
    matrix = OptimizedSparseMatrix(unique_codes, conflict_resolution_strategy=conflict_resolution_strategy)

    # Add relationships from NLI results
    for result in nli_results:
        matrix.add_direct_relationship(
            result['code_a'],
            result['code_b'],
            result['nli_label'],
            result.get('confidence', 0.8),
        )

    # Run inference
    matrix.run_optimized_inference(timeout_seconds=300)

    # Get performance stats
    stats = matrix.get_performance_stats()
    print("   Advanced detection completed:")
    print("     - Conflicts detected:", len(matrix.conflicts))
    print("     - Inferences made:", stats['inference_count'])
    print("     - Relationships:", stats['relationship_count'])
    print("     - Memory usage:", f"{stats['memory_usage_mb']:.1f} MB")

    # Convert to a dense numpy matrix; the diagonal (self-relationships) and
    # EMPTY cells stay 0.
    n_codes = len(unique_codes)
    relationship_matrix = np.zeros((n_codes, n_codes), dtype=int)
    for i in range(n_codes):
        for j in range(n_codes):
            if i == j:
                continue
            rel_type = matrix._get_relationship(i, j)
            if rel_type != RelationshipType.EMPTY:
                relationship_matrix[i, j] = rel_type.value

    # Analyze relationship matrix
    related_cells = np.sum(relationship_matrix > 0)
    print("\n📊 Relationship Matrix Analysis:")
    print(f"   Matrix size: {relationship_matrix.shape}")
    print(f"   Total relationships: {related_cells}")
    print(f"   IMPLIES (→): {np.sum(relationship_matrix == 1)}")
    print(f"   MUTUAL (↔): {np.sum(relationship_matrix == 2)}")
    print(f"   CONTRADICTS (×): {np.sum(relationship_matrix == 3)}")
    print(f"   Empty relationships: {np.sum(relationship_matrix == 0)}")

    # Calculate relationship density over ordered pairs, excluding self-loops
    total_possible = n_codes * (n_codes - 1)
    density = related_cells / total_possible if total_possible > 0 else 0
    print(f"   Relationship density: {density:.4f} ({density*100:.2f}%)")

    # SKIP MERGING IN CONFLICT DETECTION - let graph construction handle all merging
    print("\n⏭️  Skipping merging in conflict detection (graph construction will handle merging)")
    merged_matrix = relationship_matrix
    merged_codes = unique_codes
    merged_code_to_datapoints = code_to_datapoints

    # Save results
    print("💾 Saving advanced conflict detection results...")
    os.makedirs(output_dir, exist_ok=True)

    def _save_parquet(frame: pd.DataFrame, filename: str, label: str) -> None:
        # Write one parquet artifact into output_dir and report its path.
        path = os.path.join(output_dir, filename)
        frame.to_parquet(path)
        print(f"   Saved {label}:", path)

    _save_parquet(pd.DataFrame(matrix.conflicts), "advanced_conflicts.parquet", "conflicts")
    _save_parquet(pd.DataFrame([stats]), "advanced_performance_stats.parquet", "performance stats")

    # Save relationship matrix
    matrix_path = os.path.join(output_dir, "final_relationship_matrix.npy")
    np.save(matrix_path, relationship_matrix)
    print("   Saved relationship matrix:", matrix_path)

    _save_parquet(pd.DataFrame({'code': unique_codes}), "unique_codes.parquet", "unique codes")

    # NOTE(review): the printed analysis above labels value 2 as MUTUAL and
    # value 3 as CONTRADICTS, while this summary counts value 3 as "mutual"
    # (halved, presumably because a mutual edge is stored in both directions).
    # Confirm the intended encoding against RelationshipType; the counting is
    # left exactly as it was.
    relationship_summary = {
        'original_codes': len(unique_codes),
        'merged_codes': len(merged_codes),
        'original_a_into_b': np.sum(relationship_matrix == 1),
        'merged_a_into_b': np.sum(merged_matrix == 1),
        'original_mutual': np.sum(relationship_matrix == 3) // 2,
        'merged_mutual': np.sum(merged_matrix == 3) // 2,
        'merging_strategy': merging_strategy.value,
        'strategy_kwargs': str(strategy_kwargs) if strategy_kwargs else '{}'
    }
    _save_parquet(pd.DataFrame([relationship_summary]), "relationship_summary.parquet", "relationship summary")

    return merged_matrix, merged_codes, merged_code_to_datapoints

def compare_all_strategies(nli_results: List[Dict[str, Any]],
                          output_dir: str,
                          corpus_df: pd.DataFrame) -> Dict[str, Any]:
    """
    Compare all merging strategies on the same data

    A base relationship matrix is built once from the NLI results (using the
    OLD_REJECTION_BASED conflict resolution strategy), then each merging
    strategy is applied to that same matrix and its relationship counts are
    collected. A tabular comparison is written to
    ``<output_dir>/strategy_comparison/strategy_comparison.parquet``.

    Args:
        nli_results: NLI classification results. Each entry must provide
            'code_a', 'code_b' and 'nli_label'; 'confidence' is optional and
            defaults to 0.8.
        output_dir: Directory to save results
        corpus_df: Corpus dataframe. Kept for interface compatibility; this
            function does not read it.

    Returns:
        Dictionary keyed by strategy value. Each entry holds the code counts,
        relationship counts, the strategy's code mapping, the merged matrix
        and the merged code list.
    """
    print("🔍 COMPARING ALL MERGING STRATEGIES")
    print("=" * 60)

    # First, run conflict detection to get the base matrix
    print("🚀 Running base conflict detection...")

    # Every code mentioned on either side of an NLI pair, sorted so that
    # matrix indices are deterministic.
    unique_codes = sorted({
        code
        for result in nli_results
        for code in (result['code_a'], result['code_b'])
    })

    # Create sparse matrix with default conflict resolution strategy
    matrix = OptimizedSparseMatrix(unique_codes, conflict_resolution_strategy=ConflictResolutionStrategy.OLD_REJECTION_BASED)

    # Add relationships
    for result in nli_results:
        matrix.add_direct_relationship(
            result['code_a'],
            result['code_b'],
            result['nli_label'],
            result.get('confidence', 0.8),
        )

    # Run inference
    matrix.run_optimized_inference(timeout_seconds=300)

    # Convert to a dense numpy matrix; the diagonal and EMPTY cells stay 0.
    n_codes = len(unique_codes)
    relationship_matrix = np.zeros((n_codes, n_codes), dtype=int)
    for i in range(n_codes):
        for j in range(n_codes):
            if i == j:
                continue
            rel_type = matrix._get_relationship(i, j)
            if rel_type != RelationshipType.EMPTY:
                relationship_matrix[i, j] = rel_type.value

    # Now compare all strategies
    strategies = [
        MergingStrategy.NO_MERGING,
        MergingStrategy.INHERITANCE_MERGING,
        MergingStrategy.LIMITED_MERGING
    ]

    results = {}

    for strategy_type in strategies:
        print(f"\n📊 Testing {strategy_type.value}:")

        # Only the limited strategy takes an extra argument (component size cap).
        factory_kwargs = {}
        if strategy_type == MergingStrategy.LIMITED_MERGING:
            factory_kwargs['max_component_size'] = 3
        strategy = MergingStrategyFactory.create_strategy(strategy_type, **factory_kwargs)

        # Apply merging
        new_matrix, new_codes, code_mapping = strategy.merge_nodes(relationship_matrix, unique_codes)

        # Count relationships.
        # NOTE(review): value 3 is treated as "mutual" and halved, presumably
        # because a mutual edge occupies both (i, j) and (j, i) — confirm the
        # encoding against RelationshipType / the merging strategies.
        a_into_b_count = np.sum(new_matrix == 1)
        mutual_count = np.sum(new_matrix == 3) // 2
        total_count = a_into_b_count + mutual_count

        results[strategy_type.value] = {
            'original_codes': len(unique_codes),
            'final_codes': len(new_codes),
            'a_into_b_relationships': a_into_b_count,
            'mutual_relationships': mutual_count,
            'total_relationships': total_count,
            'code_mapping': code_mapping,
            'matrix': new_matrix,
            'codes': new_codes
        }

        print(f"   Original codes: {len(unique_codes)}")
        print(f"   Final codes: {len(new_codes)}")
        print(f"   A_into_B relationships: {a_into_b_count}")
        print(f"   Mutual relationships: {mutual_count}")
        print(f"   Total relationships: {total_count}")

    # Save comparison results
    comparison_dir = os.path.join(output_dir, "strategy_comparison")
    os.makedirs(comparison_dir, exist_ok=True)

    # Only the scalar columns go to disk; mappings and matrices stay in the
    # returned dictionary.
    summary_columns = (
        'original_codes',
        'final_codes',
        'a_into_b_relationships',
        'mutual_relationships',
        'total_relationships',
    )
    comparison_df = pd.DataFrame([
        {'strategy': strategy_name, **{column: data[column] for column in summary_columns}}
        for strategy_name, data in results.items()
    ])

    comparison_path = os.path.join(comparison_dir, "strategy_comparison.parquet")
    comparison_df.to_parquet(comparison_path)
    print(f"\n💾 Saved strategy comparison: {comparison_path}")

    return results