import torch
import torch.nn as nn
import math

class LoraPool(nn.Module):
    """A pool of per-task low-rank (LoRA) adapters for key and value features.

    Each of the ``pool_size`` slots holds an independent pair of factors for
    keys (``k_lora_A``/``k_lora_B``) and values (``v_lora_A``/``v_lora_B``).
    ``forward`` selects one slot by ``task_id`` and returns the low-rank
    updates ``x @ A @ B * scaling`` for ``k`` and ``v``.

    Args:
        pool_size: Number of adapter slots (indexed by ``task_id``).
        dim: Feature size of ``k``/``v``; each update maps ``dim -> dim``.
        rank: LoRA rank ``r`` shared by every slot.
        lora_alpha: Numerator of the update scale ``lora_alpha / rank``.
    """

    def __init__(self, pool_size, dim, rank, lora_alpha=1):
        super().__init__()
        self.r = rank
        self.lora_alpha = lora_alpha
        # Plain 0-dim tensor (not a registered buffer): it is not in the
        # state_dict and ``.to()`` does not move it. Multiplying a CUDA tensor
        # by a 0-dim CPU tensor is allowed, so forward can still use it as is.
        self.scaling = torch.tensor(self.lora_alpha / self.r)
        self.pool_size = pool_size
        self.dim = dim

        self.create_parameters()
        self.reset_parameters()

    def create_parameters(self):
        """Register zero-filled factor tensors for the key and value adapters.

        Creates ``k_lora_A``/``v_lora_A`` of shape ``(pool_size, dim, r)`` and
        ``k_lora_B``/``v_lora_B`` of shape ``(pool_size, r, dim)``.
        """
        for attr_name in ('k_lora', 'v_lora'):
            setattr(self, attr_name + '_A',
                    nn.Parameter(torch.zeros((self.pool_size, self.dim, self.r))))
            setattr(self, attr_name + '_B',
                    nn.Parameter(torch.zeros((self.pool_size, self.r, self.dim))))

    def reset_parameters(self):
        """Initialise every slot: A factors Kaiming-uniform, B factors zero.

        Because every B starts at zero, each adapter's update is exactly zero
        until it is trained.
        """
        for param_name in ('k_lora_A', 'k_lora_B', 'v_lora_A', 'v_lora_B'):
            param = getattr(self, param_name)
            if not isinstance(param, nn.Parameter):
                continue
            if param_name.endswith('_A'):
                # Initialise slot by slot so torch treats each (dim, r) slice
                # as a 2-D weight; on the full 3-D tensor it would read the
                # trailing axis as a receptive field when computing fan-in.
                # NOTE(review): for a (dim, r) slice torch uses fan_in = r,
                # whereas nn.Linear-based LoRA (weight (r, dim)) uses
                # fan_in = dim, giving a larger bound here — confirm intended.
                for i in range(param.shape[0]):
                    nn.init.kaiming_uniform_(param[i], a=math.sqrt(5))
            else:
                nn.init.zeros_(param)

    def forward(self, k, v, task_id=-1):
        """Return the LoRA updates for ``k`` and ``v`` from slot ``task_id``.

        Args:
            k: Key features whose last dimension is ``dim``; any leading
                dimensions (e.g. ``(batch, seq, dim)``) are supported.
            v: Value features, same convention as ``k``.
            task_id: Index of the adapter slot to use, or ``-1`` to disable
                the adapters.

        Returns:
            ``(k_update, v_update)``: tensors with the same shape, dtype and
            device as ``k`` and ``v`` respectively. Returns ``(0, 0)`` when
            ``task_id == -1``.

        Raises:
            TypeError: If ``task_id`` is not an ``int``.
            IndexError: If ``task_id`` is outside the pool (raised by tensor
                indexing).
        """
        if task_id == -1:
            return 0, 0
        if not isinstance(task_id, int):
            raise TypeError(
                f"task_id must be an int, got {type(task_id).__name__}")
        # NOTE(review): negative ids other than -1 (e.g. -2) index from the
        # end of the pool like a Python sequence; confirm that is intended.
        k_lora = self._low_rank_update(
            k, self.k_lora_A[task_id], self.k_lora_B[task_id])
        v_lora = self._low_rank_update(
            v, self.v_lora_A[task_id], self.v_lora_B[task_id])
        return k_lora, v_lora

    def _low_rank_update(self, x, lora_a, lora_b):
        """Compute ``x @ lora_a @ lora_b * scaling`` in ``x``'s dtype/device."""
        # Cast only the selected slot, per call, instead of ``self.to(...)``.
        # Moving the module inside forward would permanently change the
        # dtype/device of the trainable factors (e.g. to fp16 under autocast),
        # losing precision and conflicting with optimizer/GradScaler state.
        # Gradients still flow back through the cast to the original params.
        lora_a = lora_a.to(dtype=x.dtype, device=x.device)
        lora_b = lora_b.to(dtype=x.dtype, device=x.device)
        # Project through the rank-r bottleneck first; this never builds the
        # dense (dim, dim) product lora_a @ lora_b. ``@`` broadcasts over any
        # leading dimensions of x.
        return (x @ lora_a) @ lora_b * self.scaling