import torch
from torch import nn
import torch.nn.functional as F

class nolora_vanilla(nn.Module):
    """
    Two-layer MLP-LoRA adapter: A (out_dim -> dim), GELU, B (dim -> out_dim).

    There are no intermediate layers and no residual connection (depth = 2: A and B).
    B's weight is zero-initialised, so a freshly built adapter outputs zeros.
    args:
        dim: hidden dimension (a.k.a. rank)
        out_dim: input and output dimension
    """
    def __init__(self, dim=32, out_dim=768):
        super().__init__()
        self.non_linear = nn.GELU()
        # A is created before B so parameter init consumes the RNG in the same order.
        self.A = nn.Linear(out_dim, dim, bias=False)
        self.B = nn.Linear(dim, out_dim, bias=False)
        nn.init.zeros_(self.B.weight)

    def forward(self, x):
        down = self.A(x)
        activated = self.non_linear(down)
        return self.B(activated)

class nolora_multilayers_depth_6(nn.Module):
    """
    Hand-unrolled six-layer MLP-LoRA adapter: A, i1, i2, i3, i4, B, each followed by GELU
    (except B).

    Two nested skip connections are used:
      - the GELU(A(x)) output is added after i4's activation (outer skip);
      - the GELU(i1(.)) output is added after i3's activation (inner skip).
    This is the version used for the analysis in the paper; the authors note it seems to
    perform better. For a flexibly stackable variant see nolora_multilayers below.
    args:
        dim: hidden dimension (a.k.a. rank)
        out_dim: input and output dimension
    """
    def __init__(self, dim=32, out_dim=768):
        super().__init__()
        self.non_linear = nn.GELU()
        # Creation order (A, i1..i4, B) fixes both parameter registration and RNG use.
        self.A = nn.Linear(out_dim, dim, bias=False)
        self.i1 = nn.Linear(dim, dim, bias=False)
        self.i2 = nn.Linear(dim, dim, bias=False)
        self.i3 = nn.Linear(dim, dim, bias=False)
        self.i4 = nn.Linear(dim, dim, bias=False)  # 4 intermediate layers
        self.B = nn.Linear(dim, out_dim, bias=False)
        nn.init.zeros_(self.B.weight)

    def forward(self, x):
        act = self.non_linear
        h = act(self.A(x))
        outer_skip = h.clone()
        h = act(self.i1(h))
        inner_skip = h.clone()
        h = act(self.i3(act(self.i2(h))))
        h = h + inner_skip
        h = act(self.i4(h)) + outer_skip
        return self.B(h)

class nolora_multilayers(nn.Module):
    """
    Implementation of MLP-LoRA with residual and intermediate layers.
    Uses GELU as the default non-linear activation.

    Layout: A -> GELU -> encoders -> decoders -> B. The (depth - 2) intermediate layers are
    split evenly into encoders and decoders. The output of decoder i (after GELU) is added to
    the input of encoder (res_num - 1 - i), so the skips are nested: the first decoder pairs
    with the last encoder. B's weight is zero-initialised, so a fresh adapter outputs zeros.
    args:
        dim: hidden dimension (a.k.a. rank)
        out_dim: input and output dimension
        depth: number of adapters in MLP-LoRA (A, B and all the intermediate layers);
            must be an even integer >= 2
    raises:
        ValueError: if depth is odd or smaller than 2
    """
    def __init__(self, dim=32, out_dim=768, depth=4):
        super().__init__()
        num_of_intermediate_layers = depth - 2
        # Every decoder needs a matching encoder skip, so the intermediate count must be even.
        # A real exception (not assert) so the check survives `python -O`. depth < 2 used to
        # slip through as a silent depth-2 adapter with res_num == -1.
        if num_of_intermediate_layers < 0 or num_of_intermediate_layers % 2 != 0:
            raise ValueError(f"depth must be an even integer >= 2, got {depth}")
        self.non_linear = nn.GELU()
        self.depth = depth
        self.res_num = num_of_intermediate_layers // 2
        self.A = nn.Linear(out_dim, dim, bias=False)
        self.encoders = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(self.res_num)])
        self.decoders = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(self.res_num)])
        self.B = nn.Linear(dim, out_dim, bias=False)
        nn.init.zeros_(self.B.weight)

    def forward(self, x):
        x = self.non_linear(self.A(x))  # Linear transformation A + GELU
        skips = []
        for encoder in self.encoders:
            # GELU is not in-place, so storing the tensor itself is safe (no clone needed).
            skips.append(x)
            x = self.non_linear(encoder(x))
        # Consume skips last-in-first-out to form the nested residual connections.
        for decoder, skip in zip(self.decoders, reversed(skips)):
            x = self.non_linear(decoder(x)) + skip
        return self.B(x)  # Linear transformation B
    
def forward_attn(
        self, hidden_states, head_mask=None, output_attentions: bool = False
        ):
    mixed_query_layer = self.query(hidden_states) + 1.0 * self.q_adapter(hidden_states)

    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states) + 1.0 * self.v_adapter(hidden_states))
    query_layer = self.transpose_for_scores(mixed_query_layer)

    # 使用 torch.nn.functional.scaled_dot_product_attention
    dropout_prob = getattr(self, 'attention_probs_dropout_prob', 0.0)  # 默认值为 0.0
    context_layer = torch.nn.functional.scaled_dot_product_attention(
        query_layer,
        key_layer,
        value_layer,
        head_mask,
        dropout_prob if self.training else 0.0,
        is_causal=False,
        scale=None,
    )

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(new_context_layer_shape)

    return context_layer, None

def forward_ffn(self, input):
    """
    Replacement ``forward`` for a linear layer: the layer's own affine map plus its adapter.

    Meant to be bound onto the layer. It reads ``self.weight``, ``self.bias`` and
    ``self.adapter``; both terms see the same ``input``.
    """
    base_out = F.linear(input, self.weight, self.bias)
    adapter_out = self.adapter(input)
    return base_out + 1.0 * adapter_out
 
def set_nolora_vanilla(model, nolora_mode=1, mhsa_dim=16, ffn_dim=16):
    """
    Recursively attach vanilla MLP-LoRA adapters (nolora_vanilla) to ``model`` in place.

    Children are handled by name:
      - name contains 'attention': the child is expected to expose the self-attention module
        as ``.attention``. That module gets ``q_adapter``/``v_adapter`` and its ``forward`` is
        replaced by ``forward_attn``.
      - name contains 'dense' (mode 2 only): the child gets an ``adapter`` sized from its
        weight, and its ``forward`` is replaced by ``forward_ffn``.
      - otherwise, children that have sub-modules are searched recursively.
    Matched children are not descended into.

    nolora mode:
        1: vanilla nolora on mhsa q, v
        2: vanilla nolora on both mhsa q, v and ffn
        (any value other than 1 also patches the ffn layers)
    args:
        model: module to modify in place
        nolora_mode: see above
        mhsa_dim: rank of the attention q/v adapters
        ffn_dim: rank of the ffn adapters
    """
    def _ffn_adapter(linear):
        # The adapter runs in parallel with `linear`, so it must map in_features -> out_features.
        # nolora_vanilla assumes the two are equal. When they differ (e.g. a 768 -> 3072
        # projection), swap in an A that accepts in_features; otherwise forward fails with a
        # shape mismatch.
        out_features, in_features = linear.weight.shape
        adapter = nolora_vanilla(dim=ffn_dim, out_dim=out_features)
        if in_features != out_features:
            adapter.A = nn.Linear(in_features, ffn_dim, bias=False)
        return adapter

    for name, layer in model.named_children():
        if 'attention' in name:
            attn = layer.attention
            attn.nolora_mode = nolora_mode
            # NOTE(review): q/v adapters use the default out_dim=768, i.e. they assume a
            # 768-wide attention projection -- confirm for other model sizes.
            attn.q_adapter = nolora_vanilla(dim=mhsa_dim)
            attn.v_adapter = nolora_vanilla(dim=mhsa_dim)
            bound_method = forward_attn.__get__(attn, attn.__class__)
            setattr(attn, 'forward', bound_method)
        elif 'dense' in name:
            if nolora_mode == 1:
                continue
            layer.nolora_mode = nolora_mode
            layer.adapter = _ffn_adapter(layer)
            bound_method = forward_ffn.__get__(layer, layer.__class__)
            setattr(layer, 'forward', bound_method)
        elif len(list(layer.children())) != 0:
            set_nolora_vanilla(layer, nolora_mode, mhsa_dim, ffn_dim)

def set_nolora_multilayers(model, nolora_mode=1, mhsa_dim=16, ffn_dim=16, depth=2):
    """
    Recursively attach multi-layer MLP-LoRA adapters (nolora_multilayers) to ``model`` in place.

    Children are handled by name:
      - name contains 'attention': the child is expected to expose the self-attention module
        as ``.attention``. That module gets ``q_adapter``/``v_adapter`` and its ``forward`` is
        replaced by ``forward_attn``.
      - name contains 'dense' (mode 2 only): the child gets an ``adapter`` sized from its
        weight, and its ``forward`` is replaced by ``forward_ffn``.
      - otherwise, children that have sub-modules are searched recursively.
    Matched children are not descended into.

    nolora mode:
        1: nolora with more depth on mhsa q, v (qv setting)
        2: nolora with more depth on both mhsa q, v and ffn (qvmlp setting)
        (any value other than 1 also patches the ffn layers)
    args:
        model: module to modify in place
        nolora_mode: see above
        mhsa_dim: rank of the attention q/v adapters
        ffn_dim: rank of the ffn adapters
        depth: adapter depth used for BOTH the attention and the ffn adapters
    """
    def _ffn_adapter(linear):
        # The adapter runs in parallel with `linear`, so it must map in_features -> out_features.
        # nolora_multilayers assumes the two are equal. When they differ (e.g. a 768 -> 3072
        # projection), swap in an A that accepts in_features; otherwise forward fails with a
        # shape mismatch.
        # depth is forwarded here too: previously the ffn adapters always used depth 4.
        out_features, in_features = linear.weight.shape
        adapter = nolora_multilayers(dim=ffn_dim, out_dim=out_features, depth=depth)
        if in_features != out_features:
            adapter.A = nn.Linear(in_features, ffn_dim, bias=False)
        return adapter

    for name, layer in model.named_children():
        if 'attention' in name:
            attn = layer.attention
            attn.nolora_mode = nolora_mode
            # NOTE(review): q/v adapters use the default out_dim=768, i.e. they assume a
            # 768-wide attention projection -- confirm for other model sizes.
            attn.q_adapter = nolora_multilayers(dim=mhsa_dim, depth=depth)
            attn.v_adapter = nolora_multilayers(dim=mhsa_dim, depth=depth)
            bound_method = forward_attn.__get__(attn, attn.__class__)
            setattr(attn, 'forward', bound_method)
        elif 'dense' in name:
            if nolora_mode == 1:
                continue
            layer.nolora_mode = nolora_mode
            layer.adapter = _ffn_adapter(layer)
            bound_method = forward_ffn.__get__(layer, layer.__class__)
            setattr(layer, 'forward', bound_method)
        elif len(list(layer.children())) != 0:
            set_nolora_multilayers(layer, nolora_mode, mhsa_dim, ffn_dim, depth=depth)