# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import constant_init, kaiming_init, trunc_normal_
from mmseg.registry import MODELS
from ..utils import PatchEmbed, resize
from .vit import VisionTransformer

from .lora_layers import LoRALinear


class LoRAMultiheadAttention(nn.Module):
    """Self-attention whose q/k/v input projections may be LoRA-adapted.

    The input is projected by ``self.q`` / ``self.k`` / ``self.v`` (plain
    ``nn.Linear`` layers, or ``LoRALinear`` layers when ``lora_cfg`` is
    given). The projections are attended over with
    ``nn.MultiheadAttention``, passed through ``self.proj`` and
    ``self.proj_drop``, and finally added to a residual.

    NOTE(review): ``nn.MultiheadAttention`` owns its own input projection
    (``in_proj_weight``) and output projection (``out_proj``). The input is
    therefore projected twice before attention and twice after it. Confirm
    this is intended; pretrained ViT qkv/proj weights will not map
    one-to-one onto this layout.

    Args:
        embed_dims (int): Feature dimension of the input and output.
        num_heads (int): Number of attention heads.
        attn_drop (float): Dropout probability inside
            ``nn.MultiheadAttention``. Default: 0.
        proj_drop (float): Dropout probability after ``self.proj``.
            Default: 0.
        bias (bool): Bias flag for the plain ``nn.Linear`` q/k/v projections
            and for ``nn.MultiheadAttention``. It is not forwarded to
            ``LoRALinear``. Default: True.
        batch_first (bool): Forwarded to ``nn.MultiheadAttention``.
            Default: True.
        lora_cfg (dict | None): Keyword arguments for ``LoRALinear``. When
            None, plain ``nn.Linear`` projections are used. Default: None.
    """

    def __init__(self, embed_dims, num_heads, attn_drop=0., proj_drop=0., bias=True, batch_first=True, lora_cfg=None):
        super().__init__()
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.q = nn.Linear(embed_dims, embed_dims, bias=bias) if lora_cfg is None else LoRALinear(embed_dims, embed_dims, **lora_cfg)
        self.k = nn.Linear(embed_dims, embed_dims, bias=bias) if lora_cfg is None else LoRALinear(embed_dims, embed_dims, **lora_cfg)
        self.v = nn.Linear(embed_dims, embed_dims, bias=bias) if lora_cfg is None else LoRALinear(embed_dims, embed_dims, **lora_cfg)

        self.attn = nn.MultiheadAttention(embed_dim=embed_dims, num_heads=num_heads, dropout=attn_drop, bias=bias, batch_first=batch_first)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, identity=None):
        """Attend over ``x`` and add the result to a residual.

        Args:
            x (Tensor): Input sequence, laid out as ``nn.MultiheadAttention``
                expects for the configured ``batch_first``.
            identity (Tensor | None): Residual added to the attention output.
                When None, ``x`` itself is used.

        Returns:
            Tensor: ``identity + proj_drop(proj(attn(q(x), k(x), v(x))))``.
        """
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # Attention weights are never consumed, so skip computing them.
        out, _ = self.attn(q, k, v, need_weights=False)
        out = self.proj_drop(self.proj(out))
        residual = x if identity is None else identity
        return residual + out


class LoRATransformerEncoderLayer(BaseModule):
    """Pre-norm ViT encoder layer built on :class:`LoRAMultiheadAttention`.

    ``forward`` computes ``x = attn(norm1(x)) + x`` followed by
    ``x = ffn(norm2(x)) + x``. It optionally runs both steps under
    ``torch.utils.checkpoint``.

    NOTE(review): ``drop_path_rate`` is only applied to the FFN branch (as
    its ``dropout_layer``). The attention branch gets no DropPath. Confirm
    this asymmetry is intended.

    Args:
        embed_dims (int): Feature dimension.
        num_heads (int): Number of attention heads.
        feedforward_channels (int): Hidden width of the FFN.
        drop_rate (float): Dropout after the attention projection and inside
            the FFN. Default: 0.
        attn_drop_rate (float): Dropout inside the attention. Default: 0.
        drop_path_rate (float): DropPath probability for the FFN branch; no
            DropPath layer is configured when it is not positive. Default: 0.
        num_fcs (int): Number of linear layers in the FFN. Default: 2.
        qkv_bias (bool): Forwarded to the attention as ``bias``.
            Default: True.
        act_cfg (dict): Activation config for the FFN.
        norm_cfg (dict): Config used for both normalization layers.
        batch_first (bool): Forwarded to the attention. Default: True.
        with_cp (bool): Use gradient checkpointing in ``forward`` when the
            input requires grad. Default: False.
        lora_cfg (dict | None): Forwarded to the attention; None selects
            plain linear q/k/v projections. Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 with_cp=False,
                 lora_cfg=None):
        super().__init__()
        # Submodules are registered as norm1 -> attn -> norm2 -> ffn. This
        # order fixes the ordering of named_modules()/state_dict() entries.
        self.norm1_name, first_norm = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)
        self.add_module(self.norm1_name, first_norm)

        self.attn = LoRAMultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            bias=qkv_bias,
            batch_first=batch_first,
            lora_cfg=lora_cfg,
        )

        self.norm2_name, second_norm = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, second_norm)

        if drop_path_rate > 0:
            ffn_dropout_cfg = dict(type='DropPath', drop_prob=drop_path_rate)
        else:
            ffn_dropout_cfg = None
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=ffn_dropout_cfg,
            act_cfg=act_cfg,
        )

        self.with_cp = with_cp

    @property
    def norm1(self):
        """Normalization layer applied before the attention."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """Normalization layer applied before the FFN."""
        return getattr(self, self.norm2_name)

    def _attn_then_ffn(self, x):
        """Apply the attention and FFN sub-blocks, each with its residual."""
        x = self.attn(self.norm1(x), identity=x)
        return self.ffn(self.norm2(x), identity=x)

    def forward(self, x):
        """Run the layer, checkpointing it when ``with_cp`` and grads apply."""
        if not (self.with_cp and x.requires_grad):
            return self._attn_then_ffn(x)
        # Recompute activations in backward instead of storing them.
        return cp.checkpoint(self._attn_then_ffn, x)


@MODELS.register_module()
class VisionTransformer_LoRA(VisionTransformer):
    """:class:`VisionTransformer` whose encoder layers use LoRA attention.

    The parent constructor runs unchanged. Afterwards ``self.layers`` is
    replaced by a new stack of :class:`LoRATransformerEncoderLayer` built from
    the same hyper-parameters.

    Args:
        *args: Forwarded to :class:`VisionTransformer`. NOTE: the layer
            hyper-parameters (``num_layers``, ``embed_dims``, ``num_heads``,
            ``mlp_ratio``, drop rates, ``num_fcs``, ``qkv_bias``,
            ``act_cfg``, ``norm_cfg``, ``with_cp``) are read from keyword
            arguments only. Pass them by keyword, or the replacement layers
            fall back to the defaults below.
        lora_cfg (dict | None): Keyword arguments for ``LoRALinear`` in every
            layer's q/k/v projections. When None, defaults to
            ``dict(r=4, lora_alpha=16, lora_dropout=0.1)``.
        **kwargs: Forwarded to :class:`VisionTransformer`.
    """

    def __init__(self, *args, lora_cfg=None, **kwargs):
        # ``lora_cfg`` is keyword-only, so it never appears in ``kwargs``.
        # Honour the caller's value and only fall back to the default
        # when none was given.
        if lora_cfg is None:
            lora_cfg = dict(r=4, lora_alpha=16, lora_dropout=0.1)

        # Defaults here presumably mirror VisionTransformer's own defaults.
        # Confirm if the parent signature changes.
        num_layers = kwargs.get('num_layers', 12)
        drop_path_rate = kwargs.get('drop_path_rate', 0.)
        drop_rate = kwargs.get('drop_rate', 0.)
        attn_drop_rate = kwargs.get('attn_drop_rate', 0.)
        num_fcs = kwargs.get('num_fcs', 2)
        qkv_bias = kwargs.get('qkv_bias', True)
        act_cfg = kwargs.get('act_cfg', dict(type='GELU'))
        norm_cfg = kwargs.get('norm_cfg', dict(type='LN'))
        with_cp = kwargs.get('with_cp', False)

        super().__init__(*args, **kwargs)

        self.num_layers = num_layers
        self.drop_path_rate = drop_path_rate
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.num_fcs = num_fcs
        self.qkv_bias = qkv_bias
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.embed_dims = kwargs.get('embed_dims', 768)
        self.num_heads = kwargs.get('num_heads', 12)
        self.mlp_ratio = kwargs.get('mlp_ratio', 4)

        # Stochastic-depth rate grows linearly from 0 to drop_path_rate
        # across the layer stack.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
        # int() so a float mlp_ratio (e.g. 4.0) still yields an integer
        # layer width for the FFN.
        feedforward_channels = int(self.mlp_ratio * self.embed_dims)

        # Replace the layers built by the parent constructor.
        self.layers = ModuleList(
            LoRATransformerEncoderLayer(
                embed_dims=self.embed_dims,
                num_heads=self.num_heads,
                feedforward_channels=feedforward_channels,
                drop_rate=self.drop_rate,
                attn_drop_rate=self.attn_drop_rate,
                drop_path_rate=layer_dpr,
                num_fcs=self.num_fcs,
                qkv_bias=self.qkv_bias,
                act_cfg=self.act_cfg,
                norm_cfg=self.norm_cfg,
                with_cp=self.with_cp,
                lora_cfg=lora_cfg,
            ) for layer_dpr in dpr)