from typing import List
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
import math

from src.models.transformer_model import GraphTransformer

from torch_scatter import scatter_add
from torch_sparse import SparseTensor, matmul, fill_diag, sum, mul_
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    """Symmetrically normalize edge weights as ``D^-1/2 * A * D^-1/2``.

    Args:
      edge_index: Either a ``SparseTensor`` adjacency, or an index tensor
        whose rows 0 and 1 are read as the two endpoints of every edge.
      edge_weight: Optional per-edge weights (index-tensor input only). When
        omitted, every edge gets weight 1.
      num_nodes: Number of nodes; resolved through ``maybe_num_nodes`` for
        index-tensor input.
      improved: If True, added self-loops get weight 2 instead of 1.
      add_self_loops: If True, self-loops are added before the degrees are
        computed.
      dtype: dtype used for the weights that are created when none exist.

    Returns:
      The normalized ``SparseTensor`` for sparse input; otherwise the tuple
      ``(edge_index, normalized_edge_weight)``.
    """
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        if not adj_t.has_value():
            # fill_value() is out-of-place: it returns a new SparseTensor.
            # The result must be kept, otherwise the adjacency stays
            # value-less and the degree normalization below cannot scale it.
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        # NOTE: `sum` here is torch_sparse.sum (imported at file level), not
        # the Python builtin.
        deg = sum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        # Zero-degree nodes give inf after pow(-0.5); zero them out.
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = mul_(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = mul_(adj_t, deg_inv_sqrt.view(1, -1))
        return adj_t

    else:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)

        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes)
            assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight

        row, col = edge_index[0], edge_index[1]
        # Degree of a node = sum of the weights of edges indexed by `col`.
        deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow_(-0.5)
        # Zero-degree nodes give inf after pow(-0.5); zero them out.
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

def build_mlp(
        input_size: int,
        hidden_layer_sizes: List[int],
        output_size: int = None,
        output_activation: nn.Module = nn.Identity,
        activation: nn.Module = nn.ReLU) -> nn.Module:
  """Build a MultiLayer Perceptron.

  Args:
    input_size: Size of input layer.
    hidden_layer_sizes: Output size of each hidden layer, in order. Any
      iterable of ints is accepted; it is copied, never mutated.
    output_size: Size of the output layer. When falsy (``None`` or 0) no
      extra output layer is appended and the last hidden layer is the output.
    output_activation: Activation class instantiated after the last linear
      layer.
    activation: Activation class instantiated after every other linear layer.

  Returns:
    mlp: An ``nn.Sequential`` whose children are named ``NN-<i>`` (linear
      layer) and ``Act-<i>`` (activation). It is empty when there is no layer
      to build, i.e. no hidden sizes and no output size.
  """
  # Size of each layer boundary: input, hidden..., (output)
  layer_sizes = [input_size, *hidden_layer_sizes]
  if output_size:
    layer_sizes.append(output_size)

  # Number of linear layers
  nlayers = len(layer_sizes) - 1

  # Create a torch sequential container. With nlayers == 0 the loop body never
  # runs and an empty container is returned instead of failing on an index
  # lookup into an empty activation list.
  mlp = nn.Sequential()
  for i, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
    # Only the final layer uses the output activation.
    act_cls = output_activation if i == nlayers - 1 else activation
    mlp.add_module("NN-" + str(i), nn.Linear(n_in, n_out))
    mlp.add_module("Act-" + str(i), act_cls())

  return mlp


class Encoder(nn.Module):
  r"""Graph network encoder. Encodes node and edge states with MLPs. The
  Encoder :math:`\mathcal{X} \rightarrow \mathcal{G}` embeds the
  particle-based state representation, :math:`\mathcal{X}`, as a latent graph,
  :math:`G^0 = encoder(\mathcal{X})`, where :math:`G = (V, E, u)`,
  :math:`v_i \in V`, and :math:`e_{i,j} \in E`.
  """

  def __init__(
          self,
          nnode_in_features: int,
          nnode_out_features: int,
          nedge_in_features: int,
          nedge_out_features: int,
          nmlp_layers: int,
          mlp_hidden_dim: int):
    r"""The Encoder implements the node encoder :math:`\varepsilon_v` and the
    edge encoder :math:`\varepsilon_e` as multilayer perceptrons (MLP), each
    followed by a LayerNorm, producing the latent vectors :math:`v_i` and
    :math:`e_{i,j}`.

    Args:
      nnode_in_features: Number of node input features (for 2D = 30, calculated
        as [10 = 5 times steps * 2 positions (x, y) +
        4 distances to boundaries (top/bottom/left/right) +
        16 particle type embeddings]).
      nnode_out_features: Number of node output features (latent dimension,
        e.g. 128).
      nedge_in_features: Number of edge input features (for 2D = 3, calculated
        as [2 (x, y) relative displacements between 2 particles + distance
        between 2 particles]).
      nedge_out_features: Number of edge output features (latent dimension,
        e.g. 128).
      nmlp_layers: Number of hidden layers in each MLP (typically 2).
      mlp_hidden_dim: Size of each hidden layer (e.g. 128).

    """
    super().__init__()
    # Encode node features as an MLP followed by LayerNorm
    self.node_fn = nn.Sequential(
        build_mlp(nnode_in_features,
                  [mlp_hidden_dim for _ in range(nmlp_layers)],
                  nnode_out_features),
        nn.LayerNorm(nnode_out_features))
    # Encode edge features as an MLP followed by LayerNorm
    self.edge_fn = nn.Sequential(
        build_mlp(nedge_in_features,
                  [mlp_hidden_dim for _ in range(nmlp_layers)],
                  nedge_out_features),
        nn.LayerNorm(nedge_out_features))

  def forward(
          self,
          x: torch.Tensor,
          edge_features: torch.Tensor):
    """Encode node and edge features into the latent space.

    Args:
      x: Particle state representation as a torch tensor with shape
        (nparticles, nnode_in_features)
      edge_features: Edge features as a torch tensor whose last dimension is
        nedge_in_features (one row per edge).

    Returns:
      tuple: ``(node_latents, edge_latents)`` -- the outputs of the node and
        edge encoders, with last dimensions nnode_out_features and
        nedge_out_features respectively.
    """
    return self.node_fn(x), self.edge_fn(edge_features)


class InteractionNetwork(MessagePassing):
  """One message-passing layer with timestep conditioning.

  Edge messages are computed by ``edge_fn`` and sum-aggregated per node
  (``aggr='add'``); nodes are then updated by ``node_fn``. Both MLPs receive a
  learned embedding of the scalar timestep. ``forward`` additionally mixes the
  current features and the caller-supplied initial features ``x0`` / ``edge0``
  through the scalars ``alpha`` and ``beta`` and four learned weight matrices.
  """

  def __init__(
      self,
      nnode_in: int,
      nnode_out: int,
      nedge_in: int,
      nedge_out: int,
      nmlp_layers: int,
      mlp_hidden_dim: int,
      timestep_embed_dim: int,
  ):
    """InteractionNetwork derived from torch_geometric MessagePassing class

    Args:
      nnode_in: Number of node inputs (latent dimension, e.g. 128).
      nnode_out: Number of node outputs (latent dimension, e.g. 128).
      nedge_in: Number of edge inputs (latent dimension, e.g. 128).
      nedge_out: Number of edge output features (latent dimension, e.g. 128).
      nmlp_layers: Number of hidden layers in the MLP (typically of size 2).
      mlp_hidden_dim: Size of the hidden layer (e.g. 128).
      timestep_embed_dim: Width of the timestep embedding that is concatenated
        to the inputs of both the node MLP and the edge MLP.

    NOTE(review): forward() adds ``x`` to ``x @ weight_node1`` (and likewise
    for edges), which looks like it needs nnode_in == nnode_out and
    nedge_in == nedge_out -- confirm against callers.
    """
    # Aggregate features from neighbors
    super(InteractionNetwork, self).__init__(aggr='add')
    # Node MLP: input is [aggregated messages, node features, timestep embed]
    self.node_fn = nn.Sequential(*[build_mlp(nnode_in + nedge_out + timestep_embed_dim,
                                             [mlp_hidden_dim
                                              for _ in range(nmlp_layers)],
                                             nnode_out),
                                   nn.LayerNorm(nnode_out)])
    # Edge MLP: input is [x_i, x_j, edge features, timestep embed]
    self.edge_fn = nn.Sequential(*[build_mlp(nnode_in + nnode_in + nedge_in + timestep_embed_dim,
                                             [mlp_hidden_dim
                                              for _ in range(nmlp_layers)],
                                             nedge_out),
                                   nn.LayerNorm(nedge_out)])
    # Timestep embedding: scalar timestep (last dim 1) -> timestep_embed_dim
    self.timestep_embed = nn.Sequential(
        nn.Linear(1, timestep_embed_dim),
        nn.SiLU(),
        nn.Linear(timestep_embed_dim, timestep_embed_dim)
    )

    # Learned linear maps applied to the current (1) and the initial (2)
    # node / edge features in forward(). Allocated uninitialized and filled by
    # reset_parameters() below.
    self.weight_node1 = Parameter(torch.Tensor(nnode_in, nnode_out))
    self.weight_node2 = Parameter(torch.Tensor(nnode_in, nnode_out))
    self.weight_edge1 = Parameter(torch.Tensor(nedge_in, nedge_out))
    self.weight_edge2 = Parameter(torch.Tensor(nedge_in, nedge_out))

    self.reset_parameters()
    # When True, forward() runs gcn_norm on edge_index (see the note there).
    self.normalize = False

  def reset_parameters(self):
    """Glorot-initialize the four weight matrices.

    Only the ``weight_*`` parameters are touched; the MLPs and the timestep
    embedding are not re-initialized here.
    """
    glorot(self.weight_node1)
    glorot(self.weight_node2)
    glorot(self.weight_edge1)
    glorot(self.weight_edge2)

  def forward(self,
              x: torch.Tensor,
              edge_index: torch.Tensor,
              edge_features: torch.Tensor,
              timestep: torch.Tensor,
              alpha: float,
              beta: float,
              x0: torch.Tensor,
              edge0: torch.Tensor):
    """Run one round of message passing and add the initial-feature term.

    Args:
      x: Particle state representation as a torch tensor with shape
        (nparticles, nnode_in)
      edge_index: A torch tensor list of source and target nodes with shape
        (2, nedges)
      edge_features: Edge features as a torch tensor with shape
        (nedges, nedge_in)
      timestep: Input to ``self.timestep_embed``; its last dimension must be 1
        because the first layer is ``nn.Linear(1, ...)``.
      alpha: Scalar weighting the initial features against the current ones.
      beta: Scalar weighting the learned linear maps against the plain
        features.
      x0: Initial node features, combined as
        ``(1-beta)*alpha*x0 + beta*(x0 @ weight_node2)``.
      edge0: Initial edge features, combined the same way with
        ``weight_edge2``.

    Returns:
      tuple: Updated node and edge features
    """
    # encode timestep
    timestep_embed = self.timestep_embed(timestep)
    # Start propagating messages.
    # Takes in the edge indices and all additional data which is needed to
    # construct messages and to update node embeddings.
    # Call PyG propagate() method:
    # 1. Message phase - compute messages for each edge
    # 2. Aggregate phase - aggregate messages for each node
    # 3. Update phase - updates only the node features
    # Update uses the message from step 1 and any original arguments passed to
    # propagate() to update the node embeddings. This is why we need to store
    # the updated edge features to return them from the update() method.
    #
    # NOTE(review): the edge_weight computed in this branch is never used
    # below, and gcn_norm defaults to add_self_loops=True, which can change
    # edge_index while edge_features keeps its original rows. self.normalize
    # is set to False in __init__, so this path is currently inactive --
    # confirm before enabling it.
    if self.normalize:
      edge_index, edge_weight = gcn_norm(  # yapf: disable
            edge_index, None, x.shape[0],
            dtype=x.dtype)

    # "support" terms are built from the current features and are what gets
    # propagated; "initial" terms are built from x0 / edge0 and are added to
    # the propagated result at the end.
    support_node = (1-beta)*(1-alpha)*x + beta*torch.matmul(x, self.weight_node1)
    initial_node = (1-beta)*(alpha)*x0 + beta*torch.matmul(x0, self.weight_node2)
    support_edge = (1-beta)*(1-alpha)*edge_features + beta*torch.matmul(edge_features, self.weight_edge1)
    initial_edge = (1-beta)*(alpha)*edge0 + beta*torch.matmul(edge0, self.weight_edge2)

    x, edge_features = self.propagate(
        edge_index=edge_index, x=support_node, edge_features=support_edge, timestep_embed=timestep_embed)

    return x + initial_node, edge_features + initial_edge

  def message(self,
              x_i: torch.Tensor,
              x_j: torch.Tensor,
              edge_features: torch.Tensor,
              timestep_embed: torch.Tensor) -> torch.Tensor:
    """Constructs message from j to i of edge :math:`e_{i, j}`. Tensors :obj:`x`
    passed to :meth:`propagate` can be mapped to the respective nodes :math:`i`
    and :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name,
    i.e., :obj:`x_i` and :obj:`x_j`.

    Args:
      x_i: Node features at node i of every edge. They are concatenated with
        ``edge_features`` along the last dimension, so there is one row per
        edge.
      x_j: Node features at node j of every edge, one row per edge.
      edge_features: Edge features as a torch tensor with shape
        (nedges, nedge_in)
      timestep_embed: Timestep embedding computed in forward(); expanded here
        to one row per edge.

    Returns:
      The output of ``edge_fn`` for every edge. It is also stored on
      ``self._edge_features`` so that update() can return it.
    """
    # Repeat the timestep embedding for every edge (expand requires its
    # leading dimension to be broadcastable to nedges).
    timestep_embed = timestep_embed.expand(edge_features.shape[0], -1)
    # Concat along the feature dimension: [x_i, x_j, edge_features,
    # timestep_embed], matching the edge MLP input size set in __init__.
    edge_features = torch.cat([x_i, x_j, edge_features, timestep_embed], dim=-1)
    self._edge_features = self.edge_fn(edge_features)  # Create and store
    return self._edge_features  # This gets passed to aggregate()

  def update(self,
             x_updated: torch.Tensor,
             x: torch.Tensor,
             edge_features: torch.Tensor,
             timestep_embed: torch.Tensor):
    """Update the particle state representation

    Args:
      x_updated: Output of the aggregation step (first argument supplied to
        update()), one row per node.
      x: The node features that were passed to propagate() (forward() passes
        its ``support_node`` tensor).
      edge_features: The edge features that were passed to propagate(); not
        read here -- the stored ``self._edge_features`` is returned instead.
      timestep_embed: Timestep embedding computed in forward(); expanded here
        to one row per node.

    Returns:
      tuple: Updated node and edge features
    """
    # This gets called later, after message() and aggregate()
    # Update modified from MessagePassing takes the output of aggregation
    # as first argument and any argument which was initially passed to
    # propagate hence we need to return the stored value of edge_features
    timestep_embed = timestep_embed.expand(x.shape[0], -1)
    # Concat [aggregated messages, node features, timestep_embed] along the
    # feature dimension, matching the node MLP input size set in __init__.
    x_updated = torch.cat([x_updated, x, timestep_embed], dim=-1)
    x_updated = self.node_fn(x_updated)
    return x_updated, self._edge_features


class Processor(MessagePassing):
  r"""The Processor: :math:`\mathcal{G} \rightarrow \mathcal{G}` computes
  interactions among nodes via :math:`M` steps of learned message-passing, to
  generate a sequence of updated latent graphs, :math:`G = (G_1, ..., G_M)`,
  where :math:`G^{m+1} = GN^{m+1}(G^m)`. It returns the final graph,
  :math:`G^M = PROCESSOR(G^0)`. Message-passing allows information to
  propagate and constraints to be respected: the number of message-passing
  steps required will likely scale with the complexity of the interactions.

  """

  def __init__(
      self,
      nnode_in: int,
      nnode_out: int,
      nedge_in: int,
      nedge_out: int,
      nmessage_passing_steps: int,
      nmlp_layers: int,
      mlp_hidden_dim: int,
      timestep_embed_dim: int,
      alpha: float = 0.1,
      lamda: float = 0.5,
  ):
    r"""Processor derived from torch_geometric MessagePassing class. The
    processor uses a stack of :math:`M` interaction networks (where :math:`M`
    is a hyperparameter) with identical structure and unshared parameters.
    Every layer receives, besides the running node/edge latents, the latents
    that entered the processor and two mixing scalars ``alpha`` and ``beta``.

    Args:
      nnode_in: Number of node inputs (latent dimension, e.g. 128).
      nnode_out: Number of node outputs (latent dimension, e.g. 128).
      nedge_in: Number of edge inputs (latent dimension, e.g. 128).
      nedge_out: Number of edge output features (latent dimension, e.g. 128).
      nmessage_passing_steps: Number of message passing steps.
      nmlp_layers: Number of hidden layers in the MLP (typically of size 2).
      mlp_hidden_dim: Size of the hidden layer (e.g. 128).
      timestep_embed_dim: Width of the timestep embedding used inside every
        interaction network.
      alpha: Mixing scalar forwarded unchanged to every layer. Defaults to
        0.1.
      lamda: Controls the per-layer scalar
        ``beta = log(lamda / layer_number + 1)`` with layer_number starting at
        1. Defaults to 0.5.

    """
    super().__init__(aggr='max')
    # Create a stack of M Graph Networks GNs.
    self.gnn_stacks = nn.ModuleList([
        InteractionNetwork(
            nnode_in=nnode_in,
            nnode_out=nnode_out,
            nedge_in=nedge_in,
            nedge_out=nedge_out,
            nmlp_layers=nmlp_layers,
            mlp_hidden_dim=mlp_hidden_dim,
            timestep_embed_dim=timestep_embed_dim,
        ) for _ in range(nmessage_passing_steps)])

    self.alpha = alpha
    self.lamda = lamda

  def forward(self,
              x: torch.Tensor,
              edge_index: torch.Tensor,
              edge_features: torch.Tensor,
              timestep: torch.Tensor,):
    """Run the latent graph through every interaction network in order.

    Args:
      x: Particle state representation as a torch tensor with shape
        (nparticles, latent_dim)
      edge_index: A torch tensor list of source and target nodes with shape
        (2, nedges)
      edge_features: Edge features as a torch tensor with shape
        (nedges, latent_dim)
      timestep: Timestep tensor forwarded unchanged to every layer.

    Returns:
      tuple: Node and edge latents produced by the last layer.
    """
    # Keep the latents that entered the processor; every layer mixes them
    # back in.
    x0 = x
    edge_features0 = edge_features
    for layer_number, gnn in enumerate(self.gnn_stacks, start=1):
      # beta shrinks with depth: log(lamda / layer_number + 1).
      beta = math.log(self.lamda / layer_number + 1)
      x, edge_features = gnn(x, edge_index, edge_features, timestep,
                             self.alpha, beta, x0, edge_features0)
    return x, edge_features


class Processor_attn(MessagePassing):
  """Processor variant that returns a learned weighted sum of the outputs of
  all message-passing steps instead of only the last one.

  One softmax-normalized weight per step is learned for the node outputs
  (``step_weights``) and, separately, for the edge outputs (``edge_weights``).
  """

  def __init__(
      self,
      nnode_in: int,
      nnode_out: int,
      nedge_in: int,
      nedge_out: int,
      nmessage_passing_steps: int,
      nmlp_layers: int,
      mlp_hidden_dim: int,
      timestep_embed_dim: int,
      alpha: float = 0.1,
      lamda: float = 0.5,
  ):
    """Build the stack of interaction networks and the per-step weights.

    Args:
      nnode_in: Number of node inputs (latent dimension).
      nnode_out: Number of node outputs (latent dimension).
      nedge_in: Number of edge inputs (latent dimension).
      nedge_out: Number of edge output features (latent dimension).
      nmessage_passing_steps: Number of message passing steps.
      nmlp_layers: Number of hidden layers in each MLP.
      mlp_hidden_dim: Size of each hidden layer.
      timestep_embed_dim: Width of the timestep embedding used inside every
        interaction network.
      alpha: Mixing scalar forwarded unchanged to every layer. Defaults to
        0.1, the value Processor uses.
      lamda: Controls the per-layer scalar
        ``beta = log(lamda / layer_number + 1)``. Defaults to 0.5, the value
        Processor uses.
    """
    super().__init__(aggr='max')
    # Create a stack of M Graph Networks GNs.
    self.gnn_stacks = nn.ModuleList([
        InteractionNetwork(
            nnode_in=nnode_in,
            nnode_out=nnode_out,
            nedge_in=nedge_in,
            nedge_out=nedge_out,
            nmlp_layers=nmlp_layers,
            mlp_hidden_dim=mlp_hidden_dim,
            timestep_embed_dim=timestep_embed_dim,
        ) for _ in range(nmessage_passing_steps)])

    self.alpha = alpha
    self.lamda = lamda

    # One learnable logit per step; softmax-ed in forward().
    self.step_weights = nn.Parameter(torch.ones(nmessage_passing_steps))
    self.edge_weights = nn.Parameter(torch.ones(nmessage_passing_steps))

  def forward(self,
              x: torch.Tensor,
              edge_index: torch.Tensor,
              edge_features: torch.Tensor,
              timestep: torch.Tensor,):
    """Run every step and blend the per-step outputs.

    Args:
      x: Node latents, one row per particle.
      edge_index: A torch tensor list of source and target nodes with shape
        (2, nedges)
      edge_features: Edge latents, one row per edge.
      timestep: Timestep tensor forwarded unchanged to every layer.

    Returns:
      tuple: ``(final_nodes, final_edges)``, the softmax-weighted sums of the
        node and edge outputs of all steps.
    """
    # InteractionNetwork.forward requires alpha, beta and the initial node /
    # edge latents; calling it with only four arguments raises TypeError.
    # The same schedule as Processor is used here.
    # NOTE(review): confirm this is the intended schedule for this variant.
    x0 = x
    edge_features0 = edge_features

    node_outputs = []
    edge_outputs = []

    for layer_number, gnn in enumerate(self.gnn_stacks, start=1):
      beta = math.log(self.lamda / layer_number + 1)
      x, edge_features = gnn(x, edge_index, edge_features, timestep,
                             self.alpha, beta, x0, edge_features0)
      node_outputs.append(x)
      edge_outputs.append(edge_features)

    step_weights = torch.softmax(self.step_weights, dim=0)
    edge_weights = torch.softmax(self.edge_weights, dim=0)

    # Stack to (nsteps, rows, features) and weight each step's slice.
    final_nodes = torch.stack(node_outputs) * step_weights.view(-1, 1, 1)
    final_nodes = final_nodes.sum(dim=0)
    final_edges = torch.stack(edge_outputs) * edge_weights.view(-1, 1, 1)
    final_edges = final_edges.sum(dim=0)

    return final_nodes, final_edges


class Decoder(nn.Module):
  r"""The Decoder: :math:`\mathcal{G} \rightarrow \mathcal{Y}` extracts the
  dynamics information from the nodes of the final latent graph,
  :math:`y_i = \delta v (v_i^M)`

  """

  def __init__(
          self,
          nnode_in: int,
          nnode_out: int,
          nmlp_layers: int,
          mlp_hidden_dim: int):
    r"""The Decoder's learned function, :math:`\delta v`, is an MLP.
    After the Decoder, the future position and velocity are updated using an
    Euler integrator, so the :math:`y_i` corresponds to accelerations,
    :math:`\ddot{p}_i`, with 2D or 3D dimension, depending on the physical
    domain.

    Args:
      nnode_in: Number of node inputs (latent dimension, e.g. 128).
      nnode_out: Number of node outputs (particle dimension).
      nmlp_layers: Number of hidden layers in the MLP (typically of size 2).
      mlp_hidden_dim: Size of the hidden layer (e.g. 128).
    """
    super().__init__()
    # Plain MLP; unlike the encoder there is no LayerNorm on the output.
    self.node_fn = build_mlp(
        nnode_in, [mlp_hidden_dim for _ in range(nmlp_layers)], nnode_out)

  def forward(self,
              x: torch.Tensor):
    """Decode node latents with the MLP.

    Args:
      x: Particle state representation as a torch tensor with shape
        (nparticles, nnode_in)

    Returns:
      The MLP output, with last dimension nnode_out.
    """
    return self.node_fn(x)


class EncodeProcessDecode(nn.Module):
  """Encoder -> Processor -> Decoder pipeline over a particle graph."""

  def __init__(
      self,
      nnode_in_features: int,
      nnode_out_features: int,
      nedge_in_features: int,
      latent_dim: int,
      nmessage_passing_steps: int,
      nmlp_layers: int,
      mlp_hidden_dim: int,
      timestep_embed_dim: int,
  ):
    """Encode-Process-Decode function approximator for learnable simulator.

    Args:
      nnode_in_features: Number of node input features (for 2D = 30,
        calculated as [10 = 5 times steps * 2 positions (x, y) +
        4 distances to boundaries (top/bottom/left/right) +
        16 particle type embeddings]).
      nnode_out_features: Number of node outputs (particle dimension).
      nedge_in_features: Number of edge input features (for 2D = 3,
        calculated as [2 (x, y) relative displacements between 2 particles +
        distance between 2 particles]).
      latent_dim: Size of the latent node and edge vectors; used for the
        encoder outputs, all processor inputs/outputs and the decoder input.
      nmessage_passing_steps: Number of message passing steps in the
        processor.
      nmlp_layers: Number of hidden layers in every MLP.
      mlp_hidden_dim: Size of the hidden layers in every MLP.
      timestep_embed_dim: Width of the timestep embedding handed to the
        processor.

    """
    super().__init__()

    # Settings shared by the MLPs of all three stages.
    mlp_config = dict(nmlp_layers=nmlp_layers, mlp_hidden_dim=mlp_hidden_dim)

    # Stage 1: lift raw node / edge features into the latent space.
    self._encoder = Encoder(
        nnode_in_features=nnode_in_features,
        nnode_out_features=latent_dim,
        nedge_in_features=nedge_in_features,
        nedge_out_features=latent_dim,
        **mlp_config,
    )
    # Stage 2: message passing entirely inside the latent space.
    self._processor = Processor(
        nnode_in=latent_dim,
        nnode_out=latent_dim,
        nedge_in=latent_dim,
        nedge_out=latent_dim,
        nmessage_passing_steps=nmessage_passing_steps,
        timestep_embed_dim=timestep_embed_dim,
        **mlp_config,
    )
    # Stage 3: map node latents to the requested output size.
    self._decoder = Decoder(
        nnode_in=latent_dim,
        nnode_out=nnode_out_features,
        **mlp_config,
    )

  def forward(self,
              x: torch.tensor,
              edge_index: torch.tensor,
              edge_features: torch.tensor,
              timestep: torch.tensor):
    """Encode the graph, run message passing, and decode the node latents.

      Args:
        x: Particle state representation as a torch tensor with shape
          (nparticles, nnode_in_features)
        edge_index: A torch tensor list of source and target nodes with shape
          (2, nedges)
        edge_features: Edge features as a torch tensor with shape
          (nedges, nedge_in_features)
        timestep: Timestep tensor passed through to the processor.
          NOTE(review): presumably a scalar timestep with a trailing
          dimension of 1 -- confirm against the caller.

      Returns:
        x: Particle state representation as a torch tensor with shape
          (nparticles, nnode_out_features)
    """
    node_latents, edge_latents = self._encoder(x, edge_features)
    # The processor also returns updated edge latents; only the node latents
    # are decoded.
    node_latents, edge_latents = self._processor(
        node_latents, edge_index, edge_latents, timestep)
    return self._decoder(node_latents)
