import torch
import torch.nn as nn
from typing import Tuple, Union

from data.structures import MaskMatrices
from data.encode import num_atom_features, num_bond_features
from net.utils.components import MLP, GraphIsomorphismNetwork


class ForceModel(nn.Module):
    """Predict one scalar force value per edge from GIN edge features.

    The graph is first extended with ``GraphIsomorphismNetwork.extend_graph``,
    encoded by a 4-layer GIN with hidden size 256, and the resulting edge
    features are mapped to a single value per edge by a bias-free MLP.
    """

    def __init__(self, use_cuda=False):
        """Create the GIN encoder and the edge-force MLP head.

        Args:
            use_cuda: Stored on the instance and forwarded to the GIN, the
                MLP and (at forward time) ``extend_graph``.
        """
        super().__init__()
        self.use_cuda = use_cuda
        self.gin = GraphIsomorphismNetwork(
            atom_dim=num_atom_features(),
            bond_dim=num_bond_features(),
            h_dim=256,
            n_layer=4,
            use_cuda=use_cuda,
        )
        # NOTE(review): 768 is presumably the width of the GIN edge features
        # (3 * h_dim?) -- confirm against GraphIsomorphismNetwork.
        self.force_mlp = MLP(
            768,
            1,
            hidden_dims=[256],
            use_cuda=use_cuda,
            bias=False,
        )

    def forward(
            self,
            atom_ftr: torch.Tensor,
            bond_ftr: torch.Tensor,
            pos: torch.Tensor,
            mask_matrices: MaskMatrices,
            target_pos: torch.Tensor = None
    ) -> Tuple[Union[torch.Tensor, None], torch.Tensor, Union[torch.Tensor, None]]:
        """Compute per-edge force values and either vertex forces or distance errors.

        Args:
            atom_ftr: Atom features passed to the GIN.
            bond_ftr: Bond features; replaced by the output of ``extend_graph``.
            pos: Current vertex positions.
            mask_matrices: Supplies ``vertex_edge_w1`` / ``vertex_edge_w2``.
                NOTE(review): these look like vertex-by-edge selectors for the
                first / second endpoint of each edge -- confirm against
                ``MaskMatrices``.
            target_pos: Optional reference positions. When given, no vertex
                forces are produced and the edge distance error is returned
                instead.

        Returns:
            ``(vertex_force_vec, edge_force_abs, None)`` when ``target_pos`` is
            ``None``; otherwise ``(None, edge_force_abs, edge_distance_delta)``
            where the delta is the current edge distance minus the target one.
        """
        ext_bond_ftr, w1, w2 = GraphIsomorphismNetwork.extend_graph(
            bond_ftr=bond_ftr,
            pos=pos,
            vew1=mask_matrices.vertex_edge_w1,
            vew2=mask_matrices.vertex_edge_w2,
            use_cuda=self.use_cuda
        )
        # Only the edge-level output of the GIN is used here.
        gin_out = self.gin.forward(atom_ftr, ext_bond_ftr, w1, w2)
        _, edge_hidden = gin_out
        edge_force_abs = self.force_mlp.forward(edge_hidden)

        if target_pos is None:
            # Edge direction runs from endpoint 1 to endpoint 2.
            eps = 1e-6  # keeps the division finite for zero-length edges
            end_pos = w2.t() @ pos
            start_pos = w1.t() @ pos
            direction = end_pos - start_pos
            length = torch.norm(direction, dim=1, keepdim=True)
            unit_direction = direction / (length + eps)
            edge_force_vec = edge_force_abs * unit_direction
            # Scatter back to vertices with opposite signs for the two endpoints.
            from_w1 = w1 @ edge_force_vec
            from_w2 = w2 @ edge_force_vec
            vertex_force_vec = from_w1 - from_w2
            return vertex_force_vec, edge_force_abs, None

        # Reference positions supplied: report how far each edge length is
        # from its target instead of producing vertex forces.
        target_dist = GraphIsomorphismNetwork.edge_distances(target_pos, w1, w2)
        current_dist = GraphIsomorphismNetwork.edge_distances(pos, w1, w2)
        edge_distance_delta = current_dist - target_dist
        return None, edge_force_abs, edge_distance_delta
