import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsde
import torchdiffeq
from torch_geometric.nn import GCNConv

class SimpleSpatialGNN(nn.Module):
    """Three-layer graph convolutional network.

    Layout: ``GCNConv(in -> hidden)`` + ReLU, then ``GCNConv(hidden -> out)``
    followed directly by ``GCNConv(out -> out)``. Only the first convolution
    is followed by a non-linearity; the last two are applied back to back
    and the output is returned without an activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Creation order is kept fixed so parameter initialisation and
        # state_dict key order stay stable.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.conv3 = GCNConv(out_channels, out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        """Run message passing over ``edge_index`` and return node outputs.

        Args:
            x: Node feature tensor fed to ``conv1``.
            edge_index: Graph connectivity passed unchanged to every layer.

        Returns:
            Output of ``conv3`` (``out_channels`` features per node).
        """
        hidden = self.conv1(x, edge_index)
        hidden = self.relu(hidden)
        # The two trailing convolutions have no activation between them.
        for layer in (self.conv2, self.conv3):
            hidden = layer(hidden, edge_index)
        return hidden



class GraphNeuralSDE(torchsde.SDEIto):
    """Ito neural SDE over node embeddings with an MLP classification head.

    Raw node features are embedded by ``encoder_mlp``. The embeddings are
    then evolved by an SDE whose drift ``f`` and diffusion ``g`` are
    ``SimpleSpatialGNN`` networks sharing one ``edge_index``. Finally, the
    state at the last time point is classified.
    """

    def __init__(self, in_channels=2, 
                 hidden_channels=64, num_classes=2,
                 covariance_matrix=None):
        """
        Neural SDE with GNN-based drift and diffusion networks and classification head.

        Args:
            in_channels: Size of the raw node feature vectors.
            hidden_channels: Size of the latent SDE state per node.
            num_classes: Number of output logits per node.
            covariance_matrix: Optional matrix right-multiplied onto the
                diffusion network output in ``g``. ``g`` must return the
                state's shape, so this should be
                ``(hidden_channels, hidden_channels)``. Non-tensor inputs are
                converted to a float32 tensor. An ``nn.Parameter`` is kept as
                a learnable parameter.
        """
        # ``g`` returns a tensor shaped like the state, which is torchsde's
        # *diagonal* noise layout. "general" noise would require a 3-D
        # (batch, state, brownian) diffusion and is not supported by the
        # 'srk' solver used in ``sample_trajectory``.
        super().__init__(noise_type="diagonal")
        
        self.encoder_mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )

        # Drift and diffusion networks
        self.drift_net = SimpleSpatialGNN(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            out_channels=hidden_channels
        )
        
        self.diffusion_net = SimpleSpatialGNN(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            out_channels=hidden_channels
        )
      
        # Filled in later by ``set_edge_index``. Registered as a buffer so it
        # follows ``.to(device)``.
        self.register_buffer('edge_index', None)
        # linear classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, num_classes)
        )

        if isinstance(covariance_matrix, nn.Parameter):
            # A learnable matrix stays a registered parameter.
            self.covariance_matrix = covariance_matrix
        else:
            if covariance_matrix is not None and not isinstance(covariance_matrix, torch.Tensor):
                covariance_matrix = torch.tensor(covariance_matrix, dtype=torch.float32)
            # A buffer (rather than a plain attribute) moves with ``.to(device)``.
            # It is non-persistent so the state_dict keys are unchanged.
            self.register_buffer('covariance_matrix', covariance_matrix, persistent=False)
    
    def set_edge_index(self, edge_index):
        """Set the graph connectivity used by the drift and diffusion GNNs."""
        self.edge_index = edge_index

    def f(self, t, y):
        """Drift function: GNN message passing over the current state ``y``."""
        return self.drift_net(y, self.edge_index)

    def g(self, t, y):
        """Diffusion function (diagonal noise): returns a tensor shaped like ``y``."""
        diffusion = self.diffusion_net(y, self.edge_index)
        if self.covariance_matrix is not None:
            diffusion = diffusion @ self.covariance_matrix
        return diffusion

    def sample_trajectory(self, x0, ts):
        """Sample one trajectory of the SDE starting at ``x0``.

        Args:
            x0: Initial latent state (presumably one row per node; torchsde
                treats dim 0 as the batch dimension).
            ts: Time points at which the solution is returned.

        Returns:
            The states at each time in ``ts``, stacked along dim 0.

        Raises:
            ValueError: If ``set_edge_index`` has not been called.
        """
        if self.edge_index is None:
            raise ValueError("Edge index not set. Call set_edge_index first.")
        # Fixed-step SRK with dt=1e-2. rtol/atol only affect adaptive
        # stepping, so they have no effect here.
        xs = torchsde.sdeint(
            self,
            x0,
            ts,
            method='srk',
            dt=1e-2,
            adaptive=False,
            rtol=1e-3,
            atol=1e-3,
        )
        return xs

    def forward(self, x, ts):
        """Encode ``x``, integrate the SDE over ``ts`` and classify the final state.

        Returns:
            Tuple ``(logits, final_state)``.
        """
        z = self.encoder_mlp(x)
        trajectories = self.sample_trajectory(z, ts)
        final_state = trajectories[-1]
        logits = self.classifier(final_state)
        return logits, final_state
    
    def compute_loss(self, logits, y, mask=None, lambda_joint=0):
        """Compute loss for masked nodes.

        Returns:
            Tuple ``(loss, loss_cv, loss_joint)``:

            * ``loss_cv`` is the cross-entropy over the rows selected by
              ``mask``, or over all rows when ``mask`` is None.
            * ``loss_joint`` is the negative mean, over all rows (ignoring
              ``mask``), of the summed logits.
            * ``loss = loss_cv + lambda_joint * loss_joint``.

            NOTE(review): ``loss_joint`` uses a plain sum of logits rather
            than ``logsumexp``; confirm this is the intended joint term.
        """
        if mask is not None:
            loss_cv = F.cross_entropy(logits[mask], y[mask])
        else:
            loss_cv = F.cross_entropy(logits, y)
        loss_joint = -torch.mean(torch.sum(logits, dim=-1))
        loss = loss_cv + lambda_joint * loss_joint
        return loss, loss_cv, loss_joint
    
    def energy_function(self, latent, return_logits=False):
        """Compute energy ``-logsumexp(classifier(latent))`` along dim 1.

        ``torch.logsumexp`` is used instead of ``log(sum(exp(.)))`` so that
        large logits do not overflow to ``inf``.
        """
        logits = self.classifier(latent) # classifier is a ebm
        energy = -torch.logsumexp(logits, dim=1)
        
        if return_logits:
            return energy, logits
        return energy

    # obtain expected entropy and variance of predictive distribution
    def compute_expected_entropy_and_variance(self, x, ts, n_trajectories=20):
        """Average per-node entropy and logit variance over sampled trajectories.

        For each of ``n_trajectories`` stochastic forward passes, this
        computes the softmax entropy per node and the variance of the logits
        across classes per node. Both are then averaged over the passes.

        Returns:
            Tuple ``(entropies, variances)``, each of shape ``(n_nodes,)``.
        """
        entropies_ensemble = []
        variances_ensemble = []
        for _ in range(n_trajectories):
            logits, _ = self.forward(x, ts) # logits: (n_nodes, n_classes), final_state: (n_nodes, n_features)
            # log_softmax avoids log(0) -> -inf (and 0 * -inf -> NaN) when a
            # probability underflows.
            log_probs = F.log_softmax(logits, dim=1)
            entropy = -torch.sum(log_probs.exp() * log_probs, dim=1) # (n_nodes,)
            # Variance across classes (dim=1) gives one value per node.
            variances = torch.var(logits, dim=1) # (n_nodes,)
            entropies_ensemble.append(entropy)
            variances_ensemble.append(variances)
        # average over ensemble
        entropies = torch.mean(torch.stack(entropies_ensemble, dim=0), dim=0)
        variances = torch.mean(torch.stack(variances_ensemble, dim=0), dim=0)
        return entropies, variances # (n_nodes,), (n_nodes,)




class GraphNeuralODE(nn.Module):
    """Neural ODE over node features with a noisy, GNN-based vector field and a classifier head."""

    def __init__(self, in_channels, hidden_channels, out_channels, 
                 covariance_matrix=None, noise_scale=1.0, num_classes=2):
        """
        Neural ODE with GNN-based vector field network incorporating Gaussian noise.
        
        Args:
            in_channels: Dimension of input features. The input is used
                directly as the ODE state.
            hidden_channels: Dimension of hidden layers and of the noise vector Z.
            out_channels: Dimension of output features. ``odeint`` requires
                the vector field to return the state's shape, so this must
                equal ``in_channels``.
            covariance_matrix: Covariance matrix L for noise transformation
                (default: None). A 2-D L is applied as ``Z @ L``; any other
                rank is multiplied element-wise. Non-tensors are converted to
                float32 tensors.
            noise_scale: Scaling factor for the noise, used only when no
                covariance matrix is set (default: 1.0).
            num_classes: Number of classifier outputs (default: 2, the value
                that used to be hard-coded).
        """
        super().__init__()
        
        # Two MLP encoders for deterministic and stochastic parts
        self.encoder_F = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
        )
        
        self.encoder_G = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
        )
        
        # GNN for spatial message passing
        self.gnn = SimpleSpatialGNN(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels
        )
        
        # Store the covariance matrix (can be None initially)
        if covariance_matrix is not None and not isinstance(covariance_matrix, torch.Tensor):
            covariance_matrix = torch.tensor(covariance_matrix, dtype=torch.float32)
        self.register_buffer('covariance_matrix', covariance_matrix)
        
        self.noise_scale = noise_scale
        self.hidden_channels = hidden_channels
        self.register_buffer('edge_index', None)
        
        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(out_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, num_classes)
        )
    
    def set_edge_index(self, edge_index):
        """Set the edge index for message passing"""
        self.edge_index = edge_index
    
    def set_covariance_matrix(self, covariance_matrix):
        """Set, update, or (with ``None``) clear the covariance matrix L.

        Non-tensor inputs are converted to float32. The matrix is moved to
        the device of the model's parameters (when there are any), so a
        model that already lives on an accelerator keeps working.
        """
        if covariance_matrix is not None:
            if not isinstance(covariance_matrix, torch.Tensor):
                covariance_matrix = torch.tensor(covariance_matrix, dtype=torch.float32)
            reference = next(self.parameters(), None)
            if reference is not None:
                covariance_matrix = covariance_matrix.to(reference.device)
        self.register_buffer('covariance_matrix', covariance_matrix)

    def _sample_noise(self, x):
        """Draw Z ~ N(0, I) of shape (x.shape[0], hidden_channels) and apply L or noise_scale."""
        Z = torch.randn(x.shape[0], self.hidden_channels, device=x.device)
        if self.covariance_matrix is None:
            # Without covariance matrix, just use Z scaled by noise_scale
            return Z * self.noise_scale
        if self.covariance_matrix.dim() == 2:
            # Full matrix: matrix multiplication (row-vector convention Z @ L)
            return Z @ self.covariance_matrix
        # Vector (diagonal): element-wise multiplication
        return Z * self.covariance_matrix

    def _field_with_noise(self, x, L_Z):
        """Evaluate GNN(F(x) + G(x) * L_Z) for already-sampled noise ``L_Z``."""
        h_x = self.encoder_F(x) + self.encoder_G(x) * L_Z
        return self.gnn(h_x, self.edge_index)
    
    def vector_field(self, t, x):
        """
        Vector field function for the ODE solver.
        Computes: H(x) = F(x) + G(x) * L * Z, then passes through GNN.

        Fresh noise Z is drawn on every call. ``integrate`` does not use this
        method directly; it holds Z fixed for a whole solve instead.

        Raises:
            ValueError: If ``set_edge_index`` has not been called.
        """
        if self.edge_index is None:
            raise ValueError("Edge index not set. Call set_edge_index first.")
        return self._field_with_noise(x, self._sample_noise(x))
    
    def integrate(self, x0, ts):
        """
        Solve the ODE from initial state x0 using the vector field function.

        The noise Z is sampled once per call and held fixed during the solve.
        Redrawing it at every right-hand-side evaluation would make the
        vector field random within a single step, which invalidates
        dopri5's adaptive error estimate. Separate calls still draw
        independent noise.
        
        Args:
            x0: Initial state tensor of shape [batch_size, in_channels]
            ts: Time points tensor at which to evaluate the solution
            
        Returns:
            Tensor of shape [len(ts), batch_size, out_channels]

        Raises:
            ValueError: If ``set_edge_index`` has not been called.
        """
        if self.edge_index is None:
            raise ValueError("Edge index not set. Call set_edge_index first.")
        L_Z = self._sample_noise(x0)

        def fixed_noise_field(t, x):
            return self._field_with_noise(x, L_Z)

        solution = torchdiffeq.odeint(
            fixed_noise_field,
            x0,
            ts,
            method='dopri5',
            rtol=1e-3,
            atol=1e-3,
        )
        return solution
    
    def forward(self, x, ts):
        """
        Forward pass through the neural ODE.
        
        Args:
            x: Input features tensor of shape [batch_size, in_channels]
            ts: Time points tensor at which to evaluate the solution
            
        Returns:
            Tuple of (logits, final_state)
        """
        # Solve ODE
        trajectory = self.integrate(x, ts)
        
        # Extract final state
        final_state = trajectory[-1]
        
        # Apply classifier to final state
        logits = self.classifier(final_state)
        
        return logits, final_state
    
    def compute_loss(self, logits, y, mask=None, lambda_joint=0):
        """Compute loss for masked nodes.

        Returns:
            Tuple ``(loss, loss_cv, loss_joint)``:

            * ``loss_cv`` is the cross-entropy over the rows selected by
              ``mask``, or over all rows when ``mask`` is None.
            * ``loss_joint`` is the negative mean, over all rows (ignoring
              ``mask``), of the summed logits.
            * ``loss = loss_cv + lambda_joint * loss_joint``.

            NOTE(review): ``loss_joint`` uses a plain sum of logits rather
            than ``logsumexp``; confirm this is the intended joint term.
        """
        if mask is not None:
            loss_cv = F.cross_entropy(logits[mask], y[mask])
        else:
            loss_cv = F.cross_entropy(logits, y)
        loss_joint = -torch.mean(torch.sum(logits, dim=-1))
        loss = loss_cv + lambda_joint * loss_joint
        return loss, loss_cv, loss_joint
    
    def energy_function(self, latent, return_logits=False):
        """Compute energy ``-logsumexp(classifier(latent))`` along dim 1.

        ``torch.logsumexp`` is used instead of ``log(sum(exp(.)))`` so that
        large logits do not overflow to ``inf``.
        """
        logits = self.classifier(latent)
        energy = -torch.logsumexp(logits, dim=1)
        
        if return_logits:
            return energy, logits
        return energy
    
    def compute_expected_entropy_and_variance(self, x, ts, n_trajectories=20):
        """
        Compute expected entropy and variance of predictions across multiple trajectories.

        For each stochastic forward pass, this computes the softmax entropy
        per row and the variance of the logits across classes (dim 1). Both
        are then averaged over the passes.
        
        Args:
            x: Input features tensor
            ts: Time points tensor
            n_trajectories: Number of trajectory samples to use
            
        Returns:
            Tuple of (entropies, variances)
        """
        entropies_ensemble = []
        variances_ensemble = []
        
        for _ in range(n_trajectories):
            logits, _ = self.forward(x, ts)
            # log_softmax is exact where softmax probabilities underflow, so
            # no epsilon is needed.
            log_probs = F.log_softmax(logits, dim=1)
            entropy = -torch.sum(log_probs.exp() * log_probs, dim=1)
            variances = torch.var(logits, dim=1)
            
            entropies_ensemble.append(entropy)
            variances_ensemble.append(variances)
        
        # Average over ensemble
        entropies = torch.mean(torch.stack(entropies_ensemble, dim=0), dim=0)
        variances = torch.mean(torch.stack(variances_ensemble, dim=0), dim=0)
        
        return entropies, variances