import torch
import torch.nn.functional as F
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import random
from models import SPAGAN
from forge import forge_phase
from morph import morph_phase
from refine import refine_phase
from utils import sample_critical_pairs, generate_training_data
from inference import hephaestus_inference


# Seed Python's, NumPy's and PyTorch's global RNGs with one shared value
# so that repeated runs of this script draw the same random numbers.
_GLOBAL_SEED = 42
torch.manual_seed(_GLOBAL_SEED)
np.random.seed(_GLOBAL_SEED)
random.seed(_GLOBAL_SEED)

# Compute device shared by the script: CUDA when torch reports it as
# available, otherwise the CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)

def _keep_largest_weak_component(graph):
    """Return `graph` if it is weakly connected, else a copy of its largest weakly connected component."""
    if nx.is_weakly_connected(graph):
        return graph
    largest_wcc = max(nx.weakly_connected_components(graph), key=len)
    return graph.subgraph(largest_wcc).copy()


def _assign_random_weights(graph, low=1, high=10):
    """Set a `weight` attribute drawn uniformly from [low, high] on every edge of `graph` (in place)."""
    for u, v in graph.edges():
        graph[u][v]['weight'] = random.uniform(low, high)


def _random_weighted_digraph(n_nodes, edge_probability):
    """Build a directed G(n, p) graph, reduce it to one weak component, and attach random edge weights."""
    graph = nx.gnp_random_graph(n_nodes, edge_probability, directed=True)
    graph = _keep_largest_weak_component(graph)
    _assign_random_weights(graph)
    return graph


def _mean_loss(total, count):
    """Return total / count, or NaN when no batches were seen (avoids ZeroDivisionError)."""
    return total / count if count else float('nan')


def _run_epoch(model, loader, optimizer=None):
    """Run one pass over `loader` and return the mean Huber loss per batch.

    When `optimizer` is given the model is put in train mode and updated
    after every batch; otherwise the model is put in eval mode and
    gradients are disabled.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    n_batches = 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device)

            # The perturbation attribute is optional on a batch; pass None when absent.
            predicted_distances = model(
                batch,
                batch.source_indices,
                batch.target_indices,
                getattr(batch, 'perturbation', None),
            )

            # Huber loss for robustness to outliers
            loss = F.huber_loss(predicted_distances, batch.true_distances)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1

    return _mean_loss(total_loss, n_batches)


def _train_spagan(model, train_loader, val_loader, optimizer, epochs=50):
    """Train `model` for `epochs` epochs, printing train/val loss every 10th epoch, and return it."""
    model.to(device)

    for epoch in range(epochs):
        avg_train_loss = _run_epoch(model, train_loader, optimizer)
        avg_val_loss = _run_epoch(model, val_loader)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

    return model


def _average_baseline_cost(graph, pairs):
    """Return (mean weighted shortest-path cost, number of reachable pairs) over `pairs`.

    Pairs with no path are excluded from both the sum and the divisor so
    they do not drag the average towards zero. The mean is 0.0 when no
    pair is reachable.
    """
    costs = []
    for s, t in pairs:
        try:
            costs.append(nx.shortest_path_length(graph, s, t, weight='weight'))
        except nx.NetworkXNoPath:
            continue
    average = sum(costs) / len(costs) if costs else 0.0
    return average, len(costs)


def _percent_improvement(baseline, improved):
    """Return the relative reduction from `baseline` to `improved` in percent, or None if `baseline` is zero."""
    if baseline == 0:
        return None
    return (baseline - improved) / baseline * 100


def main():
    """
    Run the entire Hephaestus framework end to end.

    Steps, in order: build two synthetic weighted digraphs, train a SPAGAN
    model on data generated from the first one, pick an edge function and
    per-edge budget bound, run the Forge, Morph and Refine phases, then
    run inference on a freshly generated graph with and without the
    refined policy and print the total budgets of both solutions.

    Takes no arguments and returns None; all results are printed.
    """
    print("=== Hephaestus: QoS Degradation Framework ===")

    # Re-seed here as well so a call to main() starts from seed 42
    # regardless of any random draws made since import.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    # 1. Generate or load graphs
    print("\n=== Generating Test Graphs ===")

    # Generate synthetic graphs for testing
    graphs = []

    # Erdős–Rényi graph
    print("Generating Erdős–Rényi graph...")
    n_nodes = 500  # Smaller for faster testing
    edge_probability = 0.02
    G_er = _random_weighted_digraph(n_nodes, edge_probability)
    graphs.append(G_er)

    # Barabási-Albert graph (scale-free)
    print("Generating Barabási-Albert graph...")
    n_nodes = 500
    m_edges = 5
    G_ba = nx.DiGraph(nx.barabasi_albert_graph(n_nodes, m_edges))  # Convert to directed
    _assign_random_weights(G_ba)
    graphs.append(G_ba)

    print(f"Created {len(graphs)} graphs for testing")

    # 2. Train SPAGAN
    print("\n=== Training SPAGAN Model ===")

    # Create training data (the third return value is not needed here)
    print("Generating training data for SPAGAN...")
    train_loader, val_loader, _ = generate_training_data(G_er, num_samples=1000, batch_size=32)

    # Initialize SPAGAN model
    print("Initializing SPAGAN model...")
    spagan_model = SPAGAN(
        input_dim=3,
        hidden_dim=128,
        n_layers=4,
        n_heads=8,
        dropout=0.2
    ).to(device)

    # Train the model
    print("Training SPAGAN model...")
    optimizer = torch.optim.Adam(spagan_model.parameters(), lr=0.001, weight_decay=1e-5)
    spagan_model = _train_spagan(spagan_model, train_loader, val_loader, optimizer, epochs=50)
    print("SPAGAN training complete")

    # 3. Define edge functions and budget bounds
    print("\n=== Defining Edge Functions ===")

    # Linear function: f_e(x) = x
    def linear_function(x):
        return x

    # Quadratic Convex function: f_e(x) = x^2
    def quadratic_function(x):
        return x**2

    # Log Concave function: f_e(x) = ln(x+1)
    def log_concave_function(x):
        return np.log(x + 1)

    # Choose function type for the test
    edge_function_type = "linear"  # Options: "linear", "quadratic", "log_concave"

    # Any unrecognised type falls back to the log-concave function.
    edge_function = {
        "linear": linear_function,
        "quadratic": quadratic_function,
    }.get(edge_function_type, log_concave_function)

    print(f"Using {edge_function_type} edge function")

    # Set max budget per edge
    budget_bound = 20.0
    print(f"Max budget per edge: {budget_bound}")

    # 4. Run Forge phase
    print("\n=== Running Forge Phase ===")

    # Sample critical pairs for each graph
    critical_pairs_list = []
    thresholds = [10.0, 15.0, 20.0]  # Different thresholds to test

    for G in graphs:
        # Draw the pairs one at a time and flatten them, so each graph
        # contributes a flat list of pairs rather than a list of
        # one-element lists.
        all_pairs = []
        for _ in range(5):  # Generate 5 pairs per graph
            pair = sample_critical_pairs(G, num_pairs=1)
            if pair:  # Make sure we got a valid pair
                all_pairs.append(pair[0])

        critical_pairs_list.append(all_pairs)
        print(f"Generated {len(all_pairs)} critical pairs: {all_pairs}")

    # Run Forge phase to generate initial solutions
    solution_dataset = forge_phase(
        graphs=graphs,
        critical_pairs_list=critical_pairs_list,
        thresholds=thresholds,
        spagan_model=spagan_model,
        edge_functions=edge_function,
        budget_bounds=budget_bound
    )

    print(f"Generated {len(solution_dataset)} solution instances")

    # 5. Run Morph phase
    print("\n=== Running Morph Phase ===")

    # For faster testing, use smaller number of epochs
    ebm, mix_cvae = morph_phase(
        solution_dataset=solution_dataset,
        num_epochs=100,  # Reduced for demonstration
        initial_experts=1,
        max_experts=5
    )

    print("Morph phase complete")

    # 6. Run Refine phase
    print("\n=== Running Refine Phase ===")

    # For faster testing, use smaller number of episodes
    policy, refined_solutions = refine_phase(
        ebm=ebm,
        mix_cvae=mix_cvae,
        solution_dataset=solution_dataset,
        spagan_model=spagan_model,
        num_episodes=200  # Reduced for demonstration
    )

    print("Refine phase complete")
    print(f"Generated {len(refined_solutions)} refined solutions")

    # 7. Test on a new instance
    print("\n=== Testing on New Instance ===")

    # Create a new graph for testing
    print("Generating new test graph...")
    n_nodes = 300
    edge_probability = 0.015
    G_test = _random_weighted_digraph(n_nodes, edge_probability)

    # Sample critical pairs
    K_test = sample_critical_pairs(G_test, num_pairs=3)

    # Choose threshold
    T_test = 15.0

    print(f"Test graph has {G_test.number_of_nodes()} nodes and {G_test.number_of_edges()} edges")
    print(f"Testing with {len(K_test)} critical pairs and threshold T = {T_test}")

    # Get baseline (shortest path without perturbation), averaged over the
    # pairs that are actually connected.
    avg_baseline_cost, n_reachable = _average_baseline_cost(G_test, K_test)
    print(f"Average baseline path cost: {avg_baseline_cost:.2f}")
    if n_reachable < len(K_test):
        print(f"  ({len(K_test) - n_reachable} of {len(K_test)} critical pairs have no path and were excluded)")

    # Run inference without policy
    print("\nRunning inference WITHOUT policy refinement:")
    solution_without_policy = hephaestus_inference(
        G=G_test,
        K=K_test,
        T=T_test,
        spagan_model=spagan_model,
        mix_cvae=mix_cvae,
        policy=None,
        edge_functions=edge_function,
        budget_bounds=budget_bound
    )

    # Run inference with policy
    print("\nRunning inference WITH policy refinement:")
    solution_with_policy = hephaestus_inference(
        G=G_test,
        K=K_test,
        T=T_test,
        spagan_model=spagan_model,
        mix_cvae=mix_cvae,
        policy=policy,
        edge_functions=edge_function,
        budget_bounds=budget_bound
    )

    # Compare solutions
    budget_without_policy = sum(solution_without_policy.values())
    budget_with_policy = sum(solution_with_policy.values())

    print("\n=== Solution Comparison ===")
    print(f"Budget without policy: {budget_without_policy:.2f}")
    print(f"Budget with policy: {budget_with_policy:.2f}")
    # A zero reference budget has no meaningful relative improvement.
    improvement = _percent_improvement(budget_without_policy, budget_with_policy)
    if improvement is None:
        print("Improvement: n/a (budget without policy is zero)")
    else:
        print(f"Improvement: {improvement:.2f}%")
    print("\n=== Hephaestus Framework Complete ===")

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()