from typing import Optional, Tuple, Union, Dict, List
from functools import partial

import torch
from torch import nn, LongTensor, Tensor
from torch.nn import functional as F
from torch_geometric.data import Data, batch
from torch_geometric.nn import global_add_pool, global_mean_pool
from torch_scatter import scatter

class BaseMLP(nn.Module):
    """Two-layer perceptron with optional normalisation, trailing activation and skip.

    Layout (always six stages, so checkpoint keys stay ``mlp.0.*`` / ``mlp.3.*``)::

        Linear(in, hidden) -> norm(hidden) -> act -> Linear(hidden, out) -> norm(out) -> act?

    Args:
        input_dim: width of the input features (last axis).
        hidden_dim: width of the hidden layer.
        output_dim: width of the output features.
        activation: module used after the hidden layer (and after the output
            layer when ``last_act``). The same instance is reused in both places.
        norm: optional factory called with a width (e.g. ``nn.LayerNorm``);
            ``None`` inserts ``nn.Identity`` instead.
        residual: if True, return ``x + mlp(x)``; requires ``output_dim == input_dim``.
        last_act: if True, apply ``activation`` after the output layer as well.
    """

    def __init__(
        self, 
        input_dim: int, 
        hidden_dim: int, 
        output_dim: int, 
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        residual: bool = False, 
        last_act: bool = False,
    ) -> None:
        super().__init__()
        self.residual = residual
        if residual:
            # The skip connection adds input to output, so the widths must agree.
            assert output_dim == input_dim

        def _norm_layer(width: int) -> nn.Module:
            # ``norm`` is used as a constructor taking the feature width.
            return nn.Identity() if norm is None else norm(width)

        tail = activation if last_act else nn.Identity()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            _norm_layer(hidden_dim),
            activation,
            nn.Linear(hidden_dim, output_dim),
            _norm_layer(output_dim),
            tail,
        )

    def forward(self, x):
        """Apply the MLP to ``x``, adding ``x`` back when ``residual`` is set."""
        out = self.mlp(x)
        if self.residual:
            out = x + out
        return out


class FastEGNNLayer(nn.Module):
    """One FastEGNN message-passing layer over real nodes plus per-graph virtual nodes.

    Every graph in the batch owns ``vn_channel`` virtual nodes. A layer
    (1) builds real-real edge messages and real-virtual messages (``Msg``),
    (2) aggregates them onto real nodes and onto virtual nodes (``Agg``),
    (3) updates features and positions of both (``Upd``).

    Shapes, as implied by the indexing in this class: N real nodes, E edges,
    B graphs, D spatial dims, H = ``hidden_dim``, C = ``vn_channel``;
    ``node_feat`` (N, H), ``node_pos`` (N, D), ``vn_feat`` (B, H, C),
    ``vn_pos`` (B, D, C), ``batch`` (N,) graph index of each real node,
    ``edge_index`` (2, E).
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        vn_channel: int = 4,
        edge_attr_dim: int = 2,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
    ) -> None:
        super(FastEGNNLayer, self).__init__()
        self.vn_channel = vn_channel
        self.hidden_dim = hidden_dim
        MLP = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation, norm=norm)
        # Edge message input: h_i, h_j, edge_attr, squared distance (+1).
        # When edges carry no attributes, edge_attr_dim must be 0 for the widths to match.
        self.mlp_msg = MLP(input_dim=2 * hidden_dim + edge_attr_dim + 1, output_dim=hidden_dim, last_act=True)
        # Real-virtual message input: h_i, virtual feature, distance (+1), one row of the C x C Gram matrix (+C).
        self.mlp_msg_vr = MLP(input_dim=2 * hidden_dim + 1 + vn_channel, output_dim=hidden_dim, last_act=True)

        # Scalar gates that scale relative-position vectors (keeps updates equivariant).
        self.mlp_pos = MLP(input_dim=hidden_dim, output_dim=1)
        self.mlp_pos_vr_r = MLP(input_dim=hidden_dim, output_dim=1)
        self.mlp_pos_vr_v = MLP(input_dim=hidden_dim, output_dim=1)

        self.mlp_node_feat = MLP(input_dim=2 * hidden_dim, output_dim=hidden_dim)
        self.mlp_node_feat_v = MLP(input_dim=2 * hidden_dim, output_dim=hidden_dim)
        self.mlp_vel = MLP(input_dim=hidden_dim, output_dim=1)

    def Msg(self, edge_index, edge_attr, batch, node_feat, node_pos, vn_feat, vn_pos):
        """Compute real-real and real-virtual messages.

        Returns:
            msg: (E, H) edge messages.
            diff_pos: (E, D) ``pos[row] - pos[col]`` scaled by a learned per-edge scalar.
            msg_vr: (N, H, C) message between each real node and each virtual
                node of its own graph.
            diff_pos_vr: (N, D, C) offsets ``vn_pos - node_pos`` per virtual channel.
        """
        row, col = edge_index
        diff_pos = node_pos[row] - node_pos[col]
        # Squared distance: invariant to rotations and translations.
        dist = torch.norm(diff_pos, dim=-1, keepdim=True) ** 2

        diff_pos_vr = vn_pos[batch] - node_pos.unsqueeze(-1)      # (N, D, C)
        vr_dist = torch.norm(diff_pos_vr, dim=1, keepdim=True)     # (N, 1, C)

        # edge_attr is optional and skipped when None.
        edge_inputs = [node_feat[row], node_feat[col], edge_attr, dist]
        msg = torch.cat([t for t in edge_inputs if t is not None], dim=-1)
        msg = self.mlp_msg(msg)
        diff_pos = diff_pos * self.mlp_pos(msg)

        # Gram matrix of virtual-node offsets from the graph centroid, (B, C, C):
        # X^T X is unchanged by rotating X, so it is a safe invariant input.
        node_pos_mean = global_mean_pool(node_pos, batch)
        m_X = vn_pos - node_pos_mean.unsqueeze(-1)
        m_X = torch.einsum('bji, bjk -> bik', m_X, m_X)
        msg_vr = torch.cat([
            node_feat.unsqueeze(-1).repeat(1, 1, self.vn_channel),
            vn_feat[batch],
            vr_dist,
            m_X[batch],
        ], dim=1).permute(0, 2, 1)                                 # (N, C, 2H + 1 + C), channels-last for the MLP
        msg_vr = self.mlp_msg_vr(msg_vr).permute(0, 2, 1)           # (N, H, C)

        return msg, diff_pos, msg_vr, diff_pos_vr

    def Agg(self, edge_index, batch, dim_size, msg, diff_pos, msg_vr, diff_pos_vr):
        """Aggregate messages onto real nodes (by edge receiver) and onto virtual nodes (by graph).

        Returns:
            msg_agg: (N, H) mean of incoming edge messages.
            pos_agg: (N, D) mean edge displacement plus the virtual-node pull.
            msg_agg_v: (B, H, C) per-graph mean of real-virtual messages.
            pos_agg_v: (B, D, C) per-graph mean displacement for virtual nodes.
        """
        row = edge_index[0]
        msg_agg = scatter(src=msg, index=row, dim=0, dim_size=dim_size, reduce='mean')
        pos_agg = scatter(src=diff_pos, index=row, dim=0, dim_size=dim_size, reduce='mean')

        # Virtual -> real: gate each offset, then average over the C virtual
        # channels only, giving one (D,) displacement per node. A reduction over
        # all axes would collapse this to a single scalar added to every
        # coordinate of every node, which is not equivariant.
        weight_r = self.mlp_pos_vr_r(msg_vr.permute(0, 2, 1)).permute(0, 2, 1)   # (N, 1, C)
        diff_pos_vr_r = torch.mean(-diff_pos_vr * weight_r, dim=-1)              # (N, D)
        pos_agg = pos_agg + diff_pos_vr_r

        # Real -> virtual: gate each offset, then average over the real nodes of each graph.
        weight_v = self.mlp_pos_vr_v(msg_vr.permute(0, 2, 1)).permute(0, 2, 1)   # (N, 1, C)
        diff_pos_vr_v = diff_pos_vr * weight_v                                   # (N, D, C)
        spatial_dim = diff_pos_vr.size(1)
        pos_agg_v = global_mean_pool(diff_pos_vr_v.flatten(1), batch).view(-1, spatial_dim, self.vn_channel)

        msg_agg_v = global_mean_pool(msg_vr.flatten(1), batch).view(-1, self.hidden_dim, self.vn_channel)

        return msg_agg, pos_agg, msg_agg_v, pos_agg_v

    def Upd(self, node_feat, node_pos, node_vel, vn_feat, vn_pos, msg_agg, pos_agg, msg_agg_v, pos_agg_v):
        """Apply aggregated updates to real and virtual node positions and features."""
        node_pos = node_pos + pos_agg
        if node_vel is not None:
            # Velocity term is gated by the *pre-update* node features.
            node_pos = node_pos + self.mlp_vel(node_feat) * node_vel

        node_feat = self.mlp_node_feat(torch.cat([node_feat, msg_agg], dim=-1))

        vn_pos = vn_pos + pos_agg_v
        # (B, 2H, C) -> (B, C, 2H) so the MLP acts on the feature axis, then back to (B, H, C).
        vn_feat = torch.cat([vn_feat, msg_agg_v], dim=1).permute(0, 2, 1)
        vn_feat = self.mlp_node_feat_v(vn_feat).permute(0, 2, 1)

        return node_feat, node_pos, vn_feat, vn_pos

    def forward(self, node_feat, node_pos, node_vel, vn_feat, vn_pos, edge_index, edge_attr, batch):
        """Run Msg -> Agg -> Upd and return ``(node_feat, node_pos, vn_feat, vn_pos)``."""
        msg, diff_pos, msg_vr, diff_pos_vr = self.Msg(edge_index, edge_attr, batch, node_feat, node_pos, vn_feat, vn_pos)
        msg_agg, pos_agg, msg_agg_v, pos_agg_v = self.Agg(edge_index, batch, node_feat.size(0), msg, diff_pos, msg_vr, diff_pos_vr)
        node_feat, node_pos, vn_feat, vn_pos = self.Upd(node_feat, node_pos, node_vel, vn_feat, vn_pos, msg_agg, pos_agg, msg_agg_v, pos_agg_v)

        return node_feat, node_pos, vn_feat, vn_pos


class FastEGNNModel(nn.Module):
    """Stack of ``FastEGNNLayer``s that predicts updated node positions.

    Node features are linearly embedded to ``hidden_dim``. Each graph gets
    ``vn_channel`` virtual nodes whose features start from one shared learnable
    tensor and whose positions start at that graph's mean node position.

    Args:
        num_layer: number of stacked layers.
        hidden_dim: feature width inside the network.
        vn_channel: number of virtual nodes per graph (must be > 0).
        node_input_dim: width of ``data.node_feat``.
        edge_attr_dim: width of ``data.edge_attr`` (forwarded to each layer).
        activation: activation module forwarded to each layer.
        norm: optional normalisation factory forwarded to each layer.
        device: device the module is moved to after construction.
    """

    def __init__(
        self,
        num_layer: int = 4,
        hidden_dim: int = 64,
        vn_channel: int = 4,
        node_input_dim: int = 2,
        edge_attr_dim: int = 2,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        device: str = 'cpu',
    ):
        super().__init__()
        assert vn_channel > 0, f'Channels of virtual node must greater than 0 (got {vn_channel})'
        self.vn_channel = vn_channel
        self.embedding = nn.Linear(node_input_dim, hidden_dim)

        # Shared initial virtual-node features, (1, hidden_dim, vn_channel); tiled per graph in forward.
        init_vn = torch.randn(size=(1, hidden_dim, vn_channel))
        self.vn_feat = nn.Parameter(data=init_vn, requires_grad=True)

        self.layers = nn.ModuleList(
            FastEGNNLayer(hidden_dim, vn_channel, edge_attr_dim, activation, norm)
            for _ in range(num_layer)
        )

        self.to(device)

    def forward(self, data: Data) -> Tensor:
        """Run all layers on a (batched) graph and return the final node positions."""
        graph_of = data.batch
        pos = data.node_pos
        vel = data.node_vel if 'node_vel' in data else None
        attrs = data.edge_attr if 'edge_attr' in data else None
        feat = self.embedding(data.node_feat)

        vn_feat = self.vn_feat.repeat(data.num_graphs, 1, 1)
        # Every virtual channel of a graph starts at that graph's centroid: (B, D, C).
        centroid = global_mean_pool(pos, graph_of)
        vn_pos = centroid.unsqueeze(-1).repeat(1, 1, self.vn_channel)

        for layer in self.layers:
            feat, pos, vn_feat, vn_pos = layer(
                feat, pos, vel, vn_feat, vn_pos, data.edge_index, attrs, graph_of
            )
        return pos