import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Callable, Optional
from collections import OrderedDict
from torch import Tensor

from open_clip.transformer import LayerNorm, LayerScale

class CustomMultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate query and key/value widths.

    Queries are projected from ``q_dim`` and keys/values from ``kv_dim`` into a
    shared ``out_dim``-wide space that is split evenly across ``num_heads``
    heads. No attention mask and no dropout are applied.

    Args:
        q_dim: size of the last dimension of the query input.
        kv_dim: size of the last dimension of the key and value inputs.
        out_dim: width of the projected space and of the returned tensor.
        num_heads: number of attention heads; must be positive and divide ``out_dim``.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``out_dim``.
    """

    def __init__(self, q_dim: int, kv_dim: int, out_dim: int, num_heads: int):
        super().__init__()
        # Fail at construction time instead of with an opaque ``view`` error
        # on the first forward pass.
        if num_heads <= 0 or out_dim % num_heads != 0:
            raise ValueError(
                f"out_dim ({out_dim}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.in_q = nn.Linear(q_dim, out_dim)
        self.in_k = nn.Linear(kv_dim, out_dim)
        self.in_v = nn.Linear(kv_dim, out_dim)
        self.out = nn.Linear(out_dim, out_dim)

        self.out_dim = out_dim
        self.n_heads = num_heads
        self.out_dim_per_head = self.out_dim // self.n_heads

    def split_heads(self, t: Tensor) -> Tensor:
        """Split the feature axis into heads and move the head axis before the tokens.

        [batch, num_token, out_dim] -> [batch, n_heads, num_token, out_dim_per_head]
        [num_token, out_dim]        -> [n_heads, num_token, out_dim_per_head]

        Raises:
            ValueError: if ``t`` is neither 2-D nor 3-D.
        """
        if t.dim() not in (2, 3):
            raise ValueError(
                f"expected a 2-D or 3-D tensor, got shape {tuple(t.shape)}"
            )
        per_head = t.reshape(*t.shape[:-1], self.n_heads, self.out_dim_per_head)
        # Swapping the token and head axes makes the trailing two dimensions
        # (num_token, out_dim_per_head), i.e. one attention problem per head.
        return per_head.transpose(-3, -2)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """
        q    : [batch, num_token_q, q_dim]    or [num_token_q, q_dim]
        k, v : [batch, num_token_kv, kv_dim]  or [num_token_kv, kv_dim]
        ->
        out  : [batch, num_token_q, out_dim]  or [num_token_q, out_dim]

        2-D inputs are treated as a single unbatched sequence (the same
        convention as ``torch.nn.MultiheadAttention``). All three inputs are
        expected to have the same rank.
        """
        q = self.split_heads(self.in_q(q))
        k = self.split_heads(self.in_k(k))
        v = self.split_heads(self.in_v(v))

        # Scaled dot-product attention, computed independently for every head.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.out_dim_per_head)
        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, v)

        # Undo split_heads: put the head axis back next to the per-head
        # features and fold the two into a single out_dim-wide axis.
        context = context.transpose(-3, -2)
        merged = context.reshape(*context.shape[:-2], self.out_dim)
        return self.out(merged)


class CustomCrossAttention(nn.Module):
    """Pre-norm cross-attention: queries attend over a separate key/value sequence.

    The query input and the key/value input are each passed through their own
    ``nn.LayerNorm`` before being handed to ``CustomMultiheadAttention``. No
    residual connection is added here; the raw attention output is returned.

    Args:
        q_dim: size of the last dimension of the query input.
        kv_dim: size of the last dimension of the key/value input.
        out_dim: width of the attention output.
        num_heads: number of attention heads.
    """

    def __init__(self, q_dim: int, kv_dim: int, out_dim: int, num_heads: int):
        super().__init__()
        self.attention = CustomMultiheadAttention(q_dim, kv_dim, out_dim, num_heads)
        self.ln_q = nn.LayerNorm(q_dim)
        self.ln_kv = nn.LayerNorm(kv_dim)

    def forward(self, q: Tensor, kv: Tensor) -> Tensor:
        """
        q  : [batch, num_token_q, q_dim]
        kv : [batch, num_token_kv, kv_dim]
        ->
        out: [batch, num_token_q, out_dim]
        """
        # Keys and values come from the same normalised tensor, so run the
        # LayerNorm once and reuse the result instead of evaluating it twice.
        kv_normed = self.ln_kv(kv)
        return self.attention(self.ln_q(q), kv_normed, kv_normed)


class CustomSelfAttention(nn.Module):
    """Pre-norm self-attention: one sequence acts as query, key and value.

    The input is passed through a single ``nn.LayerNorm`` and then used for
    all three roles of ``CustomMultiheadAttention``. No residual connection is
    added here; the raw attention output is returned.

    Args:
        dim: size of the last dimension of the input.
        out_dim: width of the attention output.
        num_heads: number of attention heads.
    """

    def __init__(self, dim: int, out_dim: int, num_heads: int):
        super().__init__()
        self.attention = CustomMultiheadAttention(dim, dim, out_dim, num_heads)
        self.ln_qkv = nn.LayerNorm(dim)

    def forward(self, qkv: Tensor) -> Tensor:
        """
        qkv: [batch, num_token, dim]
        ->
        out: [batch, num_token, out_dim]
        """
        # Query, key and value are the same normalised tensor, so run the
        # LayerNorm once and reuse the result instead of evaluating it three times.
        normed = self.ln_qkv(qkv)
        return self.attention(normed, normed, normed)


class CustomSelfCrossAttentionLayer(nn.Module):
    """One decoder-style layer: self-attention, cross-attention, then an MLP.

    Each of the three stages is applied with a residual connection. The MLP
    expands to four times ``out_dim`` and is preceded by a ``nn.LayerNorm``.

    Note: ``forward`` adds the ``q_dim``-wide query stream to the
    ``out_dim``-wide cross-attention output, so the two widths must be
    broadcast-compatible (normally ``q_dim == out_dim``).
    """

    def __init__(self, q_dim, kv_dim, out_dim, num_heads):
        super().__init__()
        hidden_dim = out_dim * 4

        # Self-attention keeps the query width so its output can be added back to q.
        self.self_attention = CustomSelfAttention(
            dim=q_dim, out_dim=q_dim, num_heads=num_heads
        )
        self.cross_attention = CustomCrossAttention(
            q_dim=q_dim, kv_dim=kv_dim, out_dim=out_dim, num_heads=num_heads
        )
        self.ln_out = nn.LayerNorm(out_dim)

        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(out_dim, hidden_dim)
        feed_forward["gelu"] = nn.GELU()
        feed_forward["c_proj"] = nn.Linear(hidden_dim, out_dim)
        self.mlp = nn.Sequential(feed_forward)

    def forward(self, q: Tensor, kv: Tensor) -> Tensor:
        """
        q  : [batch, num_token_q, dim]
        kv : [batch, num_token_kv, dim]
        ->
        out: [batch, num_token_q, dim]
        """
        # Stage 1: the queries attend to each other.
        self_attended = q + self.self_attention(q)
        # Stage 2: the updated queries attend over the key/value sequence.
        cross_attended = self_attended + self.cross_attention(self_attended, kv)
        # Stage 3: position-wise MLP on the normalised stream.
        return cross_attended + self.mlp(self.ln_out(cross_attended))
