
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.attention.flex_attention import flex_attention, BlockMask

from timm.layers import Mlp

from rotary import RotaryEmbedding

import math

def norm(x):
    return F.rms_norm(x, (x.size(-1),))

def modulate(x, shift, scale):
    """Apply an affine modulation: multiply ``x`` by ``1 + scale``, then add ``shift``."""
    gain = 1 + scale
    return x * gain + shift

class UnifiedTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by an MLP.

    Both sub-layers are wrapped in a residual connection and read an
    RMS-normalised copy of the running state. The attention backend is
    selectable through ``attn_backend``. ``n_channels_pair`` is accepted but
    not used by this block.
    """

    def __init__(self, *, n_channels, n_heads, mlp_factor, dropout=0.1, attn_backend='flex', n_channels_pair=None, pos_embed=False):
        super().__init__()
        hidden_channels = n_channels * mlp_factor
        self.attn = SelfAttention(
            dim=n_channels,
            n_heads=n_heads,
            attn_backend=attn_backend,
            dropout=dropout,
            pos_embed=pos_embed,
        )
        self.mlp = Mlp(
            n_channels,
            hidden_channels,
            n_channels,
            act_layer=nn.GELU,
            norm_layer=None,
            drop=dropout,
        )

    def forward(self, s_BLD, **attn_params):
        """Run the block on ``s_BLD``.

        Args:
            s_BLD: input state; ``SelfAttention`` unpacks its shape as (B, L, D).
            **attn_params: forwarded untouched to the attention layer, which
                hands them to the selected attention kernel.

        Returns:
            The updated state, same shape as ``s_BLD``.
        """
        attn_out = self.attn(norm(s_BLD), **attn_params)
        s_BLD = s_BLD + attn_out
        mlp_out = self.mlp(norm(s_BLD))
        return s_BLD + mlp_out

class TransformerBlockSPDA(nn.Module):
    """Pre-norm transformer block whose attention backend is fixed to ``'spda'``.

    Structure is self-attention then MLP, each inside a residual connection and
    each fed an RMS-normalised copy of the state. ``pair_bias`` and
    ``n_channels_pair`` are accepted by the constructor but not used.
    """

    def __init__(
        self, *, n_channels, n_heads, mlp_factor, dropout=0.1, pair_bias=False, n_channels_pair=None, pos_embed=False
    ):
        super().__init__()
        hidden_channels = n_channels * mlp_factor
        self.attn = SelfAttention(
            dim=n_channels,
            n_heads=n_heads,
            attn_backend='spda',
            dropout=dropout,
            pos_embed=pos_embed,
        )
        self.mlp = Mlp(
            n_channels,
            hidden_channels,
            n_channels,
            act_layer=nn.GELU,
            norm_layer=None,
            drop=dropout,
        )

    def forward(self, s_BLD, **attn_params):
        """Run the block on ``s_BLD``.

        Args:
            s_BLD: input state; ``SelfAttention`` unpacks its shape as (B, L, D).
            **attn_params: forwarded untouched to the attention layer, which
                passes them on to ``F.scaled_dot_product_attention``.

        Returns:
            The updated state, same shape as ``s_BLD``.
        """
        attn_out = self.attn(norm(s_BLD), **attn_params)
        s_BLD = s_BLD + attn_out
        mlp_out = self.mlp(norm(s_BLD))
        return s_BLD + mlp_out



class TransformerStack(nn.Module):
    def __init__(
        self, 
        *, 
        n_channels, 
        n_heads, 
        mlp_factor, 
        window_size: int,
        use_flex_attn=True, 
        n_layers: int, 
        dropout=0.1,
        n_channels_pair=None,
        is_causal=False,
    ):
        """ 
        is this class doing anything useful?
        """
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(
                n_channels=n_channels, n_heads=n_heads, mlp_factor=mlp_factor, use_flex_attn=use_flex_attn, dropout=dropout, pair_bias=n_channels_pair > 0, n_channels_pair=n_channels_pair) 
                for _ in range(n_layers)
        ])
        self.norm = norm
        self.window_size = window_size
        self.is_causal = is_causal

    # a range of useful mask mods
    @classmethod
    def sliding_window(cls, window_size):
        def _sliding_window(b, h, q_idx, kv_idx):
            return abs(q_idx - kv_idx) <= window_size
        return _sliding_window
    
    @classmethod
    def sequence_packed(cls, doc_ids):
        def _sequence_packed(b, h, q_idx, kv_idx):
            return doc_ids[q_idx] == doc_ids[kv_idx]
        return _sequence_packed
    

    def forward(self, s_BLD, block_mask: BlockMask, pair_bias_BLLD=None):
        # batch is broadcasted
        # get device real quick
        device = next(self.parameters()).device

        for block in self.blocks:
            s_BLD = block(s_BLD, block_mask, pair_bias_BLLD=pair_bias_BLLD)
        return s_BLD


class TransformerBlock(nn.Module):
    def __init__(
        self, *, n_channels, n_heads, mlp_factor, use_flex_attn=True, dropout=0.1, pair_bias=False, n_channels_pair=None
    ):
        super().__init__()
        # i know this says causal but it supports other types, just pass the mask correctly
        self.attn_backend = 'flex' if use_flex_attn else 'spda'
        self.attn = SelfAttention(dim=n_channels, n_heads=n_heads, dropout=dropout, 
        attn_backend=self.attn_backend)
        self.mlp = Mlp(
            n_channels, n_channels * mlp_factor, n_channels, act_layer=nn.GELU, norm_layer=None, drop=dropout
        )

        self.pair_bias = pair_bias
        if pair_bias:
            self.proj_pair_bias_W1 = nn.Linear(n_channels_pair, n_channels_pair, bias=True)
            self.proj_pair_bias = nn.Linear(n_channels_pair, 1, bias=True)
            self.pair_ln = nn.LayerNorm(n_channels_pair)
        else:
            self.proj_pair_bias = None
        
    def get_score_mod(self, pair_bias_BLLD):
        pair_bias_BLLD = F.silu(self.proj_pair_bias_W1(pair_bias_BLLD))
        pair_bias_BLL = self.proj_pair_bias(self.pair_ln(pair_bias_BLLD)).squeeze(-1)

        def pair_biased_attn(score, b, h, q_idx, kv_idx):
            return score + pair_bias_BLL[b, q_idx, kv_idx]

        return pair_biased_attn
    
    def _init_weights(self, module):
        # i think we can zero init the pair bias
        with torch.no_grad():
            self.proj_pair_bias.weight.data.zero_()
            self.proj_pair_bias.bias.data.zero_()
            self.proj_pair_bias_W1.weight.data.zero_()
            self.proj_pair_bias_W1.bias.data.zero_()
   
    def forward(self, s_BLD, block_mask, pair_bias_BLLD):
        """
        c_BL is time conditioning. eventually the other conditioning is concatenated on. 
        """
        score_mod = None
        if self.pair_bias:
            score_mod = self.get_score_mod(pair_bias_BLLD)

        s_BLD = s_BLD + self.attn(norm(s_BLD), block_mask=block_mask, score_mod=score_mod)
        s_BLD = s_BLD + self.mlp(norm(s_BLD))
        return s_BLD

class CrossAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.attn_dropout = nn.Dropout(dropout)
        self.dropout = dropout
        self.dim_per_head = dim // n_heads

        self.to_q = nn.Linear(dim, dim, bias=True)
        self.to_kv = nn.Linear(dim, 2 * dim, bias=True)

        self.to_out = nn.Linear(dim, dim)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, s_BLD, c_BLD, block_mask=None, score_mod=None):
        B, L, D = s_BLD.shape
        q = self.to_q(s_BLD)
        k, v = self.to_kv(c_BLD).split(self.dim, dim=-1)

        q = q.view(B, L, self.n_heads, self.dim_per_head).transpose(1, 2)
        k = k.view(B, L, self.n_heads, self.dim_per_head).transpose(1, 2)
        v = v.view(B, L, self.n_heads, self.dim_per_head).transpose(1, 2)

        y = flex_attention(q, k, v, block_mask=block_mask, score_mod=score_mod)
        y = self.attn_dropout(y)
        y = y.transpose(1, 2).contiguous().view(B, L, self.dim)
        y = self.resid_dropout(self.to_out(y))
        return y

# TODO: add bias parameter by hacking attn_mask and passing explicitly
class SelfAttention(nn.Module):
    def __init__(
        self, dim: int, n_heads: int, attn_backend='flex', dropout=0.1, pos_embed: bool = True
    ):
        super().__init__()
        # NOTE: this is a good place to implement pair bias via score mod
        self.dim = dim
        self.n_heads = n_heads
        assert self.dim % self.n_heads == 0
        self.dim_per_head = self.dim // self.n_heads
        self.attn_dropout = nn.Dropout(dropout)
        self.dropout = dropout
        self.attn_backend = attn_backend

        if self.attn_backend == 'flex':
            self.attn = flex_attention
        elif self.attn_backend == 'spda':
            self.attn = F.scaled_dot_product_attention
        else:
            raise ValueError(f"Invalid attention backend: {self.attn_backend}")
        

        if pos_embed:
            self.pos_embed = RotaryEmbedding(self.dim_per_head)
        else:
            self.pos_embed = None
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=True)
        self.to_out = nn.Linear(dim, dim)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x_BLD, **attn_params):
        # incorporate pair bias via score_mod
        B, L, D = x_BLD.shape

        q, k, v = self.to_qkv(x_BLD).split(self.dim, dim=-1)
        q = q.view(B, L, self.n_heads, self.dim_per_head).transpose(1, 2)
        k = k.view(B, L, self.n_heads, self.dim_per_head).transpose(1, 2)
        v = v.view(B, L, self.n_heads, self.dim_per_head).transpose(1, 2)


        if self.pos_embed is not None:
            q, k = self.pos_embed(q, k)
        
        if self.attn_backend in ['spda', 'flex']:
            y = self.attn(q, k, v, **attn_params)
        else:
            # if not pair bias, do attention manually
            raise ValueError("This may not be the best idea")
            sim_BHLL = q @ k.transpose(-2, -1) / math.sqrt(self.dim_per_head)
            sim_BHLL = sim_BHLL + pair_bias_BLL.unsqueeze(-3)
            if self.is_causal:
                mask_LL = torch.ones((L, L), device=sim_BHLL.device, dtype=bool)
                sim_BHLL.masked_fill_(torch.triu(mask_LL, diagonal=1), float('-inf'))

            sim_BHLL = sim_BHLL.softmax(dim=-1)
            y = sim_BHLL @ v

        y = self.attn_dropout(y)
        y = y.transpose(1, 2).contiguous().view(B, L, self.dim)
        y = self.resid_dropout(self.to_out(y))
        return y


