"""
2023.03.05 init
"""
import math
import sys
import torch
import torch.nn as nn
from typing import Tuple, Sequence, Optional
from functools import partial

#. model
from myopenfold.model.primitives import Linear, LayerNorm
from myopenfold.model.dropout import DropoutRowwise   # DropoutColumnwise
from myopenfold.model.msa import (
    MSARowAttentionWithPairBias,
    MSAColumnAttention,
    MSAColumnGlobalAttention,
)
from myopenfold.model.outer_product_mean import OuterProductMean
from myopenfold.model.pair_transition import PairTransition
from myopenfold.model.triangular_attention import TriangleAttention
from myopenfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)

#. utils
from myopenfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from myopenfold.utils.tensor_utils import add


class MSATransition(nn.Module):
    """
    Feed-forward network applied to MSA activations after attention.

    Implements Algorithm 9: LayerNorm -> Linear(c_m, n * c_m) -> ReLU ->
    Linear(n * c_m, c_m), with the result multiplied by the MSA mask.
    """

    def __init__(self, c_m, n, depth=0, ind=0, log=False):
        """
        Args:
            c_m:
                MSA channel dimension
            n:
                Factor multiplied to c_m to obtain the hidden channel
                dimension
            depth:
                Nesting depth of this module in the whole model (stored for
                debug printing)
            ind:
                Index of the enclosing block in the stack
            log:
                Debug-logging flag (stored; not read by this module)
        """
        super().__init__()

        self.depth, self.ind, self.log = depth, ind, log
        self.c_m, self.n = c_m, n

        hidden_dim = n * c_m
        # Submodules are created in this fixed order so parameter
        # initialisation consumes the RNG in the same sequence.
        self.layer_norm = LayerNorm(c_m)
        self.linear_1 = Linear(c_m, hidden_dim, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(hidden_dim, c_m, init="final")

    def _transition(self, m, mask):
        """Run LN -> Linear -> ReLU -> Linear on ``m`` and multiply by ``mask``."""
        activated = self.relu(self.linear_1(self.layer_norm(m)))
        return self.linear_2(activated) * mask

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA activation
            mask:
                [*, N_seq, N_res] MSA mask; when None an all-ones mask is used
        Returns:
            [*, N_seq, N_res, C_m] MSA activation update (not yet added to m)
        """
        # DeepMind's reference does not mask this transition; substituting an
        # all-ones mask reproduces that when the caller passes no mask.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # Trailing singleton axis broadcasts the mask over the channel dim.
        return self._transition(m, mask.unsqueeze(-1))


class EvoformerBlockCore(nn.Module):
    """
    Operations shared by EvoformerBlock and ExtraMSABlock that follow the MSA
    row/column attention: MSA transition, outer product mean, triangular
    multiplicative updates (outgoing, incoming), triangular attention
    (starting node, ending node) and pair transition.

    The switches ``no_extra_msa``, ``no_triangular_multiplication`` and
    ``no_triangular_attention`` skip the corresponding groups of operations in
    ``forward``. The submodules are still constructed, so parameters exist
    regardless of the switches.
    """
    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        pair_dropout: float,
        inf: float,
        eps: float,
        _is_extra_msa_stack: bool = False,
        depth: int = 0,
        ind: int = 0,
        log: bool = False,
        no_extra_msa: bool = False,
        no_triangular_attention: bool = False,
        no_triangular_multiplication: bool = False,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair channel dimension
            c_hidden_opm:
                Hidden dimension of the outer product mean
            c_hidden_mul:
                Hidden dimension of the triangular multiplicative updates
            c_hidden_pair_att:
                Hidden dimension of the triangular attention
            no_heads_msa:
                Accepted for signature symmetry; not used by this class
            no_heads_pair:
                Number of heads of the triangular attention
            transition_n:
                Expansion factor of the MSA and pair transitions
            pair_dropout:
                Rate of the row-wise dropout applied to pair updates
            inf:
                Forwarded to TriangleAttention as ``inf``
            eps:
                Accepted for signature symmetry; not used by this class
            _is_extra_msa_stack:
                Forwarded only to TriangleMultiplicationOutgoing
            depth, ind, log:
                Nesting depth, block index and flag for debug printing
            no_extra_msa:
                If True, forward skips the MSA transition and the outer
                product mean (m is returned unchanged by this core)
            no_triangular_attention:
                If True, forward skips both triangular attention updates
            no_triangular_multiplication:
                If True, forward skips both triangular multiplicative updates
        """
        super(EvoformerBlockCore, self).__init__()

        self.depth = depth
        self.ind = ind
        self.log = log
        self.no_extra_msa = no_extra_msa
        self._is_extra_msa_stack = _is_extra_msa_stack

        self.no_triangular_attention = no_triangular_attention
        self.no_triangular_multiplication = no_triangular_multiplication

        # NOTE: submodule creation order determines the order in which
        # parameter initialisation consumes the RNG; do not reorder.

        #. MSA Transition
        self.msa_transition = MSATransition(
            c_m=c_m,
            n=transition_n,
            depth=depth + 1,
            ind = ind,
            log=log,
        )

        #. Outer Product Mean
        self.outer_product_mean = OuterProductMean(
            c_m,
            c_z,
            c_hidden_opm,
            depth=depth + 1,
            ind = ind,
            log=log,
        )

        #. Triangle Multiplication Outgoing
        # NOTE(review): only the outgoing update receives _is_extra_msa_stack;
        # confirm whether TriangleMultiplicationIncoming should as well.
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            c_z,
            c_hidden_mul,
            depth=depth + 1,
            ind = ind,
            log=log,
            _is_extra_msa_stack = self._is_extra_msa_stack,
        )

        #. Triangle Multiplication Incoming
        self.tri_mul_in = TriangleMultiplicationIncoming(
            c_z,
            c_hidden_mul,
            depth=depth + 1,
            ind = ind,
            log=log,
        )

        #. Triangle Attention with Starting Node
        self.tri_att_start = TriangleAttention(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
            depth=depth + 1,
            ind = ind,
            log=log,
        )

        #. Triangle Attention with Ending Node
        self.tri_att_end = TriangleAttention(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
            depth=depth + 1,
            ind = ind,
            log=log,
        )

        #. Pair Transition
        self.pair_transition = PairTransition(
            c_z,
            transition_n,
            depth=depth + 1,
            ind = ind,
            log=log,
        )

        #. Dropout Row-Wise
        # NOTE there is no separate column-wise dropout layer: forward
        # transposes the pair tensor before the ending-node attention and
        # reuses this row-wise layer there.
        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)

    def forward(self,
        input_tensors: Sequence[torch.Tensor],
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            input_tensors:
                [m, z] with m the [*, N_seq, N_res, C_m] MSA activation and
                z the [*, N_res, N_res, C_z] pair activation. When
                inplace_safe is True this must be a mutable list: entry 1 is
                reassigned after each transpose around the ending-node
                attention.
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
            inplace_safe:
                Forwarded to ``add`` and the submodules to request in-place
                updates. On this path the triangular multiplication output
                replaces z directly, without the pair dropout.
            _mask_trans:
                If False, the MSA and pair transitions are run unmasked
                (closer to DeepMind's original activations)
        Returns:
            (m, z) updated MSA and pair activations
        """
        if self.log: print("\t" * self.depth + 'In EvoformerBlockCore: init, ind', self.ind)  # DEBUG
        if self.ind == 0 and self.log: print("\t" * self.depth + "input_tensors: ", input_tensors[0].shape, input_tensors[1].shape)  # DEBUG
        if self.ind == 0 and self.log: print("\t" * self.depth + "msa_mask: ", msa_mask.shape)  # DEBUG
        if self.ind == 0 and self.log: print("\t" * self.depth + "pair_mask: ", pair_mask.shape)  # DEBUG

        #. MSA and Pair transitions are only masked if _mask_trans is True,
        #. which differs from the original AF implementation.
        # DeepMind doesn't mask these transitions in the source, so _mask_trans
        # should be disabled to better approximate the exact activations of
        # the original.
        msa_trans_mask = msa_mask if _mask_trans else None
        pair_trans_mask = pair_mask if _mask_trans else None

        m, z = input_tensors


        if self.no_extra_msa:
            # Ablation: skip the MSA transition and the outer product mean,
            # so m is not updated and z receives no MSA-derived update here.
            pass
            # print("In Core: no extra msa mode detected, bypassing extra MSA operation..")  # DEBUG
        else:
            #. MSA Transition
            m = add(
                m,
                self.msa_transition(
                    m, mask=msa_trans_mask
                ),
                inplace=inplace_safe,
            )
            # print('After msa_transition, m', torch.isnan(m).any())  # DEBUG

            #. Outer Product Mean (MSA -> pair communication)
            opm = self.outer_product_mean(
                m, mask=msa_mask, inplace_safe=inplace_safe
            )
            #print(self.ind, "opm", opm)  # DEBUG
            # print('After outer_product_mean, opm', torch.isnan(opm).any())  # DEBUG

            #. here starts the Pair Stack

            z = add(z, opm, inplace=inplace_safe)
            #print(self.ind, "z after opm", z)  # DEBUG
            # print('After outer_product_mean, z', torch.isnan(z).any())  # DEBUG
            del opm

        if self.no_triangular_multiplication:
            # print("In Core: no triangular multiplication mode detected, bypassing..")  # DEBUG
            pass
        else:
            # NOTE with inplace_safe, the returned tensor is used as the new z
            # directly (presumably the module already added its update to z
            # because of _add_with_inplace=True); no dropout on that path.
            #. Triangle Multiplication Outgoing
            tmu_update = self.tri_mul_out(
                z,
                mask=pair_mask,
                inplace_safe=inplace_safe,
                _add_with_inplace=True,
            )
            # print('After tri_mul_out, tmu_update', torch.isnan(tmu_update).any())  # DEBUG
            #print(self.ind, "tmu_update outgoing", tmu_update)  # DEBUG
            if(not inplace_safe):
                z = z + self.ps_dropout_row_layer(tmu_update)
            else:
                z = tmu_update
            
            del tmu_update
            # NOTE same in-place convention as the outgoing update above.
            #. Triangle Multiplication Incoming
            tmu_update = self.tri_mul_in(
                z,
                mask=pair_mask,
                inplace_safe=inplace_safe,
                _add_with_inplace=True,
            )
            # print('After tri_mul_in, tmu_update', torch.isnan(tmu_update).any())  # DEBUG
            #print(self.ind, "tmu_update incoming", tmu_update)  # DEBUG
            if(not inplace_safe):
                z = z + self.ps_dropout_row_layer(tmu_update)
            else:
                z = tmu_update
        
            del tmu_update

        if self.no_triangular_attention:
            # print("In Core: no triangular attention mode detected, bypassing..")  # DEBUG
            pass
        else:
            #. Triangle Gated Attention around Starting Node
            z = add(z, 
                self.ps_dropout_row_layer(
                    self.tri_att_start(
                        z, 
                        mask=pair_mask, 
                        use_memory_efficient_kernel=False,
                        inplace_safe=inplace_safe,
                    )
                ),
                inplace=inplace_safe,
            )
            #print(self.ind, "z after tri_att_start", z)  # DEBUG

            #. Triangle Gated Attention around Ending Node
            #. z (and the mask) are transposed here, before calling the
            #.  attention module, so the module itself does not transpose;
            #.  the row-wise dropout on the transposed tensor acts column-wise.

            z = z.transpose(-2, -3)
            # NOTE in the in-place path, the contiguous copy is written back
            # into the caller's list. This presumably ensures the list no
            # longer references the previous pair tensor, so it can be freed.
            # The callers in this file `del` their own m/z before calling core.
            if(inplace_safe):
                input_tensors[1] = z.contiguous()
                z = input_tensors[1]
            
            z = add(z,
                self.ps_dropout_row_layer(
                    self.tri_att_end(
                        z,
                        mask=pair_mask.transpose(-1, -2),
                        use_memory_efficient_kernel=False,
                        inplace_safe=inplace_safe,
                    )
                ),
                inplace=inplace_safe,
            )
            #print(self.ind, "z after tri_att_end", z)  # DEBUG

            # Undo the transpose (and, in place, refresh the list entry again).
            z = z.transpose(-2, -3)
            if(inplace_safe):
                input_tensors[1] = z.contiguous()
                z = input_tensors[1]

        #. Pair Transition (always applied, independent of the ablation flags)
        z = add(z,
            self.pair_transition(
                z, mask=pair_trans_mask,
            ),
            inplace=inplace_safe,
        )
        #print(self.ind, "z after pair_transition", z)  # DEBUG

        return m, z


class EvoformerBlock(nn.Module):
    """
    One block of the main Evoformer trunk.

    This block applies MSA row attention with pair bias and MSA column
    attention. Everything after that (MSA transition, outer product mean,
    triangular updates, pair transition) is delegated to
    ``EvoformerBlockCore``.
    """
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        depth: int = 0,
        ind=0,
        log: bool = False,
        no_triangular_attention = False,
        no_triangular_multiplication = False,
    ):
        super().__init__()

        self.depth = depth
        self.ind = ind
        self.log = log
        self.no_triangular_attention = no_triangular_attention
        self.no_triangular_multiplication = no_triangular_multiplication

        child_depth = depth + 1

        # Creation order (row attention, column attention, dropout, core) is
        # fixed so parameter initialisation consumes the RNG identically.
        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
            depth=child_depth,
            ind=ind,
            log=log,
        )
        self.msa_att_col = MSAColumnAttention(
            c_m,
            c_hidden_msa_att,
            no_heads_msa,
            inf=inf,
            depth=child_depth,
            ind=ind,
            log=log,
        )
        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
            depth=child_depth,
            ind=ind,
            log=log,
            no_triangular_attention=no_triangular_attention,
            no_triangular_multiplication=no_triangular_multiplication,
        )

    def forward(self,
        m: Optional[torch.Tensor],
        z: Optional[torch.Tensor],
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA activation
            z:
                [*, N_res, N_res, C_z] pair activation
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
            inplace_safe:
                Request in-place residual additions (forwarded to ``add`` and
                to the core)
            _mask_trans:
                Forwarded to the core; controls masking of the transitions
        Returns:
            (m, z) updated MSA and pair activations
        """
        verbose = self.log and self.ind == 0
        indent = '\t' * self.depth if verbose else ''

        if verbose:
            print(indent, 'In EvoformerBlock: init')  # DEBUG
            print(indent, 'm', m.shape)  # DEBUG
            print(indent, 'z', z.shape)  # DEBUG
            print(indent, 'msa_mask', msa_mask.shape)  # DEBUG
            print(indent, 'pair_mask', pair_mask.shape)  # DEBUG
            print(indent, 'inplace_safe', inplace_safe)  # DEBUG

        # MSA row-wise gated self-attention with pair bias, then row dropout.
        row_update = self.msa_att_row(
            m,
            z=z,
            mask=msa_mask,
            use_memory_efficient_kernel=False,
        )
        m = add(m, self.msa_dropout_layer(row_update), inplace=inplace_safe)
        del row_update
        if verbose:
            print(indent + 'In EvoformerBlock: after msa_att_row')  # DEBUG

        # MSA column-wise gated self-attention (no dropout).
        m = add(m, self.msa_att_col(m, mask=msa_mask), inplace=inplace_safe)
        if verbose:
            print(indent + 'In EvoformerBlock: after msa_att_col')  # DEBUG

        # Hand the activations over in a list and drop the local names. The
        # core's in-place path rebinds list entries, and references held here
        # would otherwise keep the superseded tensors alive.
        activations = [m, z]
        del m, z

        m, z = self.core(
            activations,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )
        if verbose:
            print(indent + 'In EvoformerBlock: after core')  # DEBUG

        return m, z


class ExtraMSABlock(nn.Module):
    """
    Evoformer-style block used by the extra MSA stack.

    Almost identical to the standard EvoformerBlock, except that the
    ExtraMSABlock uses global attention for MSA column attention and has its
    own activation-checkpointing switch (``ckpt``). It is kept separate from
    its twin to preserve the TorchScript-ability of the latter.
    """
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        ckpt: bool,
        no_extra_msa = False,
    ):
        super().__init__()

        self.ckpt = ckpt
        self.no_extra_msa = no_extra_msa

        # Creation order (row attention, column attention, dropout, core) is
        # fixed so parameter initialisation consumes the RNG identically.
        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )
        self.msa_att_col = MSAColumnGlobalAttention(
            c_in=c_m,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
            eps=eps,
        )
        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
            no_extra_msa=no_extra_msa,
            _is_extra_msa_stack=True,
        )

    def forward(self,
        m: Optional[torch.Tensor],
        z: Optional[torch.Tensor],
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                [*, N_extra, N_res, C_m] extra MSA activation
            z:
                [*, N_res, N_res, C_z] pair activation
            msa_mask:
                [*, N_extra, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
            inplace_safe:
                Request in-place residual additions
            _mask_trans:
                Forwarded to the core; controls masking of the transitions
        Returns:
            (m, z) updated extra-MSA and pair activations

        With ``no_extra_msa`` the row attention is skipped here. Column
        attention still runs, and the core (built with the same flag) skips
        its MSA transition and outer product mean.
        """
        if not self.no_extra_msa:
            # When autograd is on, the attention receives copies so the
            # tensors that flow into the residual add are not aliased.
            grad_on = torch.is_grad_enabled()
            row_update = self.msa_att_row(
                m.clone() if grad_on else m,
                z=z.clone() if grad_on else z,
                mask=msa_mask,
                use_memory_efficient_kernel=True,
            )
            m = add(m, self.msa_dropout_layer(row_update), inplace=inplace_safe)
            del row_update

        activations = [m, z]
        del m, z

        def column_attention_and_core(tensors):
            # Column-wise global attention on the extra MSA (no dropout).
            msa = add(
                tensors[0],
                self.msa_att_col(tensors[0], mask=msa_mask),
                inplace=inplace_safe,
            )
            if not inplace_safe:
                tensors = [msa, tensors[1]]
            del msa

            new_m, new_z = self.core(
                tensors,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                inplace_safe=inplace_safe,
                _mask_trans=_mask_trans,
            )
            return new_m, new_z

        if torch.is_grad_enabled() and self.ckpt:
            run_checkpointed = get_checkpoint_fn()
            m, z = run_checkpointed(column_attention_and_core, activations)
        else:
            m, z = column_attention_and_core(activations)

        return m, z


class EvoformerStack(nn.Module):
    """
    Main Evoformer trunk.

    Implements Algorithm 6.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        c_s: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        blocks_per_ckpt: int,
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False, 
        depth = 0,
        log=False,
        no_triangular_attention = False,
        no_triangular_multiplication = False,
        get_evoformer_embedding = False,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair channel dimension
            c_hidden_msa_att:
                Hidden dimension in MSA attention
            c_hidden_opm:
                Hidden dimension in outer product mean module
            c_hidden_mul:
                Hidden dimension in multiplicative updates
            c_hidden_pair_att:
                Hidden dimension in triangular attention
            c_s:
                Channel dimension of the output "single" embedding
            no_heads_msa:
                Number of heads used for MSA attention
            no_heads_pair:
                Number of heads used for pair attention
            no_blocks:
                Number of Evoformer blocks in the stack
            transition_n:
                Factor by which to multiply c_m to obtain the MSATransition
                hidden dimension
            msa_dropout:
                Dropout rate for MSA activations
            pair_dropout:
                Dropout used for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
                (ignored when autograd is disabled)
            inf, eps:
                Numerical constants forwarded to every block
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
            depth, log:
                Nesting depth and flag for debug printing
            no_triangular_attention, no_triangular_multiplication:
                Ablation switches forwarded to every block
            get_evoformer_embedding:
                If True, forward returns a fourth element with per-block
                (m, z) snapshots (collected only when autograd is disabled)
            **kwargs:
                Ignored; accepted so a shared config dict can be splatted in
        """
        super(EvoformerStack, self).__init__()

        self.depth = depth
        self.log = log

        self.no_triangular_attention = no_triangular_attention
        self.no_triangular_multiplication = no_triangular_multiplication
        self.get_evoformer_embedding = get_evoformer_embedding

        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        self.blocks = nn.ModuleList()

        for i in range(no_blocks):
            block = EvoformerBlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_msa_att=c_hidden_msa_att,
                c_hidden_opm=c_hidden_opm,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_msa=no_heads_msa,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
                depth=self.depth+1,
                ind=i,
                log=self.log,
                no_triangular_attention=self.no_triangular_attention,
                no_triangular_multiplication=self.no_triangular_multiplication,
            )
            self.blocks.append(block)

        self.linear = Linear(c_m, c_s)

    def _prep_blocks(self, 
        m: torch.Tensor, 
        z: torch.Tensor, 
        msa_mask: Optional[torch.Tensor],
        pair_mask: Optional[torch.Tensor],
        inplace_safe: bool,
        _mask_trans: bool,
    ):
        """
        Bind the masks and flags to every block, producing callables of (m, z).

        ``m`` and ``z`` are accepted for interface symmetry and are not used.
        With ``clear_cache_between_blocks`` each callable first empties the
        CUDA cache.
        """
        blocks = [
            partial(
                b,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                inplace_safe=inplace_safe,
                _mask_trans=_mask_trans,
            )
            for b in self.blocks
        ]
        
        if(self.clear_cache_between_blocks):
            def block_with_cache_clear(block, *args, **kwargs):
                torch.cuda.empty_cache()
                return block(*args, **kwargs)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        return blocks

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> tuple:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
            inplace_safe:
                Request in-place residual additions inside the blocks
            _mask_trans:
                Forwarded to the blocks; controls masking of the transitions
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding (projection of the first MSA row)
            embedding_list:
                Only returned when ``get_evoformer_embedding`` is True. It is a
                list of detached (m, z) pairs, one per block. It is collected
                only when autograd is disabled; otherwise it is None.
        """
        if self.log: print('\t' * self.depth + 'In EvoformerStack: init')  # DEBUG
        if self.log: print('\t' * self.depth + 'm', m.shape)  # DEBUG
        if self.log: print('\t' * self.depth + 'z', z.shape)  # DEBUG
        if self.log: print('\t' * self.depth + 'msa_mask', msa_mask.shape)  # DEBUG
        if self.log: print('\t' * self.depth + 'pair_mask', pair_mask.shape)  # DEBUG
        if self.log: print('\t' * self.depth + 'inplace_safe', inplace_safe)  # DEBUG

        blocks = self._prep_blocks(
            m=m,
            z=z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )

        # Activation checkpointing is pointless without autograd.
        blocks_per_ckpt = self.blocks_per_ckpt
        if(not torch.is_grad_enabled()):
            blocks_per_ckpt = None

        # Previously this name was only bound inside the no-grad branch, so
        # get_evoformer_embedding=True with autograd enabled crashed with
        # UnboundLocalError at the return below.
        embedding_list = None
        if not torch.is_grad_enabled() and self.get_evoformer_embedding:
            # Run the blocks one by one to snapshot each block's output.
            embedding_list = []
            for block in blocks:
                m, z = block(m, z)
                m_snap, z_snap = m.detach(), z.detach()
                if inplace_safe:
                    # detach() shares storage with m/z. With inplace_safe, later
                    # blocks update the activations via add(..., inplace=True),
                    # which would overwrite earlier snapshots, so copy them.
                    m_snap, z_snap = m_snap.clone(), z_snap.clone()
                embedding_list.append((m_snap, z_snap))
                torch.cuda.empty_cache()
        else:
            #. forward pass through the blocks with checkpointing
            m, z = checkpoint_blocks(
                blocks,
                args=(m, z),
                blocks_per_ckpt=blocks_per_ckpt,
            )
        if self.log: print('\t' * self.depth + 'In EvoformerStack: after evoformer')  # DEBUG
        if self.log: print('\t' * self.depth + 'm', m.shape)  # DEBUG
        if self.log: print('\t' * self.depth + 'z', z.shape)  # DEBUG

        # The single representation is a projection of the first MSA row.
        s = self.linear(m[..., 0, :, :])
        if self.log: print('\t' * self.depth + 'In EvoformerStack: after linear')  # DEBUG
        if self.log: print('\t' * self.depth + 's', s.shape)  # DEBUG

        if self.get_evoformer_embedding:
            return m, z, s, embedding_list

        return m, z, s


class ExtraMSAStack(nn.Module):
    """
    Implements Algorithm 18.

    A stack of ExtraMSABlocks that refines the pair representation using the
    extra MSA. Only the pair activation is returned. Checkpointing is decided
    per block by this stack (the blocks themselves are built with
    ``ckpt=False``).
    """
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        ckpt: bool,
        clear_cache_between_blocks: bool = False,
        no_extra_msa = False,
        **kwargs,
    ):
        super().__init__()

        self.no_extra_msa = no_extra_msa
        self.ckpt = ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        block_config = dict(
            c_m=c_m,
            c_z=c_z,
            c_hidden_msa_att=c_hidden_msa_att,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            msa_dropout=msa_dropout,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
            ckpt=False,
            no_extra_msa=no_extra_msa,
        )
        # Blocks are still constructed one after another, in stack order.
        self.blocks = nn.ModuleList(
            ExtraMSABlock(**block_config) for _ in range(no_blocks)
        )

    def _prep_blocks(self, 
        m: torch.Tensor, 
        z: torch.Tensor, 
        msa_mask: Optional[torch.Tensor],
        pair_mask: Optional[torch.Tensor],
        inplace_safe: bool,
        _mask_trans: bool,
    ):
        """
        Bind the masks and flags to each block, giving callables of (m, z).

        ``m`` and ``z`` are accepted for interface symmetry and are unused.
        """
        def bind(block):
            return partial(
                block,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                inplace_safe=inplace_safe,
                _mask_trans=_mask_trans,
            )

        prepared = [bind(block) for block in self.blocks]

        if self.clear_cache_between_blocks:
            def with_cache_clear(fn, *args, **kwargs):
                torch.cuda.empty_cache()
                return fn(*args, **kwargs)

            prepared = [partial(with_cache_clear, fn) for fn in prepared]

        return prepared

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: Optional[torch.Tensor],
        pair_mask: Optional[torch.Tensor],
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            m: [*, N_extra, N_res, C_m] 
                extra MSA embedding
            z: [*, N_res, N_res, C_z] 
                pair embedding
            msa_mask:
                Optional [*, N_extra, N_res] MSA mask
            pair_mask:
                Optional [*, N_res, N_res] pair mask
            inplace_safe:
                Request in-place residual additions inside the blocks
            _mask_trans:
                Forwarded to the blocks; controls masking of the transitions
        Returns:
            [*, N_res, N_res, C_z] pair update
        """
        checkpoint_fn = get_checkpoint_fn()
        prepared = self._prep_blocks(
            m=m,
            z=z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )

        # Each block is checkpointed individually, and only while training.
        checkpoint_each = self.ckpt and torch.is_grad_enabled()
        for run_block in prepared:
            if checkpoint_each:
                m, z = checkpoint_fn(run_block, m, z)
            else:
                m, z = run_block(m, z)

        return z
