import torch
import numpy as np
import random
from utils import sample_critical_pairs, prepare_graph_data

# Seed every random number generator used here (stdlib, NumPy, PyTorch)
# with the same fixed value so that repeated runs are reproducible
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Global parameters
# Tensors are placed on the GPU when CUDA is available, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def predictive_path_stressing(G, K, T, edge_functions, budget_bounds, spagan_model, data,
                              epsilon=0.01, max_iterations=1000,
                              max_potential_iterations=100):
    """
    Implements the Predictive Path Stressing (PPS) algorithm to find a feasible perturbation

    Repeatedly asks the model which critical pairs still have a predicted cost
    below T, collects the predicted paths of those pairs, and greedily raises
    the perturbation of single path edges (best potential gain per unit of
    budget first) until the potential sum(min(T, predicted cost)) over those
    paths reaches |P| * T - epsilon.

    Args:
        G: NetworkX graph. Its edge order (G.edges()) defines the layout of the
            perturbation vector.
        K: List of (source, target) pairs
        T: Threshold for path costs
        edge_functions: Dictionary mapping edge (u,v) to a function f_e(x).
            Accepted for interface compatibility only; this implementation
            never applies it, all costs come from spagan_model.
        budget_bounds: Dictionary mapping edge (u,v) to maximum budget. Edges
            missing from the dictionary, or any non-dict value (e.g. None),
            are capped at T.
        spagan_model: Trained SPAGAN model. Called as
            spagan_model(data, s_idx, t_idx, perturbation=x); the result must
            support .item(). Must also provide predict_path(...) with the same
            arguments, returning a node sequence (empty/falsy if no path).
        data: PyTorch Geometric Data object of the graph
        epsilon: Epsilon parameter for the soft constraint
        max_iterations: Safety limit on the number of outer PPS rounds
        max_potential_iterations: Safety limit on greedy steps per round

    Returns:
        (perturbation_dict, x): perturbation_dict maps each edge of G (in the
        orientation reported by G.edges()) with a positive perturbation to its
        value; x is the full perturbation tensor with one entry per edge.
    """
    # NOTE(review): nodes in K and in the predicted paths are compared directly
    # against the edge tuples of G, so both must use G's node labels. Confirm
    # this holds when the caller remaps node IDs to indices.

    # Perturbation vector (all zeros), one entry per edge in G.edges() order
    num_edges = G.number_of_edges()
    x = torch.zeros(num_edges, device=device)
    edge_to_idx = {(u, v): i for i, (u, v) in enumerate(G.edges())}

    # An undirected graph lists each edge in one orientation only, while a
    # path may traverse it in the other one
    undirected = not G.is_directed()

    def find_edge(u, v):
        """Return (index, key as stored in edge_to_idx) of a path edge, or (None, None)."""
        if (u, v) in edge_to_idx:
            return edge_to_idx[(u, v)], (u, v)
        if undirected and (v, u) in edge_to_idx:
            return edge_to_idx[(v, u)], (v, u)
        return None, None

    def edge_budget(edge):
        """Return the maximum total perturbation allowed on `edge` (default T)."""
        if not isinstance(budget_bounds, dict):
            return T
        if edge in budget_bounds:
            return budget_bounds[edge]
        if undirected:
            return budget_bounds.get((edge[1], edge[0]), T)
        return T

    def endpoint_tensors(s, t):
        """Wrap a source/target pair into the index tensors the model is called with."""
        return torch.tensor([s], device=device), torch.tensor([t], device=device)

    def predicted_cost(s, t, perturbation):
        """Return the model's predicted s-t cost under `perturbation` as a float."""
        s_idx, t_idx = endpoint_tensors(s, t)
        return spagan_model(data, s_idx, t_idx, perturbation=perturbation).item()

    def calculate_potential(paths, perturbation):
        """Sum min(T, predicted cost) over the endpoints of every path."""
        potential = 0.0
        for path in paths:
            potential += min(T, predicted_cost(path[0], path[-1], perturbation))
        return potential

    # Main PPS loop
    for _ in range(max_iterations):
        # Violating pairs: predicted cost is still below the threshold T
        violating_pairs = [(s, t) for s, t in K if predicted_cost(s, t, x) < T]

        # If no violations, we're done
        if not violating_pairs:
            break

        # Set P of predicted shortest paths for the violating pairs
        P = []
        for s, t in violating_pairs:
            s_idx, t_idx = endpoint_tensors(s, t)
            path = spagan_model.predict_path(data, s_idx, t_idx, perturbation=x)
            if path:  # Only include if path exists
                P.append(path)

        # If P is empty (all paths are disconnected), exit
        if not P:
            break

        # Candidate edges: every edge of G used by a path in P. P is fixed for
        # the whole round, so this is computed once rather than per greedy step.
        candidates = {}
        for path in P:
            for u, v in zip(path, path[1:]):
                idx, key = find_edge(u, v)
                if idx is not None:
                    candidates[idx] = key

        current_potential = calculate_potential(P, x)

        # Target potential (|P| * T - epsilon)
        target_potential = len(P) * T - epsilon

        # Greedy steps: keep improving the potential until the target is reached
        for _ in range(max_potential_iterations):
            if current_potential >= target_potential:
                break

            best_idx = None
            best_delta = None
            # Only strictly positive gain counts as an improvement; accepting
            # zero or negative gain would spend budget without any benefit
            best_ratio = 0.0

            # Edges are visited in index order so that ties are broken
            # deterministically (first candidate with the best ratio wins)
            for idx in sorted(candidates):
                remaining = int(edge_budget(candidates[idx]) - x[idx].item())

                # Try every integer increment that still fits the edge budget
                for delta in range(1, remaining + 1):
                    x_new = x.clone()
                    x_new[idx] += delta

                    gain = calculate_potential(P, x_new) - current_potential
                    ratio = gain / delta

                    if ratio > best_ratio:
                        best_ratio = ratio
                        best_idx = idx
                        best_delta = delta

            if best_idx is None:
                # No improvement possible
                break

            x[best_idx] += best_delta
            current_potential = calculate_potential(P, x)

        # If we couldn't improve the potential enough, give up
        if current_potential < target_potential:
            break

    # Convert the perturbation tensor to a dictionary of the non-zero entries
    values = x.tolist()
    perturbation_dict = {
        edge: values[idx] for edge, idx in edge_to_idx.items() if values[idx] > 0
    }

    return perturbation_dict, x


def forge_phase(graphs, critical_pairs_list, thresholds, spagan_model, edge_functions=None, budget_bounds=None):
    """
    Implements the Forge phase to generate a pre-trained solution dataset

    For every graph, every entry of its critical-pair list and every threshold,
    runs predictive_path_stressing and records the result. A failure for one
    combination is reported on stdout and does not stop the remaining ones.

    Args:
        graphs: List of NetworkX graphs
        critical_pairs_list: List of lists of (source, target) pairs for each graph.
            Each 2-tuple entry is solved on its own; an entry that is itself a
            list of pairs is solved jointly. Graphs without an entry get 5
            pairs from sample_critical_pairs, solved jointly.
        thresholds: List of thresholds T to test
        spagan_model: Trained SPAGAN model
        edge_functions: Functions to calculate edge weights; forwarded unchanged
            to predictive_path_stressing
        budget_bounds: Maximum budget for each edge; forwarded unchanged to
            predictive_path_stressing

    Returns:
        solution_dataset: List of
            (graph, critical_pairs, threshold, perturbation_dict, perturbation_vector)
            tuples. critical_pairs holds the original (unmapped) node IDs.
    """
    solution_dataset = []

    for i, G in enumerate(graphs):
        print(f"Processing graph {i+1}/{len(graphs)}")

        # Create PyG Data object for the graph
        data, node_mapping = prepare_graph_data(G)
        data = data.to(device)

        # Get critical pairs for this graph
        if i < len(critical_pairs_list):
            K_list = critical_pairs_list[i]
        else:
            # Wrap the sampled pairs so they are processed as one joint set
            K_list = [sample_critical_pairs(G, num_pairs=5)]

        for K in K_list:
            # Handle the case where K is a single tuple, not a list of tuples
            if isinstance(K, tuple) and len(K) == 2:
                K = [K]  # Convert to a list containing one pair

            # Map node IDs to indices if needed. The mapping does not depend on
            # the threshold, so it is done once per pair set and unknown nodes
            # are reported once instead of once per threshold.
            if node_mapping:
                K_mapped = []
                for s, t in K:
                    if s in node_mapping and t in node_mapping:
                        K_mapped.append((node_mapping[s], node_mapping[t]))
                    else:
                        print(f"Warning: Node {s} or {t} not found in node_mapping")
            else:
                K_mapped = K

            for T in thresholds:
                print(f"  Finding solution for threshold T={T} with {len(K)} critical pairs")

                if not K_mapped:
                    print("  No valid critical pairs after mapping, skipping")
                    continue

                # Generate solution using PPS; one failing combination must not
                # abort the whole dataset generation
                try:
                    perturbation_dict, perturbation_vector = predictive_path_stressing(
                        G, K_mapped, T, edge_functions, budget_bounds,
                        spagan_model, data
                    )
                except Exception as e:
                    print(f"    Error generating solution: {e}")
                else:
                    # Store the solution
                    solution_dataset.append((G, K, T, perturbation_dict, perturbation_vector))

                    # Print statistics
                    total_budget = sum(perturbation_dict.values())
                    print(f"    Solution found with total budget: {total_budget:.2f}")

    return solution_dataset