import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from spec_benchmark.Engine.models.base import LoRAConfig
from spec_benchmark.profiler import bucket_timer


class GatedLoRALinear(nn.Module):
    """Linear layer with a low-rank (LoRA) branch that is applied only on demand.

    ``forward`` always computes ``base_layer(x)``. When a ``gate_mask`` is passed
    it additionally adds ``lora_scaling * (lora_A(x) * gate_mask) @ lora_BT``.

    ``lora_BT`` is stored as a ``(rank, out_features)`` buffer, i.e. the
    transpose of a conventional ``lora_B.weight``, so that the update can be
    fused into a single in-place ``addmm_``. Because it is a buffer it is saved
    in / loaded from ``state_dict`` but is not returned by ``parameters()``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        lora_config: "LoRAConfig",
        bias: bool = False,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        """Build the base projection and the LoRA factors.

        Args:
            in_features: size of the last dimension of the input.
            out_features: size of the last dimension of the output.
            lora_config: supplies ``rank``, ``lora_scaling`` and ``lora_bias``;
                ``lora_bias`` must be falsy.
            bias: whether ``base_layer`` has a bias term.
            device: device for every weight and buffer (``None`` = torch default).
            dtype: dtype for every weight and buffer (``None`` = torch default).

        Raises:
            AssertionError: if ``lora_config.lora_bias`` is truthy (only while
                assertions are enabled, i.e. not under ``python -O``).
        """
        super().__init__()
        assert not lora_config.lora_bias, "lora_bias is not supported for GatedLoRALinear"

        self.in_features = in_features
        self.out_features = out_features
        self.lora_rank = lora_config.rank
        self.lora_scaling = lora_config.lora_scaling

        # Every tensor must share device/dtype: forward() combines the base
        # output with the LoRA factors via addmm_, which rejects mixed
        # dtypes or devices.
        self.base_layer = nn.Linear(
            in_features, out_features, bias=bias, device=device, dtype=dtype
        )
        self.lora_A = nn.Linear(
            in_features, self.lora_rank, bias=False, device=device, dtype=dtype
        )
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))

        # Zero-initialised, so the LoRA branch adds nothing until real weights
        # are loaded into it.
        self.register_buffer(
            "lora_BT",
            torch.zeros(self.lora_rank, out_features, device=device, dtype=dtype),
            persistent=True,
        )
        # Accept checkpoints that use plain nn.Linear / lora_B key names.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args, **kwargs):
        """Rename legacy checkpoint keys in ``state_dict`` before loading.

        - ``<prefix>weight`` / ``<prefix>bias`` (the keys of a plain
          ``nn.Linear``) are moved to ``<prefix>base_layer.weight`` /
          ``<prefix>base_layer.bias``.
        - ``<prefix>lora_B.weight`` (expected shape ``(out_features, rank)``)
          is transposed into ``<prefix>lora_BT``.

        Any further arguments passed by PyTorch's pre-hook mechanism are ignored.
        """
        if prefix + "weight" in state_dict:
            state_dict[prefix + "base_layer.weight"] = state_dict.pop(prefix + "weight")
        if prefix + "bias" in state_dict:
            state_dict[prefix + "base_layer.bias"] = state_dict.pop(prefix + "bias")
        if prefix + "lora_B.weight" in state_dict:
            lora_b = state_dict.pop(prefix + "lora_B.weight")
            state_dict[prefix + "lora_BT"] = lora_b.t().contiguous()

    def forward(self, x: torch.Tensor, gate_mask: Optional[torch.Tensor] = None):
        """Apply the base projection and, if gated, the LoRA update.

        Args:
            x: input whose last dimension is ``in_features``.
            gate_mask: optional tensor multiplied element-wise, in place, into
                ``lora_A(x)``. Because the multiply is in place, it must
                broadcast to that tensor's shape ``(..., rank)``. ``None`` skips
                the LoRA branch entirely.

        Returns:
            ``base_layer(x)`` when ``gate_mask`` is None. Otherwise
            ``base_layer(x) + lora_scaling * (lora_A(x) * gate_mask) @ lora_BT``.
            In both cases the result has the shape of ``base_layer(x)``.
        """
        # Profiling buckets are optional attributes; None when not set.
        base_bucket = getattr(self, "_prof_base_bucket", None)
        lora_bucket = getattr(self, "_prof_lora_bucket", None)

        with bucket_timer(base_bucket):
            y = self.base_layer(x)
        if gate_mask is None:
            return y

        with bucket_timer(lora_bucket):
            z = self.lora_A(x)
            z.mul_(gate_mask)  # gate the rank-r subspace; z is a fresh tensor

            # Flatten leading dims so the update is one fused in-place GEMM:
            #   y2d += lora_scaling * (z2d @ lora_BT)
            out_local = y.size(-1)
            y2d = y.reshape(-1, out_local)
            z2d = z.reshape(-1, self.lora_rank)
            y2d.addmm_(z2d, self.lora_BT, alpha=self.lora_scaling)
            y = y2d.view_as(y)
        return y