import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import networkx as nx
from torch_geometric.nn import GATConv
import random

# Seed every random number generator used in this module so runs are repeatable
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Module-wide compute device: the GPU when CUDA is usable, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class GraphAttentionLayer(nn.Module):
    """
    Multi-head graph attention layer.

    Thin wrapper that records the layer geometry and delegates the actual
    attention computation to PyTorch Geometric's GATConv.
    """
    def __init__(self, in_features, out_features, n_heads, dropout=0.6, alpha=0.2, concat=True):
        """
        in_features: Size of each input node feature vector
        out_features: Per-head output size handed to GATConv
        n_heads: Number of attention heads (GATConv `heads`)
        dropout: Forwarded to GATConv as `dropout`
        alpha: Forwarded to GATConv as `negative_slope`
        concat: Forwarded to GATConv as `concat`
        """
        super().__init__()

        # Remember the geometry so callers can inspect the layer later
        self.in_features, self.out_features = in_features, out_features
        self.n_heads, self.concat = n_heads, concat

        # All attention work is delegated to the PyTorch Geometric operator
        conv_options = {
            'heads': n_heads,
            'dropout': dropout,
            'negative_slope': alpha,
            'concat': concat,
        }
        self.gat_conv = GATConv(in_features, out_features, **conv_options)

    def forward(self, x, edge_index):
        """
        Run the wrapped GATConv on the graph.

        x: Node features [N, in_features]
        edge_index: Graph connectivity [2, E]
        """
        attended = self.gat_conv(x, edge_index)
        return attended


class SPAGAN(nn.Module):
    """
    Shortest Path Graph Attention Network (SPAGAN)

    Embeds nodes with a stack of graph attention layers and regresses a
    non-negative shortest-path cost for each (source, target) node pair.
    """
    def __init__(self, input_dim=3, hidden_dim=64, n_layers=4, n_heads=4, dropout=0.1, alpha=0.2):
        """
        input_dim: Number of raw features per node
        hidden_dim: Width of the node embeddings and of the predictor MLP
        n_layers: Requested number of attention layers. A first and a last
            layer are always created, so at least two layers are built.
        n_heads: Attention heads per layer
        dropout: Dropout probability (attention layers, between layers, predictor)
        alpha: Negative slope handed to every attention layer
        """
        super(SPAGAN, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dropout = dropout

        # Input projection
        self.input_projection = nn.Linear(input_dim, hidden_dim)

        # Node embedding layers - keeping track of output dimensions carefully
        self.gat_layers = nn.ModuleList()

        # First layer takes hidden_dim input and produces hidden_dim * n_heads output when concat=True
        self.gat_layers.append(
            GraphAttentionLayer(hidden_dim, hidden_dim, n_heads, dropout, alpha, concat=True)
        )

        # Middle layers take hidden_dim * n_heads as input and produce the same output dimension
        for _ in range(n_layers - 2):
            self.gat_layers.append(
                GraphAttentionLayer(hidden_dim * n_heads, hidden_dim, n_heads, dropout, alpha, concat=True)
            )

        # Last layer to produce final node embeddings - use concat=False to reduce dimensions
        self.gat_layers.append(
            GraphAttentionLayer(hidden_dim * n_heads, hidden_dim, n_heads, dropout, alpha, concat=False)
        )

        # Source-target feature projection
        self.st_projection = nn.Linear(2 * hidden_dim, hidden_dim)

        # Path cost predictor takes hidden_dim from the final node embeddings
        self.path_predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
            nn.ReLU()  # Ensure non-negative path costs
        )

    def forward(self, data, source_indices, target_indices, perturbation=None):
        """
        Predicts the shortest path distance between source and target nodes.

        data: PyTorch Geometric data object containing the graph
        source_indices: Indices of source nodes [batch_size]
        target_indices: Indices of target nodes [batch_size]
        perturbation: Edge weight perturbations [num_edges] or None.
            NOTE: accepted for interface parity with predict_path, but the
            attention layers are only ever given (x, edge_index), so edge
            weights and their perturbation do not influence the prediction.

        Returns the predicted path costs [batch_size]
        """
        x, edge_index = data.x, data.edge_index

        # Initial node embeddings
        x = self.input_projection(x)

        # Apply GAT layers. The position of the last layer comes from the
        # layers actually built, not from n_layers: the constructor always
        # creates at least two layers, so n_layers < 2 would otherwise skip
        # the activation between them.
        last_layer = len(self.gat_layers) - 1
        for i, gat_layer in enumerate(self.gat_layers):
            x = gat_layer(x, edge_index)
            if i < last_layer:
                x = F.elu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        # Final node embeddings should be [num_nodes, hidden_dim]

        # Extract source and target embeddings
        source_embeddings = x[source_indices]  # [batch_size, hidden_dim]
        target_embeddings = x[target_indices]  # [batch_size, hidden_dim]

        # Combine source and target features
        combined_features = torch.cat([source_embeddings, target_embeddings], dim=1)
        path_features = self.st_projection(combined_features)  # [batch_size, hidden_dim]

        # Predict path cost
        return self.path_predictor(path_features).squeeze(-1)

    def predict_path(self, data, source_idx, target_idx, perturbation=None):
        """
        Identifies the shortest path between source and target.

        The path is extracted with NetworkX on the (optionally perturbed)
        edge weights; the learned layers are not involved.

        data: PyTorch Geometric data object containing the graph
        source_idx: Source node index (int or single-element tensor)
        target_idx: Target node index (int or single-element tensor)
        perturbation: Edge weight perturbations [num_edges] or None; ignored
            when the graph carries no edge attributes

        Returns the path as a list of node indices, or [] when the target
        cannot be reached. An index that is not a node of the graph is not
        handled here and propagates as a NetworkX error.
        """
        # Move data to CPU for NetworkX
        edge_index = data.edge_index.cpu()

        # An absent edge_attr may surface as None rather than as a missing
        # attribute (torch_geometric's Data does this), so test the value
        # instead of relying on hasattr.
        edge_attr = getattr(data, 'edge_attr', None)

        if edge_attr is None:
            # Unweighted graph: every edge costs 1
            edge_weights = torch.ones(edge_index.size(1))
        else:
            edge_attr = edge_attr.detach().cpu()
            # The weight is the first attribute column; a 1-D edge_attr is
            # taken to be the weight vector itself.
            edge_weights = edge_attr if edge_attr.dim() == 1 else edge_attr[:, 0]
            if perturbation is not None:
                # Ensure perturbation is on CPU for NetworkX
                edge_weights = edge_weights + perturbation.detach().cpu()

        # Convert to NetworkX for path extraction
        G = nx.DiGraph()
        G.add_nodes_from(range(data.x.size(0)))

        # Bulk-convert once instead of calling .item() per edge
        sources, targets = edge_index.tolist()
        G.add_weighted_edges_from(zip(sources, targets, edge_weights.tolist()), weight='weight')

        source_node = source_idx.item() if isinstance(source_idx, torch.Tensor) else source_idx
        target_node = target_idx.item() if isinstance(target_idx, torch.Tensor) else target_idx

        # Weighted shortest path (NetworkX uses Dijkstra by default)
        try:
            return nx.shortest_path(G, source=source_node, target=target_node, weight='weight')
        except nx.NetworkXNoPath:
            return []  # No path exists


class Swish(nn.Module):
    """
    Swish activation: every element is scaled by its own sigmoid,
    i.e. x * sigmoid(x).
    """
    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class EnergyModel(nn.Module):
    """
    Energy-Based Model (EBM): an MLP that maps an input vector to one
    scalar energy, used to approximate the solution distribution.
    """
    def __init__(self, input_dim, hidden_dim=512, num_layers=6):
        """
        input_dim: Size of the input vector
        hidden_dim: Width of every hidden layer
        num_layers: Total number of Linear layers (hidden layers + output layer)
        """
        super().__init__()

        # Feature widths seen along the network: the input followed by
        # (num_layers - 1) hidden widths
        widths = [input_dim] + [hidden_dim] * (num_layers - 1)

        # Each consecutive pair of widths becomes Linear -> Swish
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend([nn.Linear(fan_in, fan_out), Swish()])

        # Output layer (produces scalar energy value)
        stack.append(nn.Linear(widths[-1], 1))

        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """
        Compute energy value for input x
        Lower energy = higher probability
        """
        energy = self.network(x)
        return energy
    
    
class ConditionalVAE(nn.Module):
    """
    Conditional VAE expert for a single mode.

    The encoder maps (input, condition) to the parameters of a diagonal
    Gaussian over the latent code; the decoder maps (latent, condition)
    back to a non-negative vector of the input size.
    """
    def __init__(self, input_dim, condition_dim, latent_dim=128, hidden_dim=512):
        super().__init__()

        self.latent_dim = latent_dim

        def hidden_stack(first_in):
            """Three Linear + LeakyReLU(0.2) stages; the first consumes `first_in` features."""
            stages = []
            for fan_in in (first_in, hidden_dim, hidden_dim):
                stages.append(nn.Linear(fan_in, hidden_dim))
                stages.append(nn.LeakyReLU(0.2))
            return stages

        # Encoder (inference network)
        self.encoder = nn.Sequential(*hidden_stack(input_dim + condition_dim))

        # Heads producing the mean and log-variance of the latent distribution
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder (generative network)
        self.decoder = nn.Sequential(
            *hidden_stack(latent_dim + condition_dim),
            nn.Linear(hidden_dim, input_dim),
            nn.ReLU(),  # Ensure non-negative outputs for perturbation
        )

    def encode(self, x, condition):
        """Map (x, condition) to the latent mean and log-variance"""
        hidden = self.encoder(torch.cat((x, condition), dim=1))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def decode(self, z, condition):
        """Map (z, condition) back to the output space"""
        return self.decoder(torch.cat((z, condition), dim=1))

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I) so sampling stays differentiable"""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x, condition):
        """Encode, sample a latent code, decode; returns (reconstruction, mu, logvar)"""
        mu, logvar = self.encode(x, condition)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decode(latent, condition)
        return reconstruction, mu, logvar

    def sample(self, condition, num_samples=1):
        """Decode `num_samples` latent codes drawn from the prior N(0, I)"""
        latent = torch.randn(num_samples, self.latent_dim, device=condition.device)

        # A single condition row is shared by all requested samples
        needs_broadcast = num_samples > 1 and condition.size(0) == 1
        if needs_broadcast:
            condition = condition.repeat(num_samples, 1)

        return self.decode(latent, condition)


class MixtureOfCVAE(nn.Module):
    """
    Mixture of Conditional VAEs (Mix-CVAE)

    Holds a list of ConditionalVAE experts plus a gating MLP that turns
    (input, condition) into softmax weights over the experts.
    """
    def __init__(self, input_dim, condition_dim, num_experts=1, latent_dim=128, hidden_dim=512):
        """
        input_dim: Size of the vectors being modelled
        condition_dim: Size of the conditioning vector
        num_experts: Number of experts created up front
        latent_dim: Latent size of every expert
        hidden_dim: Hidden width of every expert and of the gating network
        """
        super(MixtureOfCVAE, self).__init__()

        self.num_experts = num_experts
        self.input_dim = input_dim
        self.condition_dim = condition_dim
        self.latent_dim = latent_dim
        # Remembered so that add_expert builds layers of matching width
        self.hidden_dim = hidden_dim

        # Create initial experts
        self.experts = nn.ModuleList([
            ConditionalVAE(input_dim, condition_dim, latent_dim, hidden_dim)
            for _ in range(num_experts)
        ])

        # Expert gating network (simple MLP for weighting experts)
        self.gating_network = nn.Sequential(
            nn.Linear(input_dim + condition_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x, condition):
        """
        Forward pass through the mixture model

        Returns (weighted_sum, mus, logvars, expert_weights):
        weighted_sum is the gate-weighted sum of the expert reconstructions
        [batch, input_dim]; mus and logvars are lists with one entry per
        expert; expert_weights is [batch, num_experts].
        """
        # Get expert weights from gating network
        x_cond = torch.cat([x, condition], dim=1)
        expert_weights = self.gating_network(x_cond)

        # Process through each expert
        expert_outputs = []
        mus = []
        logvars = []

        for expert in self.experts:
            output, mu, logvar = expert(x, condition)
            expert_outputs.append(output)
            mus.append(mu)
            logvars.append(logvar)

        # Stack expert outputs
        stacked_outputs = torch.stack(expert_outputs, dim=1)  # [batch, num_experts, input_dim]

        # Weight expert outputs
        weighted_sum = torch.sum(stacked_outputs * expert_weights.unsqueeze(2), dim=1)

        return weighted_sum, mus, logvars, expert_weights

    def sample(self, condition, num_samples=1):
        """
        Sample from the mixture model

        condition: Either one row per requested sample [num_samples, condition_dim]
            or a single row [1, condition_dim] shared by every sample
        num_samples: Number of samples to draw

        A single sample comes from one randomly chosen expert. Several
        samples are split as evenly as possible across the experts, in
        expert order, and concatenated.
        """
        if num_samples == 1:
            # For single sample, use a random expert
            expert = self.experts[random.randint(0, self.num_experts - 1)]
            return expert.sample(condition, num_samples)

        # Distribute samples across experts; the first `extra` experts take one more
        base, extra = divmod(num_samples, self.num_experts)

        # A lone condition row applies to all samples. Slicing it per expert
        # would hand every expert after the first an empty condition.
        shared_condition = condition.size(0) == 1

        samples = []
        offset = 0  # Position of the next unused row of a per-sample condition
        for i, expert in enumerate(self.experts):
            n_samples = base + (1 if i < extra else 0)
            if n_samples <= 0:
                continue
            if shared_condition:
                expert_condition = condition
            else:
                expert_condition = condition[offset:offset + n_samples]
            samples.append(expert.sample(expert_condition, n_samples))
            offset += n_samples

        return torch.cat(samples, dim=0)

    def add_expert(self):
        """
        Add a new expert to the mixture

        The new expert uses the same dimensions as the existing ones and is
        placed on the device the model currently lives on. The gating
        network keeps its hidden layers but gets a new output layer with
        one more unit.
        NOTE: that output layer is freshly initialised, so the previous
        gate weights for the existing experts are not carried over.

        Returns the newly created expert.
        """
        # Follow the model's actual device instead of a module-level default
        model_device = next(self.parameters()).device

        new_expert = ConditionalVAE(
            self.input_dim,
            self.condition_dim,
            self.latent_dim,
            self.hidden_dim
        ).to(model_device)

        # Add to expert list
        self.experts.append(new_expert)
        self.num_experts += 1

        # Rebuild the gate: reuse everything except the last Linear and the
        # Softmax, then append an output layer sized for the new expert count
        self.gating_network = nn.Sequential(
            *self.gating_network[:-2],
            nn.Linear(self.hidden_dim, self.num_experts),
            nn.Softmax(dim=1)
        ).to(model_device)

        return new_expert