import jax.numpy as jnp
import flax.linen as nn
from typing import Any
import jax
import ml_collections
import utils_qm9
from . import utils
from configs.datasets_config import get_dataset_info

#============================================================#
# Underlying MLPs for GCL
class edge_mlp(nn.Module):
  """Two-layer MLP that turns concatenated edge inputs into edge messages.

  Both Dense layers have width `hidden_nf`, and each one is followed by
  `act_fn`. The last axis of the input must be `input_edge + edges_in_d` wide.
  """
  hidden_nf: int
  input_edge: int
  edges_in_d: int
  act_fn: Any

  @nn.compact
  def __call__(self, x):
    """Map (B, input_edge + edges_in_d) -> (B, hidden_nf)."""
    expected_width = self.input_edge + self.edges_in_d
    assert x.shape[-1] == expected_width
    # Dense -> act -> Dense -> act; layers are created in the same order as before.
    for _ in range(2):
      x = self.act_fn(nn.Dense(self.hidden_nf)(x))
    return x

class node_mlp(nn.Module):
  """Two-layer MLP used for the node update.

  A hidden Dense layer of width `hidden_nf` with `act_fn` is followed by a
  linear projection to `output_nf`. The last axis of the input must be
  `hidden_nf + input_nf + nodes_att_dim` wide.
  """
  hidden_nf: int
  input_nf: int
  nodes_att_dim: int
  output_nf: int
  act_fn: Any

  @nn.compact
  def __call__(self, x):
    """Map (B, hidden_nf + input_nf + nodes_att_dim) -> (B, output_nf)."""
    expected_width = self.hidden_nf + self.input_nf + self.nodes_att_dim
    assert x.shape[-1] == expected_width
    hidden = self.act_fn(nn.Dense(self.hidden_nf)(x))
    # No activation on the output projection.
    return nn.Dense(self.output_nf)(hidden)

class att_mlp(nn.Module):
  """Per-edge attention gate: a single Dense unit squashed by a sigmoid.

  The output lies in (0, 1) and is used to scale edge messages.
  """
  hidden_nf: int

  @nn.compact
  def __call__(self, x):
    """Map (B, hidden_nf) -> (B, 1)."""
    assert x.shape[-1] == self.hidden_nf
    return nn.sigmoid(nn.Dense(1)(x))


def unsorted_segment_sum(
    data, segment_ids, num_segments, normalization_factor, aggregation_method: str
):
  """Aggregate entries of `data` into `num_segments` buckets.

  JAX equivalent of TensorFlow's `unsorted_segment_sum`, built on
  `jax.ops.segment_sum`:

    result[i] = sum_j data[j...]  such that  segment_ids[j...] == i

  Args:
    data: array whose leading axis is indexed by `segment_ids`.
    segment_ids: integer segment index for each leading-axis entry of `data`.
    num_segments: number of output segments; segments with no entries are 0.
    normalization_factor: divisor applied to the summed result in 'sum' mode
      (ignored in 'mean' mode).
    aggregation_method: 'sum' -> segment sum / normalization_factor;
      'mean' -> segment sum / number of entries in that segment
      (empty segments stay 0).

  Returns:
    Array of shape (num_segments,) + data.shape[1:].

  Raises:
    ValueError: if `aggregation_method` is neither 'sum' nor 'mean'.
  """
  # Validate explicitly: an `assert` would be stripped under `python -O`.
  if aggregation_method not in ("sum", "mean"):
    raise ValueError(
        f"aggregation_method must be 'sum' or 'mean', got {aggregation_method!r}"
    )

  result = jax.ops.segment_sum(data, segment_ids, num_segments)

  if aggregation_method == "sum":
    return result / normalization_factor

  # 'mean': divide by the per-segment entry count. Empty segments are clamped to
  # a count of 1 so they stay 0 instead of becoming NaN (0 / 0).
  counts = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments)
  counts = jnp.where(counts == 0, 1, counts)
  return result / counts
#============================================================#
# edge and node model for GCL
class edge_model(nn.Module):
  """Compute edge messages from the two endpoint features of every edge.

  Attributes:
    edge_mlp: module applied to the concatenated [source, target(, edge_attr)].
    att_mlp: gating module applied to the edge MLP output; only called when
      `attention` is True.
    attention: whether to scale messages by `att_mlp`'s output.
  """
  edge_mlp: nn.Module
  att_mlp: nn.Module
  attention: bool

  @nn.compact
  def __call__(self, source, target, edge_attr, edge_mask):
    """Build per-edge messages.

    Input
      source: features of each edge's source node.
      target: features of each edge's target node.
      edge_attr: optional per-edge features (None to omit).
      edge_mask: optional multiplicative mask for the messages (None to skip).
    Output
      out: edge messages, gated by attention if enabled and then masked.
      mij: raw edge MLP output (before attention gating and masking).
    """
    pieces = [source, target] if edge_attr is None else [source, target, edge_attr]
    mij = self.edge_mlp(jnp.concatenate(pieces, axis=1))

    out = mij * self.att_mlp(mij) if self.attention else mij
    if edge_mask is not None:
      out = out * edge_mask
    return out, mij

class node_model(nn.Module):
  """Residual node update from aggregated edge messages.

  Attributes:
    normalization_factor: forwarded to `unsorted_segment_sum`.
    aggregation_method: 'sum' or 'mean', forwarded to `unsorted_segment_sum`.
    node_mlp: module mapping [x, aggregated messages(, node_attr)] to a
      residual update for `x`.
  """
  normalization_factor: int
  aggregation_method: Any
  node_mlp: nn.Module

  @nn.compact
  def __call__(self, x, edge_index, edge_attr, node_attr):
    """Aggregate edge messages per node and apply the residual update.

    Input
      x: node features; axis 0 indexes nodes.
      edge_index: pair (row, col) of per-edge node indices; messages are summed
        onto the `row` node.
      edge_attr: per-edge messages to aggregate.
      node_attr: optional per-node features appended to the MLP input.
    Output
      out: x + node_mlp(agg).
      agg: the concatenated node MLP input [x, aggregated(, node_attr)].
    """
    senders, _ = edge_index
    messages = unsorted_segment_sum(
      edge_attr,
      jnp.array(senders),
      num_segments=x.shape[0],
      normalization_factor=self.normalization_factor,
      aggregation_method=self.aggregation_method,
    )
    parts = [x, messages]
    if node_attr is not None:
      parts.append(node_attr)
    agg = jnp.concatenate(parts, axis=-1)
    return x + self.node_mlp(agg), agg
#============================================================#
class GCL(nn.Module):
  """Graph convolutional layer (the E(n)-invariant part of an EGNN layer).

  Builds per-edge messages from the two endpoint features (plus optional edge
  attributes), optionally gates them with a sigmoid attention weight, sums them
  onto each edge's `row` node, and applies a residual node update.

  Attributes:
    input_nf: width of the incoming node features `h`.
    output_nf: width produced by the node MLP; it is added residually to `h`,
      so it must be broadcast-compatible with `input_nf`.
    hidden_nf: width of the edge messages and MLP hidden layers.
    normalization_factor: divisor for summed messages in 'sum' aggregation.
    aggregation_method: 'sum' or 'mean', forwarded to `unsorted_segment_sum`.
    edges_in_d: width of `edge_attr` (0 when no edge attributes are passed).
    nodes_att_dim: width of `node_attr` (0 when no node attributes are passed).
    act_fn: activation used inside the edge and node MLPs.
    attention: if True, edge messages are scaled by a learned sigmoid gate.
  """
  input_nf:             int
  output_nf:            int
  hidden_nf:            int
  normalization_factor: float
  aggregation_method:   Any
  edges_in_d:           int =   0
  nodes_att_dim:        int =   0
  act_fn:               Any =   nn.swish
  attention:            bool =  False

  @nn.compact
  def __call__(
    self,
    h,
    edge_index,
    edge_attr=None,
    node_attr=None,
    node_mask=None,
    edge_mask=None,
  ):
    """Apply one message-passing step.

    edge feature, mij <- edge_model(source node, target node, ...)
    new node feature  <- node_model(node, edge index, edge feature, node attribute)

    Args:
      h: node features; axis 0 is indexed by the entries of `edge_index`.
      edge_index: pair (row, col) of per-edge node indices.
      edge_attr: optional per-edge features appended to the edge MLP input.
      node_attr: optional per-node features appended to the node MLP input.
      node_mask: optional multiplicative mask applied to the updated features.
      edge_mask: optional multiplicative mask applied to the edge messages.

    Returns:
      (h, mij): the updated node features and the raw edge MLP output
      (before attention gating and masking).
    """
    # Initialization part. Submodules are created in a fixed order because
    # Flax derives their auto-assigned names from construction order.
    input_edge = self.input_nf * 2
    _edge_mlp = edge_mlp(hidden_nf=self.hidden_nf, input_edge=input_edge, edges_in_d=self.edges_in_d, act_fn=self.act_fn)
    _node_mlp = node_mlp(hidden_nf=self.hidden_nf, input_nf=self.input_nf, nodes_att_dim=self.nodes_att_dim, output_nf=self.output_nf, act_fn=self.act_fn)
    # The gate only exists when attention is enabled. Binding the name to None
    # otherwise keeps it defined for edge_model, which never calls it when
    # attention is False. Leaving it unbound raised UnboundLocalError.
    _att_mlp = att_mlp(hidden_nf=self.hidden_nf) if self.attention else None

    # edge and node model
    _edge_model = edge_model(_edge_mlp, _att_mlp, self.attention)
    _node_model = node_model(self.normalization_factor, self.aggregation_method, _node_mlp)

    # Network part
    row, col = edge_index
    edge_feat, mij = _edge_model(h[row], h[col], edge_attr, edge_mask)
    h, _ = _node_model(h, edge_index, edge_feat, node_attr)
    if node_mask is not None:
      h = h * node_mask
    return h, mij
#============================================================#
# EquivariantUpdate
class EquivariantUpdate(nn.Module):
  """Equivariant coordinate update driven by node features.

  For each edge (row, col), an MLP on [h[row], h[col](, edge_attr)] produces
  one scalar per edge. The scalar scales `coord_diff`, and the scaled vectors
  are aggregated onto the `row` node and added to `coord`.

  Attributes:
    hidden_nf: width of the coordinate MLP's hidden layers.
    normalization_factor: forwarded to `unsorted_segment_sum`.
    aggregation_method: 'sum' or 'mean', forwarded to `unsorted_segment_sum`.
    edges_in_d: nominal width of `edge_attr` (the Dense layers infer their input
      width, so this is not read here).
    act_fn: activation between the coordinate MLP's layers.
    tanh: if True, the per-edge scalar is squashed with tanh and multiplied by
      `coords_range`, bounding it to (-coords_range, coords_range).
    coords_range: bound used when `tanh` is True.
  """
  hidden_nf: int
  normalization_factor: float
  aggregation_method: str

  edges_in_d: int = 1
  act_fn: Any = nn.swish
  tanh: bool = False
  coords_range: float = 10.0

  def setup(self):
    # Final bias-free projection to one scalar per edge.
    layer = nn.Dense(1, use_bias=False, kernel_init=nn.initializers.glorot_uniform()) # TODO: scale down
    self.coord_mlp = nn.Sequential(
      [
        nn.Dense(self.hidden_nf),
        self.act_fn,
        nn.Dense(self.hidden_nf),
        self.act_fn,
        layer,
      ]
    )

  def coord_model(self, h, coord, edge_index, coord_diff, edge_attr, edge_mask):
    """Return `coord` plus the aggregated, edge-weighted `coord_diff` vectors.

    Args:
      h: node features; axis 0 is indexed by `edge_index` entries.
      coord: node coordinates; axis 0 is the node axis.
      edge_index: pair (row, col) of per-edge node indices (from, to).
      coord_diff: per-edge coordinate differences to be scaled.
      edge_attr: optional per-edge features appended to the MLP input.
      edge_mask: optional multiplicative mask on the per-edge translations.
    """
    row, col = edge_index # from, to
    features = [h[row], h[col]]
    # Mirror edge_model: edge attributes are optional.
    if edge_attr is not None:
      features.append(edge_attr)
    input_tensor = jnp.concatenate(features, axis=1)
    if self.tanh:
      trans = coord_diff * jnp.tanh(self.coord_mlp(input_tensor)) * self.coords_range
    else:
      trans = coord_diff * self.coord_mlp(input_tensor)
    if edge_mask is not None:
      trans = trans * edge_mask
    agg = unsorted_segment_sum(trans, row, num_segments=coord.shape[0], normalization_factor=self.normalization_factor, aggregation_method=self.aggregation_method)
    return coord + agg

  def __call__(self, h, coord, edge_index, coord_diff, edge_attr=None, node_mask=None, edge_mask=None):
    """Update coordinates and re-apply `node_mask` (if given) to the result."""
    coord = self.coord_model(h, coord, edge_index, coord_diff, edge_attr, edge_mask)
    if node_mask is not None:
      coord = coord * node_mask
    return coord
#============================================================#
class EquivariantBlock(nn.Module):
  """A stack of GCL sublayers followed by one equivariant coordinate update.

  Node features are refined by `n_layers` GCL layers that all see the same edge
  features. Positions are then updated once by `EquivariantUpdate`.
  `device` and `norm_diff` are accepted for interface compatibility but are
  not read in this block.
  """

  hidden_nf:            int
  edge_feat_nf:         int = 2
  device:               str = "cpu"
  act_fn:               Any = nn.swish
  n_layers:             int = 2
  attention:            bool = True
  norm_diff:            bool = True
  tanh:                 bool = False
  coords_range:         int = 15
  norm_constant:        int = 1
  sin_embedding:        Any = None
  normalization_factor: float = 100
  aggregation_method:   str = "sum"

  @nn.compact
  def __call__(self, h, x, edge_index, node_mask=None, edge_mask=None, edge_attr=None):
    """Return updated (h, x).

    Edge features are this block's own distances (sin-embedded when
    `sin_embedding` is truthy), concatenated along axis 1 with `edge_attr`.
    """
    # Create every submodule before running any: the GCL layers first, then the
    # coordinate update. Flax auto-names submodules in construction order, so
    # this order keeps the parameter tree unchanged.
    sublayers = [
      GCL(
        self.hidden_nf,
        self.hidden_nf,
        self.hidden_nf,
        edges_in_d=self.edge_feat_nf,
        act_fn=self.act_fn,
        attention=self.attention,
        normalization_factor=self.normalization_factor,
        aggregation_method=self.aggregation_method,
      )
      for _ in range(self.n_layers)
    ]
    coord_update = EquivariantUpdate(
      self.hidden_nf,
      edges_in_d=self.edge_feat_nf,
      act_fn=nn.swish,  # fixed here, independent of self.act_fn
      tanh=self.tanh,
      coords_range=float(self.coords_range),
      normalization_factor=self.normalization_factor,
      aggregation_method=self.aggregation_method,
    )

    dist, diff = utils_qm9.coord2diff(x, edge_index, self.norm_constant)
    if self.sin_embedding:
      dist = utils_qm9.sin_embedding(dist)
    edge_features = jnp.concatenate([dist, edge_attr], axis=1)

    for layer in sublayers:
      h, _ = layer(
        h,
        edge_index,
        edge_attr=edge_features,
        node_mask=node_mask,
        edge_mask=edge_mask,
      )
    x = coord_update(h, x, edge_index, diff, edge_features, node_mask, edge_mask)

    # The bias of the last Dense layer may be non-zero, so re-apply the mask.
    if node_mask is not None:
      h = h * node_mask
    return h, x
#============================================================#
# @utils.register_model(name='egnn')
class EGNN(nn.Module):
  """E(n)-equivariant graph network over a flattened batch of graphs.

  Appends the time feature `t` to the node features and embeds them to
  `config.model.nf` channels. It then runs `config.model.n_layers`
  EquivariantBlocks that jointly update node features and positions, and
  finally projects the node features to `out_node_nf` channels.
  """
  config: ml_collections.ConfigDict

  # Dimension parameters
  # in_node_nf:           int # Defined
  # in_edge_nf:           int # 1
  # hidden_nf:            int # config.model.nf
  # Output feature width; None -> len(atom_decoder) + int(config.model.include_charges).
  out_node_nf:          Any = None

  # Remaining parameters
  coords_range:         float = 15
  norm_constant:        float = 1
  normalization_factor: float = 100
  aggregation_method:   str = "sum"

  # Goes to config
  # tanh:                 bool = False
  # attention:            bool = False
  # norm_diff:            bool = True
  # n_layers:             int = 3
  # inv_sublayers:        int = 2
  # sin_embedding:        bool = False
  # act_fn:               Any = nn.swish

  @nn.compact
  def __call__(self, batch):
    """Run the EGNN on one batch (non-preprocessed input data).

    Nodes of all graphs are flattened into one leading axis (B: batch size,
    N: atoms per graph); the shapes below follow the inline comments further
    down.

    Input: `batch` dict with keys
      'h':         node features,             (B * N, F)
      'x':         node positions,            (B * N, 3)
      'edges':     (row, col) edge indices,   each (B * N * N,)
      'node_mask': node mask,                 (B * N, 1)
      'edge_mask': edge mask,                 (B * N * N, 1)
      't':         time feature,              (B * N, 1); appended to 'h'

    Config fields read
      data.dataset, data.remove_h: select dataset_info ('atom_decoder').
      model.include_charges: adds one channel to the default out_node_nf.
      model.nf: hidden width (embedding and EquivariantBlocks).
      model.n_layers: number of EquivariantBlocks.
      model.inv_sublayers: GCL sublayers per EquivariantBlock.
      model.attention, model.norm_diff, model.tanh: forwarded to each block.
      model.sin_embedding: embeds distances here and inside each block.

    Returns
      (h, x): node outputs with out_node_nf channels (multiplied by node_mask)
      and the updated positions.
    """
    # Initialization
    config = self.config
    dataset_info = get_dataset_info(config.data.dataset, config.data.remove_h)
    # Only used as the default output width; it does not include `t`.
    in_node_nf = len(dataset_info['atom_decoder']) + int(config.model.include_charges)
    out_node_nf = in_node_nf if self.out_node_nf is None else self.out_node_nf
    # Each EquivariantBlock concatenates its own distances with the distances
    # passed in from here, so the edge feature width is twice the width of a
    # single distance feature.
    if config.model.sin_embedding:
      sin_embedding_fn = utils_qm9.sin_embedding
      # NOTE(review): presumably sin_embedding(None) returns the embedding
      # width rather than an array; TODO confirm in utils_qm9.
      edge_feat_nf = sin_embedding_fn(None) * 2
    else:
      sin_embedding_fn = None
      edge_feat_nf = 2
    
    # Flax auto-names submodules in construction order (Dense layers first,
    # then the EquivariantBlocks), so reordering these changes the param tree.
    embedding = nn.Dense(config.model.nf)
    embedding_out = nn.Dense(out_node_nf)
    
    EqBlock = []
    for i in range(0, config.model.n_layers):
      EqBlock.append(EquivariantBlock(hidden_nf=config.model.nf,
                                      edge_feat_nf=edge_feat_nf,
                                      act_fn=nn.swish,
                                      n_layers=config.model.inv_sublayers,
                                      attention=config.model.attention,
                                      norm_diff=config.model.norm_diff,
                                      tanh=config.model.tanh,
                                      coords_range=self.coords_range,
                                      norm_constant=self.norm_constant,
                                      sin_embedding=config.model.sin_embedding,
                                      normalization_factor=self.normalization_factor,
                                      aggregation_method=self.aggregation_method,))

    # forward (B: batch size, N: n_atoms, C: number of contexts)
    h = batch['h']                   # (B * N, 5 (atom-cat) + 1 (charge-int) + C + aug_dim)
    x = batch['x']                   # (B * N, 3)
    edge_index = batch['edges']      # [(B * N * N,), (B * N * N,)]
    node_mask = batch['node_mask']   # (B * N, 1)
    edge_mask = batch['edge_mask']   # (B * N * N, 1)
    t = batch['t']                   # (B * N, 1)
    h = jnp.concatenate([h, t], axis=-1)

    # Distances of the input positions; passed unchanged to every block as its
    # extra edge attributes.
    distances, _ = utils_qm9.coord2diff(x, edge_index) # (B * N * N, 1)

    if config.model.sin_embedding:
      distances = sin_embedding_fn(distances)
    h = embedding(h)
    for i in range(0, config.model.n_layers):
      h, x = EqBlock[i](h,
                        x,
                        edge_index,
                        node_mask=node_mask,
                        edge_mask=edge_mask,
                        edge_attr=distances)
    
    # Important, the bias of the last linear might be non-zero, so padded
    # nodes are re-masked after the output projection.
    h = embedding_out(h)
    if node_mask is not None:
      h = h * node_mask
    return h, x
