"""
Neural network architectures for learning discrete gradient flow dynamics.

This module contains neural network models that learn potential functions (V)
and entropy coefficients (beta) that drive discrete Fokker-Planck dynamics.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class VAndBetaMLP(nn.Module):
    """
    Per-node MLP that jointly predicts a potential V and an entropy coefficient beta.

    Each node index is embedded, passed through a shared two-layer ReLU trunk,
    and then fed to two heads: a linear head for V and a small MLP head ending
    in Softplus for beta (so beta is non-negative).

    Args:
        num_nodes: Number of nodes in the graph (size of the embedding table)
        embedding_dim: Dimension of node embeddings
        hidden_dim: Width of the shared trunk and of the beta head's hidden layer
    """
    def __init__(self, num_nodes, embedding_dim=16, hidden_dim=64):
        super().__init__()
        # Layers are created in the same order as always so that seeded
        # initialisation and state_dict keys stay unchanged.
        self.embedding = nn.Embedding(num_nodes, embedding_dim)
        trunk_layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        self.v_head = nn.Linear(hidden_dim, 1)
        beta_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Softplus(),  # keeps beta >= 0
        ]
        self.beta_head = nn.Sequential(*beta_layers)

    def forward(self, node_idx):
        """
        Evaluate V and beta at the requested nodes.

        Args:
            node_idx: Integer tensor of node indices, e.g. shape (N,)

        Returns:
            V: Potential per node, shape (N,) for a 1-D node_idx
            beta_nodes: Non-negative beta per node, shape (N,) for a 1-D node_idx
        """
        embedded = self.embedding(node_idx)
        features = self.shared(embedded)
        # Each head emits a trailing singleton dim; drop it.
        potential = self.v_head(features).squeeze(-1)
        beta_per_node = self.beta_head(features).squeeze(-1)
        return potential, beta_per_node


class EnergyNetwork(nn.Module):
    """
    Energy-based network producing a per-node gradient-of-V row and a global beta.

    Each node index is embedded and passed through a shared two-layer ReLU
    trunk; a linear head then maps it to a vector of length ``num_nodes``.
    A single learnable scalar ``beta`` (float64, initialised to 0.0, not
    constrained to be positive) is shared across all nodes.

    Args:
        num_nodes: Number of nodes in the graph
        embedding_dim: Dimension of node embeddings
        hidden_dim: Dimension of hidden layers
    """
    def __init__(self, num_nodes, embedding_dim=16, hidden_dim=64):
        super().__init__()
        self.embedding = nn.Embedding(num_nodes, embedding_dim)
        self.shared = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
        )
        # One output per node: row i presumably holds gradV entries from node i
        # to every other node (TODO confirm interpretation with the caller).
        self.v_head = nn.Linear(hidden_dim, num_nodes)
        # NOTE(review): beta is float64 while the rest of the network uses the
        # default dtype, and it has no positivity constraint (unlike the
        # Softplus beta in VAndBetaMLP). Kept as-is for checkpoint compatibility.
        self.beta = nn.Parameter(torch.tensor(0.0, dtype=torch.float64))

    def forward(self, node_idx):
        """
        Forward pass computing gradient of V and beta.

        Args:
            node_idx: Tensor of node indices, e.g. shape (N,)

        Returns:
            gradV: Output of ``v_head``, shape (N, num_nodes) for a 1-D node_idx
            beta: The global entropy coefficient parameter (0-dim, float64)
        """
        h = self.shared(self.embedding(node_idx))
        gradV = self.v_head(h)
        return gradV, self.beta

    @torch.no_grad()
    def get_potential(self):
        """
        Evaluate gradV for every node and return it with beta, detached from autograd.

        Returns:
            gradV: Shape (num_nodes, num_nodes); row i is the forward output for node i
            beta: Detached copy-view of the global beta (0-dim, float64)
        """
        all_nodes = torch.arange(self.embedding.num_embeddings, device=self.beta.device)
        gradV, beta = self.forward(all_nodes)
        # no_grad() does not strip requires_grad from an existing Parameter,
        # so detach explicitly to honour the "without gradients" contract.
        return gradV, beta.detach()
