from math import sqrt

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F


class VanillaGeometricReasoningOriginalImpl(nn.Module):
    def __init__(
        self,
        c_s: int,
        v_heads: int,
        num_vector_messages: int = 1,
        mask_and_zero_frameless: bool = True,
        divide_residual_by_depth: bool = False,
        bias: bool = False,
    ):
        """Approximate implementation:

        ATTN(A, v) := (softmax_j A_ij) v_j
        make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3)
        make_vectors(x) := T(i->g) Linear(x).reshape(..., 3)

        v <- make_rot_vectors(x)
        q_dir, k_dir <- make_rot_vectors(x)
        q_dist, k_dist <- make_vectors(x)

        A_ij       <- dot(q_dir_i, k_dir_j) -||q_dist_i - k_dist_j||^2
        x          <- x + Linear(T(g->i) ATTN(A, v))
        """
        super().__init__()
        self.c_s = c_s
        self.v_heads = v_heads
        self.num_vector_messages = num_vector_messages
        self.mask_and_zero_frameless = mask_and_zero_frameless

        self.s_norm = nn.LayerNorm(c_s, bias=bias)
        dim_proj = (
            4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages
        )  # 2 x (q, k) * number of heads * (x, y, z) + number of heads * number of vector messages * (x, y, z)
        self.proj = nn.Linear(c_s, dim_proj, bias=bias)
        channels_out = self.v_heads * 3 * self.num_vector_messages
        self.out_proj = nn.Linear(channels_out, c_s, bias=bias)

        # The basic idea is for some attention heads to pay more or less attention to rotation versus distance,
        # as well as to control the sharpness of the softmax (i.e., should this head only attend to those residues
        # very nearby or should there be shallower dropoff in attention weight?)
        self.distance_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
        self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))

    def forward(self, s, attention_mask, affine, affine_mask, sequence_id, chain_id):
        """
        s: [B * L, 16, d_model]
        attention_mask: [B * L, 16]
        affine: [B * L, 16]
        affine_mask: [B * L, 16]
        sequence_id: [B * L, 16]
        chain_id: [B * L, 16]
        """
        if sequence_id is None:
            sequence_id = torch.zeros_like(s[..., 0], dtype=torch.int64)

        
        #attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2) # [B * L, 16, 16]
        #attn_bias = attn_bias.unsqueeze(1).float()
        attn_bias = torch.zeros(attention_mask.shape[0], 1, attention_mask.shape[1], attention_mask.shape[1], device=s.device) # [B * L, 1, 16, 16]
        attn_bias = attn_bias.masked_fill(
            ~torch.logical_and(affine_mask, attention_mask)[:, None, None, :], torch.finfo(attn_bias.dtype).min
        )
        assert (~torch.logical_and(affine_mask, attention_mask) == ~affine_mask).all()
        chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
        attn_bias = attn_bias.masked_fill(
            chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min
        )

        ns = self.s_norm(s)
        vec_rot, vec_dist = self.proj(ns).split(
            [
                self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
                self.v_heads * 2 * 3,
            ],
            dim=-1,
        )

        # Rotate the queries and keys for the rotation term. We also rotate the values.
        # NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change
        # this in the future.
        query_rot, key_rot, value = (
            affine.rot[..., None]
            .apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
            .split(
                [
                    self.v_heads,
                    self.v_heads,
                    self.v_heads * self.num_vector_messages,
                ],
                dim=-2,
            )
        )

        # Rotate and translate the queries and keys for the distance term
        # NOTE(thayes): a simple speedup would be to apply all rotations together, then
        # separately apply the translations.
        query_dist, key_dist = (
            affine[..., None]
            .apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3))
            .chunk(2, dim=-2)
        )

        query_dist = rearrange(query_dist, "b s h d -> b h s 1 d")
        key_dist = rearrange(key_dist, "b s h d -> b h 1 s d")
        query_rot = rearrange(query_rot, "b s h d -> b h s d")
        key_rot = rearrange(key_rot, "b s h d -> b h d s")
        value = rearrange(
            value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages
        )

        distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3)
        rotation_term = query_rot.matmul(key_rot) / sqrt(3)
        distance_term_weight = rearrange(
            F.softplus(self.distance_scale_per_head), "h -> h 1 1"
        )
        rotation_term_weight = rearrange(
            F.softplus(self.rotation_scale_per_head), "h -> h 1 1"
        )

        attn_weight = (
            rotation_term * rotation_term_weight - distance_term * distance_term_weight
        )

        if attn_bias is not None:
            # we can re-use the attention bias from the transformer layers
            # NOTE(thayes): This attention bias is expected to handle two things:
            # 1. Masking attention on padding tokens
            # 2. Masking cross sequence attention in the case of bin packing
            s_q = attn_weight.size(2)
            s_k = attn_weight.size(3)
            _s_q = max(0, attn_bias.size(2) - s_q)
            _s_k = max(0, attn_bias.size(3) - s_k)
            attn_bias = attn_bias[:, :, _s_q:, _s_k:]
            # attn_weight: [B * L, 128, 16, 16]
            attn_weight = attn_weight + attn_bias

        attn_weight = torch.softmax(attn_weight, dim=-1)

        attn_out = attn_weight.matmul(value)

        attn_out = (
            affine.rot[..., None]
            .invert()
            .apply(
                rearrange(
                    attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages
                )
            )
        )

        attn_out = rearrange(
            attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages
        )
        if self.mask_and_zero_frameless:
            attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0)
        s = self.out_proj(attn_out)

        return s