from typing import Dict

import einops
import torch

from openfold.np.residue_constants import RESTYPE_ATOM37_MASK
from proteinfoundation.nn.feature_factory import FeatureFactory


def get_atom_mask(device: torch.device = None):
    """Return ``RESTYPE_ATOM37_MASK`` as a boolean torch tensor.

    Args:
        device: Device to place the tensor on. ``None`` leaves the tensor on
            the device ``torch.from_numpy`` produces it on.

    Returns:
        A ``torch.bool`` tensor with the same shape as the NumPy table.
        Presumably one row per residue type and one column per atom37 slot
        -- TODO confirm against openfold's residue constants.
    """
    mask_table = torch.from_numpy(RESTYPE_ATOM37_MASK)
    return mask_table.to(device=device, dtype=torch.bool)


class ResidualLayer(torch.nn.Module):
    """Pre-norm residual block: ``x + softplus(linear(layer_norm(x)))``.

    The feature dimension is preserved, so blocks can be stacked freely.
    """

    def __init__(self, dim):
        """Create the block.

        Args:
            dim: Size of the last (feature) dimension of the input.
        """
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.layer_norm = torch.nn.LayerNorm(dim)
        self.linear = torch.nn.Linear(dim, dim)
        self.softplus = torch.nn.Softplus()

    def forward(self, x):
        """Apply the residual update to ``x`` and return a same-shaped tensor."""
        residual = self.softplus(self.linear(self.layer_norm(x)))
        return x + residual


class DecoderFFLocal(torch.nn.Module):
    """Per-residue feed-forward decoder for sequence logits and atom37 coordinates.

    Each residue is decoded independently: an initial representation from a
    ``FeatureFactory`` is passed through a stack of ``ResidualLayer`` blocks,
    then two bias-free linear heads produce 20 residue-type logits and
    37 x 3 atom coordinates. The coordinates in atom slot 1 are not predicted;
    they are copied from the ``ca_coors_nm`` input.
    """

    # Sizes of the two output heads.
    _NUM_RESTYPES = 20
    _NUM_ATOMS = 37
    _COORD_DIM = 3
    # Atom slot overwritten with the input ``ca_coors_nm`` (presumably the
    # C-alpha position in the atom37 ordering -- TODO confirm).
    _CA_SLOT = 1

    def __init__(self, **kwargs):
        """Build the decoder from ``kwargs["decoder"]``.

        Keys read directly from ``kwargs["decoder"]``:
            nlayers: Number of ``ResidualLayer`` blocks.
            token_dim: Width of the per-residue representation.
            feats_seq: Feature list handed to ``FeatureFactory`` as ``feats``.

        The whole ``kwargs["decoder"]`` mapping is additionally forwarded to
        ``FeatureFactory`` as keyword arguments.
        """
        super().__init__()
        decoder_cfg = kwargs["decoder"]
        nlayers = decoder_cfg["nlayers"]
        token_dim = decoder_cfg["token_dim"]

        self.init_repr_factory = FeatureFactory(
            feats=decoder_cfg["feats_seq"],
            dim_feats_out=token_dim,
            use_ln_out=False,
            mode="seq",
            **decoder_cfg,
        )

        self.ff_nn = torch.nn.Sequential(
            *(ResidualLayer(dim=token_dim) for _ in range(nlayers))
        )

        self.logit_linear = torch.nn.Sequential(
            torch.nn.LayerNorm(token_dim),
            torch.nn.Linear(token_dim, self._NUM_RESTYPES, bias=False),
        )
        self.struct_linear = torch.nn.Sequential(
            torch.nn.LayerNorm(token_dim),
            torch.nn.Linear(
                token_dim, self._NUM_ATOMS * self._COORD_DIM, bias=False
            ),
        )

    def forward(self, input: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Decode sequence logits and atom coordinates for every residue.

        Args:
            input: Batch dictionary. ``"ca_coors_nm"`` and ``"residue_mask"``
                are read here; the whole dictionary is also passed to the
                feature factory. The mask is expected to be ``[b, n]`` and
                ``ca_coors_nm`` to broadcast against ``[b, n, 3]``
                (presumably a boolean mask and nanometre coordinates --
                TODO confirm with the data pipeline).

        Returns:
            Dictionary with:
                ``"coors_nm"``: ``[b, n, 37, 3]`` coordinates, zeroed where
                    the mask is zero except for slot 1, which holds
                    ``ca_coors_nm`` as given.
                ``"seq_logits"``: ``[b, n, 20]`` masked logits.
                ``"residue_mask"``: the input mask, unchanged.
                ``"aatype_max"``: ``[b, n]`` argmax residue index, zeroed
                    where the mask is zero.
                ``"atom_mask"``: ``[b, n, 37]`` atom mask looked up from the
                    predicted residue type and multiplied by the residue mask.
        """
        ca_coors_nm = input["ca_coors_nm"]
        mask = input["residue_mask"]
        mask_feat = mask[..., None]  # broadcast the mask over the feature axis

        seqs = self.init_repr_factory(input) * mask_feat
        seqs = self.ff_nn(seqs) * mask_feat

        logits_out = self.logit_linear(seqs) * mask_feat

        coors_flat_nm = self.struct_linear(seqs) * mask_feat
        coors_a37_nm = einops.rearrange(
            coors_flat_nm,
            "b n (a t) -> b n a t",
            a=self._NUM_ATOMS,
            t=self._COORD_DIM,
        )
        # Write the input coordinates straight into the slot. Scaling the
        # prediction by 0.0 and adding would turn any Inf/NaN emitted by the
        # network into NaN here, corrupting coordinates that are meant to be
        # a verbatim copy of the input.
        coors_a37_nm[..., self._CA_SLOT, :] = ca_coors_nm

        # Masked positions have all-zero logits; multiplying by the mask pins
        # their residue index to 0 explicitly.
        aatype_max = torch.argmax(logits_out, dim=-1) * mask

        aa_a37_mask = get_atom_mask(device=logits_out.device)
        atom_mask = aa_a37_mask[aatype_max, :] * mask_feat

        return {
            "coors_nm": coors_a37_nm,
            "seq_logits": logits_out,
            "residue_mask": mask,
            "aatype_max": aatype_max,
            "atom_mask": atom_mask,
        }
