import networkx as nx
from utils.util import distance, density
from objective_function import objective_function

def charikar_densest_subgraph_with_bounds(G, min_size=None, max_size=None):
    """
    Greedy peeling (Charikar) search for the densest subgraph of bounded size.

    A copy of ``G`` is peeled one minimum-degree node at a time. Before each
    deletion, if the number of remaining nodes lies within the bounds, the
    remaining graph is scored as ``2 * |E| / |V|`` and the best-scoring node
    set is remembered.

    Args:
        G: NetworkX graph (not modified; a copy is peeled)
        min_size: minimum subgraph size (inclusive), or None for no lower bound
        max_size: maximum subgraph size (inclusive), or None for no upper bound

    Returns:
        (best_nodes, best_density); ``(set(), -inf)`` when no peeled graph
        ever fits the bounds.
    """
    peeled = G.copy()
    top_nodes = set()
    top_score = -float('inf')

    while peeled.number_of_nodes() > 0:
        size = peeled.number_of_nodes()
        edge_count = peeled.number_of_edges()

        fits = (min_size is None or size >= min_size) and (max_size is None or size <= max_size)
        if fits:
            score = (2.0 * edge_count) / size
            if score > top_score:
                top_score = score
                top_nodes = set(peeled.nodes)

        # Peel: drop one node of minimum degree (first one in iteration order on ties).
        victim, _ = min(peeled.degree, key=lambda pair: pair[1])
        peeled.remove_node(victim)

    return top_nodes, top_score

def charikar_densest_connected_subgraph_with_bounds(G, min_size=None, max_size=None):
    """
    Charikar-style greedy peeling restricted to connected subgraphs within size bounds.

    A copy of ``G`` is peeled one minimum-degree node at a time. Before each
    deletion, every connected component of the remaining graph whose node
    count lies within ``[min_size, max_size]`` is scored by ``2 * |E| / |V|``,
    and the best component seen so far is kept.

    Args:
        G: NetworkX graph (not modified; a copy is peeled)
        min_size: minimum subgraph size (inclusive), or None for no lower bound
        max_size: maximum subgraph size (inclusive), or None for no upper bound

    Returns:
        (best_nodes, best_density): node set of the best connected component
        found and its ``2|E|/|V|`` score, or ``(set(), -inf)`` if no component
        ever fits the bounds.
    """
    H = G.copy()
    best_nodes = set()
    best_density = -float('inf')

    def in_bounds(size):
        return (min_size is None or size >= min_size) and (max_size is None or size <= max_size)

    while H.number_of_nodes() > 0:
        # Score every component, not only a fully connected H: once peeling
        # (or the input itself) leaves H disconnected, each of its components
        # is still a valid connected candidate.
        for component in nx.connected_components(H):
            size = len(component)
            if not in_bounds(size):
                continue
            # A whole component has no outgoing edges, so its degree sum
            # equals 2 * (its edge count).
            density_val = sum(d for _, d in H.degree(component)) / size
            if density_val > best_density:
                best_density = density_val
                best_nodes = set(component)

        # Peel: remove one node of minimum degree (loop guarantees H is non-empty).
        min_degree_node = min(H.degree, key=lambda x: x[1])[0]
        H.remove_node(min_degree_node)

    return best_nodes, best_density

def is_subset_of_any(subgraph, existing_subgraphs):
    """
    Report whether a subgraph's node set is contained in any existing subgraph.

    Only node sets are compared; edges are ignored.

    Args:
        subgraph: NetworkX subgraph
        existing_subgraphs: iterable of existing subgraphs

    Returns:
        True if every node of ``subgraph`` belongs to at least one existing
        subgraph's node set, False otherwise (including when there are none).
    """
    candidate_nodes = set(subgraph.nodes())
    return any(
        candidate_nodes <= set(other.nodes())
        for other in existing_subgraphs
    )

def calculate_diversity_score(subgraph, existing_subgraphs, lambda_param):
    """
    Score how far a subgraph is from the existing ones, scaled by lambda.

    The score is ``lambda_param`` times the mean of ``distance(subgraph, s)``
    over every ``s`` in ``existing_subgraphs``.

    Args:
        subgraph: NetworkX subgraph
        existing_subgraphs: list of existing subgraphs
        lambda_param: trade-off multiplier applied to the mean distance

    Returns:
        The diversity score, or 0.0 when there are no existing subgraphs.
    """
    if not existing_subgraphs:
        return 0.0

    distance_sum = sum(
        (distance(subgraph, other) for other in existing_subgraphs),
        0.0,
    )
    mean_distance = distance_sum / len(existing_subgraphs)
    return lambda_param * mean_distance

def find_connected_subgraphs_with_diversity_optimization(G, W, lambda_param, min_size, max_size, num_candidates=15):
    """
    Find the connected subgraph that maximizes density + diversity.

    Several heuristics generate candidate connected subgraphs of ``G``. Each
    distinct candidate (by node set) is considered if it:

    - has between ``min_size`` and ``max_size`` nodes,
    - is connected, and
    - is not a node-subset of any subgraph already in ``W``.

    Qualifying candidates are scored as
    ``density(candidate) + calculate_diversity_score(candidate, W, lambda_param)``.

    Args:
        G: NetworkX graph
        W: List of existing subgraphs
        lambda_param: Trade-off parameter (high = favor diversity, low = favor density)
        min_size: Minimum subgraph size (int, inclusive)
        max_size: Maximum subgraph size (int, inclusive)
        num_candidates: Currently unused; kept for backward compatibility.

    Returns:
        The best connected subgraph (a copy induced from ``G``), or None if no
        candidate qualifies.
    """
    candidates = []

    # Approach 1: pure densest connected subgraph within the requested bounds.
    densest_nodes, _ = charikar_densest_connected_subgraph_with_bounds(G, min_size, max_size)
    if densest_nodes:
        candidates.append(G.subgraph(densest_nodes).copy())

    # Approach 2: densest connected subgraph after deleting nodes already used by W.
    candidates.extend(_node_removal_candidates(G, W, lambda_param, min_size, max_size))

    # Approach 3: densest connected subgraphs for slightly shifted size bounds.
    candidates.extend(_size_offset_candidates(G, min_size, max_size))

    # Approach 4: whole connected components that already fit the bounds.
    candidates.extend(_component_candidates(G, min_size, max_size))

    # Approach 5: greedy growth from the highest-degree nodes.
    candidates.extend(_greedy_growth_candidates(G, min_size, max_size))

    # Approach 6: random connected growth from random seed nodes.
    candidates.extend(_random_growth_candidates(G, min_size, max_size))

    return _select_best_candidate(candidates, W, lambda_param, min_size, max_size)


def _node_removal_candidates(G, W, lambda_param, min_size, max_size):
    """Densest connected subgraph of G after removing a lambda-dependent share of W's nodes."""
    # Only when there are existing subgraphs and diversity is favored.
    if not (W and lambda_param > 0.5):
        return []

    # Collect W's nodes in first-seen order so the removed prefix is
    # reproducible (slicing a plain set would follow hash order).
    existing_nodes = list(dict.fromkeys(v for sg in W for v in sg.nodes()))

    # Higher lambda removes a larger prefix, capped at 60% of the existing nodes.
    removal_ratio = min(0.6, lambda_param / 5.0)
    nodes_to_remove = existing_nodes[:int(len(existing_nodes) * removal_ratio)]
    if not nodes_to_remove:
        return []

    G_diverse = G.copy()
    G_diverse.remove_nodes_from(nodes_to_remove)
    diverse_nodes, _ = charikar_densest_connected_subgraph_with_bounds(G_diverse, min_size, max_size)
    return [G.subgraph(diverse_nodes).copy()] if diverse_nodes else []


def _size_offset_candidates(G, min_size, max_size):
    """Densest connected subgraphs for size bounds shifted by -3..+3 (offset 0 is Approach 1)."""
    found = []
    for size_offset in range(-3, 4):
        if size_offset == 0:
            # Same peeling as Approach 1 on the same graph; its result is already a candidate.
            continue
        temp_min = max(1, min_size + size_offset)
        temp_max = max_size + size_offset
        # Skip empty ranges and upper bounds that exceed the graph size.
        if temp_min <= temp_max <= G.number_of_nodes():
            temp_nodes, _ = charikar_densest_connected_subgraph_with_bounds(G, temp_min, temp_max)
            if temp_nodes:
                found.append(G.subgraph(temp_nodes).copy())
    return found


def _component_candidates(G, min_size, max_size):
    """Connected components of G whose size lies within the bounds."""
    return [
        G.subgraph(component).copy()
        for component in nx.connected_components(G)
        if min_size <= len(component) <= max_size
    ]


def _greedy_growth_candidates(G, min_size, max_size, num_seeds=8):
    """Grow a connected node set from each top-degree seed, always adding the highest-degree frontier node."""
    found = []
    seeds = sorted(G.degree, key=lambda item: item[1], reverse=True)[:num_seeds]
    for node, degree in seeds:
        if degree < 2:
            continue
        current_nodes = {node}
        frontier = set(G.neighbors(node))

        # Grow until the size limit is reached or the frontier is exhausted.
        while len(current_nodes) < max_size and frontier:
            best_neighbor = max(frontier, key=lambda v: G.degree(v))
            current_nodes.add(best_neighbor)
            frontier.remove(best_neighbor)
            frontier.update(set(G.neighbors(best_neighbor)) - current_nodes)

        if min_size <= len(current_nodes) <= max_size:
            found.append(G.subgraph(current_nodes).copy())
    return found


def _random_growth_candidates(G, min_size, max_size, num_samples=3, p_extend=0.7):
    """Grow connected node sets by random neighbor expansion from random seeds (global `random` state)."""
    import random

    all_nodes = list(G.nodes())
    # random.choice([]) raises IndexError; the size check alone does not
    # prevent that when min_size <= 0, so guard the empty graph explicitly.
    if not all_nodes or len(all_nodes) < min_size:
        return []

    found = []
    for _ in range(num_samples):
        start_node = random.choice(all_nodes)
        current_nodes = {start_node}
        frontier = set(G.neighbors(start_node))

        # Each step continues with probability p_extend, otherwise stops early.
        while len(current_nodes) < max_size and frontier:
            if random.random() >= p_extend:
                break
            neighbor = random.choice(list(frontier))
            current_nodes.add(neighbor)
            frontier.remove(neighbor)
            frontier.update(set(G.neighbors(neighbor)) - current_nodes)

        if min_size <= len(current_nodes) <= max_size:
            found.append(G.subgraph(current_nodes).copy())
    return found


def _select_best_candidate(candidates, W, lambda_param, min_size, max_size):
    """Return the first valid candidate with the highest density + diversity, or None."""
    best_subgraph = None
    best_objective_value = -float('inf')
    seen = set()

    for candidate in candidates:
        # Several heuristics often return the same node set; scoring it again
        # would only repeat the distance computations against every subgraph in W.
        node_key = frozenset(candidate.nodes())
        if node_key in seen:
            continue
        seen.add(node_key)

        size = candidate.number_of_nodes()
        # nx.is_connected raises on an empty graph, so reject size 0 first.
        if size == 0 or not (min_size <= size <= max_size):
            continue
        if not nx.is_connected(candidate) or is_subset_of_any(candidate, W):
            continue

        objective_value = density(candidate) + calculate_diversity_score(candidate, W, lambda_param)
        if objective_value > best_objective_value:
            best_objective_value = objective_value
            best_subgraph = candidate

    return best_subgraph

def densest_distinct_subgraph(G, W, lambda_param, min_subset_size, max_subset_size):
    """
    Return the connected subgraph of ``G`` with the best density + diversity score.

    This is a thin entry point that delegates to
    ``find_connected_subgraphs_with_diversity_optimization``, mapping the
    subset-size arguments onto its ``min_size`` / ``max_size`` parameters.

    Args:
        G (object): A graph represented as networkx graph object
        W (list): Already selected subgraphs that the result should differ from
        lambda_param (float): Trade-off between density and diversity of the subgraphs
        min_subset_size (int): Minimum number of vertices in a subgraph
        max_subset_size (int): Maximum number of vertices in a subgraph

    Returns:
        The connected subgraph that maximizes the objective function, or
        ``None`` if no valid subgraph is found.
    """
    best = find_connected_subgraphs_with_diversity_optimization(
        G=G,
        W=W,
        lambda_param=lambda_param,
        min_size=min_subset_size,
        max_size=max_subset_size,
    )
    return best

def find_distinct_densest_subgraphs_iterative(G, k, lambda_param, min_size, max_size, max_attempts=15):
    """
    Find up to k connected subgraphs that optimize density + diversity.

    Subgraphs are found one at a time on a working copy of ``G``. Each one
    must be connected and must not be a node-subset of an earlier result.

    When the strict search fails, the fallback depends on the attempt count:

    - For the first ``max_attempts // 2`` failed attempts, the size bounds are
      relaxed by ``attempts`` nodes. A subgraph accepted this way may lie
      outside ``[min_size, max_size]``.
    - For later attempts, part of the most recent result's remaining nodes is
      deleted from the working graph.

    Deleted nodes stay deleted for all later subgraphs. The search stops early
    if a subgraph cannot be found within ``max_attempts`` attempts.

    Args:
        G: NetworkX graph (not modified; a copy is used)
        k: number of subgraphs to find
        lambda_param: trade-off parameter (high = favor diversity, low = favor density)
        min_size: minimum subgraph size
        max_size: maximum subgraph size
        max_attempts: maximum attempts per subgraph

    Returns:
        List of the connected subgraphs found (possibly fewer than k).
    """
    subgraphs = []
    G_working = G.copy()

    for i in range(k):
        attempts = 0
        found_subgraph = False

        while attempts < max_attempts and not found_subgraph:
            # Find connected subgraph that optimizes the objective function
            subgraph = densest_distinct_subgraph(
                G_working, subgraphs, lambda_param, min_size, max_size
            )

            if (subgraph is not None and
                subgraph.number_of_nodes() > 0 and
                nx.is_connected(subgraph) and
                min_size <= subgraph.number_of_nodes() <= max_size and
                not is_subset_of_any(subgraph, subgraphs)):

                subgraphs.append(subgraph)
                found_subgraph = True

                # Objective components, for reporting only.
                subgraph_density = density(subgraph)
                diversity_score = calculate_diversity_score(subgraph, subgraphs[:-1], lambda_param)
                total_objective = subgraph_density + diversity_score

                print(f"Found connected subgraph {i+1}: {subgraph.number_of_nodes()} nodes, "
                      f"density {subgraph_density:.4f}, "
                      f"diversity {diversity_score:.4f}, "
                      f"objective {total_objective:.4f}")
            else:
                attempts += 1

                if attempts <= max_attempts // 2:
                    # Relax size constraints by `attempts` nodes on each side.
                    min_size_temp = max(1, min_size - attempts)
                    max_size_temp = min(max_size + attempts, G_working.number_of_nodes())
                    subgraph = densest_distinct_subgraph(
                        G_working, subgraphs, lambda_param, min_size_temp, max_size_temp
                    )
                    if (subgraph is not None and
                        subgraph.number_of_nodes() > 0 and
                        nx.is_connected(subgraph) and
                        not is_subset_of_any(subgraph, subgraphs)):

                        subgraphs.append(subgraph)
                        found_subgraph = True
                        print(f"Found connected subgraph {i+1} (relaxed): {subgraph.number_of_nodes()} nodes")
                elif subgraphs:
                    # Encourage diversity by deleting part of the most recent
                    # result from the working graph. Only nodes still present
                    # count: remove_nodes_from ignores missing nodes, so
                    # re-removing the same prefix each attempt would make no
                    # progress. At least one node goes per attempt.
                    remaining = [v for v in subgraphs[-1].nodes() if v in G_working]
                    nodes_to_remove = remaining[:max(1, len(remaining) // 4)]
                    G_working.remove_nodes_from(nodes_to_remove)

        if not found_subgraph:
            print(f"Could not find connected subgraph {i+1} after {max_attempts} attempts")
            break

    return subgraphs


