"""
2023.03.05 init
"""
from functools import partialmethod, partial
import math
from typing import Optional, List

import torch
import torch.nn as nn

#. model
from myopenfold.model.primitives import Linear, LayerNorm, Attention

#. utils
from myopenfold.utils.tensor_utils import (
    permute_final_dims,
    flatten_final_dims,
)


class TriangleAttention(nn.Module):
    """
    Triangle attention over a pair-like tensor of shape [*, I, J, C_in].

    The input is layer-normalised and then passed through ``Attention`` as
    both query and key/value source, together with two additive biases: one
    built from the mask and one projected from the normalised input itself.
    When ``starting`` is False, the two pair axes are swapped before the
    computation and swapped back on the result.
    """

    def __init__(
        self,
        c_in,
        c_hidden,
        no_heads,
        starting=True,
        inf=1e9,
        depth=0,
        ind=0,
        log=False,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Hidden channel dimension handed to Attention (the original
                note describes it as "overall, not per-head"; confirm
                against the Attention implementation)
            no_heads:
                Number of attention heads
            starting:
                True for the starting-node variant; False makes forward()
                transpose the pair axes on the way in and out
            inf:
                Magnitude of the negative bias applied where mask == 0
            .depth:
                Depth of this module in the whole model
            .ind:
                Index of this block in the stack
            .log
                Whether to print some log information
        """
        super().__init__()

        # Attention hyper-parameters.
        self.c_in, self.c_hidden, self.no_heads = c_in, c_hidden, no_heads
        self.starting, self.inf = starting, inf

        # Bookkeeping values; only stored here, never read by this class.
        self.depth, self.ind, self.log = depth, ind, log

        # Sub-modules (registration order: layer_norm, linear, mha).
        self.layer_norm = LayerNorm(c_in)
        # One scalar bias per head for every (i, j) pair.
        self.linear = Linear(c_in, no_heads, bias=False, init="normal")
        self.mha = Attention(c_in, c_in, c_in, c_hidden, no_heads)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        use_memory_efficient_kernel: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            mask:
                [*, I, J] mask with 1 for valid and 0 for masked positions;
                an all-ones mask is used when omitted
            use_memory_efficient_kernel:
                Forwarded unchanged to the attention module
            inplace_safe:
                Accepted for interface compatibility; not used here
        Returns:
            [*, I, J, C_in] output tensor
        """
        # Without a mask every position counts as valid. [*, I, J]
        if mask is None:
            mask = x.new_ones(x.shape[:-1])

        # Ending-node variant: swap I and J now, undo the swap at the end.
        swap_axes = not self.starting
        if swap_axes:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        normed = self.layer_norm(x)

        # mask 1 -> 0, mask 0 -> -inf
        # [*, I, 1, 1, J]
        mask_term = self.inf * (mask - 1)
        mask_term = mask_term[..., :, None, None, :]

        # A scalar per (h, i, j), because it is added to the q.k logits.
        # [*, I, J, H] -> [*, H, I, J] -> [*, 1, H, I, J]
        head_logits = self.linear(normed)
        pair_term = permute_final_dims(head_logits, (2, 0, 1)).unsqueeze(-4)

        # Shape trace through self.mha for the starting node (author's notes;
        # Attention is defined elsewhere, so read this as a guide):
        #   q, k, v    [*, I, H, J, C_hidden]
        #   logits a   [*, I, H, J(query), J(key)]
        #   a_ihjk   = q_ihj . k_ihk + mask_term_ik + pair_term_hjk
        #   o = a @ v  [*, I, H, J, C_hidden] -> [*, I, J, H, C_hidden]
        #   gate g   = sigmoid(linear(q_x))
        #              [*, I, J, H * C_hidden] -> [*, I, J, H, C_hidden]
        #   output     [*, I, J, C_in]
        # For the ending node the same trace holds with I and J exchanged
        # (a_jhik = q_jhi . k_jhk + mask_term_jk + pair_term_hik), which is
        # the change from Algorithm 13 to Algorithm 14.
        attended = self.mha(
            q_x=normed,
            kv_x=normed,
            biases=[mask_term, pair_term],
            use_memory_efficient_kernel=use_memory_efficient_kernel,
        )

        # [*, J, I, C_in] -> [*, I, J, C_in] for the ending node.
        return attended.transpose(-2, -3) if swap_axes else attended


# Implements Algorithm 13. The starting-node variant is the base class itself,
# because TriangleAttention is constructed with starting=True by default.
TriangleAttentionStartingNode = TriangleAttention


class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14.

    Identical to TriangleAttention except that the constructor defaults to
    ``starting=False``, which makes forward() transpose the two pair axes of
    the input (and of its mask) before attention and transpose the result
    back afterwards.
    """
    # partialmethod pre-binds starting=False on the parent constructor; every
    # other constructor argument is passed through unchanged.
    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
