from typing import Dict

import einops
import torch

from openfold.np.residue_constants import RESTYPE_ATOM37_MASK
from proteinfoundation.nn.feature_factory import FeatureFactory
from proteinfoundation.nn.modules.attn_n_transition import MultiheadAttnAndTransition
from proteinfoundation.nn.modules.pair_update import PairReprUpdate
from proteinfoundation.nn.modules.seq_transition_af3 import Transition


def get_atom_mask(device: torch.device = None):
    """Return the ``RESTYPE_ATOM37_MASK`` table as a boolean torch tensor.

    Args:
        device: Target device, forwarded as-is to ``Tensor.to``. With
            ``None`` the tensor stays where ``torch.from_numpy`` created it.

    Returns:
        A ``torch.bool`` tensor holding the contents of
        ``RESTYPE_ATOM37_MASK``.
    """
    mask_table = torch.from_numpy(RESTYPE_ATOM37_MASK)
    return mask_table.to(dtype=torch.bool, device=device)


class DecoderTransformer(torch.nn.Module):
    """Transformer decoder emitting per-residue sequence logits and coordinates.

    A stack of ``MultiheadAttnAndTransition`` layers refines a per-residue
    token representation, conditioned on a separate conditioning embedding
    and biased by a pair representation. Two linear heads then produce
    20 logits and ``37 * 3`` coordinate values per residue.
    """

    def __init__(self, **kwargs):
        """Build all submodules from the ``kwargs["decoder"]`` configuration.

        Keys read from ``kwargs["decoder"]``: ``nlayers``, ``token_dim``,
        ``pair_repr_dim``, ``update_pair_repr``, ``update_pair_repr_every_n``,
        ``use_tri_mult``, ``use_qkln``, ``feats_seq``, ``feats_cond_seq``,
        ``feats_pair_repr``, ``dim_cond``, ``nheads`` and, optionally,
        ``abs_coors`` (defaults to ``True``). The whole mapping is also
        forwarded as extra keyword arguments to every ``FeatureFactory``.
        """
        super().__init__()
        cfg = kwargs["decoder"]

        self.nlayers = cfg["nlayers"]
        self.token_dim = cfg["token_dim"]
        self.pair_repr_dim = cfg["pair_repr_dim"]
        self.update_pair_repr = cfg["update_pair_repr"]
        self.update_pair_repr_every_n = cfg["update_pair_repr_every_n"]
        self.use_tri_mult = cfg["use_tri_mult"]
        self.use_qkln = cfg["use_qkln"]
        self.abs_coors = cfg.get("abs_coors", True)

        # Initial per-residue token representation.
        self.init_repr_factory = FeatureFactory(
            feats=cfg["feats_seq"],
            dim_feats_out=cfg["token_dim"],
            use_ln_out=False,
            mode="seq",
            **cfg,
        )

        # Conditioning embedding, post-processed by two transitions.
        self.cond_factory = FeatureFactory(
            feats=cfg["feats_cond_seq"],
            dim_feats_out=cfg["dim_cond"],
            use_ln_out=False,
            mode="seq",
            **cfg,
        )
        self.transition_c_1 = Transition(cfg["dim_cond"], expansion_factor=2)
        self.transition_c_2 = Transition(cfg["dim_cond"], expansion_factor=2)

        # Pair representation handed to every attention layer.
        self.pair_rep_factory = FeatureFactory(
            feats=cfg["feats_pair_repr"],
            dim_feats_out=cfg["pair_repr_dim"],
            use_ln_out=False,
            mode="pair",
            **cfg,
        )

        attn_kwargs = dict(
            dim_token=self.token_dim,
            dim_pair=self.pair_repr_dim,
            nheads=cfg["nheads"],
            dim_cond=cfg["dim_cond"],
            residual_mha=True,
            residual_transition=True,
            parallel_mha_transition=False,
            use_attn_pair_bias=True,
            use_qkln=self.use_qkln,
        )
        self.transformer_layers = torch.nn.ModuleList(
            [MultiheadAttnAndTransition(**attn_kwargs) for _ in range(self.nlayers)]
        )

        if self.update_pair_repr:
            # One slot per layer except the last; slots that are not a
            # multiple of `update_pair_repr_every_n` hold None (no update).
            pair_updates = []
            for layer_idx in range(self.nlayers - 1):
                if layer_idx % self.update_pair_repr_every_n == 0:
                    pair_updates.append(
                        PairReprUpdate(
                            token_dim=cfg["token_dim"],
                            pair_dim=cfg["pair_repr_dim"],
                            use_tri_mult=self.use_tri_mult,
                        )
                    )
                else:
                    pair_updates.append(None)
            self.pair_update_layers = torch.nn.ModuleList(pair_updates)

        # Output heads: 20 sequence logits and 37 * 3 coordinate values.
        self.logit_linear = torch.nn.Sequential(
            torch.nn.LayerNorm(self.token_dim),
            torch.nn.Linear(self.token_dim, 20, bias=False),
        )
        self.struct_linear = torch.nn.Sequential(
            torch.nn.LayerNorm(self.token_dim),
            torch.nn.Linear(self.token_dim, 37 * 3, bias=False),
        )

    def forward(self, input: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Decode sequence logits and 37-slot coordinates for each residue.

        Args:
            input: Feature dictionary. This method reads ``"ca_coors_nm"``
                and ``"residue_mask"`` directly; the feature factories
                receive the whole dictionary and may read further keys.
                The indexing below implies ``residue_mask`` is ``(b, n)``
                and ``ca_coors_nm`` is ``(b, n, 3)`` -- TODO confirm
                against the caller.

        Returns:
            Dictionary with:
                ``"coors_nm"``: ``(b, n, 37, 3)`` coordinates.
                ``"seq_logits"``: ``(b, n, 20)`` logits, multiplied by the mask.
                ``"residue_mask"``: the input mask, passed through.
                ``"aatype_max"``: argmax of the logits, multiplied by the mask.
                ``"atom_mask"``: rows of the ``get_atom_mask`` table selected
                    by ``aatype_max``, multiplied by the mask.
        """
        ca_coors_nm = input["ca_coors_nm"]
        mask = input["residue_mask"]
        mask_col = mask[..., None]

        # Conditioning embedding followed by its two transitions.
        cond = self.cond_factory(input)
        cond = self.transition_c_1(cond, mask)
        cond = self.transition_c_2(cond, mask)

        tokens = self.init_repr_factory(input) * mask_col
        pair_rep = self.pair_rep_factory(input)

        final_idx = self.nlayers - 1
        for idx in range(self.nlayers):
            tokens = self.transformer_layers[idx](tokens, pair_rep, cond, mask)

            # No pair update after the final layer or when disabled.
            if not self.update_pair_repr or idx >= final_idx:
                continue
            pair_layer = self.pair_update_layers[idx]
            if pair_layer is not None:
                pair_rep = pair_layer(tokens, pair_rep, mask)

        logits_out = self.logit_linear(tokens) * mask_col

        coors_flat_nm = self.struct_linear(tokens) * mask_col
        coors_a37_nm = einops.rearrange(
            coors_flat_nm, "b n (a t) -> b n a t", a=37, t=3
        )

        # Slot 1 never keeps its predicted value: it is scaled by 0.0 instead
        # of being replaced by a constant (presumably to stay connected to the
        # autograd graph -- TODO confirm), then combined with `ca_coors_nm`.
        zeroed_slot = coors_a37_nm[..., 1, :] * 0.0
        if self.abs_coors:
            # Other slots are used as predicted; slot 1 becomes `ca_coors_nm`.
            coors_a37_nm[..., 1, :] = zeroed_slot + ca_coors_nm
        else:
            # Every slot is shifted by `ca_coors_nm`, so slot 1 ends up equal
            # to `ca_coors_nm` and the rest are offsets from it.
            coors_a37_nm[..., 1, :] = zeroed_slot
            coors_a37_nm = coors_a37_nm + ca_coors_nm[:, :, None, :]
        # NOTE(review): coordinates are not re-masked after `ca_coors_nm` is
        # injected, so masked residues carry whatever `ca_coors_nm` holds.

        aatype_max = logits_out.argmax(dim=-1) * mask

        aa_a37_mask = get_atom_mask(device=logits_out.device)
        atom_mask = aa_a37_mask[aatype_max, :] * mask_col

        return {
            "coors_nm": coors_a37_nm,
            "seq_logits": logits_out,
            "residue_mask": mask,
            "aatype_max": aatype_max,
            "atom_mask": atom_mask,
        }
