from typing import Optional, Tuple, Union, Dict, List
from functools import partial

import torch
from torch import nn, LongTensor, Tensor
from torch.nn import functional as F
from torch_geometric.data import Data
from torch_geometric.nn import global_add_pool, global_mean_pool
from torch_scatter import scatter
import e3nn

class BaseMLP(nn.Module):
    """Two-layer perceptron with optional normalisation and residual skip.

    The network is ``Linear -> norm -> activation -> Linear -> norm -> act``
    stored as a six-slot ``nn.Sequential`` in ``self.mlp``; every slot that is
    switched off holds an ``nn.Identity`` so slot indices stay fixed.

    Args:
        input_dim: Width of the incoming features.
        hidden_dim: Width produced by the first linear layer.
        output_dim: Width produced by the second linear layer.
        activation: Module applied after the first normalisation slot and,
            when ``last_act`` is set, once more at the end (the very same
            instance is placed in both slots).
        norm: Optional factory invoked as ``norm(width)`` to create the
            normalisation module following each linear layer; ``None``
            disables normalisation.
        residual: When true, ``forward`` returns ``x + mlp(x)``; this
            requires ``output_dim == input_dim``.
        last_act: When true, the activation is also applied to the output.
    """

    def __init__(
        self, 
        input_dim: int, 
        hidden_dim: int, 
        output_dim: int, 
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        residual: bool = False, 
        last_act: bool = False,
    ) -> None:
        super().__init__()
        self.residual = residual
        if residual:
            # A skip connection can only add tensors of matching width.
            assert output_dim == input_dim

        def build_norm(width: int) -> nn.Module:
            """Return the normalisation module for ``width`` features."""
            if norm is None:
                return nn.Identity()
            return norm(width)

        # Modules are created in forward order so that seeded parameter
        # initialisation is reproducible.
        stages = [
            nn.Linear(input_dim, hidden_dim),
            build_norm(hidden_dim),
            activation,
            nn.Linear(hidden_dim, output_dim),
            build_norm(output_dim),
        ]
        if last_act:
            stages.append(activation)
        else:
            stages.append(nn.Identity())
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x``, adding ``x`` back when residual is on."""
        out = self.mlp(x)
        if self.residual:
            return x + out
        return out
    
class VNInitial(nn.Module):
    """Create the initial ``vn`` features and positions attached to each node.

    ``vn`` presumably stands for "virtual node"; every real node receives
    ``vn_channel`` 3-vectors (stored flat, ``3 * vn_channel`` columns) that
    start at the node's own position and are shifted by an edge-aggregated
    combination of relative position and relative velocity.

    Args:
        hidden_dim: Width of node features and of every internal MLP.
        vn_channel: Number of 3-vectors kept per node.
        edge_attr_dim: Number of edge-attribute columns fed to the message
            MLP. Use ``0`` when ``forward`` is called with ``edge_attr=None``.
        activation: Activation module handed to every ``BaseMLP``.
        norm: Optional normalisation factory handed to every ``BaseMLP``.
        device: Accepted for interface compatibility; not used by this class.
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        vn_channel: int = 4,
        edge_attr_dim: int = 2,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        device: str = 'cpu',
    ) -> None:
        super().__init__()
        self.vn_channel = vn_channel

        # Mixes the two relative vectors of an edge (position and velocity
        # difference) into `vn_channel` output vectors. With
        # shared_weights=False the mixing weights are supplied per edge at
        # call time instead of being stored in the module.
        self.diff_vec_coff = e3nn.o3.FullyConnectedTensorProduct(
            '2x1o', '1x0e', f'{self.vn_channel}x1o', shared_weights=False
        )

        MLP = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation, norm=norm)
        self.mlp_feat_init = MLP(input_dim=hidden_dim, output_dim=hidden_dim, last_act=True)
        # "+ 4" is the flattened 2x2 matrix of inner products between the
        # relative position and relative velocity of an edge.
        self.mlp_msg = MLP(input_dim=2 * hidden_dim + edge_attr_dim + 4, output_dim=hidden_dim)
        # Predicts the per-edge weights consumed by `diff_vec_coff`.
        self.mlp_diff_vel_coff = MLP(input_dim=hidden_dim, output_dim=self.diff_vec_coff.weight_numel)
        self.mlp_vn_feat = MLP(input_dim=hidden_dim, output_dim=hidden_dim)

    def forward(self, node_feat, node_pos, node_vel, edge_index, edge_attr):
        """Compute initial ``vn`` features and positions.

        Args:
            node_feat: Node features with ``hidden_dim`` columns.
            node_pos: Node positions with 3 columns.
            node_vel: Node velocities with 3 columns.
            edge_index: Pair ``(row, col)`` of node indices, one entry per
                edge; results are aggregated onto ``row``.
            edge_attr: Edge attributes with ``edge_attr_dim`` columns, or
                ``None`` to leave them out of the message input.

        Returns:
            Tuple ``(vn_feat, vn_pos)`` with ``hidden_dim`` and
            ``3 * vn_channel`` columns respectively, one row per node.
        """
        vn_feat = self.mlp_feat_init(node_feat)
        # Every channel starts at the position of its owning node.
        vn_pos = node_pos.repeat(1, self.vn_channel)

        row, col = edge_index
        diff_pos = node_pos[row] - node_pos[col]
        diff_vel = node_vel[row] - node_vel[col]

        diff_vec = torch.cat([diff_pos, diff_vel], dim=1)
        # 2x2 matrix of inner products between the two relative vectors of
        # each edge, flattened to four scalars per edge.
        pair = diff_vec.view(-1, 2, 3)
        ip = torch.einsum('bij,bkj->bik', pair, pair).view(-1, 4)
        # torch.norm without `dim` is a single norm over all edges.
        ip = ip / (torch.norm(ip) + 1e-3)
        # NOTE(review): `ip` has already been normalised above, so this
        # divisor is close to 1 and barely rescales `diff_vec`. Kept as is to
        # preserve numerics; confirm whether the pre-normalisation norm was
        # intended.
        diff_vec = diff_vec / (torch.norm(ip) + 1e-3)

        # Skip a missing `edge_attr` instead of failing inside torch.cat.
        msg_parts = [node_feat[row], node_feat[col], edge_attr, ip]
        msg = torch.cat([part for part in msg_parts if part is not None], dim=1)
        msg = self.mlp_msg(msg)

        # Constant scalar input of the tensor product; new_ones keeps the
        # dtype and device of `diff_vec`.
        one = diff_vec.new_ones(diff_vec.size(0), 1)
        diff_vec = self.diff_vec_coff(diff_vec, one, self.mlp_diff_vel_coff(msg))

        num_nodes = node_feat.size(0)
        agg_feat = scatter(src=msg, index=row, dim=0, dim_size=num_nodes, reduce='mean')
        agg_vec = scatter(src=diff_vec, index=row, dim=0, dim_size=num_nodes, reduce='mean')
        vn_feat = vn_feat + self.mlp_vn_feat(agg_feat)
        vn_pos = vn_pos + agg_vec
        return vn_feat, vn_pos
    
class VNLayer(nn.Module):
    """Exchange information between nodes and their ``vn`` channels.

    ``vn`` presumably stands for "virtual node". For every edge the layer
    forms the offsets between the source node position and each ``vn``
    channel of the target node, turns them into a scalar message and uses
    that message to weight two tensor products: one updating node positions
    and one updating ``vn`` positions.

    Args:
        hidden_dim: Width of node / ``vn`` features and of every MLP.
        vn_channel: Number of 3-vectors kept per node.
        edge_attr_dim: Accepted for interface compatibility; not used.
        activation: Activation module handed to every ``BaseMLP``.
        norm: Optional normalisation factory handed to every ``BaseMLP``.
        device: Accepted for interface compatibility; not used.
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        vn_channel: int = 4,
        edge_attr_dim: int = 2,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        device: str = 'cpu',
    ) -> None:
        super().__init__()
        self.vn_channel = vn_channel

        # Both tensor products take their weights per edge at call time
        # (shared_weights=False). "vr" collapses the `vn_channel` offsets to
        # one vector, "rv" maps them to `vn_channel` vectors.
        self.diff_pos_vr_coff = e3nn.o3.FullyConnectedTensorProduct(
            f'{self.vn_channel}x1o', '1x0e', '1x1o', shared_weights=False
        )
        self.diff_pos_rv_coff = e3nn.o3.FullyConnectedTensorProduct(
            f'{self.vn_channel}x1o', '1x0e', f'{self.vn_channel}x1o', shared_weights=False
        )

        MLP = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation, norm=norm)
        # "vn_channel ** 2" is the flattened matrix of inner products between
        # the per-channel offsets of an edge.
        self.mlp_com_msg = MLP(input_dim=2 * hidden_dim + vn_channel ** 2, output_dim=hidden_dim)
        self.mlp_diff_pos_vr_coff = MLP(input_dim=hidden_dim, output_dim=self.diff_pos_vr_coff.weight_numel)
        self.mlp_diff_pos_rv_coff = MLP(input_dim=hidden_dim, output_dim=self.diff_pos_rv_coff.weight_numel)
        self.mlp_feat = MLP(input_dim=hidden_dim, output_dim=hidden_dim)
        self.mlp_vn_feat = MLP(input_dim=hidden_dim, output_dim=hidden_dim)

    def forward(self, node_feat, node_pos, vn_feat, vn_pos, edge_index, edge_attr):
        """Update node and ``vn`` state with one round of message passing.

        Args:
            node_feat: Node features with ``hidden_dim`` columns.
            node_pos: Node positions with 3 columns.
            vn_feat: ``vn`` features with ``hidden_dim`` columns.
            vn_pos: ``vn`` positions with ``3 * vn_channel`` columns.
            edge_index: Pair ``(row, col)`` of node indices, one entry per
                edge; all aggregations are written to ``row``.
            edge_attr: Not used by this layer.

        Returns:
            Tuple ``(node_feat, node_pos, vn_feat, vn_pos)`` holding the
            updated tensors, each with the shape of its input.
        """
        row, col = edge_index
        channels = self.vn_channel

        # Offset from every vn channel of the `col` node to the `row` node.
        diff_pos_vr = node_pos.repeat(1, channels)[row] - vn_pos[col]
        offsets = diff_pos_vr.view(-1, channels, 3)
        ip = torch.einsum('bij,bkj->bik', offsets, offsets).view(-1, channels ** 2)
        # torch.norm without `dim` is a single norm over all edges.
        ip = ip / (torch.norm(ip) + 1e-3)
        # NOTE(review): `ip` has already been normalised above, so this
        # divisor is close to 1 and barely rescales the offsets. Kept as is
        # to preserve numerics; confirm whether the pre-normalisation norm
        # was intended.
        diff_pos_vr = diff_pos_vr / (torch.norm(ip) + 1e-3)
        diff_pos_rv = - diff_pos_vr

        com_msg_vr = torch.cat([node_feat[row], vn_feat[col], ip], dim=1)
        com_msg_vr = self.mlp_com_msg(com_msg_vr)

        # One constant scalar input serves both tensor products; new_ones
        # keeps the dtype and device of the offsets.
        ones = diff_pos_vr.new_ones(diff_pos_vr.size(0), 1)
        diff_pos_vr = self.diff_pos_vr_coff(diff_pos_vr, ones, self.mlp_diff_pos_vr_coff(com_msg_vr))
        diff_pos_rv = self.diff_pos_rv_coff(diff_pos_rv, ones, self.mlp_diff_pos_rv_coff(com_msg_vr))

        num_nodes = node_feat.size(0)
        agg_feat_vr = scatter(src=com_msg_vr, index=row, dim=0, dim_size=num_nodes, reduce='mean')
        agg_pos_vr = scatter(src=diff_pos_vr, index=row, dim=0, dim_size=num_nodes, reduce='mean')
        agg_pos_rv = scatter(src=diff_pos_rv, index=row, dim=0, dim_size=num_nodes, reduce='mean')

        node_feat = node_feat + self.mlp_feat(agg_feat_vr)
        node_pos = node_pos + agg_pos_vr

        # The vn features reuse the same aggregated message as the nodes but
        # pass it through their own MLP.
        vn_feat = vn_feat + self.mlp_vn_feat(agg_feat_vr)
        vn_pos = vn_pos + agg_pos_rv
        return node_feat, node_pos, vn_feat, vn_pos
    

class MyLayer(nn.Module):
    """Message-passing layer that updates node features and positions.

    Each edge produces a feature message from its two endpoint features, the
    optional edge attributes and the squared endpoint distance. The message
    also yields a scalar that scales the relative position. Both quantities
    are mean-aggregated per node and used to update that node.

    Args:
        hidden_dim: Width of node features and of every internal MLP.
        edge_attr_dim: Number of edge-attribute columns in the message input.
        activation: Activation module handed to every ``BaseMLP``.
        norm: Optional normalisation factory handed to every ``BaseMLP``.
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        edge_attr_dim: int = 2,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        make_mlp = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation, norm=norm)
        # Input: both endpoint features, edge attributes, one squared distance.
        self.mlp_msg = make_mlp(
            input_dim=2 * hidden_dim + edge_attr_dim + 1,
            output_dim=hidden_dim,
            last_act=True,
        )
        # One scalar per edge that scales the relative position.
        self.mlp_pos = make_mlp(input_dim=hidden_dim, output_dim=1)
        # Input: current node features concatenated with aggregated messages.
        self.mlp_node_feat = make_mlp(input_dim=hidden_dim + hidden_dim, output_dim=hidden_dim)
        # One scalar per node that scales the node velocity.
        self.mlp_vel = make_mlp(input_dim=hidden_dim, output_dim=1)

    def forward(
        self,
        node_feat: Tensor, 
        node_pos: Tensor, 
        node_vel: Optional[Tensor], 
        edge_index: List[LongTensor], 
        edge_attr: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        """Run message, aggregation and update; return ``(node_feat, node_pos)``."""
        num_nodes = node_feat.size(0)
        edge_msg, edge_shift = self.Msg(edge_index, edge_attr, node_feat, node_pos)
        pooled_msg, pooled_shift = self.Agg(edge_index, num_nodes, edge_msg, edge_shift)
        return self.Upd(node_feat, node_pos, node_vel, pooled_msg, pooled_shift)

    def Msg(self, edge_index, edge_attr, node_feat, node_pos):
        """Build per-edge messages and scaled relative positions.

        Returns:
            Tuple ``(msg, diff_pos)``: the output of ``mlp_msg`` per edge and
            the relative position multiplied by the scalar from ``mlp_pos``.
        """
        src, dst = edge_index
        rel_pos = node_pos[src] - node_pos[dst]
        sq_dist = torch.norm(rel_pos, p=2, dim=-1, keepdim=True).pow(2)

        # Edge attributes are optional and simply left out when absent.
        pieces = [node_feat[src], node_feat[dst]]
        if edge_attr is not None:
            pieces.append(edge_attr)
        pieces.append(sq_dist)

        edge_msg = self.mlp_msg(torch.cat(pieces, dim=-1))
        return edge_msg, rel_pos * self.mlp_pos(edge_msg)
    
    def Agg(self, edge_index, dim_size, msg, diff_pos):
        """Mean-aggregate edge quantities onto the first index of each edge."""
        receivers, _ = edge_index

        def mean_per_node(values):
            return scatter(src=values, index=receivers, dim=0, dim_size=dim_size, reduce='mean')

        pooled_msg = mean_per_node(msg)
        pooled_shift = mean_per_node(diff_pos)
        return pooled_msg, pooled_shift
    
    def Upd(self, node_feat, node_pos, node_vel, msg_agg, pos_agg):
        """Apply the aggregated shift (and velocity term) and refresh features.

        The velocity scale is computed from the node features as they were
        before this update.
        """
        moved = node_pos + pos_agg
        if node_vel is not None:
            moved = moved + self.mlp_vel(node_feat) * node_vel

        fused = self.mlp_node_feat(torch.cat([node_feat, msg_agg], dim=-1))
        return fused, moved

class EGNNModel_cpl_local(torch.nn.Module):
    """Stack of ``MyLayer`` / ``VNLayer`` pairs that predicts node positions.

    Node inputs are embedded linearly, ``VNInitial`` creates the per-node
    ``vn`` state, and then ``num_layer`` rounds alternate a ``MyLayer``
    update with a ``VNLayer`` update.

    Args:
        num_layer: Number of ``MyLayer`` / ``VNLayer`` pairs.
        hidden_dim: Width of node features throughout the model.
        vn_channel: Number of ``vn`` 3-vectors kept per node.
        node_input_dim: Number of raw node-feature columns.
        edge_attr_dim: Number of edge-attribute columns.
        activation: Activation module used by every sub-module, including
            the ``vn`` initialiser.
        norm: Optional normalisation factory used by every sub-module,
            including the ``vn`` initialiser.
        device: Device the whole model is moved to after construction.
    """

    def __init__(
        self,
        num_layer: int = 4,
        hidden_dim: int = 64,
        vn_channel: int = 4,
        node_input_dim: int = 2,
        edge_attr_dim: int = 2,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        device: str = 'cpu',
    ) -> None:
        super().__init__()
        self.embedding = nn.Linear(node_input_dim, hidden_dim)
        # Forward `activation` and `norm` so the initialiser honours the same
        # configuration as the layers below rather than its own defaults.
        self.vn_init = VNInitial(hidden_dim, vn_channel, edge_attr_dim, activation, norm)

        self.layers = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.layers.append(MyLayer(hidden_dim, edge_attr_dim, activation, norm))

        self.vn_layers = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.vn_layers.append(VNLayer(hidden_dim, vn_channel, edge_attr_dim, activation, norm))

        self.to(device)

    def forward(self, data: Data) -> Tensor:
        """Return the node positions after the last layer.

        Reads ``node_feat``, ``node_pos`` and ``edge_index`` from ``data``.
        ``node_vel`` falls back to a detached all-zero tensor shaped like
        ``node_pos`` when absent, and ``edge_attr`` is passed on as ``None``
        when absent.
        """
        node_feat = self.embedding(data.node_feat)
        node_pos = data.node_pos
        node_vel = data.node_vel if 'node_vel' in data else (data.node_pos * 0).detach()

        edge_index = data.edge_index
        edge_attr = data.edge_attr if 'edge_attr' in data else None

        vn_feat, vn_pos = self.vn_init(node_feat, node_pos, node_vel, edge_index, edge_attr)

        for layer, vn_layer in zip(self.layers, self.vn_layers):
            node_feat, node_pos = layer(node_feat, node_pos, node_vel, edge_index, edge_attr)
            node_feat, node_pos, vn_feat, vn_pos = vn_layer(node_feat, node_pos, vn_feat, vn_pos, edge_index, edge_attr)
        return node_pos