#!/usr/bin/env python3
"""
Merging strategies for conflict detection
"""

import numpy as np
import networkx as nx
from typing import List, Dict, Set, Tuple, Any
from enum import Enum

class MergingStrategy(Enum):
    """Identifiers for the node-merging strategies.

    Each member's string value is what ``BaseMergingStrategy.get_strategy_name``
    returns and what ``compare_merging_strategies`` uses as a result key.
    """

    # Old implementation: reject if relationship exists (no merging at all)
    NO_MERGING = "no_merging"
    # Current implementation: merged representatives inherit relationships
    INHERITANCE_MERGING = "inheritance_merging"
    # Like inheritance merging, but merged groups have a bounded size
    LIMITED_MERGING = "limited_merging"
    # Merge based on confidence.
    # NOTE(review): no strategy class for this member is created by the
    # factory in this module — confirm whether it is implemented elsewhere.
    CONFIDENCE_BASED = "confidence_based"

class BaseMergingStrategy:
    """Common interface shared by every merging strategy.

    Subclasses pass their ``MergingStrategy`` member to ``__init__`` and
    override ``merge_nodes``.
    """

    def __init__(self, strategy: MergingStrategy):
        # Enum member identifying which strategy this instance implements.
        self.strategy = strategy

    def merge_nodes(self, relationship_matrix: np.ndarray, codes: List[str]) -> Tuple[np.ndarray, List[str], Dict[str, str]]:
        """Collapse ``codes`` (and the matrix indexed by them) per this strategy.

        Implementations return ``(new_matrix, new_codes, code_mapping)``, where
        ``code_mapping`` maps every original code to the code that now stands
        for it.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def get_strategy_name(self) -> str:
        """Return the string value of this strategy's enum member."""
        strategy_value = self.strategy.value
        return strategy_value

class NoMergingStrategy(BaseMergingStrategy):
    """Old implementation: No merging, just reject conflicting relationships"""

    def __init__(self):
        super().__init__(MergingStrategy.NO_MERGING)

    def merge_nodes(self, relationship_matrix: np.ndarray, codes: List[str]) -> Tuple[np.ndarray, List[str], Dict[str, str]]:
        """No merging - return copies of the original matrix and codes.

        The other strategies always build a brand-new matrix. This one copies
        its inputs too, so callers get the same guarantee from every strategy:
        editing the returned matrix or code list never changes the caller's
        originals.

        Args:
            relationship_matrix: Relationship matrix indexed like ``codes``.
            codes: Code labels.

        Returns:
            ``(matrix_copy, codes_copy, code_mapping)``, where ``code_mapping``
            maps every code to itself.
        """
        print("🔄 Using NO MERGING strategy (old implementation)")
        print("   - No mutual relationship merging")
        print("   - All relationships preserved")
        print("   - Conflicts resolved by rejection")

        # Identity mapping: every code is its own representative.
        code_mapping = {code: code for code in codes}
        # Copy rather than alias the caller's objects (see docstring).
        return np.copy(relationship_matrix), list(codes), code_mapping

class InheritanceMergingStrategy(BaseMergingStrategy):
    """Current implementation: full inheritance merging.

    Codes linked, directly or through a chain, by mutual relationships
    (value 3 in both directions) collapse into one representative code, which
    inherits the relationships of all its members.
    """

    def __init__(self):
        super().__init__(MergingStrategy.INHERITANCE_MERGING)

    def merge_nodes(self, relationship_matrix: np.ndarray, codes: List[str]) -> Tuple[np.ndarray, List[str], Dict[str, str]]:
        """Merge every mutually-related group of codes into one representative.

        Groups are joined with union-by-size. When two groups meet, the larger
        keeps its representative; on a tie, the group of the first code in the
        pair keeps it. Every positive relationship is then re-indexed onto the
        representatives, keeping the maximum value when several land on the
        same cell. Relationships between codes of the same group therefore end
        up on the diagonal.

        Returns:
            ``(new_matrix, representative_nodes, node_to_representative)``:
            an int matrix over the representatives, the representatives in
            order of first appearance in ``codes``, and a mapping from every
            code to its representative.
        """
        banner = (
            "🔄 Using INHERITANCE MERGING strategy (current implementation)",
            "   - Full mutual relationship merging",
            "   - Relationships inherited by representative nodes",
            "   - May cause over-merging in large components",
        )
        for line in banner:
            print(line)

        size = len(codes)
        groups = {code: [code] for code in codes}   # representative -> members
        owner = {code: code for code in codes}      # member -> representative

        # Mutual = value 3 in both directions; scanned in upper-triangle order.
        mutual_pairs = [
            (codes[row], codes[col])
            for row in range(size)
            for col in range(row + 1, size)
            if relationship_matrix[row, col] == 3 and relationship_matrix[col, row] == 3
        ]
        print(f"   Found {len(mutual_pairs)} mutual relationships to merge")

        for first, second in mutual_pairs:
            keeper, absorbed = owner[first], owner[second]
            if keeper == absorbed:
                continue
            # Union by size; on a tie the first code's representative survives.
            if len(groups[keeper]) < len(groups[absorbed]):
                keeper, absorbed = absorbed, keeper
            for member in groups[absorbed]:
                owner[member] = keeper
            groups[keeper].extend(groups.pop(absorbed))

        representative_nodes = list(groups)
        print(f"   Merged into {len(representative_nodes)} representative nodes")

        # Every owner value is a surviving key of `groups`, so lookups always hit.
        slot = {rep: pos for pos, rep in enumerate(representative_nodes)}
        width = len(representative_nodes)
        new_matrix = np.zeros((width, width), dtype=int)
        for row, code_row in enumerate(codes):
            for col, code_col in enumerate(codes):
                strength = relationship_matrix[row, col]
                if strength > 0:
                    a = slot[owner[code_row]]
                    b = slot[owner[code_col]]
                    # Keep the strongest relationship mapped onto this cell.
                    new_matrix[a, b] = max(new_matrix[a, b], strength)

        return new_matrix, representative_nodes, owner

class LimitedMergingStrategy(BaseMergingStrategy):
    """Limited merging to prevent over-merging.

    As in inheritance merging, codes linked by mutual relationships (value 3
    in both directions) are merged into one representative. No merged group
    may hold more than ``max_component_size`` codes, though.

    Oversized mutual components are split into groups that are each connected
    by mutual relationships, so two codes are only merged when a chain of
    mutual relationships inside their group links them. Every choice follows
    the order of ``codes``, so results are reproducible across interpreter
    runs.
    """

    def __init__(self, max_component_size: int = 3):
        """
        Args:
            max_component_size: Largest number of codes merged into one
                representative; 1 effectively disables merging.

        Raises:
            ValueError: If ``max_component_size`` is smaller than 1.
        """
        if max_component_size < 1:
            raise ValueError(
                f"max_component_size must be at least 1, got {max_component_size}"
            )
        super().__init__(MergingStrategy.LIMITED_MERGING)
        self.max_component_size = max_component_size

    def _split_component(self, graph, component, position) -> List[List[str]]:
        """Split one mutual component into connected groups of bounded size.

        A component that already fits is returned whole. Otherwise, each
        group is grown breadth-first from the earliest not-yet-grouped code
        (in ``codes`` order), walking only through codes not yet grouped.
        That keeps every group connected by mutual edges. The first element
        of each returned group is its earliest code.
        """
        ordered = sorted(component, key=position.__getitem__)
        limit = self.max_component_size
        if len(ordered) <= limit:
            return [ordered]

        unassigned = set(ordered)
        groups = []
        for seed in ordered:
            if seed not in unassigned:
                continue
            unassigned.discard(seed)
            group = [seed]
            head = 0  # `group` doubles as the BFS queue
            while head < len(group) and len(group) < limit:
                current = group[head]
                head += 1
                for neighbor in sorted(graph.neighbors(current), key=position.__getitem__):
                    if len(group) >= limit:
                        break
                    if neighbor in unassigned:
                        unassigned.discard(neighbor)
                        group.append(neighbor)
            groups.append(group)
        return groups

    def merge_nodes(self, relationship_matrix: np.ndarray, codes: List[str]) -> Tuple[np.ndarray, List[str], Dict[str, str]]:
        """Limited merging to prevent over-merging.

        Returns:
            ``(new_matrix, representative_nodes, node_to_representative)``:
            - an int matrix over the representatives, holding per cell the
              maximum relationship value mapped onto it (relationships inside
              a group land on the diagonal);
            - the representatives in order of first appearance in ``codes``;
            - the mapping of every code to its representative, which is the
              earliest code of its group.
        """
        print(f"🔄 Using LIMITED MERGING strategy (max size: {self.max_component_size})")
        print("   - Limited mutual relationship merging")
        print("   - Prevents over-merging in large components")
        print("   - Preserves more relationships")

        n_codes = len(codes)
        # Index of each code. It drives every ordering decision, because set
        # iteration order of strings differs between interpreter runs.
        position = {code: idx for idx, code in enumerate(codes)}

        # Graph whose edges are mutual relationships (3 in both directions)
        G = nx.Graph()
        for i in range(n_codes):
            for j in range(i + 1, n_codes):
                if relationship_matrix[i, j] == 3 and relationship_matrix[j, i] == 3:
                    G.add_edge(codes[i], codes[j])

        components = list(nx.connected_components(G))
        print(f"   Found {len(components)} mutual components")

        # Limit component size, keeping each resulting group connected
        limited_components = []
        for comp in components:
            limited_components.extend(self._split_component(G, comp, position))

        print(f"   After limiting to max size {self.max_component_size}: {len(limited_components)} components")

        # The first (earliest) code of each group represents it; codes in no
        # group represent themselves. Built in `codes` order for determinism.
        group_representative = {}
        for group in limited_components:
            for node in group:
                group_representative[node] = group[0]
        node_to_representative = {code: group_representative.get(code, code) for code in codes}

        # Representatives in order of first appearance in `codes`
        representative_nodes = list(dict.fromkeys(node_to_representative.values()))
        print(f"   Final representative nodes: {len(representative_nodes)}")

        n_repr = len(representative_nodes)
        new_matrix = np.zeros((n_repr, n_repr), dtype=int)
        repr_to_idx = {repr_node: idx for idx, repr_node in enumerate(representative_nodes)}

        # Re-index every positive relationship onto representatives, keeping
        # the strongest value that lands on each cell.
        for i, code_i in enumerate(codes):
            for j, code_j in enumerate(codes):
                value = relationship_matrix[i, j]
                if value > 0:
                    new_i = repr_to_idx[node_to_representative[code_i]]
                    new_j = repr_to_idx[node_to_representative[code_j]]
                    new_matrix[new_i, new_j] = max(new_matrix[new_i, new_j], value)

        return new_matrix, representative_nodes, node_to_representative

class MergingStrategyFactory:
    """Factory for creating merging strategies"""

    @staticmethod
    def create_strategy(strategy_type: MergingStrategy, **kwargs) -> BaseMergingStrategy:
        """Create a merging strategy based on type.

        Args:
            strategy_type: A ``MergingStrategy`` member, or its string value
                (e.g. ``"limited_merging"``).
            **kwargs: Strategy options. Only ``max_component_size`` (default 3)
                is read, and only for ``LIMITED_MERGING``; anything else is
                ignored.

        Raises:
            ValueError: If ``strategy_type`` is neither a ``MergingStrategy``
                member nor one of their values, or if it names a strategy this
                factory has no implementation for (currently
                ``CONFIDENCE_BASED``).
        """
        # Accept both enum members and their string values.
        try:
            strategy_type = MergingStrategy(strategy_type)
        except ValueError:
            raise ValueError(f"Unknown strategy type: {strategy_type}") from None

        if strategy_type is MergingStrategy.NO_MERGING:
            return NoMergingStrategy()
        if strategy_type is MergingStrategy.INHERITANCE_MERGING:
            return InheritanceMergingStrategy()
        if strategy_type is MergingStrategy.LIMITED_MERGING:
            max_size = kwargs.get('max_component_size', 3)
            return LimitedMergingStrategy(max_size)
        # A declared member without a concrete strategy class here.
        raise ValueError(
            f"Merging strategy '{strategy_type.value}' is declared but has no implementation"
        )

def compare_merging_strategies(relationship_matrix: np.ndarray, codes: List[str]) -> Dict[str, Any]:
    """Compare different merging strategies on the same input.

    Runs NO_MERGING, INHERITANCE_MERGING and LIMITED_MERGING (default options),
    prints a summary for each, and returns the summaries.

    Two relationship values are counted: ``1`` (A_into_B) and ``3`` (mutual,
    stored in both directions, so each pair is counted once). Diagonal cells
    are ignored. The merging strategies re-index relationships between codes
    of the same merged group onto the diagonal, and those self-loops are not
    relationships between distinct codes.

    Args:
        relationship_matrix: Square relationship matrix indexed like ``codes``.
        codes: Code labels.

    Returns:
        Mapping from strategy value (e.g. ``"no_merging"``) to a dict with
        keys ``original_codes``, ``final_codes``, ``a_into_b_relationships``,
        ``mutual_relationships``, ``total_relationships`` (all ``int``) and
        ``code_mapping``.
    """
    print("🔍 COMPARING MERGING STRATEGIES")
    print("=" * 60)

    strategies = [
        MergingStrategy.NO_MERGING,
        MergingStrategy.INHERITANCE_MERGING,
        MergingStrategy.LIMITED_MERGING
    ]

    results = {}

    for strategy_type in strategies:
        print(f"\n📊 Testing {strategy_type.value}:")

        strategy = MergingStrategyFactory.create_strategy(strategy_type)
        new_matrix, new_codes, code_mapping = strategy.merge_nodes(relationship_matrix, codes)

        # Count only relationships between distinct codes (see docstring).
        matrix = np.asarray(new_matrix)
        off_diagonal = ~np.eye(*matrix.shape, dtype=bool)
        a_into_b_count = int(np.count_nonzero((matrix == 1) & off_diagonal))
        # Mutual relationships appear in both directions; halve to count pairs.
        mutual_count = int(np.count_nonzero((matrix == 3) & off_diagonal)) // 2

        results[strategy_type.value] = {
            'original_codes': len(codes),
            'final_codes': len(new_codes),
            'a_into_b_relationships': a_into_b_count,
            'mutual_relationships': mutual_count,
            'total_relationships': a_into_b_count + mutual_count,
            'code_mapping': code_mapping
        }

        print(f"   Original codes: {len(codes)}")
        print(f"   Final codes: {len(new_codes)}")
        print(f"   A_into_B relationships: {a_into_b_count}")
        print(f"   Mutual relationships: {mutual_count}")
        print(f"   Total relationships: {a_into_b_count + mutual_count}")

    return results

if __name__ == "__main__":
    # Smoke-test every strategy on a small hand-built six-code example.
    print("🧪 Testing merging strategies")

    codes = ['A', 'B', 'C', 'D', 'E', 'F']
    matrix = np.zeros((len(codes), len(codes)), dtype=int)
    index = {code: k for k, code in enumerate(codes)}

    # Mutual relationships (value 3, stored in both directions): A ↔ B, C ↔ D
    for left, right in [('A', 'B'), ('C', 'D')]:
        matrix[index[left], index[right]] = 3
        matrix[index[right], index[left]] = 3

    # One-directional A_into_B relationships (value 1)
    for source, target in [('A', 'C'), ('B', 'E'), ('C', 'F'), ('D', 'F')]:
        matrix[index[source], index[target]] = 1

    print("📊 Test matrix created")
    compare_merging_strategies(matrix, codes)