import torch
from torch import nn, einsum, broadcast_tensors
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# types

from typing import Optional, List, Union

# pytorch geometric

try:
    import torch_geometric
    from torch_geometric.nn import MessagePassing
    from torch_geometric.typing import Adj, Size, OptTensor, Tensor
    from torch_scatter import scatter_mean
    # previously only set on the failure path, so it was undefined whenever PyG loaded
    PYG_AVAILABLE = True
except Exception:
    # pytorch-geometric / torch_scatter are missing (or failed to load). Bind
    # placeholder names so the type hints and base classes below still resolve
    # at import time. `except Exception` keeps the best-effort fallback without
    # swallowing KeyboardInterrupt / SystemExit like a bare `except:` would.
    Tensor = OptTensor = Adj = MessagePassing = Size = object
    PYG_AVAILABLE = False

from .egnn_pytorch import *


# global linear attention

class Attention_Sparse(Attention):
    """ Wraps the dense `Attention` class to operate on pytorch-geometric inputs.

        Inputs are node-major 2D tensors holding the nodes of several graphs
        stacked along dim 0. Attention is computed block-sparsely: each graph
        is run through the dense parent attention on its own, and the results
        are concatenated back in the original node order.
    """

    def __init__(self, dim, heads=8, dim_head=64):
        """ Wraps the attention class to operate with pytorch-geometric inputs.

            * dim: feature dimension of the inputs.
            * heads / dim_head: forwarded to the parent `Attention`
              (previously ignored in favour of hard-coded 8 / 64).
        """
        super(Attention_Sparse, self).__init__(dim, heads=heads, dim_head=dim_head)

    def _dense_forward(self, x, context):
        """ Run the dense parent attention on a single graph.

            Adds a leading batch dim of size 1 and removes exactly that dim
            afterwards. A bare `.squeeze()` would also collapse single-node
            graphs or size-1 feature dims.
        """
        out = self.forward(x.unsqueeze(0), context.unsqueeze(0), mask=None)
        return out.squeeze(0)

    def sparse_forward(self, x, context, batch=None, batch_uniques=None, mask=None):
        """ Block-sparse attention of `x` over `context`, graph by graph.

            * x: (n_nodes, dim) query nodes.
            * context: (n_nodes, dim) context nodes. It is sliced with the same
              per-graph counts as `x`, so both must hold the same nodes.
            * batch: (n_nodes,) graph id per node. Nodes of one graph are
              assumed to be contiguous and sorted, as in pyg batching.
            * batch_uniques: optional precomputed
              `torch.unique(batch, return_counts=True)`.
            * mask: currently unused (kept for interface compatibility).
            Returns: (n_nodes, out_dim) tensor in the original node order.
        """
        assert batch is not None or batch_uniques is not None, "Batch/(uniques) must be passed for block_sparse_attn"
        if batch_uniques is None:
            batch_uniques = torch.unique(batch, return_counts=True)

        # only one example in batch - do dense - faster
        if batch_uniques[0].shape[0] == 1:
            return self._dense_forward(x, context)

        # multiple examples in batch - do block-sparse by dense loop
        outputs = []
        start = 0
        for n_nodes in batch_uniques[1].tolist():
            end = start + int(n_nodes)
            outputs.append(self._dense_forward(x[start:end], context[start:end]))
            # advance to the next graph (the original never moved this offset)
            start = end
        return torch.cat(outputs, dim=0)


class GlobalLinearAttention_Sparse(nn.Module):
    """ Global attention block operating on pytorch-geometric (node-major) inputs.

        `queries` attend to `x` to produce an induced representation. `x` then
        attends to that induced representation. A feed-forward block follows.
        Graph membership is given by `batch` (and optionally `batch_uniques`)
        and is used by both the pyg LayerNorms and the block-sparse attention.
    """

    def __init__(self, *, dim, heads=8, dim_head=64):
        super().__init__()
        make_norm = torch_geometric.nn.norm.LayerNorm
        self.norm_seq = make_norm(dim)
        self.norm_queries = make_norm(dim)
        self.attn1 = Attention_Sparse(dim, heads, dim_head)
        self.attn2 = Attention_Sparse(dim, heads, dim_head)

        # pyg norms need the extra `batch` argument, so they cannot sit inside nn.Sequential
        self.ff_norm = make_norm(dim)
        ff_hidden = dim * 4
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_hidden),
            nn.GELU(),
            nn.Linear(ff_hidden, dim),
        )

    def forward(self, x, queries, batch=None, batch_uniques=None, mask=None):
        """ Apply the attention + feed-forward block.

            * x: (n_nodes, dim) node features.
            * queries: (n_nodes, dim) query features. They are normalized and
              sliced with the same `batch`, so they must hold the same nodes.
            * batch / batch_uniques: graph assignment, forwarded to the norms
              and to `Attention_Sparse.sparse_forward`.
            * mask: forwarded to the first attention only.
            Returns: tuple (updated x, updated queries).
        """
        normed_x = self.norm_seq(x, batch=batch)
        normed_queries = self.norm_queries(queries, batch=batch)

        # queries gather information from the sequence...
        induced = self.attn1.sparse_forward(normed_queries, normed_x, batch=batch,
                                            batch_uniques=batch_uniques, mask=mask)
        # ...and the sequence reads it back
        attended = self.attn2.sparse_forward(normed_x, induced, batch=batch,
                                             batch_uniques=batch_uniques)

        # residual connections around the attention stage (pre-norm inputs)
        updated_x = attended + x
        updated_queries = induced + queries

        # NOTE(review): the feed-forward residual adds the *normalized*
        # activations (ff(norm(x)) + norm(x)), not the pre-norm stream.
        # Kept as-is so existing checkpoints behave identically.
        ff_input = self.ff_norm(updated_x, batch=batch)
        return self.ff(ff_input) + ff_input, updated_queries


#  define pytorch-geometric equivalents
class EGNN_Sparse(MessagePassing):
    """ E(n)-equivariant message-passing layer for pytorch-geometric graphs.

        Different from the dense layer since it separates the edge assignment
        (`edge_index`) from the computation. This allows for a great reduction
        in time and computation when the graph is locally or sparsely connected.

        Node inputs are packed as `x = [coors | feats]`: `pos_dim` coordinate
        columns followed by `feats_dim` feature columns.

        * aggr: one of ["add", "sum", "mean", "max"]
    """

    def __init__(
            self,
            feats_dim,
            pos_dim=3,
            edge_attr_dim=0,
            m_dim=16,
            fourier_features=0,
            soft_edge=0,
            norm_feats=False,
            norm_coors=True,
            norm_coors_scale_init=1e-2,
            update_feats=True,
            update_edge=False,
            update_coors=True,
            update_global=True,
            dropout=0.,
            coor_weights_clamp_value=None,
            aggr="add",
            mlp_num=2,
            **kwargs
    ):
        assert aggr in {'add', 'sum', 'max', 'mean'}, 'pool method must be a valid option'
        assert update_feats or update_coors, 'you must update either features, coordinates, or both'
        kwargs.setdefault('aggr', aggr)
        super(EGNN_Sparse, self).__init__(**kwargs)
        # model params
        self.fourier_features = fourier_features
        self.feats_dim = feats_dim
        self.pos_dim = pos_dim
        self.m_dim = m_dim
        self.soft_edge = soft_edge
        self.norm_feats = norm_feats
        self.norm_coors = norm_coors
        self.update_coors = update_coors
        self.update_feats = update_feats
        self.update_edge = update_edge
        self.update_global = update_global
        # previously hard-coded to None, which silently ignored the argument
        self.coor_weights_clamp_value = coor_weights_clamp_value
        self.mlp_num = mlp_num
        self.edge_input_dim = edge_attr_dim
        # message input is [x_i, x_j, edge_attr, distance feats]. This formula
        # assumes the distance feats are 1 column, or 2 * fourier_features + 1
        # when fourier-encoded.
        self.message_input_dim = (fourier_features * 2) + edge_attr_dim + 1 + (feats_dim * 2)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        #  EDGES
        if self.mlp_num > 2:
            # NOTE(review): maps edge_input_dim -> m_dim, which does not match how
            # `forward` uses edge_mlp (input [h_i, edge_attr, h_j], output added
            # to edge_attr). update_edge with mlp_num > 2 will fail. Left
            # unchanged to keep parameter shapes of existing checkpoints.
            self.edge_mlp = nn.Sequential(
                nn.Linear(self.edge_input_dim, self.edge_input_dim * 8),
                self.dropout,
                SiLU(),
                nn.Linear(self.edge_input_dim * 8, self.edge_input_dim * 4),
                self.dropout,
                SiLU(),
                nn.Linear(self.edge_input_dim * 4, self.edge_input_dim * 2),
                self.dropout,
                SiLU(),
                nn.Linear(self.edge_input_dim * 2, m_dim),
                SiLU(),
            ) if update_feats else None
        else:
            # input: [hidden_i, edge_attr, hidden_j] -> output: edge_attr width
            self.edge_mlp = nn.Sequential(
                nn.Linear(self.edge_input_dim + self.feats_dim * 2, self.edge_input_dim * 2 + self.feats_dim * 4),
                self.dropout,
                SiLU(),
                nn.Linear(self.edge_input_dim * 2 + self.feats_dim * 4, self.edge_input_dim),
                SiLU()
            )
        self.message_mlp = nn.Sequential(
            nn.Linear(self.message_input_dim, self.message_input_dim * 2),
            self.dropout,
            SiLU(),
            nn.Linear(self.message_input_dim * 2, m_dim),
            SiLU()
        )
        self.edge_weight = nn.Sequential(nn.Linear(m_dim, 1),
                                         nn.Sigmoid()
                                         ) if soft_edge else None

        if self.update_global:
            # produces a per-node sigmoid gate from [node feats, graph-mean feats]
            self.global_mlp = nn.Sequential(
                nn.Linear(2 * feats_dim, 2 * feats_dim),
                nn.ReLU(),
                nn.Linear(2 * feats_dim, 2 * feats_dim),
                nn.ReLU(),
                nn.Linear(2 * feats_dim, feats_dim),
                nn.Sigmoid()
            )

        # NODES - can't do identity in node_norm bc pyg expects 2 inputs, but identity expects 1.
        self.node_norm = torch_geometric.nn.norm.LayerNorm(feats_dim) if norm_feats else None
        self.edge_norm = torch_geometric.nn.norm.LayerNorm(self.edge_input_dim) if self.update_edge else None
        self.coors_norm = CoorsNorm(scale_init=norm_coors_scale_init) if norm_coors else nn.Identity()
        if self.mlp_num > 2:
            self.node_mlp = nn.Sequential(
                nn.Linear(feats_dim + m_dim, feats_dim * 8),
                self.dropout,
                SiLU(),
                nn.Linear(feats_dim * 8, feats_dim * 4),
                self.dropout,
                SiLU(),
                nn.Linear(feats_dim * 4, feats_dim * 2),
                self.dropout,
                SiLU(),
                nn.Linear(feats_dim * 2, feats_dim),
            ) if update_feats else None
        else:
            self.node_mlp = nn.Sequential(
                nn.Linear(feats_dim + m_dim, feats_dim * 2),
                self.dropout,
                SiLU(),
                nn.Linear(feats_dim * 2, feats_dim),
            ) if update_feats else None

        #  COORS
        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            self.dropout,
            SiLU(),
            nn.Linear(self.m_dim * 4, 1)
        ) if update_coors else None

        self.apply(self.init_)

    def init_(self, module):
        """ Xavier-normal weights / zero bias for every linear layer. """
        if isinstance(module, nn.Linear):
            # seems to be needed to keep the network from exploding to NaN with greater depths
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None, batch: Adj = None,
                angle_data: List = None, size: Size = None) -> Tensor:
        """ Inputs:
            * x: (n_points, d) where d is pos_dims + feat_dims
            * edge_index: (2, n_edges)
            * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats.
              Required when `update_edge` is set.
            * batch: (n_points,) long tensor. specifies cloud belonging for each point.
              If None, all points are treated as a single cloud.
            * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor.
              Currently unused.
            * size: currently unused (not forwarded to `propagate`).
            Returns: (n_points, d) tensor `[coors_out | hidden_out]`, plus the
            updated edge_attr as a second element when `update_edge` is set.
        """
        coors, feats = x[:, :self.pos_dim], x[:, self.pos_dim:]

        rel_coors = coors[edge_index[0]] - coors[edge_index[1]]
        # squared euclidean distance per edge, shape (n_edges, 1)
        rel_dist = (rel_coors ** 2).sum(dim=-1, keepdim=True)

        if self.fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings=self.fourier_features)
            rel_dist = rearrange(rel_dist, 'n () d -> n d')

        if exists(edge_attr):
            edge_attr_feats = torch.cat([edge_attr, rel_dist], dim=-1)
        else:
            edge_attr_feats = rel_dist

        hidden_out, coors_out = self.propagate(edge_index, x=feats, edge_attr=edge_attr_feats,
                                               coors=coors, rel_coors=rel_coors,
                                               batch=batch)
        if self.update_edge:
            hidden_out_i = hidden_out[edge_index[0]]
            hidden_out_j = hidden_out[edge_index[1]]
            hidden_edge_attr = torch.cat([hidden_out_i, edge_attr, hidden_out_j], dim=-1)
            # pyg LayerNorm treats batch=None as a single graph
            edge_batch = batch[edge_index[0]] if batch is not None else None
            hidden_edge_attr = self.edge_mlp(hidden_edge_attr)
            edge_attr = self.edge_norm(self.dropout(hidden_edge_attr) + edge_attr, edge_batch)

        if self.update_global:
            if batch is None:
                # single cloud: the global mean is the mean over all nodes
                hidden_global = hidden_out.mean(dim=0, keepdim=True).expand_as(hidden_out)
            else:
                hidden_global = scatter_mean(hidden_out, batch, dim=0)[batch]
            hidden_global = torch.cat([hidden_out, hidden_global], dim=-1)
            hidden_out = hidden_out * self.global_mlp(hidden_global)

        if self.update_edge:
            return torch.cat([coors_out, hidden_out], dim=-1), edge_attr
        return torch.cat([coors_out, hidden_out], dim=-1)

    def message(self, x_i, x_j, edge_attr) -> Tensor:
        """ Edge message m_ij = message_mlp([x_i, x_j, edge_attr]). """
        m_ij = self.message_mlp(torch.cat([x_i, x_j, edge_attr], dim=-1))
        return m_ij

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """The initial call to start propagating messages.
            Args:
            `edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional) if none, the size will be inferred
                and assumed to be quadratic.
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
                Reads `x`, `coors`, `rel_coors` and `batch` directly.
            Returns: the result of `self.update((hidden_out, coors_out), ...)`.
        """
        # older pyg exposes dunder helpers, newer pyg single-underscore ones
        try:
            size = self.__check_input__(edge_index, size)
            coll_dict = self.__collect__(self.__user_args__,
                                         edge_index, size, kwargs)
        except AttributeError:
            size = self._check_input(edge_index, size)
            coll_dict = self._collect(self._user_args,
                                      edge_index, size, kwargs)

        msg_kwargs = self.inspector.distribute('message', coll_dict)
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
        update_kwargs = self.inspector.distribute('update', coll_dict)

        #  get messages
        m_ij = self.message(**msg_kwargs)

        # update coors if specified
        if self.update_coors:
            coor_wij = self.coors_mlp(m_ij)
            # clamp if arg is set (previously computed but never applied)
            if self.coor_weights_clamp_value:
                clamp_value = self.coor_weights_clamp_value
                coor_wij = coor_wij.clamp(min=-clamp_value, max=clamp_value)

            # normalize if needed
            kwargs["rel_coors"] = self.coors_norm(kwargs["rel_coors"])

            mhat_i = self.aggregate(coor_wij * kwargs["rel_coors"], **aggr_kwargs)
            coors_out = kwargs["coors"] + mhat_i
        else:
            coors_out = kwargs["coors"]

        # update feats if specified
        if self.update_feats:
            # weight the edges if arg is passed
            if self.soft_edge:
                m_ij = m_ij * self.edge_weight(m_ij)
            m_i = self.aggregate(m_ij, **aggr_kwargs)

            if self.node_norm is not None:
                hidden_feats = self.node_norm(kwargs["x"], kwargs["batch"])
            else:
                hidden_feats = kwargs["x"]
            hidden_out = self.node_mlp(torch.cat([hidden_feats, m_i], dim=-1))
            # residual update of node features
            hidden_out = kwargs["x"] + hidden_out
        else:
            hidden_out = kwargs["x"]

        # return tuple
        return self.update((hidden_out, coors_out), **update_kwargs)

    def __repr__(self):
        return "E(n)-GNN Layer for Graphs " + str(self.__dict__)


class EGNN_Sparse_Network(nn.Module):
    r"""Sample GNN model architecture that uses the EGNN-Sparse
        message passing layer to learn over point clouds.
        Main MPNN layer introduced in https://arxiv.org/abs/2102.09844v1

        Inputs will be standard GNN: x, edge_index, edge_attr, batch, ...

        Args:
        * n_layers: int. number of MPNN layers
        * ... : same interpretation as the base layer.
        * embedding_nums: list. number of unique keys to embed for points.
                          1 entry per embedding needed.
        * embedding_dims: list. point - number of dimensions of
                          the resulting embedding. 1 entry per embedding needed.
        * edge_embedding_nums: list. number of unique keys to embed for edges.
                               1 entry per embedding needed.
        * edge_embedding_dims: list. edge - number of dimensions of
                               the resulting embedding. 1 entry per embedding needed.
        * aggr: aggregation method forwarded to every EGNN_Sparse layer.
        * recalc: int. Recalculate edge feats every `recalc` MPNN layers. 0 for no recalc
        * verbose: bool. verbosity level.
        -----
        Diff with normal layer: one has to do preprocessing before (radius, global token, ...)
    """

    def __init__(self, n_layers, feats_dim,
                 pos_dim=3,
                 edge_attr_dim=0,
                 m_dim=16,
                 fourier_features=0,
                 soft_edge=0,
                 embedding_nums=None,
                 embedding_dims=None,
                 edge_embedding_nums=None,
                 edge_embedding_dims=None,
                 update_coors=True,
                 update_feats=True,
                 norm_feats=True,
                 norm_coors=False,
                 norm_coors_scale_init=1e-2,
                 dropout=0.,
                 coor_weights_clamp_value=None,
                 aggr="add",
                 global_linear_attn_every=0,
                 global_linear_attn_heads=8,
                 global_linear_attn_dim_head=64,
                 num_global_tokens=4,
                 recalc=0, ):
        super().__init__()

        # avoid shared mutable default lists
        embedding_nums = [] if embedding_nums is None else embedding_nums
        embedding_dims = [] if embedding_dims is None else embedding_dims
        edge_embedding_nums = [] if edge_embedding_nums is None else edge_embedding_nums
        edge_embedding_dims = [] if edge_embedding_dims is None else edge_embedding_dims

        self.n_layers = n_layers

        # Embeddings? solve here
        self.embedding_nums = embedding_nums
        self.embedding_dims = embedding_dims
        self.emb_layers = nn.ModuleList()
        self.edge_embedding_nums = edge_embedding_nums
        self.edge_embedding_dims = edge_embedding_dims
        self.edge_emb_layers = nn.ModuleList()

        # instantiate point and edge embedding layers. Each embedded column is
        # replaced by `dim` columns, hence the `dim - 1` width increase.
        for i, emb_dim in enumerate(embedding_dims):
            self.emb_layers.append(nn.Embedding(num_embeddings=embedding_nums[i],
                                                embedding_dim=emb_dim))
            feats_dim += emb_dim - 1

        for i, emb_dim in enumerate(edge_embedding_dims):
            self.edge_emb_layers.append(nn.Embedding(num_embeddings=edge_embedding_nums[i],
                                                     embedding_dim=emb_dim))
            edge_attr_dim += emb_dim - 1
        # rest
        self.mpnn_layers = nn.ModuleList()
        self.feats_dim = feats_dim
        self.pos_dim = pos_dim
        self.edge_attr_dim = edge_attr_dim
        self.m_dim = m_dim
        self.fourier_features = fourier_features
        self.soft_edge = soft_edge
        self.norm_feats = norm_feats
        self.norm_coors = norm_coors
        self.norm_coors_scale_init = norm_coors_scale_init
        self.update_feats = update_feats
        self.update_coors = update_coors
        self.dropout = dropout
        self.coor_weights_clamp_value = coor_weights_clamp_value
        self.recalc = recalc

        self.has_global_attn = global_linear_attn_every > 0
        self.global_tokens = None
        self.global_linear_attn_every = global_linear_attn_every
        if self.has_global_attn:
            # NOTE(review): currently not consumed by forward (the attention
            # layer attends the node features to themselves).
            self.global_tokens = nn.Parameter(torch.randn(num_global_tokens, self.feats_dim))

        # instantiate layers
        for i in range(n_layers):
            layer = EGNN_Sparse(feats_dim=feats_dim,
                                pos_dim=pos_dim,
                                edge_attr_dim=edge_attr_dim,
                                m_dim=m_dim,
                                fourier_features=fourier_features,
                                soft_edge=soft_edge,
                                norm_feats=norm_feats,
                                norm_coors=norm_coors,
                                norm_coors_scale_init=norm_coors_scale_init,
                                update_feats=update_feats,
                                update_coors=update_coors,
                                dropout=dropout,
                                coor_weights_clamp_value=coor_weights_clamp_value,
                                # previously accepted but never forwarded
                                aggr=aggr)

            # global attention case: pair [attn_layer, mpnn_layer]
            is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0
            if is_global_layer:
                attn_layer = GlobalLinearAttention_Sparse(dim=self.feats_dim,
                                                          heads=global_linear_attn_heads,
                                                          dim_head=global_linear_attn_dim_head)
                self.mpnn_layers.append(nn.ModuleList([attn_layer, layer]))
            # normal case
            else:
                self.mpnn_layers.append(layer)

    def forward(self, x, edge_index, batch, edge_attr,
                bsize=None, recalc_edge=None, verbose=0):
        """ Recalculate edge features every `self.recalc` layers with the
            `recalc_edge` function if self.recalc is set.

            * x: (N, pos_dim+feats_dim) will be unpacked into coors, feats.
            * edge_index / edge_attr: initial graph. edge_attr may be None.
              Raw edge_attr is embedded with the edge embedding layers before
              the first layer and after every recalculation.
            * batch: (N,) graph assignment per node.
            * bsize: forwarded as `size` to each layer.
            * recalc_edge: callable x -> (edge_index, edge_attr, extra).
              Required when `self.recalc` triggers.
            Returns: (N, pos_dim+feats_dim) updated node tensor.
        """
        # NODES - Embed each dim to its target dimensions:
        x = embedd_token(x, self.embedding_dims, self.emb_layers)

        # regulates whether to embed edges before the next layer. True at the
        # start: the layers were built for the embedded edge_attr width.
        edges_need_embedding = True
        last_idx = len(self.mpnn_layers) - 1
        for i, layer in enumerate(self.mpnn_layers):

            # EDGES - Embed each dim to its target dimensions:
            if edges_need_embedding:
                if len(self.edge_emb_layers) > 0 and edge_attr is not None:
                    edge_attr = embedd_token(edge_attr, self.edge_embedding_dims, self.edge_emb_layers)
                edges_need_embedding = False

            # NOTE: global tokens are intentionally not used here. The previous
            # code set `self.global_tokens = None` on every call, which
            # destroyed the registered Parameter, and then computed tokens that
            # were never consumed.

            #  pass layers
            is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0
            if not is_global_layer:
                x = layer(x, edge_index, edge_attr, batch=batch, size=bsize)
            else:
                attn_layer, mpnn_layer = layer[0], layer[-1]
                # only pass feats to the attn layer (self-attention over feats)
                feats = x[:, self.pos_dim:]
                x_attn = attn_layer(feats, feats, batch)[0]
                #  merge attn-ed feats and coords
                x = torch.cat((x[:, :self.pos_dim], x_attn), dim=-1)
                x = mpnn_layer(x, edge_index, edge_attr, batch=batch, size=bsize)

            # recalculate edge info - not needed if last layer
            if self.recalc and (i % self.recalc == 0) and i != last_idx:
                if recalc_edge is None:
                    raise ValueError("`recalc_edge` must be provided when `recalc` is set")
                edge_index, edge_attr, _ = recalc_edge(x)  # returns idx, attr, any_other_info
                edges_need_embedding = True

        return x

    def __repr__(self):
        return 'EGNN_Sparse_Network of: {0} layers'.format(len(self.mpnn_layers))