import torch
import numpy as np
import networkx as nx
import random
from utils import prepare_graph_data, encode_graph_condition


# Seed every random number generator this module relies on (PyTorch, NumPy
# and Python's `random`) with one fixed value so runs are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)
del _seed_fn

# Global parameters
# Prefer the CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



def pps_inference(G, K, T, edge_functions, budget_bounds, initial_perturbation=None,
                  max_iterations=1000):
    """
    PPS-I algorithm - PPS variant for inference with exact shortest path computation.

    Greedy repair loop. While some pair in ``K`` has a shortest-path cost
    below ``T``, every integer increment ``delta`` (up to the edge's budget) is
    tried on every edge of the violating shortest paths. The (edge, delta)
    with the largest potential gain per unit of ``delta`` is then applied.
    The potential is the sum over violating pairs of ``min(T, shortest-path
    cost)``, using ``T`` when no path exists. The loop stops when no pair
    violates, when no candidate increment remains, or after
    ``max_iterations`` rounds.

    Perturbed edge weight: ``G[u][v].get('weight', 1.0) + f(perturbation[(u, v)])``
    where ``f`` is selected from ``edge_functions``. ``G`` itself is never
    modified; all weight updates happen on a copy.

    Args:
        G: NetworkX graph (directed or undirected).
        K: List of (source, target) pairs, in G's node labels.
        T: Threshold each pair's shortest-path cost should reach.
        edge_functions: A callable applied to every edge's perturbation, a
            dict mapping edges to callables (edges missing from the dict use
            the identity), or any other value for the identity.
        budget_bounds: Dict mapping edges to their maximum total perturbation;
            missing edges, or a non-dict value, default to ``T``.
        initial_perturbation: Initial perturbation (optional); not modified.
        max_iterations: Maximum number of greedy rounds (default 1000).

    Returns:
        perturbation_dict: Dictionary mapping edges to perturbation values,
        keeping only strictly positive values. For undirected graphs, edges
        of G are keyed in the orientation reported by ``G.edges()``.
    """
    directed = G.is_directed()
    edge_list = list(G.edges())

    # In an undirected graph a path may traverse an edge as (v, u) while
    # G.edges() reports it as (u, v). Route both orientations to a single
    # canonical key; otherwise a perturbation stored under the reversed key
    # would be overwritten when weights are re-synced and never take effect.
    canonical = {}
    for u, v in edge_list:
        canonical[(u, v)] = (u, v)
        if not directed:
            canonical.setdefault((v, u), (u, v))

    def edge_key(u, v):
        """Canonical dictionary key for the edge traversed from u to v."""
        return canonical.get((u, v), (u, v))

    def lookup(mapping, edge, default):
        """Read a per-edge setting, also accepting the reversed key for undirected graphs."""
        candidates = (edge,) if directed else (edge, (edge[1], edge[0]))
        for key in candidates:
            if key in mapping:
                return mapping[key]
        return default

    # Initialize perturbation: every edge gets a value (0.0 unless provided).
    perturbation = {edge: 0.0 for edge in edge_list}
    if initial_perturbation is not None:
        for edge, val in initial_perturbation.items():
            key = edge_key(*edge)
            # If both orientations of one edge were given, the canonical one wins.
            if key == edge or key not in initial_perturbation:
                perturbation[key] = val

    def apply_edge_function(edge, x_val):
        """Weight increase contributed by perturbing ``edge`` by ``x_val``."""
        if callable(edge_functions):
            # Global function
            return edge_functions(x_val)
        if isinstance(edge_functions, dict):
            # Edge-specific function, identity when the edge is not listed
            return lookup(edge_functions, edge, lambda x: x)(x_val)
        # Default linear function
        return x_val

    def perturbed_weight(edge, value):
        """Weight of ``edge`` in the working graph when perturbed by ``value``."""
        u, v = edge
        return G[u][v].get('weight', 1.0) + apply_edge_function(edge, value)

    # Working copy of the graph carrying the perturbed weights; G is untouched.
    G_copy = G.copy()
    for edge in edge_list:
        u, v = edge
        G_copy[u][v]['weight'] = perturbed_weight(edge, perturbation[edge])

    def shortest(s, t):
        """Return (path, cost) for s -> t on the working weights, or None if unreachable."""
        try:
            path = nx.shortest_path(G_copy, source=s, target=t, weight='weight')
        except nx.NetworkXNoPath:
            return None
        cost = sum(G_copy[a][b]['weight'] for a, b in zip(path, path[1:]))
        return path, cost

    def calculate_potential(violating_pairs, edge=None, value=None):
        """Sum of min(T, shortest-path cost) over the violating pairs.

        When ``edge`` is given it is evaluated as if perturbed to ``value``;
        its working weight is restored before returning.
        """
        if edge is not None:
            u, v = edge
            saved_weight = G_copy[u][v]['weight']
            G_copy[u][v]['weight'] = perturbed_weight(edge, value)
        try:
            potential = 0.0
            for s, t, _ in violating_pairs:
                found = shortest(s, t)
                # An unreachable target counts as fully satisfying the threshold.
                potential += T if found is None else min(T, found[1])
            return potential
        finally:
            if edge is not None:
                G_copy[u][v]['weight'] = saved_weight

    for _ in range(max_iterations):
        # Find violating pairs (no path at all is not a violation).
        violating_pairs = []
        for s, t in K:
            found = shortest(s, t)
            if found is not None and found[1] < T:
                violating_pairs.append((s, t, found[0]))

        # If no violations, we're done
        if not violating_pairs:
            break

        current_potential = calculate_potential(violating_pairs)

        # Candidate edges from the violating paths, deduplicated in first-seen
        # order so tie-breaking is reproducible (set order depends on hashing).
        path_edges = dict.fromkeys(
            edge_key(a, b)
            for _, _, path in violating_pairs
            for a, b in zip(path, path[1:])
        )

        # Find the best edge-delta pair by gain-to-cost ratio. Zero-gain moves
        # are still accepted: they let the search leave plateaus created by
        # equal-cost alternative paths.
        best_edge, best_delta, best_ratio = None, None, -float('inf')
        for edge in path_edges:
            max_budget = lookup(budget_bounds, edge, T) if isinstance(budget_bounds, dict) else T
            current_val = perturbation.get(edge, 0.0)
            for delta in range(1, int(max_budget - current_val) + 1):
                new_potential = calculate_potential(violating_pairs, edge, current_val + delta)
                ratio = (new_potential - current_potential) / delta
                if ratio > best_ratio:
                    best_edge, best_delta, best_ratio = edge, delta, ratio

        if best_edge is None:
            # No increment fits within any edge's remaining budget
            break

        perturbation[best_edge] = perturbation.get(best_edge, 0.0) + best_delta
        u, v = best_edge
        G_copy[u][v]['weight'] = perturbed_weight(best_edge, perturbation[best_edge])

    # Return the perturbation dictionary (filtering out zeros)
    return {edge: val for edge, val in perturbation.items() if val > 0}


def hephaestus_inference(G, K, T, spagan_model, mix_cvae, policy=None, edge_functions=None, budget_bounds=None):
    """
    Complete Hephaestus inference pipeline.

    1. Encode the graph and the (K, T) condition.
    2. Produce a perturbation vector from the Mix-CVAE. It is either sampled
       directly or, when ``policy`` is given, decoded by the first expert
       from a prior latent sample refined by 10 deterministic policy steps.
    3. Interpret entry ``i`` of that vector as the perturbation of the i-th
       edge of ``G.edges()``; non-positive entries are dropped.
    4. Repair the result with ``pps_inference``.
    5. Print whether every pair's shortest-path cost reaches ``T`` and the
       total budget used.

    Args:
        G: NetworkX graph
        K: List of (source, target) pairs, in G's node labels
        T: Threshold
        spagan_model: Trained SPAGAN model (NOTE(review): not used by this
            function at present)
        mix_cvae: Trained Mix-CVAE model
        policy: Trained RL policy (optional)
        edge_functions: Edge weight functions
        budget_bounds: Maximum budget per edge

    Returns:
        perturbation: Dictionary mapping edges to perturbation values
    """
    # Step 1: Encode graph and condition.
    # The node mapping returned here relabels nodes for the tensor `data` only.
    # Every step below operates on G itself, so K must stay in G's own node
    # labels. Translating it through the mapping would address wrong or
    # missing nodes of G whenever the mapping is not the identity.
    data, _node_mapping = prepare_graph_data(G)
    # NOTE(review): `data` is not consumed further in this function.
    data = data.to(device)

    # Put the condition on the same device as the latent vector below so that
    # torch.cat and the model calls see matching devices.
    condition = encode_graph_condition(G, K, T).unsqueeze(0).to(device)

    # Step 2: Generate initial solution (pure inference, no autograd needed).
    with torch.no_grad():
        if policy is not None:
            # Use RL policy to refine a latent sample drawn from the prior
            latent_dim = mix_cvae.experts[0].latent_dim
            z = torch.randn(1, latent_dim, device=device)
            for _ in range(10):
                state = torch.cat([z, condition], dim=1)
                action, _, _ = policy.get_action(state, deterministic=True)
                z = z + action
            perturbation_vector = mix_cvae.experts[0].decode(z, condition)
        else:
            # Sample directly from Mix-CVAE
            perturbation_vector = mix_cvae.sample(condition, num_samples=1)

    # Flatten rather than squeeze: squeeze() turns a single-edge output into a
    # 0-d tensor, which cannot be iterated.
    values = perturbation_vector.detach().reshape(-1).cpu().tolist()

    # Step 3: Convert to dictionary.
    # Assumes sequential edge indexing: entry i belongs to the i-th edge of
    # G.edges(). Entries beyond the number of edges are ignored.
    edge_list = list(G.edges())
    initial_perturbation = {
        edge: val for edge, val in zip(edge_list, values) if val > 0
    }

    # Step 4: Ensure feasibility with PPS-I
    final_perturbation = pps_inference(
        G, K, T,
        edge_functions=edge_functions,
        budget_bounds=budget_bounds,
        initial_perturbation=initial_perturbation,
    )

    # Step 5: Validate solution on one graph carrying all perturbations.
    # NOTE(review): only edges present in final_perturbation get the edge
    # function applied here, while pps_inference applies it to every edge
    # (with perturbation 0). The two agree only when the functions map 0 to 0.
    G_temp = G.copy()
    for (u, v), val in final_perturbation.items():
        if callable(edge_functions):
            increase = edge_functions(val)
        elif isinstance(edge_functions, dict) and (u, v) in edge_functions:
            increase = edge_functions[(u, v)](val)
        else:
            increase = val
        G_temp[u][v]['weight'] = G_temp[u][v].get('weight', 1.0) + increase

    def _violates_threshold(s, t):
        """True when the shortest s -> t path in G_temp costs less than T."""
        try:
            path = nx.shortest_path(G_temp, s, t, weight='weight')
        except nx.NetworkXNoPath:
            # No path exists, which is a trivial satisfaction
            return False
        return sum(G_temp[a][b]['weight'] for a, b in zip(path, path[1:])) < T

    is_feasible = not any(_violates_threshold(s, t) for s, t in K)

    print(f"Solution feasibility: {is_feasible}")
    print(f"Total budget: {sum(final_perturbation.values()):.2f}")

    return final_perturbation

