import functools
import math
from typing import Any

import flax.linen as nn
import jax.numpy as jnp
from jax.ops import segment_sum

from utils import batch_mul

#============================================================#
# Underlying MLPs
class edge_mlp(nn.Module):
  """Two-layer MLP producing per-edge hidden features.

  Attributes:
    hidden_nf: width of both Dense layers (and of the output).
    input_edge: width of the node-feature part of the input.
    edges_in_d: width of the extra edge-attribute part of the input.
    act_fn: activation applied after each Dense layer.
  """
  hidden_nf: int
  input_edge: int
  edges_in_d: int
  act_fn: Any

  @nn.compact
  def __call__(self, x):
    """
    edge_mlp
      Input:  (B, input_edge + edges_in_d)
      Output: (B, hidden_nf)

    Raises:
      ValueError: if the last dimension of ``x`` does not match
        ``input_edge + edges_in_d``.
    """
    expected = self.input_edge + self.edges_in_d
    if x.shape[-1] != expected:
      raise ValueError(
        f"edge_mlp expected last dim {expected}, got {x.shape[-1]}"
      )
    # Each Dense module must be *applied* to x; constructing it alone does nothing.
    x = nn.Dense(self.hidden_nf)(x)
    x = self.act_fn(x)
    x = nn.Dense(self.hidden_nf)(x)
    x = self.act_fn(x)
    return x

class node_mlp(nn.Module):
  """Two-layer MLP mapping concatenated node inputs to ``output_nf`` features.

  Attributes:
    hidden_nf: width of the hidden Dense layer (and of the aggregated messages).
    input_nf: width of the node features.
    nodes_att_dim: width of optional extra node attributes (0 if unused).
    output_nf: width of the output.
    act_fn: activation applied after the hidden layer.
  """
  hidden_nf: int
  input_nf: int
  nodes_att_dim: int
  output_nf: int
  act_fn: Any

  @nn.compact
  def __call__(self, x):
    """
    node_mlp
      Input:  (B, hidden_nf + input_nf + nodes_att_dim)
      Output: (B, output_nf)

    Raises:
      ValueError: if the last dimension of ``x`` does not match the
        configured input width.
    """
    expected = self.hidden_nf + self.input_nf + self.nodes_att_dim
    if x.shape[-1] != expected:
      raise ValueError(
        f"node_mlp expected last dim {expected}, got {x.shape[-1]}"
      )
    x = nn.Dense(self.hidden_nf)(x)
    x = self.act_fn(x)
    # No activation on the output layer.
    x = nn.Dense(self.output_nf)(x)
    return x

class att_mlp(nn.Module):
  """Per-edge scalar gate in (0, 1) computed from hidden edge features.

  Attributes:
    hidden_nf: expected width of the input features.
  """
  hidden_nf: int

  @nn.compact
  def __call__(self, x):
    """
    att_mlp
      Input:  (B, hidden_nf)
      Output: (B, 1)

    Raises:
      ValueError: if the last dimension of ``x`` is not ``hidden_nf``.
    """
    if x.shape[-1] != self.hidden_nf:
      raise ValueError(
        f"att_mlp expected last dim {self.hidden_nf}, got {x.shape[-1]}"
      )
    x = nn.Dense(1)(x)
    x = nn.sigmoid(x)
    return x
#============================================================#
# Edge and node models
class edge_model(nn.Module):
  """Compute per-edge messages from the two endpoint features.

  Attributes:
    edge_mlp: module mapping concatenated edge inputs to hidden messages.
    att_mlp: optional gating module; when it is not ``None`` its output
      multiplies the messages (attention). ``None`` disables attention.
  """
  edge_mlp: nn.Module
  att_mlp: Any = None

  @nn.compact
  def __call__(self, source, target, edge_attr, edge_mask):
    """
    edge_model
    Input
      source: features of the first endpoint of each edge (e.g. h[row]).
      target: features of the second endpoint of each edge (e.g. h[col]).
      edge_attr: optional extra per-edge features, or None.
      edge_mask: optional multiplicative per-edge mask, or None.
    Output
      out: messages after optional attention gating and masking.
      mij: raw edge_mlp output (before gating and masking).
    """
    if edge_attr is None:
      out = jnp.concatenate([source, target], axis=-1)
    else:
      out = jnp.concatenate([source, target, edge_attr], axis=-1)
    mij = self.edge_mlp(out)

    # Attention is enabled exactly when a gating module was supplied.
    if self.att_mlp is not None:
      att_val = self.att_mlp(mij)
      out = batch_mul(mij, att_val)
    else:
      out = mij

    if edge_mask is not None:
      out = out * edge_mask

    return out, mij

class node_model(nn.Module):
  """Aggregate incoming edge messages and apply a residual node update.

  Attributes:
    normalization_factor: divisor used by ``normalized_segment_sum`` for
      the ``"sum"`` aggregation.
    aggregation_method: ``"sum"`` or ``"mean"``.
    node_mlp: module mapping ``[x, agg, node_attr]`` to a node update. Its
      output width must equal ``x``'s width because of the residual add.
  """
  normalization_factor: float
  aggregation_method: Any
  node_mlp: nn.Module

  @nn.compact
  def __call__(self, x, edge_index, edge_attr, node_attr):
    """
    Input
      x: node features, leading axis indexes nodes.
      edge_index: pair (row, col) of per-edge node indices; messages are
        aggregated onto ``row``.
      edge_attr: per-edge messages to aggregate.
      node_attr: optional extra node features, or None.
    Output
      out: x + node_mlp(concatenated inputs).
      agg: the concatenated node_mlp input.
    """
    row, _ = edge_index
    agg = normalized_segment_sum(
      edge_attr,
      row,
      num_segments=x.shape[0],
      normalization_factor=self.normalization_factor,
      aggregation_method=self.aggregation_method,
    )
    if node_attr is not None:
      agg = jnp.concatenate([x, agg, node_attr], axis=-1)
    else:
      agg = jnp.concatenate([x, agg], axis=-1)
    out = x + self.node_mlp(agg)
    return out, agg
#============================================================#
class GCL(nn.Module):
  """Graph convolutional layer: edge messages followed by a node update.

  Attributes:
    input_nf: node feature width of ``h``.
    output_nf: node_mlp output width (must equal ``input_nf`` for the
      residual update in ``node_model``).
    hidden_nf: message / hidden width.
    normalization_factor: divisor for ``"sum"`` aggregation.
    aggregation_method: ``"sum"`` or ``"mean"``.
    edges_in_d: width of ``edge_attr`` (0 when no edge attributes are used).
    nodes_att_dim: width of ``node_attr`` (0 when unused).
    act_fn: activation used inside the MLPs.
    attention: whether messages are gated by an attention MLP.
  """
  input_nf: int
  output_nf: int
  hidden_nf: int
  normalization_factor: float
  aggregation_method: Any
  edges_in_d: int = 0
  nodes_att_dim: int = 0
  act_fn: Any = nn.swish
  attention: bool = False

  @nn.compact
  def __call__(
    self,
    h,
    edge_index,
    edge_attr=None,
    node_attr=None,
    node_mask=None,
    edge_mask=None,
  ):
    """Apply one message-passing step.

    Returns:
      (h, mij): updated (optionally masked) node features and the raw
      edge_mlp messages.
    """
    # Flax modules are frozen outside setup(), so sub-networks are kept as
    # locals (not assigned to self) and handed to the edge/node models.
    message_net = edge_mlp(
      hidden_nf=self.hidden_nf,
      input_edge=2 * self.input_nf,
      edges_in_d=self.edges_in_d,
      act_fn=self.act_fn,
    )
    update_net = node_mlp(
      hidden_nf=self.hidden_nf,
      input_nf=self.input_nf,
      nodes_att_dim=self.nodes_att_dim,
      output_nf=self.output_nf,
      act_fn=self.act_fn,
    )
    # None tells edge_model to skip attention gating.
    gate_net = att_mlp(hidden_nf=self.hidden_nf) if self.attention else None

    edge_update = edge_model(message_net, gate_net)
    node_update = node_model(
      self.normalization_factor, self.aggregation_method, update_net
    )

    row, col = edge_index
    edge_feat, mij = edge_update(h[row], h[col], edge_attr, edge_mask)
    h, _ = node_update(h, edge_index, edge_feat, node_attr)
    if node_mask is not None:
      h = h * node_mask
    return h, mij
#============================================================#
# Equivariant update class
class EquivariantBlock(nn.Module):
  """A stack of GCL layers followed by one coordinate update.

  Reads from ``config.model``: hidden_nf, n_layers (number of GCL
  sublayers), coords_range, norm_constant, sin_embedding (flag),
  normalization_factor, aggregation_method, act_fn, attention, tanh.
  """
  config: Any

  @nn.compact
  def __call__(self, h, x, edge_index, node_mask=None, edge_mask=None, edge_attr=None):
    """Update node features ``h`` and coordinates ``x``.

    Args:
      h: node features; last dim must be ``config.model.hidden_nf``.
      x: node coordinates.
      edge_index: pair (row, col) of per-edge node indices.
      node_mask: optional multiplicative node mask.
      edge_mask: optional multiplicative edge mask.
      edge_attr: optional extra edge features appended after the
        (optionally embedded) squared distances.

    Returns:
      (h, x): updated node features and coordinates.
    """
    model_cfg = self.config.model
    hidden_nf = model_cfg.hidden_nf
    n_layers = model_cfg.n_layers
    normalization_factor = model_cfg.normalization_factor
    aggregation_method = model_cfg.aggregation_method

    # Edge features: squared distances (optionally sinusoid-embedded),
    # followed by any caller-supplied edge attributes.
    distances, coord_diff = coord2diff(x, edge_index, model_cfg.norm_constant)
    # sin_embedding is a config flag, not a callable.
    if model_cfg.sin_embedding:
      distances = SinusoidsEmbeddingNew()(distances)
    if edge_attr is None:
      edge_attr = distances
    else:
      edge_attr = jnp.concatenate([distances, edge_attr], axis=-1)
    # Derive the width from the data so it always matches what GCL receives.
    edge_feat_nf = edge_attr.shape[-1]

    for _ in range(n_layers):
      h, _ = GCL(
        input_nf=hidden_nf,
        output_nf=hidden_nf,
        hidden_nf=hidden_nf,
        normalization_factor=normalization_factor,
        aggregation_method=aggregation_method,
        edges_in_d=edge_feat_nf,
        act_fn=model_cfg.act_fn,
        attention=model_cfg.attention,
      )(
        h,
        edge_index,
        edge_attr=edge_attr,
        node_mask=node_mask,
        edge_mask=edge_mask,
      )

    x = EquivalentUpdate(
      hidden_nf,
      edges_in_d=edge_feat_nf,
      act_fn=nn.swish,
      tanh=model_cfg.tanh,
      coords_range=float(model_cfg.coords_range),
      normalization_factor=normalization_factor,
      aggregation_method=aggregation_method,
    )(h, x, edge_index, coord_diff, edge_attr, node_mask, edge_mask)

    if node_mask is not None:
      h = h * node_mask
    return h, x


class SinusoidsEmbeddingNew(nn.Module):
  """Sinusoidal embedding of (squared) distances at geometric frequencies.

  Frequencies are ``2*pi * div_factor**k / max_res`` for
  ``k = 0 .. n_frequencies - 1``. The output concatenates sin and cos, so
  its width is ``dim == 2 * n_frequencies``.
  """
  max_res: float = 15.0
  min_res: float = 15.0 / 2000.0
  div_factor: int = 4

  @property
  def n_frequencies(self):
    """Number of frequencies: floor(log_{div_factor}(max_res / min_res)) + 1."""
    # jnp.log has no base argument; math.log(x, base) does.
    return int(math.log(self.max_res / self.min_res, self.div_factor)) + 1

  @property
  def dim(self):
    """Output feature width (sin and cos for every frequency)."""
    return 2 * self.n_frequencies

  @nn.compact
  def __call__(self, x):
    """Embed ``x``; callers pass squared distances of shape (E, 1).

    ``x`` is square-rooted (with a small epsilon for stability) before the
    frequencies are applied. Returns an array of shape (E, dim).
    """
    frequencies = (
      2 * jnp.pi * self.div_factor ** jnp.arange(self.n_frequencies) / self.max_res
    )

    # forward
    x = jnp.sqrt(x + 1e-8)
    emb = x * frequencies[None, :]
    emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
    return emb


def coord2diff(x, edge_index, norm_constant=1):
  """Per-edge squared distance and normalized coordinate difference.

  Args:
    x: coordinates, shape (N, D).
    edge_index: pair (senders, receivers) of per-edge node indices.
    norm_constant: added to the edge length before dividing, which damps
      the normalization for very short edges.

  Returns:
    (radial, coord_diff): ``radial`` is the squared edge length with shape
    (E, 1); ``coord_diff`` is ``x[senders] - x[receivers]`` divided by
    ``sqrt(radial + 1e-8) + norm_constant``.
  """
  senders, receivers = edge_index
  delta = x[senders] - x[receivers]
  sq_dist = jnp.sum(delta ** 2, axis=1, keepdims=True)
  scale = jnp.sqrt(sq_dist + 1e-8) + norm_constant
  return sq_dist, delta / scale


def normalized_segment_sum(
  data,
  segment_ids,
  num_segments,
  normalization_factor,
  aggregation_method
):
  """Aggregate rows of ``data`` into ``num_segments`` buckets.

  Args:
    data: array whose leading axis is indexed by ``segment_ids``.
    segment_ids: integer array with one segment id per row of ``data``.
    num_segments: number of output rows.
    normalization_factor: divisor applied to the summed result when
      ``aggregation_method == "sum"``; ignored for ``"mean"``.
    aggregation_method: ``"sum"`` or ``"mean"``.

  Returns:
    Array of shape ``(num_segments,) + data.shape[1:]``.

  Raises:
    ValueError: if ``aggregation_method`` is not ``"sum"`` or ``"mean"``.
  """
  if aggregation_method not in ("sum", "mean"):
    raise ValueError(
      f"aggregation_method must be 'sum' or 'mean', got {aggregation_method!r}"
    )
  total = segment_sum(data, segment_ids, num_segments)
  if aggregation_method == "sum":
    return total / normalization_factor
  # "mean": divide by the number of rows in each segment. Empty segments
  # keep a count of 1 so they stay zero instead of becoming NaN.
  counts = segment_sum(
    jnp.ones(data.shape[0], dtype=total.dtype), segment_ids, num_segments
  )
  counts = jnp.maximum(counts, 1)
  # Broadcast the per-segment counts over any trailing feature axes.
  return total / counts.reshape((-1,) + (1,) * (total.ndim - 1))

class EquivalentUpdate(nn.Module):
  """Coordinate update: move each node along its edge directions.

  Each edge contributes ``coord_diff * w_ij``, where ``w_ij`` is a scalar
  predicted by ``coord_mlp`` from ``[h[row], h[col], edge_attr]`` (passed
  through tanh and scaled by ``coords_range`` when ``tanh`` is set).
  Contributions are aggregated onto ``row`` and added to ``coord``.

  Attributes:
    hidden_nf: hidden width of coord_mlp.
    normalization_factor: divisor for ``"sum"`` aggregation.
    aggregation_method: ``"sum"`` or ``"mean"``.
    edges_in_d: declared edge-attribute width (Dense infers input width,
      so this is not used for shape inference here).
    act_fn: activation used inside coord_mlp.
    tanh: bound the per-edge weight to (-coords_range, coords_range).
    coords_range: scale applied when ``tanh`` is set.
  """
  hidden_nf: int
  normalization_factor: float
  aggregation_method: Any
  edges_in_d: int = 1
  act_fn: Any = nn.swish
  tanh: bool = False
  coords_range: float = 10.0

  def coord_mlp(self, x):
    """Map per-edge inputs to one scalar weight per edge, shape (E, 1).

    Defines Dense layers inline, so it must only be called from within the
    compact ``__call__`` (via ``coord_model``).
    """
    x = nn.Dense(self.hidden_nf)(x)
    x = self.act_fn(x)
    x = nn.Dense(self.hidden_nf)(x)
    x = self.act_fn(x)
    # Bias-free output with a tiny Xavier-uniform init (gain 1e-3, i.e.
    # variance scale 1e-6) keeps the initial coordinate updates near zero.
    x = nn.Dense(
      1,
      use_bias=False,
      kernel_init=nn.initializers.variance_scaling(1e-6, "fan_avg", "uniform"),
    )(x)
    return x

  def coord_model(self, h, coord, edge_index, coord_diff, edge_attr, edge_mask):
    """Return ``coord`` plus the aggregated per-edge translations."""
    row, col = edge_index
    if edge_attr is None:
      input_tensor = jnp.concatenate([h[row], h[col]], axis=1)
    else:
      input_tensor = jnp.concatenate([h[row], h[col], edge_attr], axis=1)
    weights = self.coord_mlp(input_tensor)
    if self.tanh:
      trans = coord_diff * nn.tanh(weights) * self.coords_range
    else:
      trans = coord_diff * weights

    if edge_mask is not None:
      trans = trans * edge_mask
    agg = normalized_segment_sum(
      trans,
      row,
      num_segments=coord.shape[0],
      normalization_factor=self.normalization_factor,
      aggregation_method=self.aggregation_method
    )
    coord = coord + agg
    return coord

  @nn.compact
  def __call__(self, h, coord, edge_index, coord_diff, edge_attr=None, node_mask=None, edge_mask=None):
    """Update and optionally mask the coordinates; returns the new coordinates."""
    coord = self.coord_model(h, coord, edge_index, coord_diff, edge_attr, edge_mask)
    if node_mask is not None:
      coord = coord * node_mask
    return coord
#============================================================#
# EGNN (adapted in part from the MolFM code)
class EGNN(nn.Module):
  """E(n)-equivariant GNN: input embedding, EquivariantBlocks, output head.

  NOTE(review): each EquivariantBlock reads its GCL depth from
  ``config.model.n_layers``; an earlier draft of this class intended
  ``config.model.inv_sublayers`` for that. Confirm which is wanted.
  """
  config: Any

  @nn.compact
  def __call__(self, h, x, edge_index, node_mask=None, edge_mask=None):
    """
    Input
      h: node feature matrix, (N, d)
      x: atom coordinate matrix, (N, 3)
      edge_index: pair (row, col) of per-edge node indices
      node_mask: optional multiplicative node mask
      edge_mask: optional multiplicative edge mask
    Output
      h: node feature drift, (N, out_node_nf)
      x: atom coordinate drift, (N, 3)
    """
    # Initialization
    config = self.config
    assert config.model.name == 'egnn'
    model_cfg = config.model
    out_node_nf = model_cfg.in_node_nf if model_cfg.out_node_nf is None else model_cfg.out_node_nf
    hidden_nf = model_cfg.hidden_nf
    n_layers = model_cfg.n_layers

    # Squared distances from the input coordinates; every block receives
    # them as extra edge attributes.
    distances, _ = coord2diff(x, edge_index)
    if model_cfg.sin_embedding:
      distances = SinusoidsEmbeddingNew()(distances)

    # forward
    h = nn.Dense(hidden_nf)(h)
    for i in range(n_layers):
      h, x = EquivariantBlock(config=config, name=f'e_block_{i}')(
        h,
        x,
        edge_index,
        node_mask=node_mask,
        edge_mask=edge_mask,
        edge_attr=distances,
      )

    # Output head runs once, after all blocks.
    # The bias of the last linear might be non-zero, so re-apply the mask.
    h = nn.Dense(out_node_nf)(h)
    if node_mask is not None:
      h = h * node_mask
    return h, x
#============================================================#
