import math
import torch
from torch import nn

from coarsebind_public.mol_encoder.models.loose_modules.activations import NewGELU
from coarsebind_public.mol_encoder.models.loose_modules.norms import RMSNorm


class SelfAttention(nn.Module):
    """
    Plain multi-head scaled dot-product self-attention.

    No attention mask is applied, so every token attends to every token of
    the same sequence (the layer is not causal).

    Args:
        dim_embed: embedding width of the input and output tokens; must be
            divisible by ``nhead``.
        nhead: number of attention heads.
    """

    def __init__(self, dim_embed, nhead):
        super().__init__()
        assert dim_embed % nhead == 0
        # A single linear layer produces queries, keys and values for all heads.
        self.c_attn = nn.Linear(dim_embed, 3 * dim_embed)
        # Mixes the concatenated head outputs back into the embedding space.
        self.c_proj = nn.Linear(dim_embed, dim_embed)
        self.n_head = nhead
        self.n_embd = dim_embed

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        queries, keys, values = (
            to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T), scaled by 1/sqrt(hs)
        scale = 1.0 / math.sqrt(keys.size(-1))
        scores = (queries @ keys.transpose(-2, -1)) * scale
        weights = torch.nn.functional.softmax(scores, dim=-1)

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        mixed = weights @ values
        # Put the heads back side by side: (B, nh, T, hs) -> (B, T, C)
        merged = mixed.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)


class MaskedSelfAttention(nn.Module):
    """
    Multi-head self-attention restricted by a per-token mask.

    Tokens whose mask entry is non-zero ("attended") attend to each other;
    tokens whose mask entry is zero ("static") are returned unchanged.

    Args:
        n_embd: embedding width of the input and output tokens; must be
            divisible by ``n_head``.
        n_head: number of attention heads.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        # Joint query/key/value projection for all heads.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # Output projection applied to the re-assembled heads.
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x, pre_attn_mask):
        """
        Args:
            x: (B, T, C) token embeddings.
            pre_attn_mask: (B, T) mask, bool or numeric. Non-zero / True =
                attended, zero / False = static. Every sequence must have at
                least one attended token.

        Returns:
            (B, T, C) tensor: projected attention output at attended
            positions, the untouched input ``x`` at static positions.

        Raises:
            AssertionError: if a sequence has no attended token, or if the
                value projection contains any NaN.
        """
        # torch.where needs a boolean condition, so normalise 0/1 masks here.
        valid = pre_attn_mask.bool()
        assert valid.any(-1).all(), "every sequence needs at least one attended token"

        B, T, C = x.size()  # batch size, sequence length, embedding width (n_embd)

        # A (query, key) pair is allowed when both tokens are attended. The
        # diagonal is always allowed so that static rows still have one finite
        # score and softmax does not produce NaN for them.
        pair_allowed = valid.unsqueeze(2) & valid.unsqueeze(1)  # (B, T, T)
        self_allowed = torch.eye(T, device=x.device, dtype=torch.bool)  # (T, T)
        # True = blocked. Shape (B, 1, T, T); masked_fill broadcasts over heads.
        blocked = torch.logical_not(pair_allowed | self_allowed).unsqueeze(1)

        # Project to q, k, v and move the head axis next to the batch axis.
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(blocked, float("-inf"))
        att = nn.functional.softmax(att, dim=-1).nan_to_num(nan=0.0)

        # A NaN anywhere in v, even at a static token, would leak into attended
        # outputs through 0 * NaN in the matmul below, so reject any NaN.
        # (att needs no check: nan_to_num has just removed NaNs from it.)
        assert not v.isnan().any(), "NaN in value projection"

        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        keep = valid.unsqueeze(-1)  # (B, T, 1), broadcasts over the channel axis
        y = torch.where(keep, y, x)

        # output projection; static tokens are restored to the input afterwards
        y = self.c_proj(y)
        return torch.where(keep, y, x)


class AttentionBlock(nn.Module):
    """
    Non-causal pre-norm Transformer block.

    Applies unmasked self-attention and then a two-layer MLP (4x hidden
    expansion, NewGELU activation), each on a normalised copy of the input
    and each wrapped in a residual connection.

    Args:
        dim_embed: embedding width.
        nhead: number of attention heads.
        rmsnorm: use ``RMSNorm`` instead of ``nn.LayerNorm`` for both norms.
    """

    def __init__(self, dim_embed, nhead, rmsnorm=False):
        super().__init__()
        norm_cls = RMSNorm if rmsnorm else nn.LayerNorm
        hidden_dim = 4 * dim_embed

        self.ln_1 = norm_cls(dim_embed)
        self.attn = SelfAttention(dim_embed, nhead)
        self.ln_2 = norm_cls(dim_embed)
        self.mlpf = nn.Sequential(
            nn.Linear(dim_embed, hidden_dim),
            NewGELU(),
            nn.Linear(hidden_dim, dim_embed),
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlpf(self.ln_2(after_attn))


class MaskedAttentionBlock(nn.Module):
    """
    Pre-norm Transformer block whose attention is a ``MaskedSelfAttention``.

    Applies masked self-attention and then a two-layer MLP (4x hidden
    expansion, NewGELU activation), each on a normalised copy of the input
    and each wrapped in a residual connection.

    Args:
        dim_embed: embedding width.
        nhead: number of attention heads.
        rmsnorm: use ``RMSNorm`` instead of ``nn.LayerNorm`` for both norms.
    """

    def __init__(self, dim_embed, nhead, rmsnorm=False):
        super().__init__()
        norm_cls = RMSNorm if rmsnorm else nn.LayerNorm
        hidden_dim = 4 * dim_embed

        self.ln_1 = norm_cls(dim_embed)
        self.attn = MaskedSelfAttention(dim_embed, nhead)
        self.ln_2 = norm_cls(dim_embed)
        self.mlpf = nn.Sequential(
            nn.Linear(dim_embed, hidden_dim),
            NewGELU(),
            nn.Linear(hidden_dim, dim_embed),
        )

    def forward(self, x, mask):
        """
        Args:
            x: token embeddings.
            mask: per-token (batch, seq) mask handed unchanged to
                ``MaskedSelfAttention``; true / non-zero == attention allowed.
                The MLP is applied to every token regardless of the mask.
        """
        after_attn = x + self.attn(self.ln_1(x), mask)
        return after_attn + self.mlpf(self.ln_2(after_attn))


class AttAgg(nn.Module):
    """
    Attention-based pooling of a padded token sequence into fixed-size vectors.

    Holds ``n_out_tokens`` independent branches, each a stack of ``n_layers``
    ``MaskedAttentionBlock``s. Every branch processes the same input, pools
    its masked tokens into one ``n_embd`` vector, and the branch vectors are
    concatenated.

    Args:
        n_embd: embedding width.
        n_head: number of attention heads in every block.
        n_layers: number of blocks per branch.
        n_out_tokens: number of independent branches (pooled vectors).
        rmsnorm: forwarded to every ``MaskedAttentionBlock``.
    """

    def __init__(self, n_embd, n_head, n_layers=2, n_out_tokens=1, rmsnorm=True):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [
                torch.nn.ModuleList(
                    [MaskedAttentionBlock(n_embd, n_head, rmsnorm=rmsnorm) for _ in range(n_layers)]
                )
                for __ in range(n_out_tokens)
            ]
        )
        self.n_out_tokens = n_out_tokens

    def forward(self, x_, mask_=None):
        """
        Pool each sequence into one vector per branch.

        Args:
            x_: batch X max_tokens X n_embd
            mask_: batch X max_tokens multiplicative mask (non-zero = real
                token). ``None`` treats every token as real.

        Returns:
            batch X (n_out_tokens * n_embd) tensor. Each branch contributes
            the sum of its masked token outputs divided by
            ``sqrt(number_of_real_tokens + 1)``.

        Raises:
            AssertionError: if the input, any intermediate activation or a
                pooled branch output contains NaN.
        """
        assert not x_.isnan().any()

        if mask_ is None:
            # The blocks require a real mask; default to "all tokens valid".
            mask_ = torch.ones(x_.shape[:2], dtype=torch.bool, device=x_.device)

        # Both depend only on the mask, so compute them once for all branches.
        token_weight = mask_.unsqueeze(2)
        denom = (mask_.sum(1) + 1).sqrt().unsqueeze(-1)

        outputs = []
        for branch in self.blocks:
            x = x_.clone()
            for block in branch:
                x = block(x, mask_)
                assert not x.isnan().any()

            # Masked sum over the token axis with a sqrt(count + 1) normalisation.
            branch_output = (x * token_weight).sum(1) / denom
            assert not branch_output.isnan().any()
            outputs.append(branch_output)

        return torch.cat(outputs, -1)
