import torch
from torch.nn import functional as F
from torch import Tensor
# TODO: also save the parameters of the original matrix (self.W, stored as `self.weight` below), and keep it on the CPU to reduce GPU memory usage.
class ProbLinear(torch.nn.Module):
    """Linear layer whose weight is represented as ``S + U @ V``.

    The forward pass computes::

        y = x @ S.T + (x @ V.T) @ U.T + bias

    ``S`` has shape ``(out_features, in_features)``, ``U`` has shape
    ``(out_features, rank)`` and ``V`` has shape ``(rank, in_features)``.
    ``S`` is named as the sparse component and ``U``/``V`` as the low-rank
    component, but all three are stored as dense tensors. They are allocated
    uninitialized (``torch.empty``) and must be filled by the caller, e.g. via
    :meth:`adjust_rank` or ``load_state_dict``.

    A copy of the original dense weight is kept in ``self.weight`` as a plain,
    detached CPU tensor. It is not a registered parameter or buffer, so it is
    not trained, is not moved by ``.to()``/``.cuda()``, and is not part of
    ``state_dict()``.
    """

    def __init__(self,
                 in_features:  int,
                 rank:         int,
                 out_features: int,
                 weight:       Tensor,
                 bias:         bool,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank

        self.V = torch.nn.Parameter(torch.empty((rank, in_features), **factory_kwargs))
        self.U = torch.nn.Parameter(torch.empty((out_features, rank), **factory_kwargs))
        self.S = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        # detach() guarantees a plain Tensor. `Tensor.to` returns its input
        # object unchanged when no move/cast is needed, so without detach an
        # nn.Parameter already on the CPU in `dtype` would be registered as a
        # trainable parameter here, moved to the GPU by `.to()`, and make the
        # CPU re-assignment in `to()` raise TypeError.
        self.weight = weight.detach().to(device="cpu", dtype=dtype)
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

    def to(self, *args, **kwargs):
        """Like ``nn.Module.to``, but keep ``self.weight`` pinned to the CPU."""
        module = super().to(*args, **kwargs)
        # `weight` is a plain attribute, so nn.Module.to does not move it. This
        # re-pins it to the CPU in case it was replaced by a non-CPU tensor.
        module.weight = module.weight.cpu()
        return module

    def adjust_rank(self, target_rank: int, lrc_V: Tensor, lrc_U: Tensor,
                    sparse_comp: Tensor, device, dtype):
        """Install new factors of rank ``target_rank``.

        Args:
            target_rank: New rank; stored in ``self.rank``.
            lrc_V: New ``V`` factor, expected ``(target_rank, in_features)``.
            lrc_U: New ``U`` factor, expected ``(out_features, target_rank)``.
            sparse_comp: New ``S``, expected ``(out_features, in_features)``.
            device: Device for the new tensors (``None`` keeps each input's).
            dtype: Dtype for the new tensors (``None`` keeps each input's).

        The inputs are always copied, never aliased. ``V`` and ``U`` become
        new ``nn.Parameter`` objects because their shapes may change, so any
        optimizer holding the old ones must be rebuilt. ``S`` keeps its
        Parameter object and only its data is swapped.
        """
        self.rank = target_rank
        # copy=True yields a fresh tensor even when no move/cast is required.
        self.V = torch.nn.Parameter(lrc_V.detach().to(device=device, dtype=dtype, copy=True))
        self.U = torch.nn.Parameter(lrc_U.detach().to(device=device, dtype=dtype, copy=True))
        self.S.data = sparse_comp.detach().to(device=device, dtype=dtype, copy=True)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x @ S.T + bias + (x @ V.T) @ U.T``."""
        return F.linear(x, self.S, self.bias) + F.linear(F.linear(x, self.V), self.U)

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, rank={self.rank}, out_features={self.out_features}, bias={self.bias is not None}"

class ProbQKV(torch.nn.Module):
    """Fused Q/K/V projection where each branch is ``S + U @ V``.

    For each branch ``p`` in ``("q", "k", "v")``::

        p = x @ p_S.T + (x @ p_V.T) @ p_U.T

    The output is ``cat([q, k, v], dim=-1)``, plus ``bias`` (length
    ``q_out + k_out + v_out``) when ``bias`` is enabled. Branch factors are
    allocated uninitialized (``torch.empty``) and must be filled by the caller,
    e.g. via :meth:`adjust_rank` or ``load_state_dict``.

    A copy of the original fused weight is kept in ``self.weight`` as a plain,
    detached CPU tensor. It is not a registered parameter or buffer, so it is
    not trained, is not moved by ``.to()``/``.cuda()``, and is not part of
    ``state_dict()``.
    """

    # Projection names accepted by `adjust_rank` -> attribute prefix.
    _PROJ_PREFIXES = {"q_proj": "q", "k_proj": "k", "v_proj": "v"}

    def __init__(self,
                 in_features:  int,
                 q_rank:       int,
                 q_out:        int,
                 k_rank:       int,
                 k_out:        int,
                 v_rank:       int,
                 v_out:        int,
                 qkv_weight:   Tensor,
                 bias:         bool,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()

        self.in_features = in_features
        self.q_rank = q_rank
        self.k_rank = k_rank
        self.v_rank = v_rank

        self.q_out = q_out
        self.k_out = k_out
        self.v_out = v_out

        self.q_V = torch.nn.Parameter(torch.empty((q_rank, in_features), **factory_kwargs))
        self.q_U = torch.nn.Parameter(torch.empty((q_out, q_rank), **factory_kwargs))
        self.q_S = torch.nn.Parameter(torch.empty((q_out, in_features), **factory_kwargs))

        self.k_V = torch.nn.Parameter(torch.empty((k_rank, in_features), **factory_kwargs))
        self.k_U = torch.nn.Parameter(torch.empty((k_out, k_rank), **factory_kwargs))
        self.k_S = torch.nn.Parameter(torch.empty((k_out, in_features), **factory_kwargs))

        self.v_V = torch.nn.Parameter(torch.empty((v_rank, in_features), **factory_kwargs))
        self.v_U = torch.nn.Parameter(torch.empty((v_out, v_rank), **factory_kwargs))
        self.v_S = torch.nn.Parameter(torch.empty((v_out, in_features), **factory_kwargs))

        # detach() guarantees a plain Tensor. `Tensor.to` returns its input
        # object unchanged when no move/cast is needed, so without detach an
        # nn.Parameter already on the CPU in `dtype` would be registered (and
        # trained / moved to the GPU) as a parameter of this module.
        self.weight = qkv_weight.detach().to(device="cpu", dtype=dtype)

        if bias:
            self.bias = torch.nn.Parameter(torch.empty(q_out + k_out + v_out, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

    def to(self, *args, **kwargs):
        """Like ``nn.Module.to``, but keep ``self.weight`` pinned to the CPU."""
        module = super().to(*args, **kwargs)
        # `weight` is a plain attribute, so nn.Module.to does not move it. This
        # re-pins it to the CPU in case it was replaced by a non-CPU tensor.
        module.weight = module.weight.cpu()
        return module

    def adjust_rank(self, name: str, target_rank: int, lrc_V: Tensor, lrc_U: Tensor,
                    sparse_comp: Tensor, device, dtype):
        """Install new factors of rank ``target_rank`` for one branch.

        Args:
            name: ``"q_proj"``, ``"k_proj"`` or ``"v_proj"``. Any other name
                is silently ignored.
            target_rank: New rank; stored in ``<p>_rank``.
            lrc_V: New ``<p>_V``, expected ``(target_rank, in_features)``.
            lrc_U: New ``<p>_U``, expected ``(<p>_out, target_rank)``.
            sparse_comp: New ``<p>_S``, expected ``(<p>_out, in_features)``.
            device: Device for the new tensors (``None`` keeps each input's).
            dtype: Dtype for the new tensors (``None`` keeps each input's).

        The inputs are always copied, never aliased. ``<p>_V`` and ``<p>_U``
        become new ``nn.Parameter`` objects because their shapes may change, so
        any optimizer holding the old ones must be rebuilt. ``<p>_S`` keeps its
        Parameter object and only its data is swapped.
        """
        prefix = self._PROJ_PREFIXES.get(name)
        if prefix is None:
            return
        setattr(self, f"{prefix}_rank", target_rank)
        # copy=True yields a fresh tensor even when no move/cast is required.
        setattr(self, f"{prefix}_V",
                torch.nn.Parameter(lrc_V.detach().to(device=device, dtype=dtype, copy=True)))
        setattr(self, f"{prefix}_U",
                torch.nn.Parameter(lrc_U.detach().to(device=device, dtype=dtype, copy=True)))
        getattr(self, f"{prefix}_S").data = sparse_comp.detach().to(device=device, dtype=dtype, copy=True)

    @staticmethod
    def _sparse_plus_low_rank(x: Tensor, S: Tensor, V: Tensor, U: Tensor) -> Tensor:
        """Return ``x @ S.T + (x @ V.T) @ U.T`` (one projection branch)."""
        return F.linear(x, S) + F.linear(F.linear(x, V), U)

    def forward(self, x: Tensor) -> Tensor:
        query = self._sparse_plus_low_rank(x, self.q_S, self.q_V, self.q_U)
        key = self._sparse_plus_low_rank(x, self.k_S, self.k_V, self.k_U)
        value = self._sparse_plus_low_rank(x, self.v_S, self.v_V, self.v_U)

        # Concatenate on the last (feature) axis, so inputs with any number of
        # leading dims work, e.g. (tokens, dim) as well as (batch, seq, dim).
        out = torch.cat([query, key, value], dim=-1)
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, q_out={self.q_out}, q_rank={self.q_rank}, k_out={self.k_out}, k_rank={self.k_rank}, v_out={self.v_out}, v_rank={self.v_rank}, bias={self.bias is not None}"