"""
2023.03.05 init
"""
from functools import partial
from typing import Optional

import torch
import torch.nn as nn

from myopenfold.model.primitives import Linear

from myopenfold.utils.precision_utils import is_fp16_enabled


class OuterProductMean(nn.Module):
    """
    Implements Algorithm 10.

    Projects the layer-normalized MSA embedding into two masked hidden
    representations a and b. For every residue pair (i, j), it sums the outer
    product a[s, i] (x) b[s, j] over the sequence axis s and projects the
    flattened C x C result to C_z. It then divides by the mask-overlap count
    sum_s mask[s, i] * mask[s, j] + eps.
    """

    def __init__(self, c_m, c_z, c_hidden, eps=1e-3, depth=0, ind=0, log=False):
        """
        Args:
            c_m:
                MSA embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
            eps:
                Added to the normalization count before dividing, so residue
                pairs that share no unmasked sequence do not divide by zero
            .depth:
                Depth of this module in the whole model (used only to indent
                the debug print in forward)
            .ind:
                Index of this block in the stack (only index 0 prints)
            .log:
                Whether to print some log information
        """
        super(OuterProductMean, self).__init__()

        self.depth = depth
        self.ind = ind
        self.log = log

        self.c_m = c_m
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.eps = eps

        self.layer_norm = nn.LayerNorm(c_m)
        self.linear_1 = Linear(c_m, c_hidden)
        self.linear_2 = Linear(c_m, c_hidden)
        self.linear_out = Linear(c_hidden ** 2, c_z, init="final")

    def _opm(self, a, b):
        """
        Outer product summed over the sequence axis, then projected to C_z.

        Args:
            a: [*, N_res_i, N_seq, C]
            b: [*, N_res_j, N_seq, C]
        Returns:
            [*, N_res_i, N_res_j, C_z] (not yet normalized)
        """
        # [*, N_res_i, N_res_j, C, C]; the shared sequence index is summed out
        outer = torch.einsum("...bac,...dae->...bdce", a, b)

        # [*, N_res_i, N_res_j, C * C]
        outer = outer.reshape(outer.shape[:-2] + (-1,))

        # [*, N_res_i, N_res_j, C_z]
        outer = self.linear_out(outer)

        return outer

    def _chunked_opm(self, a, b, chunk_size):
        """
        Same result as _opm(a, b), computed over slices of a's residue axis.

        Slicing a along dim -3 slices the output along dim -3. linear_out
        acts only on the last axis, so concatenating the per-slice results
        reproduces the unchunked output. The win is a smaller peak
        [*, chunk, N_res, C, C] intermediate.

        Args:
            a: [*, N_res, N_seq, C]
            b: [*, N_res, N_seq, C]
            chunk_size: rows per slice; None / non-positive / >= N_res
                disables chunking
        Returns:
            [*, N_res, N_res, C_z]
        """
        n_res = a.shape[-3]
        if chunk_size is None or chunk_size <= 0 or chunk_size >= n_res:
            return self._opm(a, b)

        pieces = [
            self._opm(a[..., start:start + chunk_size, :, :], b)
            for start in range(0, n_res, chunk_size)
        ]
        return torch.cat(pieces, dim=-3)

    def _forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask; defaults to all ones. It is cast
                to m's dtype, so bool masks are accepted.
            inplace_safe:
                If True, the final normalization divides `outer` in place
            chunk_size:
                If a positive int smaller than N_res, the outer product is
                computed this many rows (first residue axis) at a time
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if mask is None:
            mask = m.new_ones(m.shape[:-1])
        # Keep the mask in m's dtype. In the fp16 path m is upcast to float32
        # while the mask may still be half, which would compute the
        # normalization counts in half precision. einsum also does not
        # support bool tensors.
        mask = mask.to(dtype=m.dtype)

        # [*, N_seq, N_res, C_m]
        ln = self.layer_norm(m)

        # [*, N_seq, N_res, 1]
        mask = mask.unsqueeze(-1)

        #. a [*, N_seq, N_res, C]
        a = self.linear_1(ln)
        a = a * mask

        #. b [*, N_seq, N_res, C]
        b = self.linear_2(ln)
        b = b * mask

        del ln

        #. a [*, N_res, N_seq, C]
        #. b [*, N_res, N_seq, C]
        a = a.transpose(-2, -3)
        b = b.transpose(-2, -3)

        # [*, N_res, N_res, C_z]
        outer = self._chunked_opm(a, b, chunk_size)

        # Number of sequences in which both residues i and j are unmasked
        # (for a 0/1 mask).
        # [*, N_res, N_res, 1]
        norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
        norm = norm + self.eps

        # [*, N_res, N_res, C_z]
        if inplace_safe:
            outer /= norm
        else:
            outer = outer / norm

        return outer

    def forward(self,
                m: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                chunk_size: Optional[int] = None,
                inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask (optional, defaults to all ones)
            chunk_size:
                Optional number of residue rows per outer-product chunk, to
                bound memory; None disables chunking
            inplace_safe:
                If True, the final normalization divides in place
        Returns:
            [*, N_res, N_res, C_z] pair embedding update

        When is_fp16_enabled() is true, m is upcast to float32 and the
        computation runs with CUDA autocast disabled.
        """
        if self.ind == 0 and self.log: print('\t' * self.depth + "In OuterProductMean: init")  # DEBUG
        if is_fp16_enabled():
            #. disable dtype autocast
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(m.float(), mask, inplace_safe, chunk_size)
        else:
            return self._forward(m, mask, inplace_safe, chunk_size)
        
