from typing import Optional, Tuple, Union, Dict, List
from functools import partial

import torch
from torch import nn, LongTensor, Tensor
from torch.nn import functional as F
from torch_geometric.data import Data
from torch_geometric.nn import global_add_pool, global_mean_pool
from torch_scatter import scatter
import e3nn

from models.mace_modules.blocks import RadialEmbeddingBlock
from models.mace_modules.irreps_tools import irreps2gate

class BaseMLP(nn.Module):
    """Two-layer perceptron with optional normalisation and skip connection.

    ``self.mlp`` always holds six slots, in this order::

        Linear -> norm | Identity -> activation
               -> Linear -> norm | Identity -> activation | Identity

    Placeholders (``nn.Identity``) are used for disabled stages so that the
    positions of the two ``Linear`` layers inside the ``Sequential`` never move.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of the hidden layer.
        output_dim: Width of the output features.
        activation: Module applied after the first normalisation slot and,
            when ``last_act`` is true, after the second one as well (the very
            same module object is reused).
        norm: Callable invoked as ``norm(width)`` to create a normalisation
            layer; ``None`` disables normalisation.
        residual: If true, ``forward`` returns ``x + mlp(x)``; this requires
            ``output_dim == input_dim``.
        last_act: If true, the activation is applied to the final output too.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        residual: bool = False,
        last_act: bool = False,
    ) -> None:
        super().__init__()
        self.residual = residual
        if residual:
            assert output_dim == input_dim

        def norm_or_identity(width):
            # Builds the normalisation layer for `width`, or a no-op placeholder.
            if norm is None:
                return nn.Identity()
            return norm(width)

        stages = [
            nn.Linear(input_dim, hidden_dim),
            norm_or_identity(hidden_dim),
            activation,
            nn.Linear(hidden_dim, output_dim),
            norm_or_identity(output_dim),
        ]
        if last_act:
            stages.append(activation)
        else:
            stages.append(nn.Identity())
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x``, adding ``x`` back when ``residual`` is set."""
        transformed = self.mlp(x)
        if self.residual:
            return x + transformed
        return transformed

class NodeColor(nn.Module):
    """Per-node invariant "color" embedding of width ``hidden_dim``.

    ``color_type`` selects how the scalar descriptor fed to the output MLP
    (``mlp_node_feat``) is built:

    * ``'mp'`` -- mean, over the edges grouped by ``edge_index[0]``, of an MLP
      applied to the two endpoint features concatenated with their distance;
    * ``'center_radius'`` -- distance of each node from the mean position of
      its graph;
    * ``'tp'`` -- tensor product of a node's spherical harmonics with the
      graph-averaged spherical harmonics, with per-node weights predicted
      from ``node_feat``; yields ``max_ell + 1`` scalars.

    Any other ``color_type`` creates no sub-modules, and ``forward`` raises
    ``NotImplementedError``.
    """

    def __init__(self, hidden_dim, color_type='center_radius', max_ell=6, activation=nn.SiLU()):
        super().__init__()
        self.color_type = color_type
        make_mlp = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation)

        if color_type == 'center_radius':
            # A single radius per node is the only input.
            self.mlp_node_feat = make_mlp(input_dim=1, output_dim=hidden_dim)
        elif color_type == 'mp':
            # Message input: source feature, target feature and one distance.
            self.mlp_msg = make_mlp(input_dim=hidden_dim * 2 + 1, output_dim=hidden_dim)
            self.mlp_node_feat = make_mlp(input_dim=hidden_dim, output_dim=hidden_dim)
        elif color_type == 'tp':
            harmonics_irreps = e3nn.o3.Irreps.spherical_harmonics(max_ell)
            self.spherical_harmonics = e3nn.o3.SphericalHarmonics(
                harmonics_irreps, normalize=True, normalization="norm"
            )
            self.tp = e3nn.o3.FullyConnectedTensorProduct(
                harmonics_irreps,
                harmonics_irreps,
                f'{max_ell + 1}x0e',
                shared_weights=False,
            )
            # Weights of the tensor product are predicted per node.
            self.mlp_sh_coff = make_mlp(input_dim=hidden_dim, output_dim=self.tp.weight_numel)
            self.mlp_node_feat = make_mlp(input_dim=max_ell + 1, output_dim=hidden_dim)

    def _edge_message_scalar(self, node_feat, node_pos, edge_index):
        # 'mp' descriptor: mean of the edge messages grouped by edge_index[0].
        assert edge_index is not None
        src, dst = edge_index
        edge_len = torch.norm(node_pos[src] - node_pos[dst], dim=1, keepdim=True)
        edge_msg = self.mlp_msg(torch.cat([node_feat[src], node_feat[dst], edge_len], dim=1))
        return scatter(src=edge_msg, index=src, dim=0, dim_size=node_feat.size(0), reduce='mean')

    def _harmonic_scalar(self, node_feat, centered, batch):
        # 'tp' descriptor: couple node harmonics with their per-graph average.
        node_sh = self.spherical_harmonics(centered)
        graph_sh = global_mean_pool(node_sh, batch)
        return self.tp(node_sh, graph_sh[batch], self.mlp_sh_coff(node_feat))

    def forward(self, node_feat, node_pos, batch, edge_index=None, edge_attr=None):
        """Return the color embedding for every node.

        ``edge_index`` is required only for ``color_type == 'mp'``;
        ``edge_attr`` is accepted but not used.
        """
        graph_center = global_mean_pool(node_pos, batch)
        centered = node_pos - graph_center[batch]
        mode = self.color_type

        if mode == 'mp':
            descriptor = self._edge_message_scalar(node_feat, node_pos, edge_index)
        elif mode == 'center_radius':
            descriptor = torch.norm(centered, dim=1, keepdim=True)
        elif mode == 'tp':
            descriptor = self._harmonic_scalar(node_feat, centered, batch)
        else:
            raise NotImplementedError

        return self.mlp_node_feat(descriptor)

class VirtualNode(nn.Module):
    """Place ``vn_channel`` virtual nodes around the centre of every graph.

    For each node, the position relative to its graph's mean position (a
    ``1x1o`` vector) is combined with a constant scalar through a tensor
    product whose weights are predicted from the node features. The resulting
    ``vn_channel`` vectors are mean-pooled per graph, scaled to (almost) unit
    length and shifted by the graph centre.

    Args:
        vn_channel: Number of virtual nodes produced per graph.
        hidden_dim: Width of ``node_feat`` and of the weight-predicting MLP's
            hidden layer.
        activation: Activation used inside the MLP.
    """

    def __init__(self, vn_channel=4, hidden_dim=64, activation=nn.SiLU()):
        super().__init__()
        self.vn_channel = vn_channel
        self.get_vn_pos = e3nn.o3.FullyConnectedTensorProduct(
            '1x1o', '1x0e', f'{vn_channel}x1o', shared_weights=False
        )

        MLP = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation)
        # shared_weights=False: one weight vector per node, predicted here.
        self.mlp_vec_coff = MLP(input_dim=hidden_dim, output_dim=self.get_vn_pos.weight_numel)

    def forward(self, node_feat, node_pos, batch):
        """Return virtual-node positions, flattened to ``vn_channel * 3`` per graph.

        Args:
            node_feat: Per-node features fed to the weight MLP.
            node_pos: Per-node 3D positions.
            batch: Graph index of every node.
        """
        center = global_mean_pool(node_pos, batch)
        pos = node_pos - center[batch]
        # new_ones inherits dtype and device from `pos`, so the tensor product
        # never sees mixed precisions when the model runs in e.g. float64.
        one = pos.new_ones((pos.size(0), 1))

        vn_pos = global_mean_pool(
            self.get_vn_pos(pos, one, self.mlp_vec_coff(node_feat)), batch
        )
        # Normalise each virtual-node vector separately; the 1e-3 keeps the
        # division finite when a pooled vector is zero.
        vn_pos = vn_pos.view(-1, self.vn_channel, 3)
        vn_pos = vn_pos / (torch.norm(vn_pos, dim=2, keepdim=True) + 1e-3)
        vn_pos = vn_pos.view(-1, self.vn_channel * 3)
        # Move from centre-relative back to absolute coordinates.
        vn_pos = vn_pos + center.repeat(1, self.vn_channel)

        return vn_pos
    
class NodeFeatByVN(nn.Module):
    """Build invariant node features from the node-to-virtual-node offsets.

    For every node the ``vn_channel`` offset vectors to the virtual nodes of
    its graph are formed, the ``vn_channel x vn_channel`` matrix of pairwise
    distances between those offsets is flattened and scaled by its own norm,
    and the result is passed through an MLP producing ``hidden_dim`` features.
    """

    def __init__(self, vn_channel=4, hidden_dim=64, activation=nn.SiLU()):
        super().__init__()
        self.vn_channel = vn_channel

        make_mlp = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation)
        self.mlp_node_feat = make_mlp(input_dim=vn_channel ** 2, output_dim=hidden_dim)

    def forward(self, node_feat, node_pos, vn_pos, batch):
        """Return the per-node features; ``node_feat`` is accepted but unused."""
        num_nodes = node_pos.size(0)
        channels = self.vn_channel

        # One offset vector per (node, virtual node) pair.
        offsets = node_pos.repeat(1, channels) - vn_pos[batch]
        offsets = offsets.view(num_nodes, channels, 3)

        # Pairwise distances between the offsets of a node, flattened per node.
        pairwise = torch.cdist(offsets, offsets).view(num_nodes, channels ** 2)
        # Scale by the row norm; the 1e-3 guards against division by zero.
        row_scale = torch.norm(pairwise, dim=1, keepdim=True) + 1e-3

        return self.mlp_node_feat(pairwise / row_scale)
    

class EGNNLayer(nn.Module):
    """One message-passing layer that updates node features and positions.

    The layer runs three stages: ``Msg`` (per-edge messages and weighted
    position differences), ``Agg`` (mean over the edges grouped by
    ``edge_index[0]``) and ``Upd`` (node position and feature update).

    Args:
        hidden_dim: Width of the node features and of every MLP hidden layer.
        edge_attr_dim: Width of the edge attributes concatenated into messages.
        activation: Activation used by all MLPs.
        norm: Normalisation factory forwarded to ``BaseMLP``; ``None`` disables it.
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        edge_attr_dim: int = 2,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        make_mlp = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation, norm=norm)
        # Message input: both endpoint features, edge attributes, one squared distance.
        message_width = 2 * hidden_dim + edge_attr_dim + 1
        self.mlp_msg = make_mlp(input_dim=message_width, output_dim=hidden_dim, last_act=True)
        self.mlp_pos = make_mlp(input_dim=hidden_dim, output_dim=1)
        self.mlp_node_feat = make_mlp(input_dim=hidden_dim + hidden_dim, output_dim=hidden_dim)
        self.mlp_vel = make_mlp(input_dim=hidden_dim, output_dim=1)

    def forward(
        self,
        node_feat: Tensor,
        node_pos: Tensor,
        node_vel: Optional[Tensor],
        edge_index: List[LongTensor],
        edge_attr: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        """Run message, aggregation and update; return ``(node_feat, node_pos)``."""
        num_nodes = node_feat.size(0)
        edge_msg, weighted_diff = self.Msg(edge_index, edge_attr, node_feat, node_pos)
        pooled_msg, pooled_diff = self.Agg(edge_index, num_nodes, edge_msg, weighted_diff)
        return self.Upd(node_feat, node_pos, node_vel, pooled_msg, pooled_diff)

    def Msg(self, edge_index, edge_attr, node_feat, node_pos):
        """Return per-edge messages and message-weighted position differences."""
        src, dst = edge_index
        delta = node_pos[src] - node_pos[dst]
        # Squared Euclidean length of every edge, kept as a trailing column.
        sq_len = torch.norm(delta, p=2, dim=-1).unsqueeze(-1) ** 2

        # edge_attr is optional; it is left out of the concatenation when None.
        pieces = [node_feat[src], node_feat[dst]]
        if edge_attr is not None:
            pieces.append(edge_attr)
        pieces.append(sq_len)

        edge_msg = self.mlp_msg(torch.cat(pieces, dim=-1))
        weighted_diff = delta * self.mlp_pos(edge_msg)
        return edge_msg, weighted_diff

    def Agg(self, edge_index, dim_size, msg, diff_pos):
        """Mean-reduce messages and position differences over ``edge_index[0]``."""
        src, _ = edge_index
        pooled = tuple(
            scatter(src=values, index=src, dim=0, dim_size=dim_size, reduce='mean')
            for values in (msg, diff_pos)
        )
        return pooled[0], pooled[1]

    def Upd(self, node_feat, node_pos, node_vel, msg_agg, pos_agg):
        """Return updated ``(node_feat, node_pos)`` from the aggregated messages."""
        new_pos = node_pos + pos_agg
        if node_vel is not None:
            # Velocity contribution is gated by a scalar predicted from the
            # pre-update node features.
            new_pos = new_pos + self.mlp_vel(node_feat) * node_vel

        new_feat = self.mlp_node_feat(torch.cat([node_feat, msg_agg], dim=-1))
        return new_feat, new_pos
    
class EGNNModel(torch.nn.Module):
    """Stack of ``EGNNLayer`` blocks behind a linear node-feature embedding.

    Args:
        num_layer: Number of ``EGNNLayer`` blocks.
        hidden_dim: Width of the embedded node features.
        node_input_dim: Width of the raw ``data.node_feat``.
        edge_attr_dim: Width of the edge attributes passed to each layer.
        activation: Activation forwarded to every layer.
        norm: Normalisation factory forwarded to every layer.
        device: Device the whole model is moved to after construction.
    """

    def __init__(
        self,
        num_layer: int = 4,
        hidden_dim: int = 64,
        node_input_dim: int = 2,
        edge_attr_dim: int = 2,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        device: str = 'cpu',
    ) -> None:
        super().__init__()
        self.embedding = nn.Linear(node_input_dim, hidden_dim)
        self.layers = torch.nn.ModuleList(
            EGNNLayer(hidden_dim, edge_attr_dim, activation, norm)
            for _ in range(num_layer)
        )
        self.to(device)

    def forward(self, data: Data) -> Tuple[Tensor, Tensor]:
        """Return the final ``(node_feat, node_pos)`` for ``data``.

        ``data.node_vel`` and ``data.edge_attr`` are optional; each is passed
        to the layers as ``None`` when absent from ``data``.
        """
        velocities = None
        if 'node_vel' in data:
            velocities = data.node_vel

        edge_features = None
        if 'edge_attr' in data:
            edge_features = data.edge_attr

        features = self.embedding(data.node_feat)
        positions = data.node_pos
        connectivity = data.edge_index

        for block in self.layers:
            features, positions = block(
                features, positions, velocities, connectivity, edge_features
            )

        return features, positions

class TensorProductConvLayer(nn.Module):
    """Tensor-product convolution with a scalar side channel (``msg_cpl``).

    Each edge message is the tensor product of the source node's features
    with the edge spherical harmonics, using weights predicted from the edge
    attributes. Messages are mean-aggregated and added to the node features,
    which are then re-scaled by a second tensor product whose weights come
    from ``msg_cpl``. Finally ``msg_cpl`` receives a residual update computed
    from the scalar projection of the new node features.

    Because the aggregated messages are added to the incoming node features,
    ``in_irreps`` and ``out_irreps`` must have the same dimension.

    Args:
        in_irreps: Irreps of the incoming node features.
        out_irreps: Irreps of the outgoing node features.
        sh_irreps: Irreps of the edge spherical harmonics.
        hidden_dim: Width of ``msg_cpl`` and of the MLP hidden layers.
        edge_attr_dim: Width of the edge attributes.
        activation: Activation used by all MLPs.
        norm: Normalisation factory forwarded to ``BaseMLP``; ``None`` disables it.
    """

    def __init__(
        self,
        in_irreps: e3nn.o3.Irreps,
        out_irreps: e3nn.o3.Irreps,
        sh_irreps: e3nn.o3.Irreps,
        hidden_dim: int = 64,
        edge_attr_dim: int = 2,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.tp = e3nn.o3.FullyConnectedTensorProduct(
            in_irreps, sh_irreps, out_irreps, shared_weights=False
        )
        # Product with a constant scalar: acts as a learned, msg_cpl-dependent
        # mixing of the node features.
        self.re_scale = e3nn.o3.FullyConnectedTensorProduct(
            out_irreps, '1x0e', out_irreps, shared_weights=False
        )
        self.get_scalar = e3nn.o3.Linear(out_irreps, f'{hidden_dim}x0e')

        MLP = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation, norm=norm)
        self.mlp_sh_coff = MLP(input_dim=edge_attr_dim, output_dim=self.tp.weight_numel)
        self.mlp_cpl_coff = MLP(input_dim=hidden_dim, output_dim=self.re_scale.weight_numel)
        self.mlp_msg_cpl = MLP(input_dim=2*hidden_dim, output_dim=hidden_dim)

    def forward(self, node_feat, edge_index, edge_attr, edge_sh, msg_cpl):
        """Return the updated ``(node_feat, msg_cpl)``.

        Args:
            node_feat: Per-node features laid out as ``in_irreps``.
            edge_index: Pair ``(col, row)``; features are gathered from
                ``col`` and messages are aggregated at ``row``.
            edge_attr: Per-edge attributes of width ``edge_attr_dim``.
            edge_sh: Per-edge spherical harmonics laid out as ``sh_irreps``.
            msg_cpl: Per-node scalar features of width ``hidden_dim``.
        """
        col, row = edge_index
        msg = self.tp(node_feat[col], edge_sh, self.mlp_sh_coff(edge_attr))
        agg = scatter(src=msg, index=row, dim=0, dim_size=node_feat.size(0), reduce='mean')
        node_feat = node_feat + agg

        # new_ones inherits dtype and device from node_feat, so the tensor
        # product never mixes precisions when the model is not float32.
        one = node_feat.new_ones((node_feat.size(0), 1))
        node_feat = self.re_scale(node_feat, one, self.mlp_cpl_coff(msg_cpl))

        msg_cpl = msg_cpl + self.mlp_msg_cpl(torch.cat([self.get_scalar(node_feat), msg_cpl], dim=-1))

        return node_feat, msg_cpl


class TFNModel_cpl_global(torch.nn.Module):
    """
    Tensor Field Network model from "Tensor Field Networks".

    This variant augments the node scalars with a ``NodeColor`` embedding,
    derives per-graph virtual nodes (``VirtualNode``), turns the node-to-
    virtual-node geometry into a scalar side channel (``NodeFeatByVN``) and
    runs a stack of ``TensorProductConvLayer`` blocks. The output is the input
    position plus a predicted ``1x1o`` displacement.

    Args:
        r_max: Cutoff radius passed to the radial embedding.
        num_bessel: Number of Bessel basis functions of the radial embedding.
        num_polynomial_cutoff: Polynomial cutoff order of the radial embedding.
        max_ell: Maximum spherical-harmonics degree.
        require_vel: Whether ``input_irreps`` includes a ``1x1o`` velocity.
        num_layer: Number of tensor-product convolution layers.
        hidden_dim: Width of the scalar side channel and MLP hidden layers.
        irreps_channels: Multiplicity of every irrep in the hidden features.
        node_input_dim: Width of the raw ``data.node_feat``.
        edge_attr_dim: Width of ``data.edge_attr``.
        vn_channel: Number of virtual nodes per graph.
        activation: Activation used by the convolution layers.
        norm: Normalisation factory forwarded to the convolution layers.
        device: Device the whole model is moved to after construction.
    """
    def __init__(
        self,
        r_max: float = 10.0,
        num_bessel: int = 8,
        num_polynomial_cutoff: int = 5,
        max_ell: int = 2,
        require_vel: bool = True,
        num_layer: int = 4,
        hidden_dim: int = 64,
        irreps_channels: int = 8,
        node_input_dim: int = 2,
        edge_attr_dim: int = 2,
        vn_channel: int = 4,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        device: str = 'cpu',
    ):
        super().__init__()

        sh_irreps = e3nn.o3.Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = e3nn.o3.SphericalHarmonics(
            sh_irreps, normalize=True, normalization="norm"
        )
        self.input_irreps = e3nn.o3.Irreps(f'{irreps_channels}x0e+1x1o') if require_vel else e3nn.o3.Irreps(f'{irreps_channels}x0e')
        self.hidden_irreps = (sh_irreps * irreps_channels).sort()[0].simplify()

        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
        )

        self.embedding = nn.Linear(node_input_dim, irreps_channels)

        self.vn_channel = vn_channel
        # NodeColor and VirtualNode consume the embedded scalars, hence
        # hidden_dim=irreps_channels for both.
        self.node_color = NodeColor(hidden_dim=irreps_channels)
        self.vn = VirtualNode(vn_channel=vn_channel, hidden_dim=irreps_channels)
        self.node_feat_by_vn = NodeFeatByVN(vn_channel=vn_channel, hidden_dim=hidden_dim)

        # NOTE: not referenced by this class's forward(); kept so that the
        # registered sub-modules (and state_dict keys) stay unchanged.
        self.com_node_feat_net = EGNNModel(num_layer=1, hidden_dim=hidden_dim, node_input_dim=node_input_dim, edge_attr_dim=edge_attr_dim)

        self.layers = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.layers.append(
                TensorProductConvLayer(
                    in_irreps=self.hidden_irreps,
                    out_irreps=self.hidden_irreps,
                    sh_irreps=sh_irreps,
                    hidden_dim=hidden_dim,
                    # Edge attributes are concatenated with the radial embedding.
                    edge_attr_dim=self.radial_embedding.out_dim + edge_attr_dim,
                    activation=activation,
                    norm=norm,
                )
            )

        self.get_delta_pos = e3nn.o3.Linear(self.hidden_irreps, '1x1o')

        # Honour the requested device, as EGNNModel does; previously the
        # argument was accepted but silently ignored.
        self.to(device)

    def forward(self, data: Data):
        """Return the predicted node positions for ``data``.

        ``data`` must provide ``node_feat``, ``node_pos``, ``edge_index`` and
        ``batch``; ``node_vel`` and ``edge_attr`` are used when present.
        """
        node_feat = self.embedding(data.node_feat)
        node_pos = data.node_pos
        node_vel = data.node_vel if 'node_vel' in data else None

        edge_index = data.edge_index
        col, row = edge_index

        node_feat = node_feat + self.node_color(node_feat, node_pos, data.batch, edge_index)
        vn_pos = self.vn(node_feat, node_pos, data.batch)
        msg_cpl = self.node_feat_by_vn(data.node_feat, node_pos, vn_pos, data.batch)

        # Scalars first, then the velocity vector (when available).
        node_feat = torch.cat([i for i in [node_feat, node_vel] if i is not None], dim=-1)

        diff_pos = node_pos[col] - node_pos[row]
        dist = torch.norm(diff_pos, dim=1, keepdim=True)

        edge_attr = torch.cat([i for i in [
            data.edge_attr if 'edge_attr' in data else None,
            self.radial_embedding(dist)
        ] if i is not None], dim=-1)
        edge_sh = self.spherical_harmonics(diff_pos)

        # Zero-pad up to the hidden width. The amount is taken from the actual
        # feature width so it stays correct when the presence of node_vel does
        # not match require_vel; otherwise it equals
        # hidden_irreps.dim - input_irreps.dim.
        node_feat = F.pad(node_feat, (0, self.hidden_irreps.dim - node_feat.size(-1)))
        for layer in self.layers:
            node_feat, msg_cpl = layer(node_feat, edge_index, edge_attr, edge_sh, msg_cpl)
        node_pos = node_pos + self.get_delta_pos(node_feat)

        return node_pos