import torch.nn as nn
import torch
import torch.nn.functional as F
import math

def get_layer(config, layer_type, input_size, output_size, rank, omega, scale):
    """Construct a layer selected by name.

    Args:
        config: Accepted for call-site compatibility; neither layer class
            defined in this module takes it, so it is not forwarded.
        layer_type: ``'linear'`` for ``LinearLayer`` or ``'low_rank'`` for
            ``LowrankLayer``.
        input_size: Size of the last dimension of the layer input.
        output_size: Size of the last dimension of the layer output.
        rank: Inner dimension of the low-rank factorisation (``'low_rank'`` only).
        omega: Frequency passed to ``LowrankLayer`` (``'low_rank'`` only).
        scale: Divisor passed to ``LowrankLayer`` (``'low_rank'`` only).

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``layer_type`` is not one of the supported names.
    """
    if layer_type == 'linear':
        # Dense layer: rank / omega / scale do not apply and are ignored.
        return LinearLayer(input_size, output_size)
    if layer_type == 'low_rank':
        # LowrankLayer.__init__ is (input_dim, output_dim, rank, omega, scale);
        # forwarding `config` first would shift every argument and raise TypeError.
        return LowrankLayer(input_size, output_size, rank, omega, scale)
    # Fail loudly instead of returning None, which would only surface later
    # as an obscure error at the first use of the "layer".
    raise ValueError(
        f"Unknown layer_type {layer_type!r}; expected 'linear' or 'low_rank'"
    )


class LinearLayer(nn.Module):
    """Thin wrapper around a single dense ``nn.Linear`` projection.

    The wrapped module lives in ``self.linear``, so its parameters are named
    ``linear.weight`` and ``linear.bias`` in the module's state dict.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the wrapped affine map to the last dimension of ``x``."""
        projection = self.linear
        return projection(x)
    
class LowrankLayer(nn.Module):
    """Affine layer whose weight matrix is the product of two trainable factors.

    The effective weight has shape ``(output_dim, input_dim)`` and is built
    from ``A`` (``output_dim x rank``) and ``B`` (``rank x input_dim``):

    * ``omega == 0``: ``W = A @ B``
    * otherwise:      ``W = sin((omega * A) @ B) / scale``

    The output is ``F.linear(x, W, bias)``. ``W`` is recomputed on every
    forward pass, so gradients flow into ``A`` and ``B``.
    """

    def __init__(self, input_dim, output_dim, rank, omega, scale):
        super().__init__()
        # Hyper-parameters are stored on the instance and read back in forward().
        self.k = rank
        self.omega = omega
        self.scale = scale
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Factor matrices; registration order (A, then B) fixes parameters() order.
        self.A = nn.Parameter(torch.empty(output_dim, rank))
        self.B = nn.Parameter(torch.empty(rank, input_dim))
        # Same recipe nn.Linear uses for its weight, applied to each factor in turn.
        for factor in (self.A, self.B):
            nn.init.kaiming_uniform_(factor, a=math.sqrt(5))

        # Bias drawn from U(-1/sqrt(input_dim), 1/sqrt(input_dim)).
        bound = 1.0 / math.sqrt(input_dim)
        self.bias = nn.Parameter(torch.empty(output_dim))
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Project the last dimension of ``x`` (size ``input_dim``) to ``output_dim``."""
        if self.omega == 0:
            weight = self.A @ self.B
        else:
            # omega scales A before the matmul, matching left-to-right evaluation.
            weight = torch.sin((self.omega * self.A) @ self.B) / self.scale
        return F.linear(x, weight, self.bias)
        
