import math
import torch
import torch.nn as nn


class Adapter(nn.Module):
    """Small trainable branch over ``d_model`` features: bottleneck MLP or LoRA.

    ``forward`` returns only the branch output; no residual connection is
    added here. In both modes the last projection is zero-initialised, so a
    freshly constructed module outputs zeros.

    Args:
        d_model: Size of the last dimension of the input and output.
        bottleneck: Hidden width of the ``"adapter"`` branch.
        dropout: Dropout probability applied after the ReLU of the
            ``"adapter"`` branch (only active in training mode). The
            ``"lora"`` branch does not use it.
        init_option: ``"adapter"`` (down-proj -> ReLU -> dropout -> up-proj)
            or ``"lora"`` (two bias-free linear maps of rank ``lora_r``).
            Matched case-insensitively.
        adapter_scalar: ``"learnable_scalar"`` for a trainable one-element
            scale initialised to 1, otherwise a value convertible with
            ``float()`` that is used as a fixed multiplier.
        adapter_layernorm_option: ``"in"`` normalises the input before the
            branch, ``"out"`` normalises the scaled branch output. Any other
            value disables layer normalisation.
        lora_r: Rank of the ``"lora"`` branch.

    Raises:
        ValueError: If ``init_option`` is neither ``"adapter"`` nor ``"lora"``,
            or if a fixed ``adapter_scalar`` cannot be converted to ``float``.
    """

    def __init__(self,
                 d_model: int,
                 bottleneck: int = 64,
                 dropout: float = 0.0,
                 init_option: str = "adapter",  # "adapter" | "lora"
                 adapter_scalar: str = "1.0",   # "1.0" | "learnable_scalar"
                 adapter_layernorm_option: str = "in",
                 lora_r: int = 8):  # LoRA rank
        super().__init__()
        self.d_model = d_model
        self.bottleneck = bottleneck
        self.init_option = init_option.lower()
        self.dropout = dropout

        # The attribute keeps its historical name for both the "in" and the
        # "out" placement so existing state_dict keys stay valid.
        self.adapter_layernorm_option = adapter_layernorm_option
        self.adapter_layer_norm_before = None
        if adapter_layernorm_option in ["in", "out"]:
            self.adapter_layer_norm_before = nn.LayerNorm(d_model)

        if adapter_scalar == "learnable_scalar":
            self.scale = nn.Parameter(torch.ones(1))
        else:
            # Fixed multiplier: a plain Python float, not part of the state_dict.
            self.scale = float(adapter_scalar)

        # nn.init functions are applied without autograd tracking, so no
        # explicit ``torch.no_grad()`` block is needed around them.
        if self.init_option == "adapter":
            self.down_proj = nn.Linear(d_model, bottleneck)
            self.non_linear_func = nn.ReLU()
            self.up_proj = nn.Linear(bottleneck, d_model)

            nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
            # Zero up-projection and biases: the branch starts out as a no-op.
            nn.init.zeros_(self.up_proj.weight)
            nn.init.zeros_(self.down_proj.bias)
            nn.init.zeros_(self.up_proj.bias)

        elif self.init_option == "lora":
            self.lora_A = nn.Linear(d_model, lora_r, bias=False)
            self.lora_B = nn.Linear(lora_r, d_model, bias=False)

            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            # Zero B so that B(A(x)) is zero at initialisation.
            nn.init.zeros_(self.lora_B.weight)
        else:
            raise ValueError(f"Unknown init_option={init_option}")

    def _apply_layer_norm(self, tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Layer-normalise ``tensor`` with the affine parameters cast to ``dtype``.

        The cast is done per call on the tensors only. The stored parameters
        keep their dtype, so running forward never mutates the module (the
        previous ``module.to(dtype)`` converted the parameters in place and
        permanently changed their precision). ``Tensor.to`` returns the same
        tensor when the dtype already matches, and gradients flow back to the
        stored parameters through the cast.

        NOTE: the functional form is used, so forward hooks registered on the
        LayerNorm submodule itself are not triggered.
        """
        norm = self.adapter_layer_norm_before
        weight = None if norm.weight is None else norm.weight.to(dtype)
        bias = None if norm.bias is None else norm.bias.to(dtype)
        return nn.functional.layer_norm(
            tensor, norm.normalized_shape, weight, bias, norm.eps
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the branch output for ``x`` (last dimension ``d_model``).

        Raises:
            ValueError: If ``self.init_option`` was changed to an unknown value
                after construction.
        """
        has_norm = self.adapter_layer_norm_before is not None

        if self.adapter_layernorm_option == "in" and has_norm:
            x = self._apply_layer_norm(x, x.dtype)

        if self.init_option == "adapter":
            down = self.down_proj(x)
            down = self.non_linear_func(down)
            down = nn.functional.dropout(down, p=self.dropout, training=self.training)
            # Bound activations before the up-projection (guards against
            # overflow in low-precision dtypes).
            down = torch.clamp(down, min=-1e4, max=1e4)
            out = self.up_proj(down)

        elif self.init_option == "lora":
            out = self.lora_B(self.lora_A(x))

        else:
            raise ValueError(f"Unknown init_option={self.init_option}")

        out = out * self.scale

        if self.adapter_layernorm_option == "out" and has_norm:
            # Parameters follow the input dtype, as before.
            out = self._apply_layer_norm(out, x.dtype)

        return out