from typing import Dict

import torch
import torch.nn as nn

import schnetpack.properties as properties
import schnetpack.nn as snn
from schnetpack.representation.painn import *


class DiffPaiNN(PaiNN):
    """PaiNN variant whose atom embedding is a linear layer.

    The embedding created by ``PaiNN.__init__`` is replaced with
    ``nn.Linear(dim_h, n_atom_basis)``, so the initial scalar features are a
    linear projection of ``dim_h``-dimensional per-atom input vectors.
    """

    def __init__(self, dim_h, **kwargs):
        """
        Args:
            dim_h (int): size of the per-atom input vector consumed by the
                linear embedding layer.
            **kwargs: keyword arguments forwarded unchanged to
                ``PaiNN.__init__``.
        """
        # The keyword arguments must be unpacked; passing the dict itself would
        # bind it to the parent's first positional parameter.
        super().__init__(**kwargs)

        # Overwrite the default PaiNN embedding op, i.e.
        #   nn.Embedding(max_z, n_atom_basis, padding_idx=0)
        # TODO: open design questions about the one-hot input encoding:
        #   * Should the one-hot dimension be max_z, or the number of species
        #     actually present in the dataset?
        #   * How should the padding code 0 be handled? With nn.Embedding the
        #     entry for index 0 is not learnable. If the one-hot encoding does
        #     not include atom type 0, such an atom is encoded as all zeros and
        #     a zero input gives a zero output, so it is never trained.
        #     Alternatively, reserve the first one-hot index for atom type 0
        #     and make its weight untrainable.
        # The output width has to match the feature width used by the
        # interaction blocks (`n_atom_basis`, see the filter split in forward).
        self.embedding = nn.Linear(dim_h, self.n_atom_basis)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Compute atomic representations/embeddings.

        Args:
            inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors.
                Read keys: ``properties.Z``, ``properties.Rij``,
                ``properties.idx_i`` and ``properties.idx_j``.

        Returns:
            dict of torch.Tensor: the same ``inputs`` dictionary, updated in
            place with ``"scalar_representation"`` (first axis: atoms, last
            axis: features) and ``"vector_representation"`` (an extra axis of
            size 3 in between).
        """
        # get tensors from input dictionary
        # NOTE(review): `self.embedding` is an nn.Linear(dim_h, ...), so this
        # entry presumably has to hold floating-point per-atom features of
        # shape (n_atoms, dim_h) (e.g. one-hot atom types) rather than integer
        # atomic numbers -- confirm against the caller / data pipeline.
        atomic_numbers = inputs[properties.Z]
        r_ij = inputs[properties.Rij]
        idx_i = inputs[properties.idx_i]
        idx_j = inputs[properties.idx_j]
        n_atoms = atomic_numbers.shape[0]

        # compute atom and pair features
        d_ij = torch.norm(r_ij, dim=1, keepdim=True)
        # NOTE(review): a pair with zero distance would divide by zero here;
        # presumably the neighbor list never contains such pairs -- confirm.
        dir_ij = r_ij / d_ij
        phi_ij = self.radial_basis(d_ij)
        fcut = self.cutoff_fn(d_ij)

        filters = self.filter_net(phi_ij) * fcut[..., None]
        if self.share_filters:
            # every interaction block reuses the same filter tensor
            filter_list = [filters] * self.n_interactions
        else:
            # one chunk of width 3 * n_atom_basis per interaction block
            filter_list = torch.split(filters, 3 * self.n_atom_basis, dim=-1)

        # scalar features get a singleton middle axis; vector features start at
        # zero with three components on that axis
        q = self.embedding(atomic_numbers)[:, None]
        qs = q.shape
        # match q's dtype as well as its device so that non-default precision
        # does not mix dtypes inside the interaction blocks
        mu = torch.zeros((qs[0], 3, qs[2]), device=q.device, dtype=q.dtype)

        for i, (interaction, mixing) in enumerate(zip(self.interactions, self.mixing)):
            q, mu = interaction(q, mu, filter_list[i], dir_ij, idx_i, idx_j, n_atoms)
            q, mu = mixing(q, mu)

        # drop the singleton axis added after the embedding
        q = q.squeeze(1)

        inputs["scalar_representation"] = q
        inputs["vector_representation"] = mu
        return inputs
