from typing import Optional
from functools import partial

import torch
from torch import nn, LongTensor, Tensor
from torch.nn import functional as F
from torch_geometric.data import Data
from torch_geometric.nn import global_add_pool, global_mean_pool
from torch_scatter import scatter
import e3nn
from e3nn.o3 import Irreps, Linear, FullyConnectedTensorProduct, SphericalHarmonics

class BaseMLP(nn.Module):
    """Two-layer perceptron with optional normalisation and skip connection.

    The stack is, in order:
    ``Linear -> norm -> activation -> Linear -> norm -> (activation)``.
    Slots that are switched off are filled with ``nn.Identity`` so the
    sequential container always has six entries.

    Args:
        input_dim: width of the input features.
        hidden_dim: width of the intermediate layer.
        output_dim: width of the output features.
        activation: module applied after the first normalisation slot and,
            when ``last_act`` is set, again at the very end. The same
            instance is reused in both positions.
        norm: called as ``norm(width)`` to build each normalisation module;
            ``None`` disables normalisation.
        residual: when true, ``forward`` returns ``x + mlp(x)``; this
            requires ``output_dim == input_dim``.
        last_act: when true, apply ``activation`` after the second
            normalisation slot as well.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        residual: bool = False,
        last_act: bool = False,
    ) -> None:
        super().__init__()
        self.residual = residual
        if residual:
            # The skip connection adds input and output element-wise.
            assert output_dim == input_dim

        def make_norm(width: int) -> nn.Module:
            # Fall back to a no-op when no normalisation was requested.
            return norm(width) if norm is not None else nn.Identity()

        # Built in forward order so parameters are created in that order.
        stages = [
            nn.Linear(input_dim, hidden_dim),
            make_norm(hidden_dim),
            activation,
            nn.Linear(hidden_dim, output_dim),
            make_norm(output_dim),
        ]
        stages.append(activation if last_act else nn.Identity())
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Run the MLP on ``x``, adding ``x`` back when ``residual`` is set."""
        out = self.mlp(x)
        if self.residual:
            return x + out
        return out
    

class TensorProductConvLayer(nn.Module):
    """Message-passing layer built on a fully connected tensor product.

    For every edge, the features of the first-row node of ``edge_index`` are
    combined with the edge's spherical-harmonic embedding through a
    ``FullyConnectedTensorProduct``. The tensor product is created with
    ``shared_weights=False``; its weights are produced per edge by an MLP
    over the edge attributes. Messages are mean-reduced onto the second-row
    node of ``edge_index`` and added to the incoming node features.

    Args:
        in_irreps: irreps of the incoming node features.
        out_irreps: irreps of the messages. The result is added to the
            input features, so it must have a compatible width.
        sh_irreps: irreps of the spherical-harmonic edge embedding.
        hidden_dim: hidden width of the weight-generating MLP.
        edge_attr_dim: width of the edge attributes fed to that MLP.
        activation: activation used inside the MLP.
        norm: optional normalisation factory forwarded to the MLP.
    """

    def __init__(
        self,
        in_irreps: Irreps,
        out_irreps: Irreps,
        sh_irreps: Irreps,
        hidden_dim: int = 64,
        edge_attr_dim: int = 64,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.tp = FullyConnectedTensorProduct(
            in_irreps, sh_irreps, out_irreps, shared_weights=False
        )

        # The MLP emits exactly as many values as the tensor product has weights.
        self.mlp_sh_coff = BaseMLP(
            input_dim=edge_attr_dim,
            hidden_dim=hidden_dim,
            output_dim=self.tp.weight_numel,
            activation=activation,
            norm=norm,
        )

    def forward(self, node_feat, edge_index, edge_attr, edge_sh):
        """Return ``node_feat`` plus the mean of the messages arriving at each node.

        ``edge_index`` is unpacked into two index vectors: the first selects
        the node whose features form the message, the second selects the
        node that receives it.
        """
        sender, receiver = edge_index
        tp_weights = self.mlp_sh_coff(edge_attr)
        messages = self.tp(node_feat[sender], edge_sh, tp_weights)
        # dim_size pins the output to one row per node.
        pooled = scatter(
            src=messages,
            index=receiver,
            dim=0,
            dim_size=node_feat.size(0),
            reduce='mean',
        )
        return node_feat + pooled


class TFNModel(torch.nn.Module):
    """
    Tensor Field Network model from "Tensor Field Networks".

    Node inputs are linearly embedded into ``irreps_channels`` scalars,
    zero-padded to the width of ``hidden_irreps``, refined by ``num_layer``
    ``TensorProductConvLayer`` blocks and finally mean-pooled per graph.

    Args:
        max_ell: highest degree used for the spherical harmonics and for the
            hidden irreps (both parities of every degree ``0..max_ell``).
        num_layer: number of tensor-product convolution layers.
        hidden_dim: width of the edge embedding and of the per-layer MLPs.
        irreps_channels: multiplicity of every irrep in the hidden features.
        node_input_dim: width of ``data.x``.
        activation: activation forwarded to every layer.
        norm: optional normalisation factory forwarded to every layer.
        device: device the module is moved to at the end of construction.
        edge_input_dim: width of ``data.edge_attr``.

    Attributes:
        hidden_irreps: irreps of the node features between layers.
        output_dim: dimension of ``hidden_irreps``, i.e. the width of the
            tensor returned by ``forward``.
    """
    def __init__(
        self,
        max_ell: int = 7,
        num_layer: int = 2,
        hidden_dim: int = 64,
        irreps_channels: int = 4,
        node_input_dim: int = 11,
        activation: nn.Module = nn.SiLU(),
        norm: Optional[nn.Module] = None,
        device: str = 'cpu',
        edge_input_dim: int = 4,
    ):
        super().__init__()

        sh_irreps = Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = SphericalHarmonics(
            sh_irreps, normalize=True, normalization="norm"
        )

        self.irreps_channels = irreps_channels
        self.input_irreps = Irreps(f'{irreps_channels}x0e')
        # Even and odd irrep of every degree 0..max_ell, e.g. "0e+0o+1e+1o".
        parity_pairs = Irreps(
            '+'.join(f'{ell}e+{ell}o' for ell in range(max_ell + 1))
        )
        # Repeat per channel, then sort and merge into one entry per irrep.
        self.hidden_irreps = (parity_pairs * irreps_channels).sort()[0].simplify()
        self.output_dim = self.hidden_irreps.dim

        self.node_embedding = nn.Linear(node_input_dim, irreps_channels)
        self.edge_embedding = nn.Linear(edge_input_dim, hidden_dim)

        self.layers = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.layers.append(
                TensorProductConvLayer(
                    in_irreps=self.hidden_irreps,
                    out_irreps=self.hidden_irreps,
                    sh_irreps=sh_irreps,
                    hidden_dim=hidden_dim,
                    edge_attr_dim=hidden_dim,
                    activation=activation,
                    norm=norm,
                )
            )
        self.to(device)

    def forward(self, data: Data):
        """Compute one feature vector per graph in ``data``.

        Reads ``data.x``, ``data.pos``, ``data.edge_index``,
        ``data.edge_attr`` and ``data.batch``.

        Returns:
            The per-graph mean of the final node features.
        """
        scalars = self.node_embedding(data.x)
        # Zero-pad the scalar embedding to the full hidden width. It is
        # shifted right by one block of `irreps_channels` entries.
        # NOTE(review): presumably this aligns it with the 0e block, which
        # would follow 0o after Irreps.sort() - confirm against e3nn ordering.
        left_pad = self.irreps_channels
        right_pad = self.hidden_irreps.dim - 2 * self.irreps_channels
        node_feat = F.pad(scalars, pad=(left_pad, right_pad))

        edge_index = data.edge_index
        col, row = edge_index

        # Relative position along each edge, embedded with spherical harmonics.
        node_pos = data.pos
        diff_pos = node_pos[col] - node_pos[row]
        edge_sh = self.spherical_harmonics(diff_pos)

        edge_attr = self.edge_embedding(data.edge_attr)

        for layer in self.layers:
            node_feat = layer(node_feat, edge_index, edge_attr, edge_sh)
        return global_mean_pool(node_feat, data.batch)
