import torch
import torch.nn as nn

from e3nn import o3

from nequip.data import AtomicDataDict
from nequip.nn import GraphModuleMixin

from gto import GTOBasis

class AuxdensityHead(nn.Module):
    """Per-species equivariant linear readout onto an auxiliary GTO basis.

    One ``o3.Linear`` is created for every species known to the auxiliary
    basis. Each maps the shared atom-feature irreps ``irreps_in`` to that
    species' output irreps (``GTOBasis.per_species_irreps_out``), presumably
    the irreps of that element's auxiliary-density coefficients.
    """

    def __init__(
        self,
        irreps_in: str | o3.Irreps,
        type_names: list[str],
        biases: bool = False,
        auxbasis: str = 'def2-universal-jfit',
    ):
        """Build the per-species linear heads.

        Args:
            irreps_in: irreps of the incoming per-atom features.
            type_names: element names passed to ``GTOBasis.from_basis_name``
                as ``elements``.
            biases: forwarded to every ``o3.Linear``.
            auxbasis: name of the auxiliary basis set to load.
        """
        super().__init__()
        self.auxbasis = GTOBasis.from_basis_name(auxbasis, elements=type_names)
        irreps_by_species = self.auxbasis.per_species_irreps_out
        # Fixed iteration order for forward(); matches the ModuleDict key order.
        self.species_list = list(irreps_by_species.keys())
        heads = {}
        for name, target_irreps in irreps_by_species.items():
            heads[name] = o3.Linear(irreps_in=irreps_in, irreps_out=target_irreps, biases=biases)
        self.per_species_modules = nn.ModuleDict(heads)

    def forward(self, atom_features: torch.Tensor, species_indices: dict[str, torch.Tensor]):
        """Apply each species' head to that species' atoms.

        Args:
            atom_features: per-atom feature tensor; rows are selected with
                ``species_indices`` (assumed first dim is atoms — confirm).
            species_indices: maps each species name to an index/mask tensor
                selecting its rows of ``atom_features``. Every species in
                ``self.species_list`` must be present, otherwise ``KeyError``.

        Returns:
            dict mapping species name to the output of its ``o3.Linear``.
        """
        return {
            name: self.per_species_modules[name](atom_features[species_indices[name]])
            for name in self.species_list
        }


class AuxdensityHeadForNequip(GraphModuleMixin, torch.nn.Module):
    """NequIP graph-module wrapper around :class:`AuxdensityHead`.

    Reads per-node features from ``data[field]`` and the per-species index
    dict from ``data['species_indices']``. It applies the per-species
    auxiliary-basis readout and writes the resulting ``{species: tensor}``
    dict to ``data[out_field]``.
    """

    def __init__(
        self,
        type_names: list[str],
        auxbasis: str = 'def2-universal-jfit',
        field: str = AtomicDataDict.NODE_FEATURES_KEY,
        out_field: str | None = None,
        biases: bool = True,
        irreps_in: dict[str, str | o3.Irreps] | None = None,
    ):
        """Set up irreps bookkeeping and the wrapped readout layer.

        Args:
            type_names: element names forwarded to ``AuxdensityHead``.
            auxbasis: auxiliary basis-set name forwarded to ``AuxdensityHead``.
            field: key of the input node features in the data dict.
            out_field: key to write the output to; defaults to ``field``.
            biases: forwarded to ``AuxdensityHead`` (and its ``o3.Linear``s).
            irreps_in: NequIP irreps dict (field name -> irreps); must
                contain ``field``.
        """
        super().__init__()
        self.field = field
        self.out_field = out_field if out_field is not None else field

        self._init_irreps(
            irreps_in=irreps_in,
            required_irreps_in=[field],
            # irreps_out is deliberately not declared: this module is meant to be
            # the final layer, and its output is a per-species dict of tensors.
        )
        # Read the irreps from the dict stored by _init_irreps rather than the raw
        # `irreps_in` argument, which may be None. AuxdensityHead's keyword
        # parameter is `irreps_in` (the previous `atom_irreps=` raised TypeError).
        self.layer = AuxdensityHead(
            irreps_in=self.irreps_in[field],
            type_names=type_names,
            auxbasis=auxbasis,
            biases=biases,
        )

    def forward(self, data):
        """Run the readout and store its per-species output dict in ``data``.

        ``data['species_indices']`` must already be present (presumably added
        by an upstream module or the data pipeline — confirm). It maps species
        name to the index/mask tensor for that species' nodes.
        """
        data[self.out_field] = self.layer(data[self.field], data['species_indices'])
        return data
