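"""
AlphaFold 2 model: input / recycling / extra-MSA embedders, Evoformer trunk,
structure module and auxiliary heads, driven by a recycling loop.

2023.03.05 init
"""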
from functools import partial

import torch
import torch.nn as nn

#. model
from myopenfold.model.embedders import (
    InputEmbedder,
    RecyclingEmbedder,
    ExtraMSAEmbedder,
)
from myopenfold.model.evoformer import EvoformerStack, ExtraMSAStack
from myopenfold.model.heads import AuxiliaryHeads
from myopenfold.model.structure_module import StructureModule

#. constant
import myopenfold.np.residue_constants as residue_constants

#. utils
from myopenfold.utils.feats import (
    pseudo_beta_fn,
    build_extra_msa_feat,
    atom14_to_atom37,
)
from myopenfold.utils.tensor_utils import (
    add,
    tensor_tree_map,
)


class AlphaFold(nn.Module):
    """
    Alphafold 2.

    Implements Algorithm 2 (but with training): each recycling cycle runs the
    input / recycling / extra-MSA embedders, the Evoformer stack and the
    structure module; the auxiliary heads are applied to the final cycle.
    """

    def __init__(
        self,
        config,
        log=False,
        *,
        evoformer_head=None,
        evoformer_loss=None,
        evoformer_layers=None,
    ):
        """
        Args:
            config:
                A dict-like config object (like the one in config.py). Reads
                ``config.globals`` and ``config.model``.
            log:
                If True, print debug information and pass ``log=True`` on to
                the Evoformer stack, structure module and auxiliary heads.
            evoformer_head:
                Optional callable applied to ``outputs["evoformer_embedding"]``
                in every recycling cycle when
                ``globals.get_all_evoformer_embedding`` is set. It may also be
                assigned on the model after construction. If it is an
                ``nn.Module`` it becomes a registered submodule (and therefore
                appears in ``state_dict``).
            evoformer_loss:
                Optional callable used together with ``evoformer_head``; it is
                called as ``evoformer_loss(logits, feats, _return_breakdown=True)``
                and must return ``(loss, loss_dict)``.
            evoformer_layers:
                Optional indices selecting which entries of the head output are
                kept per cycle; ``None`` keeps the whole output.
        """
        super(AlphaFold, self).__init__()

        self.depth = 0
        self.log = log

        self.globals = config.globals
        self.config = config.model
        self.template_config = self.config.template
        self.extra_msa_config = self.config.extra_msa

        # Flags controlling which intermediate results forward() returns.
        self.get_evoformer_embedding = self.globals.get_evoformer_embedding
        self.get_all_evoformer_embedding = self.globals.get_all_evoformer_embedding
        self.get_all_structure = self.globals.get_all_structure

        if self.log:  # DEBUG
            print(self.globals)
            print(self.extra_msa_config)

        # Main trunk + structure module
        self.input_embedder = InputEmbedder(
            **self.config["input_embedder"],
        )
        self.recycling_embedder = RecyclingEmbedder(
            **self.config["recycling_embedder"],
        )

        if self.extra_msa_config.enabled:
            self.extra_msa_embedder = ExtraMSAEmbedder(
                **self.extra_msa_config["extra_msa_embedder"]
            )
            self.extra_msa_stack = ExtraMSAStack(
                **self.extra_msa_config["extra_msa_stack"]
            )

        self.evoformer = EvoformerStack(
            **self.config["evoformer_stack"],
            depth=self.depth + 1,
            log=log,
            get_evoformer_embedding=self.get_evoformer_embedding,
        )

        self.structure_module = StructureModule(
            **self.config["structure_module"], depth=self.depth + 1, log=log
        )

        # Auxiliary predictions: pLDDT, MSA mask, Distogram, isExperimental, TM
        self.aux_heads = AuxiliaryHeads(
            self.config["heads"], depth=self.depth + 1, log=log
        )

        # Optional per-cycle evoformer probe (see forward()). Assigned last:
        # iteration() takes the model dtype from next(self.parameters()), so a
        # Module head must not be registered before the trunk submodules.
        self.evoformer_head = evoformer_head
        self.evoformer_loss = evoformer_loss
        self.evoformer_layers = evoformer_layers

    def iteration(self, feats, prevs, _recycle=True):
        """
        Run one recycling cycle.

        Args:
            feats:
                Feature dict for a single cycle (recycling dimension already
                removed). Float32 tensors are cast in place to the dtype of
                the model parameters. Keys read: "target_feat", "residue_index",
                "msa_feat", "seq_mask", "msa_mask", "aatype",
                "atom37_atom_exists", plus "extra_msa_mask" and whatever
                ``build_extra_msa_feat`` needs when extra MSA is enabled.
            prevs:
                List whose last three entries are ``[m_1_prev, z_prev, x_prev]``
                from the previous cycle (``None`` on the first cycle). They are
                popped from the list so they can be freed early.
            _recycle:
                Unused; kept for interface compatibility.

        Returns:
            ``(outputs, m_1_prev, z_prev, x_prev)``. ``outputs`` holds "msa",
            "pair", "single", "sm", "final_atom_positions", "final_atom_mask",
            "final_affine_tensor" (plus "evoformer_embedding" when
            ``globals.get_evoformer_embedding`` is set); the other three are
            the recycling inputs for the next cycle.
        """
        # Shapes (from the original annotations):
        #   aatype [*, N_res], target_feat [*, N_res, C_tf],
        #   residue_index [*, N_res], msa_feat [*, N_seq, N_res, C_msa],
        #   seq_mask [*, N_res], msa_mask [*, N_seq, N_res],
        #   extra_msa_mask [*, N_extra, N_res]
        outputs = {}

        # Cast float32 inputs to the parameter dtype by hand; this needs to be
        # done manually for DeepSpeed's sake.
        dtype = next(self.parameters()).dtype
        for k in feats:
            if feats[k].dtype == torch.float32:
                feats[k] = feats[k].to(dtype=dtype)

        batch_dims = feats["target_feat"].shape[:-2]
        n_res = feats["target_feat"].shape[-2]
        n_seq = feats["msa_feat"].shape[-3]

        # In-place ops are used only when no autograd graph can be needed:
        # neither in training mode nor with grad enabled (the dual condition
        # accounts for activation checkpointing).
        inplace_safe = not (self.training or torch.is_grad_enabled())

        # seq_mask [*, N_res]; pair_mask [*, N_res, N_res]; msa_mask [*, N_seq, N_res]
        seq_mask = feats["seq_mask"]
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        msa_mask = feats["msa_mask"]

        # ---------------- Input embedding ----------------
        # m [*, N_seq, N_res, C_m]; z [*, N_res, N_res, C_z]
        m, z = self.input_embedder(
            feats["target_feat"],
            feats["residue_index"],
            feats["msa_feat"],
            inplace_safe=inplace_safe,
        )

        # ---------------- Recycling embedding ----------------
        # Pop the previous-cycle state out of the caller's list so the only
        # remaining references are local and can be dropped below.
        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])

        # Missing state (first cycle): start from zeros.
        if any(p is None for p in (m_1_prev, z_prev, x_prev)):
            # [*, N_res, C_m]
            m_1_prev = m.new_zeros(
                (*batch_dims, n_res, self.config.input_embedder.c_m),
                requires_grad=False,
            )
            # [*, N_res, N_res, C_z]
            z_prev = z.new_zeros(
                (*batch_dims, n_res, n_res, self.config.input_embedder.c_z),
                requires_grad=False,
            )
            # [*, N_res, 37, 3]
            x_prev = z.new_zeros(
                (*batch_dims, n_res, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )

        # Reduce previous atom37 positions with pseudo_beta_fn (presumably to
        # per-residue pseudo-beta coordinates -- see pseudo_beta_fn).
        x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None).to(dtype=z.dtype)

        # m_1_prev_emb [*, N_res, C_m]; z_prev_emb [*, N_res, N_res, C_z]
        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
            x_prev,
            inplace_safe=inplace_safe,
        )

        # The first MSA row receives the recycled single embedding. This
        # update is always in-place, whereas the pair update honours
        # inplace_safe.
        m[..., 0, :, :] += m_1_prev_emb
        z = add(z, z_prev_emb, inplace=inplace_safe)

        # Dropping references early matters for inference with large N_res:
        # it frees unused tensors so they can be released / offloaded.
        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

        # ---------------- Extra MSA embedding ----------------
        if self.config.extra_msa.enabled:
            # [*, N_extra, N_res, C_e]
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))
            # Merge extra-MSA information into the pair representation.
            # [*, N_res, N_res, C_z]
            z = self.extra_msa_stack(
                a, z,
                msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
                pair_mask=pair_mask.to(dtype=m.dtype),
                inplace_safe=inplace_safe,
                _mask_trans=self.config._mask_trans,
            )

        # ---------------- Evoformer ----------------
        # m [*, N_seq, N_res, C_m]; z [*, N_res, N_res, C_z]; s [*, N_res, C_s]
        # The stack returns an extra per-layer embedding list when it was
        # built with get_evoformer_embedding=True (23.04.17 hack).
        evoformer_out = self.evoformer(
            m,
            z,
            msa_mask=msa_mask.to(dtype=m.dtype),
            pair_mask=pair_mask.to(dtype=z.dtype),
            inplace_safe=inplace_safe,
            _mask_trans=self.config._mask_trans,
        )
        if self.get_evoformer_embedding:
            m, z, s, embedding_list = evoformer_out
        else:
            m, z, s = evoformer_out
        del evoformer_out

        if self.log:
            print('After evoformer..')
            print('m shape: ', m.shape)

        msa = m[..., :n_seq, :, :]
        if self.globals.fix_before_structure:
            # Detach the trunk outputs so no gradient flows back into the
            # trunk through the structure module / auxiliary heads.
            if self.log:
                print("Fixing weight before structure module, detach m, z, s")  # DEBUG
            msa, z, s = msa.detach(), z.detach(), s.detach()
        outputs["msa"] = msa
        outputs["pair"] = z
        outputs["single"] = s
        del z, msa

        # ---------------- Structure module ----------------
        outputs["sm"] = self.structure_module(
            outputs,
            feats["aatype"],
            mask=feats["seq_mask"].to(dtype=s.dtype),
            inplace_safe=inplace_safe,
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"][-1], feats
        )
        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]

        if self.get_evoformer_embedding:
            outputs["evoformer_embedding"] = embedding_list

        # Recycling state for the next cycle; m_1 is the first MSA row.
        m_1_prev = m[..., 0, :, :]                 # [*, N_res, C_m]
        z_prev = outputs["pair"]                   # [*, N_res, N_res, C_z]
        x_prev = outputs["final_atom_positions"]   # [*, N_res, 37, 3]

        return outputs, m_1_prev, z_prev, x_prev

    def _check_evoformer_probe(self):
        """Raise ValueError if the per-cycle evoformer probe cannot run."""
        if not self.get_evoformer_embedding:
            raise ValueError(
                "globals.get_all_evoformer_embedding requires "
                "globals.get_evoformer_embedding to be enabled so that "
                "outputs['evoformer_embedding'] is produced"
            )
        if (
            getattr(self, "evoformer_head", None) is None
            or getattr(self, "evoformer_loss", None) is None
        ):
            raise ValueError(
                "globals.get_all_evoformer_embedding requires evoformer_head "
                "and evoformer_loss; pass them to AlphaFold(...) or assign "
                "them on the model"
            )

    def _record_additional_outputs(self, outputs, feats, cycle_no, additional_outputs):
        """Store this cycle's probe results into ``additional_outputs`` (mutated in place)."""
        if self.get_all_evoformer_embedding:
            distogram_logits = self.evoformer_head(outputs["evoformer_embedding"])
            _, loss_dict = self.evoformer_loss(
                distogram_logits, feats, _return_breakdown=True
            )
            layers = getattr(self, "evoformer_layers", None)
            if layers is not None:
                distogram_logits = [distogram_logits[i] for i in layers]
            additional_outputs["distogram_logits_recyc_%d" % cycle_no] = tensor_tree_map(
                lambda t: t.detach().clone().cpu().half().numpy(), distogram_logits
            )
            additional_outputs["distogram_loss_recyc_%d" % cycle_no] = loss_dict
        if self.get_all_structure:
            additional_outputs["structure_recyc_%d" % cycle_no] = (
                outputs["final_atom_positions"].detach().clone().cpu().half().numpy()
            )

    def forward(self, batch):
        """
        Args:
            batch:
                Dictionary of arguments outlined in Algorithm 2. Keys must
                include the official names of the features in the
                supplement subsection 1.2.9.

                The final dimension of each input must have length equal to
                the number of recycling iterations (taken from "aatype").

                Features (without the recycling dimension):

                    "aatype" ([*, N_res]):
                        Contrary to the supplement, this tensor of residue
                        indices is not one-hot.
                    "target_feat" ([*, N_res, C_tf])
                        One-hot encoding of the target sequence. C_tf is
                        config.model.input_embedder.tf_dim.
                    "residue_index" ([*, N_res])
                        Tensor whose final dimension consists of
                        consecutive indices from 0 to N_res.
                    "msa_feat" ([*, N_seq, N_res, C_msa])
                        MSA features, constructed as in the supplement.
                        C_msa is config.model.input_embedder.msa_dim.
                    "seq_mask" ([*, N_res])
                        1-D sequence mask
                    "msa_mask" ([*, N_seq, N_res])
                        MSA mask
                    "pair_mask" ([*, N_res, N_res])
                        2-D pair mask (note: iteration() recomputes the pair
                        mask from "seq_mask" instead of reading this key)
                    "extra_msa_mask" ([*, N_extra, N_res])
                        Extra MSA mask (used only when extra MSA is enabled)

        Returns:
            The outputs of the final recycling cycle (see ``iteration``),
            updated with the auxiliary head predictions. When
            ``globals.get_all_evoformer_embedding`` or
            ``globals.get_all_structure`` is set, ``outputs["addtional"]``
            holds per-cycle results keyed like "structure_recyc_<cycle>".

        Raises:
            ValueError: if the recycling dimension is empty, or if the
                per-cycle evoformer probe is requested without
                ``globals.get_evoformer_embedding`` or without an
                ``evoformer_head`` / ``evoformer_loss``.
        """
        if self.log:  # DEBUG
            for k, v in batch.items():  # DEBUG
                print(k, v.shape)  # DEBUG

        num_iters = batch["aatype"].shape[-1]
        if num_iters < 1:
            raise ValueError(
                "batch['aatype'] must have a trailing recycling dimension of "
                f"size >= 1, got shape {tuple(batch['aatype'].shape)}"
            )

        # Fail fast, before any compute, if the probe is misconfigured.
        if self.get_all_evoformer_embedding:
            self._check_evoformer_probe()
        collect_additional = self.get_all_evoformer_embedding or self.get_all_structure
        additional_outputs = {}

        # Recycling state; iteration() pops these and replaces Nones with zeros.
        prevs = [None, None, None]

        is_grad_enabled = torch.is_grad_enabled()

        # Main recycling loop
        for cycle_no in range(num_iters):
            # Select the features for the current recycling cycle
            feats = tensor_tree_map(lambda t: t[..., cycle_no], batch)

            # Grad is enabled only for the final cycle (and only if the caller
            # had it enabled), so only that cycle is backpropagated.
            is_final_iter = cycle_no == num_iters - 1
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter and torch.is_autocast_enabled():
                    # Sidestep AMP bug (PyTorch issue #65766)
                    # https://github.com/pytorch/pytorch/issues/65766#issuecomment-932511164
                    torch.clear_autocast_cache()

                outputs, m_1_prev, z_prev, x_prev = self.iteration(
                    feats,
                    prevs,
                    _recycle=(num_iters > 1),
                )

                if collect_additional:
                    self._record_additional_outputs(
                        outputs, feats, cycle_no, additional_outputs
                    )

                if not is_final_iter:
                    # Keep only the recycling state; release everything else.
                    del outputs
                    prevs = [m_1_prev, z_prev, x_prev]
                    del m_1_prev, z_prev, x_prev
                torch.cuda.empty_cache()  # try to free as much memory as possible

        # Run auxiliary heads on the final cycle's outputs
        outputs.update(self.aux_heads(outputs))

        if collect_additional:
            # The key spelling is part of the output interface; do not rename.
            outputs['addtional'] = additional_outputs

        return outputs
