import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import numpy as np

from utils.conv_type import ConvMask, ConvMaskMW




def l1_reg_loss(model):
    """Accumulate an L1 penalty over the masked convolution layers of ``model``.

    Every module yielded by ``model.named_modules()`` is inspected:

    * a ``ConvMaskMW`` module contributes the sum of absolute values of its
      ``m`` attribute (the mask variable only, which is not the same as in
      CS);
    * a ``ConvMask`` module contributes the sum of absolute values of its
      ``weight`` attribute.

    The two checks are independent, so a module that is an instance of both
    types contributes both terms.

    Returns the accumulated penalty, or the int ``0`` when no module matched.
    """
    penalty = 0

    for _, layer in model.named_modules():
        if isinstance(layer, ConvMaskMW):
            # L1 on the mask variable only.
            mask_l1 = layer.m.abs().sum()
            penalty = penalty + mask_l1

        if isinstance(layer, ConvMask):
            # L1 on the weights.
            weight_l1 = layer.weight.abs().sum()
            penalty = penalty + weight_l1

    return penalty

def Total_reg_loss(model):
    """Return the L2 penalty on the over-parameterised ``(m, w)`` pairs.

    For every ``ConvMaskMW`` module yielded by ``model.named_modules()`` the
    element-wise squares of its ``m`` and ``w`` attributes are added together
    and summed. Modules of any other type are ignored.

    The function is silent: it does not print the value, so it can be called
    on every optimisation step without flooding stdout.

    Returns the accumulated penalty, or the int ``0`` when no module matched.
    """
    loss = 0

    for _, module in model.named_modules():
        if isinstance(module, ConvMaskMW):
            # L2 on the over-parameterisation: m^2 + w^2, element-wise.
            loss += (module.m * module.m + module.w * module.w).sum()

    return loss

def loss_l2_nobn(model):
    """
    Computes the squared L2 norm of the product ``m * w`` over ConvMaskMW layers.

    Only ``ConvMaskMW`` modules yielded by ``model.named_modules()`` are
    counted; modules of any other type contribute nothing. Returns the int
    ``0`` when no such module is found.
    """
    return sum(
        (layer.m * layer.w).pow(2).sum()
        for _, layer in model.named_modules()
        if isinstance(layer, ConvMaskMW)
    )

def loss_l1_nobn(model):
    """
    Computes the L1 norm of the product ``m * w`` over ConvMaskMW layers.

    Only ``ConvMaskMW`` modules yielded by ``model.named_modules()`` are
    counted; modules of any other type contribute nothing. Returns the int
    ``0`` when no such module is found.
    """
    abs_total = 0
    for _, layer in model.named_modules():
        if not isinstance(layer, ConvMaskMW):
            continue
        product = layer.m * layer.w
        abs_total += product.abs().sum()

    return abs_total

def Regularization(model, reg : str):
    """Select and evaluate the regularization term named by ``reg``.

    ``reg`` is compared against the names of the regularizers defined in this
    module (``'l1_reg_loss'``, ``'Total_reg_loss'``, ``'loss_l2_nobn'``,
    ``'loss_l1_nobn'``); on a match the corresponding function is called with
    ``model`` and its result is returned.

    Any other value prints a notice and returns the int ``0``, i.e. no extra
    regularization is added to the objective.
    """
    if reg == 'l1_reg_loss':
        return l1_reg_loss(model)

    if reg == 'Total_reg_loss':
        return Total_reg_loss(model)

    if reg == 'loss_l2_nobn':
        return loss_l2_nobn(model)

    if reg == 'loss_l1_nobn':
        return loss_l1_nobn(model)

    # Unknown / unspecified name: fall back to a zero penalty.
    print('No additonal regularization specified')
    return 0
