
class GenTheta(torch.autograd.Function):
    """Linear combination of seeded random basis matrices, with a low-memory backward.

    ``forward(A, c, out_dim, in_dim, seed)`` returns the ``(out_dim, in_dim)``
    matrix ``T = sum_b A[b] * W[b]``. Each basis ``W[b]`` is filled by
    ``init.kaiming_uniform_(..., a=math.sqrt(5))`` from the RNG seeded with
    ``seed``. The bases are never stored: ``backward`` regenerates them from
    the same seed, trading recomputation for memory.

    ``A`` is used as a 1-D vector of ``B = A.shape[0]`` coefficients. It is
    processed in ``c`` chunks of ``B // c`` coefficients, so only one chunk of
    bases, of shape ``(B // c, out_dim, in_dim)``, is materialised at a time.
    Coefficients past index ``c * (B // c)`` do not contribute to ``T`` and
    receive a zero gradient.

    Only ``A`` is differentiable. ``c``, ``out_dim`` and ``in_dim`` (ints) and
    ``seed`` (a Python int or a one-element integer tensor) are constants.
    """

    @staticmethod
    def _bases(A, c, out_dim, in_dim, seed):
        """Yield ``(rows, W)`` per chunk: a slice into ``A`` and its basis block.

        The bases are drawn from the global RNG seeded with ``seed``, inside
        ``torch.random.fork_rng``. The caller's CPU RNG state, and the CUDA
        state when CUDA is initialised, is restored once the generator is
        exhausted or closed. Bases are created on ``A``'s device with ``A``'s
        dtype.
        """
        chunk = A.shape[0] // c
        cuda_devices = (
            list(range(torch.cuda.device_count())) if torch.cuda.is_initialized() else []
        )
        with torch.random.fork_rng(devices=cuda_devices):
            torch.manual_seed(seed)
            for i in range(c):
                W_ = torch.zeros(chunk, out_dim, in_dim, device=A.device, dtype=A.dtype)
                init.kaiming_uniform_(W_, a=math.sqrt(5))
                yield slice(i * chunk, (i + 1) * chunk), W_

    @staticmethod
    def forward(ctx, A, c, out_dim, in_dim, seed):
        seed = int(seed)  # accept a Python int or a one-element tensor (e.g. an nn.Parameter)
        T = torch.zeros(out_dim, in_dim, device=A.device, dtype=A.dtype)
        for rows, W_ in GenTheta._bases(A, c, out_dim, in_dim, seed):
            T += torch.einsum('b,boi->oi', A[rows], W_)

        ctx.save_for_backward(A)
        # Non-tensor constants go on ctx directly; backward needs them to
        # regenerate exactly the same bases.
        ctx.c, ctx.out_dim, ctx.in_dim, ctx.seed = c, out_dim, in_dim, seed
        return T

    @staticmethod
    def backward(ctx, grad_output):
        (A,) = ctx.saved_tensors
        # Coefficients outside every chunk never reached T, so their gradient
        # stays 0. This also keeps DA the same shape as A when B % c != 0.
        DA = torch.zeros_like(A)
        for rows, W_ in GenTheta._bases(A, ctx.c, ctx.out_dim, ctx.in_dim, ctx.seed):
            # dT/dA[b] = W[b], so dL/dA[b] is the elementwise product-sum of grad_output and W[b].
            DA[rows] = torch.einsum('oi,boi->b', grad_output, W_)
        return DA, None, None, None, None




class PRANC(nn.Module):
    """Dense layer whose weight is a trainable combination of seeded random bases (PRANC).

    The ``(out_features, in_features)`` matrix is produced by
    ``GenTheta.apply(alpha, 1, out_features, in_features, seed)``: a weighted
    sum of ``ka`` random basis matrices, regenerated from ``seed`` on every
    call. Only the ``ka`` coefficients in ``alpha`` are trainable. ``seed`` is
    a frozen ``nn.Parameter``, so it is saved and loaded with the module.

    NOTE: GenTheta is called with ``c=1``, so all ``ka`` bases
    (``ka x out_features x in_features`` values) are materialised at once in
    both forward and backward.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension.
        ka: number of random basis matrices / trainable coefficients.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        ka = 1024,
        **kwargs
    ):
        super(PRANC, self).__init__(**kwargs)

        self.in_features = in_features
        self.out_features = out_features
        self.ka = ka
        self.N = in_features*out_features

        # Attribute name (including its spelling) kept for compatibility with existing code.
        self.genTetha = GenTheta.apply

        self.seed = nn.Parameter(torch.randint(int(1e10), (1,)), requires_grad=False)
        self.alpha = nn.Parameter(torch.ones(ka), requires_grad=True)

    def extra_repr(self) -> str:
        # The number of basis matrices is stored as ``ka``; there is no ``num_basis`` attribute.
        return 'in_features={}, out_features={}, num_basis={}'.format(
            self.in_features, self.out_features, self.ka)

    def forward(self, x=None):
        """Return ``x @ weight``, or the ``(in_features, out_features)`` weight when ``x`` is None."""
        weight = self.genTetha(self.alpha, 1, self.out_features, self.in_features, self.seed)
        # GenTheta yields (out, in); transpose so that ``x @ weight`` maps in -> out.
        weight = weight.transpose(-2,-1)
        if x is not None:
            return x @ weight
        return weight
    


    

class NOLA(nn.Module):
    """LoRA-style low-rank dense layer whose two factors are seeded random-basis combinations (NOLA).

    Each factor is built by ``GenTheta.apply`` from its own frozen seed:

    * ``A = GenTheta(alpha, 1, out_features, rank, seed_a)``, shape ``(out_features, rank)``;
    * ``B = GenTheta(beta, 1, rank, in_features, seed_b)``, shape ``(rank, in_features)``;
    * ``weight = (c / rank) * (A @ B).T``, shape ``(in_features, out_features)``.

    Only the coefficient vectors ``alpha`` (length ``ka``) and ``beta``
    (length ``kb``) are trainable.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension.
        ka: number of basis matrices / coefficients for factor ``A``.
        kb: number of basis matrices / coefficients for factor ``B``.
        rank: inner dimension of the low-rank product.
        d: stored as ``self.d``; only used to derive ``self.D = N // d``. Neither
            is read by ``forward``.
        c: numerator of the output scale ``c / rank``.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        ka = 1024,
        kb = 1024,
        rank=16,
        d = 128,
        c = 1.0,
        **kwargs
    ):
        super(NOLA, self).__init__(**kwargs)

        self.in_features = in_features
        self.out_features = out_features
        self.ka = ka
        self.kb = kb
        self.N = in_features*out_features
        self.rank = rank
        self.d = d
        self.D = self.N//self.d

        self.c = c
        self.scale = self.c / self.rank
        # Attribute name (including its spelling) kept for compatibility with existing code.
        self.genTetha = GenTheta.apply

        # Frozen seeds: saved and loaded with the module so the bases can be regenerated.
        self.seed_a = nn.Parameter(torch.randint(int(1e10), (1,)), requires_grad=False)
        self.seed_b = nn.Parameter(torch.randint(int(1e10), (1,)), requires_grad=False)

        self.alpha = nn.Parameter(torch.ones(ka), requires_grad=True)
        self.beta = nn.Parameter(torch.ones(kb), requires_grad=True)

    def extra_repr(self) -> str:
        # There is no ``num_basis`` attribute; report both basis counts and the rank.
        return 'in_features={}, out_features={}, rank={}, num_basis_a={}, num_basis_b={}'.format(
            self.in_features, self.out_features, self.rank, self.ka, self.kb)

    def forward(self, x=None):
        """Return ``x @ weight``, or the ``(in_features, out_features)`` weight when ``x`` is None."""
        A = self.genTetha(self.alpha, 1, self.out_features, self.rank, self.seed_a)
        B = self.genTetha(self.beta, 1, self.rank, self.in_features, self.seed_b)

        # (out, rank) @ (rank, in) -> (out, in); transpose to (in, out) so that ``x @ weight`` maps in -> out.
        weight = self.scale*(A @ B).transpose(-2,-1)

        if x is not None:
            return x @ weight
        return weight



