from functools import partial

import torch
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool

from models.layers.egnn_layer import EGNNLayer

import e3nn

class BaseMLP(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Linear [-> activation].

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last dimension of the output.
        activation: Module applied after the first Linear (and after the second
            one when ``last_act`` is true). The same instance is reused in both
            places. The default ``nn.SiLU()`` is created once, when the class is
            defined, and shared by every instance built with the default.
        residual: If true, ``forward`` returns ``x + mlp(x)``; this requires
            ``output_dim == input_dim``.
        last_act: If true, ``activation`` is also applied to the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, activation=nn.SiLU(), residual=False, last_act=False):
        super().__init__()
        self.residual = residual
        if residual:
            # The skip connection adds input and output elementwise.
            assert output_dim == input_dim
        head_act = activation if last_act else nn.Identity()
        self.mlp = nn.Sequential(*[
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Linear(hidden_dim, output_dim),
            head_act,
        ])

    def forward(self, x):
        """Apply the MLP to ``x``, adding ``x`` back when ``residual`` is set."""
        y = self.mlp(x)
        if not self.residual:
            return y
        return x + y
    
class NodeColor(nn.Module):
    """Per-node feature computed from the node's position relative to its graph's centroid.

    Two variants are selected by ``color_type``:

    * ``'center_radius'``: the Euclidean distance of each node to the mean
      position of its graph, embedded by an MLP (``node_feat`` is unused).
    * ``'tp'``: spherical harmonics (degrees ``0..max_ell``) of the centred
      position are combined with the graph-averaged harmonics through an e3nn
      tensor product with ``max_ell + 1`` scalar (``0e``) outputs. The
      tensor-product weights are predicted per node from ``node_feat``.

    Args:
        hidden_dim: Width of the MLPs, size of ``node_feat`` (``'tp'`` only)
            and size of the returned feature.
        color_type: ``'center_radius'`` or ``'tp'``.
        max_ell: Highest spherical-harmonic degree (``'tp'`` only).
        activation: Activation module passed to the MLPs.

    Raises:
        ValueError: If ``color_type`` is not one of the supported values.
    """

    def __init__(self, hidden_dim, color_type='center_radius', max_ell=6, activation=nn.SiLU()):
        super().__init__()
        # Fail fast. An unknown type would otherwise leave ``mlp_node_feat``
        # undefined and only surface as an obscure error inside ``forward``.
        if color_type not in ('center_radius', 'tp'):
            raise ValueError(
                f"Unknown color_type {color_type!r}; expected 'center_radius' or 'tp'"
            )
        MLP = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation)
        if color_type == 'center_radius':
            self.mlp_node_feat = MLP(input_dim=1, output_dim=hidden_dim)
        else:
            sh_irreps = e3nn.o3.Irreps.spherical_harmonics(max_ell)
            self.spherical_harmonics = e3nn.o3.SphericalHarmonics(
                sh_irreps, normalize=True, normalization="norm"
            )
            # Unshared weights: each node supplies its own tensor-product weights.
            self.tp = e3nn.o3.FullyConnectedTensorProduct(
                sh_irreps, sh_irreps, f'{max_ell + 1}x0e', shared_weights=False
            )
            self.mlp_sh_coff = MLP(input_dim=hidden_dim, output_dim=self.tp.weight_numel)
            self.mlp_node_feat = MLP(input_dim=max_ell + 1, output_dim=hidden_dim)
        self.color_type = color_type

    def forward(self, node_feat, node_pos, batch):
        """Return one ``hidden_dim`` feature row per node.

        ``batch`` maps every node to its graph. The centroid is taken with
        ``global_mean_pool`` over that index.
        """
        center = global_mean_pool(node_pos, batch)
        pos = node_pos - center[batch]

        if self.color_type == 'center_radius':
            scalar = torch.norm(pos, dim=1, keepdim=True)
        else:
            sh = self.spherical_harmonics(pos)
            global_sh = global_mean_pool(sh, batch)
            scalar = self.tp(sh, global_sh[batch], self.mlp_sh_coff(node_feat))

        return self.mlp_node_feat(scalar)

class VirtualNode(nn.Module):
    """Predict ``num_vn`` virtual-node positions for every graph.

    Each node contributes ``num_vn`` vectors derived from its centred position
    through an e3nn tensor product ``1x1o (x) 1x0e -> num_vn x 1o``. The
    weights of that product are predicted from the node's features. The
    contributions are averaged per graph, rescaled to length ``length`` (up to
    the ``1e-3`` stabiliser) and shifted back by the graph centroid.

    Args:
        num_vn: Number of virtual nodes per graph.
        length: Target distance of each virtual node from the graph centroid.
        hidden_dim: Size of ``node_feat`` and width of the weight MLP.
        activation: Activation module passed to the MLP.

    ``forward`` returns one row per graph in ``batch`` (as produced by
    ``global_mean_pool``) with ``num_vn * 3`` columns. Columns ``3k:3k+3``
    hold the position of virtual node ``k``.
    """

    def __init__(self, num_vn=4, length=5.0, hidden_dim=64, activation=nn.SiLU()):
        super().__init__()
        self.num_vn = num_vn
        self.length = length
        # Unshared weights: every node supplies its own tensor-product weights.
        self.get_vn_pos = e3nn.o3.FullyConnectedTensorProduct(
            '1x1o', '1x0e', f'{num_vn}x1o', shared_weights=False
        )

        make_mlp = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation)
        self.mlp_vec_coff = make_mlp(input_dim=hidden_dim, output_dim=self.get_vn_pos.weight_numel)

    def forward(self, node_feat, node_pos, batch):
        centroid = global_mean_pool(node_pos, batch)
        rel_pos = node_pos - centroid[batch]
        # Constant scalar input, so each output vector is a node-weighted multiple of rel_pos.
        ones = torch.ones([rel_pos.size(0), 1], device=rel_pos.device)

        tp_weights = self.mlp_vec_coff(node_feat)
        per_node_vecs = self.get_vn_pos(rel_pos, ones, tp_weights)
        directions = global_mean_pool(per_node_vecs, batch).view(-1, self.num_vn, 3)

        # Rescale every direction to (approximately) `length`; 1e-3 guards against zero vectors.
        radius = torch.norm(directions, dim=2, keepdim=True) + 1e-3
        offsets = (directions / radius * self.length).view(-1, self.num_vn * 3)

        return offsets + centroid.repeat(1, self.num_vn)
    
class NodeFeatByVN(nn.Module):
    """Per-node feature from the node's geometry relative to its graph's virtual nodes.

    For node ``u`` in graph ``g``, let ``a_k = pos_u - vn_k(g)`` for each of the
    ``num_vn`` virtual nodes of ``g``. The MLP input is the flattened
    ``num_vn x num_vn`` matrix ``D``, where:

    * ``D[k, l] = |a_k - a_l| = |vn_l - vn_k|`` for ``k != l`` (virtual-node
      spacing);
    * ``D[k, k] = |a_k|`` (distance from the node to virtual node ``k``).

    Both kinds of entry are invariant to translating node and virtual nodes together.

    Args:
        num_vn: Number of virtual nodes per graph.
        hidden_dim: Width of the MLP and size of the returned feature.
        activation: Activation module passed to the MLP.
    """

    def __init__(self, num_vn=4, hidden_dim=64, activation=nn.SiLU()):
        super().__init__()
        self.num_vn = num_vn

        MLP = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation)
        self.mlp_node_feat = MLP(input_dim=num_vn ** 2, output_dim=hidden_dim)

    def forward(self, node_feat, node_pos, vn_pos, batch):
        """Return a ``(num_nodes, hidden_dim)`` feature; ``node_feat`` is unused.

        ``vn_pos`` holds ``num_vn * 3`` coordinates per row (virtual node ``k``
        at columns ``3k:3k+3``), and ``batch`` selects the row of each node.
        """
        num_nodes = node_pos.size(0)
        # (num_nodes, num_vn, 3): vector from each virtual node of the node's graph to the node.
        info_vec = node_pos.unsqueeze(1) - vn_pos[batch].reshape(num_nodes, self.num_vn, 3)
        # Off-diagonal pairwise distances: node_pos cancels, leaving |vn_k - vn_l|.
        # The cdist diagonal is identically zero, so on its own this feature would be
        # identical for every node of a graph. Filling the diagonal with the
        # node-to-virtual-node distances makes it depend on the node.
        info_scalar = torch.cdist(info_vec, info_vec) + torch.diag_embed(info_vec.norm(dim=2))
        return self.mlp_node_feat(info_scalar.reshape(num_nodes, self.num_vn ** 2))
    
class CheckChiralByVN(nn.Module):
    """Graph-level chirality feature from the first four virtual nodes.

    The determinant ``det([v0 - v3, v1 - v3, v2 - v3])`` is a signed (scaled)
    volume of the tetrahedron spanned by virtual nodes 0-3. It is unchanged
    by translations and proper rotations but flips sign under a reflection.
    The MLP embeds this determinant.

    Args:
        num_vn: Number of virtual nodes per graph (must be at least 4).
        hidden_dim: Width of the MLP and size of the returned feature.
        activation: Activation module passed to the MLP.
    """

    def __init__(self, num_vn=4, hidden_dim=64, activation=nn.SiLU()):
        super().__init__()
        self.num_vn = num_vn
        # Four points are needed to span a tetrahedron.
        assert num_vn >= 4

        make_mlp = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation)
        self.mlp_node_feat = make_mlp(input_dim=1, output_dim=hidden_dim)

    def forward(self, vn_pos):
        """Map rows of ``num_vn * 3`` virtual-node coordinates to rows of ``hidden_dim`` features."""
        points = vn_pos.view(-1, self.num_vn, 3)
        apex = points[:, 3:4, :]
        edges = points[:, :3, :] - apex
        signed_volume = torch.det(edges)
        return self.mlp_node_feat(signed_volume[:, None])
    
class SHNorm(nn.Module):
    """Rescale every irrep copy of a feature to unit L2 norm.

    ``irreps`` is iterated as ``(mul, ir)`` pairs (for example an
    ``e3nn.o3.Irreps``). Each of the ``mul`` consecutive blocks of ``ir.dim``
    columns is divided by its own norm plus ``1e-8``.

    The output has the same shape as the input, with two edge cases:

    * Columns of ``sh`` beyond the total irreps dimension are returned as zeros.
    * Blocks that run past the last column of ``sh`` are truncated or empty
      slices, so they are skipped without error.
    """

    def __init__(self, irreps):
        super().__init__()
        self.irreps = irreps

    def forward(self, sh):
        out = torch.zeros_like(sh)
        start = 0
        for mul, ir in self.irreps:
            width = ir.dim
            for _ in range(mul):
                stop = start + width
                block = sh[:, start:stop]
                out[:, start:stop] = block / (block.norm(dim=1, keepdim=True) + 1e-8)
                start = stop
        return out
    
class CheckChiralByTP(nn.Module):
    """Graph-level chirality feature built from spherical-harmonic tensor products.

    The computation proceeds in five steps:

    1. Expand the centred node positions in spherical harmonics up to ``max_ell``.
    2. Map them with node-dependent weights (``tp0``) into ``sym_irreps``,
       which holds both parities of every degree ``0..max_ell`` with
       ``channels`` copies each.
    3. Average the result per graph.
    4. Combine it with itself (``tp1``).
    5. Contract it (``tp2``) to a single pseudo-scalar (``0o``).

    A ``0o`` quantity flips sign under reflection, so the result can tell
    mirror images apart.

    Args:
        max_ell: Highest spherical-harmonic degree.
        channels: Multiplicity of each irrep in ``sym_irreps``.
        hidden_dim: Size of ``node_feat`` and of the returned feature.
        activation: Activation module passed to the MLPs.
    """

    def __init__(self, max_ell=3, channels=4, hidden_dim=64, activation=nn.SiLU()):
        super().__init__()
        sh_irreps = e3nn.o3.Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = e3nn.o3.SphericalHarmonics(
            sh_irreps, normalize=True, normalization="norm"
        )
        # Both parities for every degree, so the products below can form
        # odd-parity (pseudo-tensor) combinations.
        sym_irreps = (e3nn.o3.Irreps([
            e3nn.o3.Irrep(l, p)
            for l in range(max_ell + 1)
            for p in [-1, 1]
        ]) * channels).sort()[0].simplify()
        self.sh_norm = SHNorm(sym_irreps)
        # Unshared weights: tp0 receives per-node weights from mlp_sh_coff.
        self.tp0 = e3nn.o3.FullyConnectedTensorProduct(sh_irreps, '1x0e', sym_irreps, shared_weights=False)
        self.tp1 = e3nn.o3.FullyConnectedTensorProduct(sym_irreps, sym_irreps, sym_irreps)
        self.tp2 = e3nn.o3.FullyConnectedTensorProduct(sym_irreps, sym_irreps, '1x0o')

        MLP = partial(BaseMLP, hidden_dim=hidden_dim, activation=activation)
        self.mlp_sh_coff = MLP(input_dim=hidden_dim, output_dim=self.tp0.weight_numel)
        self.mlp_node_feat = MLP(input_dim=1, output_dim=hidden_dim)

    def forward(self, node_feat, node_pos, batch):
        """Return one ``hidden_dim`` row per graph in ``batch``."""
        center = global_mean_pool(node_pos, batch)
        pos = node_pos - center[batch]
        sh = self.spherical_harmonics(pos)
        one = torch.ones([pos.size(0), 1], device=pos.device)

        # Lift the harmonics into sym_irreps with per-node weights, average per
        # graph, then normalise every irrep copy.
        global_sh = self.sh_norm(global_mean_pool(
            self.tp0(sh, one, self.mlp_sh_coff(node_feat)), batch
        ))
        global_sh = self.sh_norm(self.tp1(global_sh, global_sh))

        # Single 0o channel: shape (num_graphs, 1).
        pseudo_scalar = self.tp2(global_sh, global_sh)
        # Normalise that channel directly, mapping it to (roughly) its sign.
        # ``self.sh_norm`` describes the much wider ``sym_irreps`` layout and only
        # handled this 1-column tensor because its out-of-range slices are empty;
        # this expression computes the same value explicitly.
        pseudo_scalar = pseudo_scalar / (pseudo_scalar.abs() + 1e-8)

        return self.mlp_node_feat(pseudo_scalar)
        
class BasicModel_chiral(torch.nn.Module):
    """
    Graph-level predictor that combines atom embeddings, virtual-node geometry
    and a reflection-sensitive (chirality) feature.

    The forward pass runs these steps:

    1. Embed the atom types.
    2. Optionally add a centroid-based node feature (``NodeColor``).
    3. Predict ``num_vn`` virtual nodes per graph.
    4. Add node/virtual-node geometry features.
    5. Pool per graph and add a chirality feature: ``CheckChiralByTP`` if
       ``tp``, otherwise ``CheckChiralByVN``.
    6. Apply the prediction head.

    No message-passing layers are applied by this class.
    """
    def __init__(
        self,
        num_layers: int = 5,
        emb_dim: int = 64,
        in_dim: int = 1,
        out_dim: int = 1,
        num_vn: int = 4,
        color: bool = True,
        tp: bool = False,
        activation: str = "relu",
        norm: str = "layer",
        aggr: str = "sum",
        pool: str = "sum",
        residual: bool = True,
        equivariant_pred: bool = False
    ):
        """
        Build the model's sub-modules.

        Parameters:
        - num_layers (int): Currently unused by this class (default: 5)
        - emb_dim (int): Dimension of the node embeddings and of all internal features (default: 64)
        - in_dim (int): Number of atom types, i.e. rows of the embedding table (default: 1)
        - out_dim (int): Output dimension of the model (default: 1)
        - num_vn (int): Number of virtual nodes per graph; CheckChiralByVN is always built and
          requires at least 4 (default: 4)
        - color (bool): Whether to add the NodeColor feature (default: True)
        - tp (bool): Use the tensor-product chirality feature instead of the virtual-node one (default: False)
        - activation (str): Currently unused; the invariant head always uses ReLU (default: "relu")
        - norm (str): Currently unused (default: "layer")
        - aggr (str): Currently unused (default: "sum")
        - pool (str): Global pooling method, "mean" or "sum" (default: "sum")
        - residual (bool): Stored on the instance but currently unused (default: True)
        - equivariant_pred (bool): If True, pooled node positions are appended to the pooled
          invariant features and a linear head is used (default: False)
        """
        super().__init__()
        self.equivariant_pred = equivariant_pred
        self.residual = residual

        # Embedding lookup for initial node features
        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)
        self.num_vn = num_vn
        self.node_color = NodeColor(hidden_dim=emb_dim)
        self.vn = VirtualNode(num_vn=num_vn, hidden_dim=emb_dim)
        self.node_feat_by_vn = NodeFeatByVN(num_vn=num_vn, hidden_dim=emb_dim)
        self.check_chiral_by_vn = CheckChiralByVN(num_vn=num_vn, hidden_dim=emb_dim)
        # hidden_dim must follow emb_dim. CheckChiralByTP reads emb_dim-sized node
        # features and its output is added to the pooled emb_dim-sized features.
        self.check_chiral_by_tp = CheckChiralByTP(hidden_dim=emb_dim)
        self.color = color
        self.tp = tp

        # Global pooling/readout function
        self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

        if self.equivariant_pred:
            # Linear predictor for equivariant tasks using geometric features
            self.pred = torch.nn.Linear(emb_dim + 3, out_dim)
        else:
            # MLP predictor for invariant tasks using only scalar features
            self.pred = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(emb_dim, out_dim)
            )

    def forward(self, batch):
        """
        Predict one output row per graph.

        ``batch`` must provide ``atoms`` (node type indices), ``pos`` (``(n, 3)``
        positions) and ``batch`` (node-to-graph index). Presumably this is a
        PyTorch Geometric ``Batch``; confirm against the data pipeline.
        """
        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)
        pos = batch.pos  # (n, 3)
        graph_index = batch.batch

        if self.color:
            h = h + self.node_color(h, pos, graph_index)
        vn_pos = self.vn(h, pos, graph_index)
        h = h + self.node_feat_by_vn(h, pos, vn_pos, graph_index)

        # Graph-level chirality feature, (batch_size, d).
        if self.tp:
            chiral = self.check_chiral_by_tp(h, pos, graph_index)
        else:
            chiral = self.check_chiral_by_vn(vn_pos)

        # Add the chirality feature to the invariant channels only. Pooling acts
        # per column, so appending the pooled positions afterwards matches pooling
        # cat([h, pos]) while avoiding a (d + 3) vs d shape mismatch.
        out = self.pool(h, graph_index) + chiral  # (batch_size, d)
        if self.equivariant_pred:
            out = torch.cat([out, self.pool(pos, graph_index)], dim=-1)  # (batch_size, d + 3)

        return self.pred(out)  # (batch_size, out_dim)
