import torch
from torch import nn

from framework.layers import LayerWithVisualization

from local_attention import LocalAttention
from .routing_transformer import SymmetricRoutingAttention
from .mosa_dir.pure_mosa import PureMoSA, FixedSA

class SlidingWindowAttention(nn.Module, LayerWithVisualization):
    """Causal multi-head attention whose attention step is done by ``LocalAttention``.

    The input is projected to queries, keys and values by a single fused
    linear layer, split into heads, handed to ``LocalAttention`` and finally
    projected back to the model width ``h``.

    Args:
        n_heads: number of attention heads.
        window_size: window size forwarded to ``LocalAttention``.
        h: feature width of the input and of the output.
        h_prim: per-head width of the queries, keys and values.
        bias: accepted for interface compatibility; see the note below.
        *args, **kwargs: forwarded to the parent constructors.

    NOTE(review): ``bias`` is never passed on to the linear layers, so both
    projections are built with ``nn.Linear``'s default bias setting whatever
    value is given here. Wiring it through would change the parameter set of
    existing checkpoints, so confirm the intent before changing it.
    """

    def __init__(
        self,
        n_heads: int,
        window_size: int,
        h: int,
        h_prim: int,
        bias: bool = False,
        *args,
        **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.n_heads = n_heads
        self.h = h
        self.h_prim = h_prim

        # Combined width of all heads for a single one of Q, K or V.
        inner_width = h_prim * n_heads

        # One fused projection produces Q, K and V for every head at once.
        self.W_QKV = nn.Linear(h, inner_width * 3)
        self.W_O = nn.Linear(inner_width, h)

        self.attention = LocalAttention(window_size=window_size, causal=True)

    def forward(self, X, **kwargs):
        """Apply the attention layer to ``X`` of shape [B, T, h].

        Extra keyword arguments are accepted and ignored.

        Returns:
            Tensor of shape [B, T, h].
        """
        batch, seq_len, _ = X.shape

        # [B, T, 3*n_heads*h_prim] -> [B, n_heads, T, 3*h_prim]
        fused = self.W_QKV(X)
        fused = fused.reshape(batch, seq_len, self.n_heads, 3 * self.h_prim)
        fused = fused.transpose(1, 2)

        # Each head's last axis holds its Q, K and V slices back to back.
        Q, K, V = fused.chunk(3, dim=-1)

        per_head = self.attention(Q, K, V)  # [B, n_heads, T, h_prim]

        # Merge the heads again: [B, T, n_heads*h_prim]
        merged = per_head.transpose(1, 2).reshape(
            batch, seq_len, self.n_heads * self.h_prim
        )

        projected = self.W_O(merged)
        return projected.reshape(batch, seq_len, self.h)

class RopePartheadSparseSlidingWindowAttention(nn.Module, LayerWithVisualization):
    """Sum of a sparse-attention branch and a sliding-window attention branch.

    Either branch can be switched off by giving it zero heads; a disabled
    branch contributes the scalar ``0`` to the sum.

    Args:
        n_sw_heads: number of heads in the sliding-window branch
            (``SlidingWindowAttention``); ``0`` disables the branch.
        n_sparse_heads: number of heads in the sparse branch; ``0`` disables
            the branch.
        sparsity: sparsity setting passed to the sparse module. It also sets
            the window size, ``int(max_seq_len // sparsity)``.
        sparsity_type: which sparse module to build when ``n_sparse_heads > 0``:
            ``'EC'`` -> ``PureMoSA``, ``'routing'`` ->
            ``SymmetricRoutingAttention``, ``'fixed'`` -> ``FixedSA``.
        h: feature width of the input and of the output.
        h_prim: per-head width.
        max_seq_len: sequence length used to derive the window size.
        rotate_fraction: forwarded to ``PureMoSA`` / ``FixedSA``.
        rope_base: forwarded to ``PureMoSA`` / ``FixedSA``.
        bias: accepted for interface compatibility; not used by this class.
        QKV_hidden: accepted for interface compatibility; not used by this
            class.

    Raises:
        ValueError: if ``n_sparse_heads > 0`` and ``sparsity_type`` is not one
            of the supported values.
    """

    def __init__(self, n_sw_heads: int, n_sparse_heads: int, sparsity: int, sparsity_type: str,
                 h:int, h_prim: int, max_seq_len: int,
                 rotate_fraction: float = 0.5, rope_base: float = 10000,
                 bias=False, QKV_hidden=None
                 ):
        super().__init__()

        self.n_sparse_heads = n_sparse_heads
        self.n_sw_heads = n_sw_heads

        window_size = int(max_seq_len // sparsity)
        if n_sparse_heads > 0:
            if sparsity_type == 'EC':
                print('='*20)
                print('symmetric EC attention')
                print('='*20)
                self.sparse_attention = PureMoSA(
                    n_sparse_heads, sparsity, h, h_prim, include_first=False, rotate_fraction=rotate_fraction, rope_base=rope_base
                )
            elif sparsity_type == 'routing':
                print('='*20)
                print('symmetric Routing attention')
                print('='*20)
                self.sparse_attention = SymmetricRoutingAttention(
                    num_clusters=int(sparsity), window_size=window_size, num_heads=n_sparse_heads, h=h, h_prim=h_prim
                )
            elif sparsity_type == 'fixed':
                print('='*20)
                print('fixed sparse attention')
                print('='*20)

                # it's called strided but it's actually fixed
                self.sparse_attention = FixedSA(
                    n_sparse_heads, sparsity, h, h_prim, include_first=False, rotate_fraction=rotate_fraction, rope_base=rope_base
                )
            else:
                # Fail at construction time; otherwise `sparse_attention` would
                # stay unset and the error would only surface inside forward().
                raise ValueError(
                    f"unknown sparsity_type {sparsity_type!r}; "
                    "expected 'EC', 'routing' or 'fixed'"
                )
        else:
            self.sparse_attention = self._zero_branch

        if n_sw_heads > 0:
            self.sw_attention = SlidingWindowAttention(n_sw_heads, window_size, h, h_prim)
        else:
            # The placeholder must be callable exactly like the real branch is
            # in forward(): one positional tensor plus keyword arguments.
            self.sw_attention = self._zero_branch

    @staticmethod
    def _zero_branch(*args, **kwargs):
        """Stand-in for a disabled branch: ignore all arguments and return 0."""
        return 0

    def forward(self, X: torch.Tensor, **kwargs):
        """Run both branches on ``X`` and add their outputs.

        ``kwargs`` are forwarded unchanged to both branches.

        Returns:
            A 2-tuple ``(output, None)``; the second element is always ``None``.
            With both branches disabled ``output`` is the integer ``0``.
        """
        o_sparse = self.sparse_attention(X, **kwargs)
        o_dense = self.sw_attention(X, **kwargs)
        return o_sparse + o_dense, None