import torch.nn as nn
import open_clip
import torch
import torch.nn.functional as F
from imagebind.models import imagebind_model

import math

class ScaledDotProductAttention(nn.Module):
    """Reference (non-fused) scaled dot-product attention.

    Computes ``softmax(Q @ K^T / sqrt(d_k)) @ V`` over the last two dimensions,
    where ``d_k`` is the size of the last dimension of ``query``.
    """

    def forward(self, query, key, value, mask=None):
        """Apply attention.

        Args:
            query: Tensor whose last two dims are ``(tgt_len, d_k)``.
            key: Tensor whose last two dims are ``(src_len, d_k)``.
            value: Tensor whose second-to-last dim is ``src_len``.
            mask: Optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` are suppressed before the softmax.

        Returns:
            The attention-weighted combination of ``value``.
        """
        scale = math.sqrt(query.size(-1))
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            # A large negative logit drives the softmax weight to ~0.
            logits = logits.masked_fill(mask == 0, -1e9)
        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, value)

class PlainMultiHeadAttention(nn.Module):
    """Multi-head attention built from plain ``nn.Linear`` projections.

    The call signature follows ``nn.MultiheadAttention``, but the fused input
    projection is exposed as the ``qkv`` sub-module and the output projection
    as ``proj`` so they can be addressed by name (e.g. as LoRA target
    modules). Attention itself is delegated to
    ``F.scaled_dot_product_attention``; attention weights are never returned.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Attention dropout probability (only applied in training mode).
        bias: Whether ``qkv`` and ``proj`` carry bias terms.
        kdim: Key feature size. Only ``None`` / ``embed_dim`` is supported.
        vdim: Value feature size. Only ``None`` / ``embed_dim`` is supported.
        batch_first: If True, batched inputs/outputs are ``(batch, seq, dim)``
            instead of ``(seq, batch, dim)``.
        add_bias_kv: If True, a learned bias token is appended to the key and
            value sequences.

    Raises:
        NotImplementedError: If ``kdim`` or ``vdim`` differs from ``embed_dim``.
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
            self,
            embed_dim=768,
            num_heads=12,
            dropout=0.,
            bias=True,
            kdim=None,
            vdim=None,
            batch_first=False,
            add_bias_kv=False):
        super().__init__()

        self.add_bias_kv = add_bias_kv
        if self.add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim)))
            self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim)))
            # torch.empty leaves arbitrary memory behind; give the bias tokens
            # a defined starting value (same scheme nn.MultiheadAttention uses).
            nn.init.xavier_normal_(self.bias_k)
            nn.init.xavier_normal_(self.bias_v)
        else:
            self.bias_k = self.bias_v = None

        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise AssertionError("embed_dim must be divisible by num_heads")

        if not self._qkv_same_embed_dim:
            # The fused projection below assumes q, k and v share one width.
            raise NotImplementedError(
                "PlainMultiHeadAttention requires kdim == vdim == embed_dim, "
                f"got embed_dim={embed_dim}, kdim={self.kdim}, vdim={self.vdim}")
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
        self.scaled_dot_product_attention = F.scaled_dot_product_attention

        self.proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def init_weights(self):
        """No-op hook; parameters keep their constructor initialisation."""
        pass

    @staticmethod
    def _map_inputs(fn, query, key, value):
        """Apply ``fn`` to query/key/value while preserving object identity.

        Identity is preserved so the fused self-attention path can still be
        detected with ``is`` checks after the layout change.
        """
        new_query = fn(query)
        new_key = new_query if key is query else fn(key)
        if value is query:
            new_value = new_query
        elif value is key:
            new_value = new_key
        else:
            new_value = fn(value)
        return new_query, new_key, new_value

    def _in_projection(self, query, key, value):
        """Project the inputs to q, k, v using the matching slices of ``qkv``.

        Every projection goes through ``self.qkv(...)`` (rather than slicing
        its weight) so a wrapper installed around ``qkv`` is always honoured.
        """
        dim = self.embed_dim
        if key is query and value is query:
            # Self-attention: one fused matmul.
            return self.qkv(query).chunk(3, dim=-1)
        q = self.qkv(query)[..., :dim]
        projected_key = self.qkv(key)
        k = projected_key[..., dim:2 * dim]
        projected_value = projected_key if value is key else self.qkv(value)
        v = projected_value[..., 2 * dim:]
        return q, k, v

    def _merge_masks(self, attn_mask, key_padding_mask, bsz, tgt_len, src_len):
        """Validate the masks and combine them into one 4-D additive mask.

        Returns ``None`` when neither mask is given; otherwise a float tensor
        broadcastable to ``(bsz, num_heads, tgt_len, src_len)``.
        """
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask[None, None]
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * self.num_heads, tgt_len, src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
                attn_mask = attn_mask.reshape(bsz, self.num_heads, tgt_len, src_len)
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        if key_padding_mask is not None:
            correct_kpm_size = (bsz, src_len)
            if key_padding_mask.shape != correct_kpm_size:
                raise RuntimeError(
                    f"The shape of key_padding_mask is {key_padding_mask.shape}, "
                    f"but should be {correct_kpm_size}.")
            padding = key_padding_mask[:, None, None, :]
            attn_mask = padding if attn_mask is None else attn_mask + padding
        return attn_mask

    def forward(
            self,
            query,
            key,
            value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            average_attn_weights=True,
            is_causal=False):
        """Compute multi-head attention.

        Args:
            query: ``(tgt_len, batch, embed_dim)``; ``(batch, tgt_len,
                embed_dim)`` when ``batch_first``; or unbatched
                ``(tgt_len, embed_dim)``.
            key: Same layout as ``query`` with length ``src_len``.
            value: Same layout and length as ``key``.
            key_padding_mask: Optional ``(batch, src_len)`` mask (``(src_len,)``
                when unbatched). Boolean ``True`` or float ``-inf`` marks keys
                to ignore.
            need_weights: Accepted for signature compatibility; weights are
                never computed.
            attn_mask: Optional 2-D ``(tgt_len, src_len)`` or 3-D
                ``(batch * num_heads, tgt_len, src_len)`` mask. Boolean
                ``True`` blocks a position; float masks are added to the scores.
            average_attn_weights: Accepted for signature compatibility; unused.
            is_causal: Apply causal masking. Mutually exclusive with
                ``attn_mask``.

        Returns:
            ``(attn_output, None)`` where ``attn_output`` has the same layout
            as ``query``.

        Raises:
            AssertionError: If both ``attn_mask`` and ``is_causal`` are given.
            RuntimeError: If a mask has an unsupported rank or shape.
        """
        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")

        is_batched = query.dim() == 3
        is_unbatched = query.dim() == 2
        if is_unbatched:
            # Insert a singleton batch axis: (L, E) -> (L, 1, E).
            query, key, value = self._map_inputs(lambda t: t.unsqueeze(1), query, key, value)
        elif self.batch_first and is_batched:
            query, key, value = self._map_inputs(lambda t: t.transpose(1, 0), query, key, value)

        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )
        if is_unbatched and key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

        tgt_len, bsz, embed_dim = query.shape
        src_len = key.shape[0]

        q, k, v = self._in_projection(query, key, value)

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=F._none_or_dtype(key_padding_mask),
            other_name="key_padding_mask",
            target_type=q.dtype,
            check_other=False,
        )

        if is_causal and key_padding_mask is not None:
            # SDPA rejects an explicit mask together with is_causal, so the
            # causal pattern is materialised and merged with the padding mask.
            keep = torch.ones(tgt_len, src_len, dtype=torch.bool, device=q.device).tril()
            attn_mask = torch.zeros(
                tgt_len, src_len, dtype=q.dtype, device=q.device).masked_fill(~keep, float("-inf"))
            is_causal = False

        attn_mask = self._merge_masks(attn_mask, key_padding_mask, bsz, tgt_len, src_len)

        if self.add_bias_kv:
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            src_len += 1
            if attn_mask is not None:
                # The appended bias token is always attendable: extend the
                # additive mask with a zero column so the shapes keep matching.
                attn_mask = F.pad(attn_mask, (0, 1))

        dropout_p = self.dropout if self.training else 0.

        # (len, batch, embed) -> (batch, heads, len, head_dim)
        head_shape = (self.num_heads, self.head_dim)
        q = q.unflatten(-1, head_shape).permute(1, 2, 0, 3)
        k = k.unflatten(-1, head_shape).permute(1, 2, 0, 3)
        v = v.unflatten(-1, head_shape).permute(1, 2, 0, 3)

        attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
        # (batch, heads, tgt, head_dim) -> (tgt * batch, embed)
        attn_output = attn_output.permute(2, 0, 1, 3).reshape(tgt_len * bsz, embed_dim)
        attn_output = self.proj(attn_output)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        if is_unbatched:
            return attn_output.squeeze(1), None
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), None
        return attn_output, None

    def set_parameters(self, torch_tgt_module):
        """Adopt the weights of an equivalent ``nn.MultiheadAttention``.

        The parameter ``.data`` tensors are re-pointed at the source module's
        tensors (no copy), so both modules share storage afterwards.

        Args:
            torch_tgt_module: The ``nn.MultiheadAttention`` to take weights from.

        Raises:
            AssertionError: If the module type, any hyper-parameter, or the
                presence of bias / bias_k / bias_v does not match.
        """
        if not isinstance(torch_tgt_module, nn.MultiheadAttention):
            raise AssertionError(
                f"expected nn.MultiheadAttention, got {type(torch_tgt_module).__name__}")
        for attr in ("embed_dim", "batch_first", "dropout", "head_dim",
                     "num_heads", "kdim", "vdim"):
            mine, theirs = getattr(self, attr), getattr(torch_tgt_module, attr)
            if mine != theirs:
                raise AssertionError(f"{attr} mismatch: {mine} != {theirs}")

        has_bias = self.qkv.bias is not None
        if has_bias != (torch_tgt_module.in_proj_bias is not None):
            raise AssertionError("bias mismatch between this module and the target module")
        if self.add_bias_kv != (torch_tgt_module.bias_k is not None):
            # Dropping or inventing the bias token would silently change outputs.
            raise AssertionError("add_bias_kv mismatch between this module and the target module")

        self.qkv.weight.data = torch_tgt_module.in_proj_weight.data
        self.proj.weight.data = torch_tgt_module.out_proj.weight.data
        if has_bias:
            self.qkv.bias.data = torch_tgt_module.in_proj_bias.data
            self.proj.bias.data = torch_tgt_module.out_proj.bias.data
        if self.add_bias_kv:
            self.bias_k.data = torch_tgt_module.bias_k.data
            self.bias_v.data = torch_tgt_module.bias_v.data


class ImageBindMultiheadAttention(PlainMultiHeadAttention):
    """Self-attention variant taking a single input tensor.

    ``forward(x, attn_mask)`` feeds ``x`` as query, key and value to the parent
    implementation and returns only the attention output tensor.
    """

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        """Run self-attention over ``x`` and return the output tensor alone."""
        attended, _ = super().forward(
            query=x,
            key=x,
            value=x,
            attn_mask=attn_mask,
            need_weights=False,
        )
        return attended

if __name__ == '__main__':
    # --- Disabled sanity check: swap open_clip's attention and compare outputs ---
    # model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    #     'hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
    # model, _, _ = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion2b_s34b_b88k')
    #
    # visual_model = model.visual
    # img = torch.rand(2, 3, 224, 224)
    # result = visual_model(img)[0].cpu().detach().clone()
    # for module in visual_model.transformer.resblocks:
    #     new_module = PlainMultiHeadAttention()
    #     new_module.set_parameters(module.attn)
    #     module.attn = new_module
    # result2 = visual_model(img)[0].cpu().detach().clone()
    # print(torch.allclose(result, result2))

    # Load ImageBind and pick out the audio branch.
    ImageBind = imagebind_model.imagebind_huge(pretrained=True)
    preprocessor = ImageBind.modality_preprocessors['audio']
    backbone = ImageBind.modality_trunks['audio']

    # Replace each block's attention with the plain-Linear variant, reusing
    # the existing weights.
    for block in backbone.blocks:
        replacement = ImageBindMultiheadAttention(add_bias_kv=True)
        replacement.set_parameters(block.attn)
        block.attn = replacement

    # Wrap the trunk with LoRA adapters on the fused "qkv" projections.
    from peft import LoraModel, LoraConfig
    config = LoraConfig(r=8, lora_alpha=32, target_modules=["qkv"])
    lora_model = LoraModel(backbone, config, "default")

    # Print the name of every parameter that is still trainable.
    for param_name, param in lora_model.named_parameters():
        if not param.requires_grad:
            continue
        print(param_name)
    