import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
from timm.layers import use_fused_attn
from torch.jit import Final
from einops import rearrange, einsum
# from src.models.modules.simple_flash_attn_vit import RMSNorm
from src.models.modules.rmsnorm import RMSNorm
from src.utils.misc import default


def create_causal_mask(i, j, device):
    """
    Build a boolean causal mask of shape [i, j].

    Entry (r, c) is True when c - r <= j - i, i.e. the lower triangle
    (diagonal included) aligned to the bottom-right corner of the matrix.
    True marks positions that may be attended to.
    """
    everything = torch.ones((i, j), dtype=torch.bool, device=device)
    # Keeping diagonals <= j - i is the complement of triu(j - i + 1).
    return everything.tril(j - i)


def make_1dcoord(L, normalize=False):
    """
    Build the 1d coordinates 0, 1, ..., L - 1 as a column.

    Args:
      L: number of positions.
      normalize: if True, divide every coordinate by L so values lie in [0, 1).
    Return(torch.Tensor): 1d coord values of shape [L, 1]
    """
    positions = np.arange(L, dtype=np.float32)   # [0, L)
    scaled = positions / L if normalize else positions
    column = torch.Tensor(scaled)
    return column.reshape(L, 1)


def make_2dcoord(H, W, normalize=False):
    """
    Build the 2d grid coordinates (row, col) for an H x W grid.

    Args:
      H: number of rows; the first coordinate runs over [0, H).
      W: number of columns; the second coordinate runs over [0, W).
      normalize: if True, rows are divided by H and columns by W.
    Return(torch.Tensor): 2d coord values of shape [H, W, 2]
    """
    rows = np.arange(H, dtype=np.float32)   # [0, H)
    cols = np.arange(W, dtype=np.float32)   # [0, W)
    if normalize:
        rows, cols = rows / H, cols / W
    # 'ij' indexing: both grids have shape [H, W], rows vary along axis 0.
    row_grid, col_grid = np.meshgrid(rows, cols, indexing='ij')
    coords = np.stack([row_grid, col_grid], axis=-1)
    return torch.Tensor(coords)


def make_SO2mats(coord, nfreqs, base=10000.0):
    """
    Build 2x2 rotation matrices for every (coordinate, frequency) pair.

    The angle for coordinate dimension d and frequency index f is
    ``coord[..., d] * base ** (-f / nfreqs)``, and the matrix is
    ``[[cos, -sin], [sin, cos]]`` of that angle.

    Args:
      coord: [..., 1 or 2] coordinate tensor.
      nfreqs: number of frequencies per coordinate dimension.
      base: base of the geometric frequency ladder (default 10000).
    Return:
      mats of shape [..., n_freqs, (1 or 2), 2, 2]
    """
    # Frequencies live on coord's device so non-CPU coordinates work too.
    exponents = torch.arange(0., 2 * nfreqs, 2, device=coord.device)
    freqs = torch.exp(exponents * -(math.log(base) / (2 * nfreqs)))

    # theta[..., f, d] = coord[..., d] * freqs[f]
    theta = torch.einsum('...d,f->...fd', coord, freqs)
    cos, sin = torch.cos(theta), torch.sin(theta)

    # Row-major 2x2 layout: (cos, -sin) on the first row, (sin, cos) on the second.
    return torch.stack([cos, -sin, sin, cos], -1).unflatten(-1, (2, 2))

# GTA
@torch.jit.script
def rep_mul_x(rep, x):
    """Multiply each consecutive feature pair of x by its 2x2 matrix in rep."""
    #  rep.shape=[T, F, 2, 2], x.shape=[B, H, T, F*2]
    out_shape = x.shape
    # Split the feature axis into F two-component vectors: [..., F, 2].
    pairs = x.unflatten(-1, (-1, 2))
    # Broadcast matrix rows against the vector, then reduce over columns
    # (a batched matrix-vector product).
    weighted = rep.unsqueeze(0).unsqueeze(0) * pairs.unsqueeze(-2)
    rotated = weighted.sum(-1)
    return rotated.view(out_shape)


@torch.jit.script
def rep_mul_qkv(rep, q, k, v):
    """Apply rep_mul_x with the same rep to q, k and v; return them as a tuple."""
    q_out = rep_mul_x(rep, q)
    k_out = rep_mul_x(rep, k)
    v_out = rep_mul_x(rep, v)
    return q_out, k_out, v_out


@torch.jit.script
def rep_mul_qk(rep, q, k):
    """Apply rep_mul_x with the same rep to both tensors; return them as a pair."""
    first = rep_mul_x(rep, q)
    second = rep_mul_x(rep, k)
    return first, second


class Attention(nn.Module):
    """
    Multi-head self-attention with optional SO(2) position rotations.

    Args:
      dim: channel size of the input and output tokens.
      num_heads: number of heads; must divide ``dim``.
      qkv_bias: whether the fused qkv projection has a bias.
      qk_norm: ``True`` normalizes both q and k, ``'q'`` / ``'k'`` normalizes
        only that tensor, anything else disables normalization.
      attn_drop: dropout probability on the attention weights.
      proj_drop: dropout probability after the output projection.
      norm_layer: constructor for the q/k norm, called with ``head_dim``.
      gta: which tensors are multiplied by the SO(2) representation:
        ``'qkv'`` (GTA), ``'qk'`` (RoPE), ``'kv'``, ``'q'`` or ``'k'``.
        An empty string disables it. Any other non-empty value still builds
        the ``so2rep`` buffer but rotates no input.
      resolutions: token grid size, one entry (1d) or two entries (2d).
      use_causal_mask: apply a causal mask to the attention logits.
      inverted: softmax over the query axis followed by a renormalization
        over keys (https://openreview.net/pdf?id=m9s6rnYWqm).
      lifted: keep the key axis in the output, returning one value per
        (query, key) pair instead of summing over keys.
      simple_mixing: replace the output projection (and ``self.v``) by
        identity.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool | str = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = RMSNorm,
            gta: str = '',
            resolutions=(16, 16),
            use_causal_mask = False,
            inverted = False, # for inverted attention https://openreview.net/pdf?id=m9s6rnYWqm
            lifted: bool = False,
            simple_mixing: bool = False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()
        self.use_causal_mask = use_causal_mask

        mn = 2 * len(resolutions)
        assert self.head_dim % mn == 0, 'head_dim should be divisible by m*n, where m=2 is the SO2 action dim and n=len(resolutions) is the coordinate dim'

        self.gta = gta
        self.inverted = inverted
        self.lifted = lifted

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # `== True` (not `is True`) is kept so that 1 behaves like True.
        q_norm = qk_norm == 'q' or qk_norm == True
        k_norm = qk_norm == 'k' or qk_norm == True
        self.q_norm = norm_layer(self.head_dim) if q_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if k_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        if simple_mixing:
            self.proj = nn.Identity()
            # Not read by this class; subclasses with a separate value
            # projection pick it up.
            self.v = nn.Identity()

        if gta:
            # Number of frequencies per coordinate dimension.
            n_freqs = self.head_dim // mn
            if len(resolutions) == 1:
                coord = make_1dcoord(resolutions[0])
                rep = make_SO2mats(coord, n_freqs).flatten(2, 3)
            elif len(resolutions) == 2:
                coord = make_2dcoord(resolutions[0], resolutions[1])
                rep = make_SO2mats(coord, n_freqs).flatten(
                    2, 3).flatten(0, 1)  # [h*w, d, 2, 2]
            else:
                raise ValueError("len(resolutions) must be 1 or 2")
            self.register_buffer('so2rep', rep)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: torch.Tensor, shape=[B, N, C]
        return torch.Tensor, shape=[B, N, C] ([B, N, N, C] when lifted)
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # -> three tensors of shape [B, head, N, head_dim]
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        x = self._forward(q, k, v).transpose(1, 2)
        x = x.reshape(B, N, N, C) if self.lifted else x.reshape(B, N, C)
        return self.proj_drop(self.proj(x))

    def _forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """
        Core attention on per-head tensors.

        q, k, v: [batch, head, token, head_dim]
        return: [batch, head, token, head_dim], or
                [batch, head, q_token * k_token, head_dim] when lifted
        """
        # Bound once so the output rotation below never reads an unbound name.
        rep = self.so2rep if self.gta else None  # [T, F, 2, 2]

        # RoPE or GTA. Apply ρ^-1.
        if rep is not None:
            if self.gta == 'qkv':   # GTA
                q, k, v = rep_mul_qkv(rep, q, k, v)
            elif self.gta == 'qk':  # RoPE
                q, k = rep_mul_qk(rep, q, k)
            elif self.gta == 'kv':  # cross attn ver of GTA
                k, v = rep_mul_qk(rep, k, v)
            elif self.gta == 'q':
                q = rep_mul_x(rep, q)
            elif self.gta == 'k':
                k = rep_mul_x(rep, k)

        if self.fused_attn and not self.inverted and not self.lifted:
            # NOTE(review): per the PyTorch docs `is_causal=True` aligns the
            # mask to the top-left corner for non-square attention, whereas
            # create_causal_mask aligns it to the bottom-right. The two paths
            # agree only when q_len == k_len -- confirm which is intended.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
                is_causal=self.use_causal_mask,
            )
        else:
            attn = (q * self.scale) @ k.transpose(-2, -1)
            if self.use_causal_mask:
                allowed = create_causal_mask(
                    q.shape[-2], k.shape[-2], device=q.device)
                attn = attn.masked_fill(~allowed, float('-inf'))

            if self.inverted:
                attn = attn.softmax(dim=-2)   # attention normalized by queries
                # renormalization so each query's weights sum to ~1 over keys
                attn = attn * (1. / (attn.sum(dim=-1, keepdim=True) + 1e-6))
            else:
                attn = attn.softmax(dim=-1)   # attention normalized by keys (vanilla attention)
            attn = self.attn_drop(attn)

            if self.lifted:
                # Keep one weighted value per (query, key) pair.
                x = einsum(attn, v, "b h n m, b h m d -> b h n m d")
            else:
                x = attn @ v

        # GTA. Apply ρ.
        if rep is not None and (self.gta == 'qkv' or ('v' in self.gta and self.lifted)):
            x = rep_mul_x(rep.transpose(-2, -1), x)

        if self.lifted:
            x = rearrange(x, "b h n m d -> b h (n m) d")

        return x


class CrossAttention(Attention):
    """
    Cross-attention: queries come from ``x``, keys from ``y``, values from ``z``.

    Replaces the parent's fused ``qkv`` projection with separate ``q``, ``k``
    and ``v`` projections.

    Args:
      dim, num_heads, qkv_bias: as in ``Attention``.
      raw_q: if True, the query projection is the identity.
      return_qkk: if True, ``forward`` additionally returns the attention
        output computed with the keys used as values.
      **kwargs: forwarded to ``Attention.__init__``. With
        ``simple_mixing=True`` the value projection stays the identity set
        by the parent instead of being replaced by a linear layer.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            raw_q = False,
            return_qkk = False,
            **kwargs
    ) -> None:
        super().__init__(dim, num_heads, qkv_bias, **kwargs)
        del self.qkv
        self.q = nn.Linear(dim, dim, bias=qkv_bias) if not raw_q else nn.Identity()
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        # The parent sets self.v = nn.Identity() under simple_mixing;
        # overwriting it unconditionally would silently undo that option.
        if not kwargs.get('simple_mixing', False):
            self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.return_qkk = return_qkk

    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        x: torch.Tensor, shape=[B, N, C]
        y: torch.Tensor, shape=[B, M, C]
        z: torch.Tensor, shape=[B, M, C] or None (then y is used)
        return torch.Tensor, shape=[B, N, C] ([B, N, M, C] when lifted);
               with return_qkk, a tuple (output, qkk) where qkk has the same
               layout as the output before the output projection.
        """
        B, N, C = x.shape
        _, M, _ = y.shape
        z = default(z, y)
        assert z.shape[1] == M, 'y and z must have the same length'

        head_shape = (self.num_heads, self.head_dim)
        # -> [B, head, token, head_dim]
        q = self.q(x).reshape(B, N, *head_shape).permute(0, 2, 1, 3)
        k = self.k(y).reshape(B, M, *head_shape).permute(0, 2, 1, 3)
        v = self.v(z).reshape(B, M, *head_shape).permute(0, 2, 1, 3)
        q, k = self.q_norm(q), self.k_norm(k)

        out = self._merge_heads(self._forward(q, k, v), B, N, M, C)
        out = self.proj_drop(self.proj(out))

        if self.return_qkk:
            # Same head-merging as the main path: the heads must be moved
            # next to the channel axis before flattening, otherwise heads
            # and tokens get interleaved.
            qkk = self._merge_heads(self._forward(q, k, k), B, N, M, C)
            return out, qkk

        return out

    def _merge_heads(self, t: torch.Tensor, B: int, N: int, M: int, C: int) -> torch.Tensor:
        """Turn [B, head, tokens, head_dim] into [B, N, C] ([B, N, M, C] when lifted)."""
        t = t.transpose(1, 2)
        return t.reshape(B, N, M, C) if self.lifted else t.reshape(B, N, C)
