import torch.nn as nn
import torch




class Architecture_s(nn.Module):
    """Architecture weights normalised with a softmax over the last axis.

    Holds one learnable logit tensor ``alpha`` of shape
    ``(layer_num, 2, args.lora_r)``. ``forward`` returns its softmax along the
    last dimension. ``regularizer`` returns ``lamb`` times the summed binary
    entropy of those softmax probabilities.
    """

    def __init__(self, layer_num, args):
        super(Architecture_s, self).__init__()
        # The randn draw is immediately overwritten by normal_. It is kept so
        # the global RNG stream (and thus seeded reproducibility) is unchanged.
        alpha = torch.randn((layer_num, 2, args.lora_r))
        nn.init.normal_(alpha, mean=0.0, std=args.arch_std)
        self.alpha = nn.Parameter(alpha)
        self.lamb = args.lamb
        self.r = args.lora_r
        self.layer_num = layer_num

    def regularizer(self):
        """Return ``lamb * sum(-(p*log(p) + (1-p)*log(1-p)))`` over softmax(alpha).

        Probabilities are clamped into ``[eps, 1 - eps]`` (``eps`` is the
        dtype's machine epsilon) so the logs stay finite. Without the clamp, a
        probability of exactly 0 or 1 would give ``0 * log(0) = NaN``; this
        always happens when ``lora_r == 1``.
        """
        softmax_alpha = torch.softmax(self.alpha, dim=-1)
        eps = torch.finfo(softmax_alpha.dtype).eps
        p = softmax_alpha.clamp(eps, 1 - eps)
        q = 1 - p
        entropy = -self.lamb * torch.sum(q * torch.log(q) + p * torch.log(p))
        return entropy

    def forward(self):
        """Return softmax(alpha) taken along the last dimension."""
        softmax_alpha = torch.softmax(self.alpha, dim=-1)
        return softmax_alpha
    
class Architecture_n(nn.Module):
    """Raw, un-normalised architecture weights.

    The single learnable tensor ``alpha`` has shape
    ``(layer_num, 2, args.lora_r)`` and is initialised from
    N(0, args.arch_std^2). ``forward`` hands it back untouched.
    """

    def __init__(self, layer_num, args):
        super().__init__()
        rank = args.lora_r
        shape = (layer_num, 2, rank)
        init = torch.randn(shape)
        # normal_ overwrites the randn values in place; the randn call is kept
        # so the RNG stream matches the original initialisation order.
        nn.init.normal_(init, mean=0.0, std=args.arch_std)
        self.alpha = nn.Parameter(init)
        self.layer_num = layer_num
        self.r = rank
        self.lamb = args.lamb

    def forward(self):
        """Return the ``alpha`` parameter itself, with no transformation."""
        return self.alpha
    
class Architecture_o(nn.Module):
    """Architecture weights stored in log space and initialised around 0.5.

    ``log_alpha`` has shape ``(layer_num, 2, args.lora_r)``. It holds the log
    of a N(0.5, args.arch_std^2) draw clamped into
    ``[args.cl_bottom, 1 - args.cl_bottom]``.
    """

    def __init__(self, layer_num, args):
        super().__init__()
        self.lamb = args.lamb
        self.bottom = args.cl_bottom
        self.r = args.lora_r
        self.layer_num = layer_num
        draw = torch.randn((layer_num, 2, self.r))
        nn.init.normal_(draw, mean=0.5, std=args.arch_std)
        # Clamp before taking the log so the stored value is finite
        # (provided cl_bottom > 0).
        lo, hi = self.bottom, 1 - self.bottom
        self.log_alpha = nn.Parameter(draw.clamp(lo, hi).log())

    def regularizer(self):
        """Return ``lamb * sum(-(a*log(a) + (1-a)*log(1-a)))``.

        Here ``a = exp(log_alpha)``, clamped into ``[bottom, 1 - bottom]`` so
        both logarithms stay finite.
        """
        alpha = self.log_alpha.exp().clamp(self.bottom, 1 - self.bottom)
        comp = 1 - alpha
        terms = alpha * alpha.log() + comp * comp.log()
        return -self.lamb * terms.sum()

    def forward(self):
        """Return ``exp(log_alpha)``.

        Unlike ``regularizer``, this applies no clamp, so values may leave
        ``[bottom, 1 - bottom]`` (including exceeding 1) as training proceeds.
        """
        return self.log_alpha.exp()