from utils.adv_generator import inf_generator
from tqdm import tqdm
import copy
import torch.nn.utils.prune as prune
import torch
from torch import nn
from methods.method import Method
from utils.utils import Statistics

# No explicit pruning is used here; the method is L1-sparse fine-tuning (as reported in SalUn).
# Code modified from: https://github.com/OPTML-Group/Unlearn-Saliency/blob/master/Classification/unlearn/FT.py#L145
class SPARSE(Method):
    """L1-sparse fine-tuning unlearning method.

    The model is fine-tuned on ``self.train_remain_loader`` with a
    cross-entropy loss plus an L1 penalty over all model parameters. The
    penalty coefficient decays linearly with the epoch index. No explicit
    pruning is applied; sparsity of the Conv2d weights is only measured and
    printed at the end.
    """

    def set_hyperparameters(self, args):
        """Set the hyperparameters of the L1 schedule.

        Args:
            args: Run configuration; only ``args.data_name`` is read here.
        """
        # Number of final epochs trained without the L1 penalty
        # (value taken from the reference code).
        self.no_l1_epochs = 0
        # Initial L1 coefficient; line-searched in [1e-5, 1e-1].
        self.alpha = 1e-3 if args.data_name != 'imagenet' else 1e-5

    def unlearn(self, model, loaders, args):
        """Fine-tune ``model`` with a linearly decaying L1 penalty.

        Args:
            model: ``nn.Module`` to fine-tune; it is updated in place.
            loaders: Not used by this method; batches are drawn from
                ``self.train_remain_loader``.
            args: Run configuration; ``args.remain_epochs`` is the number of
                fine-tuning epochs.

        Returns:
            The same ``model`` object after fine-tuning.
        """
        model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = self.get_optimizer(model)

        # Feed batches to wherever the model lives instead of assuming the
        # default CUDA device; fall back to CUDA for a parameter-less model.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda")

        # Number of leading epochs during which the L1 penalty is active.
        l1_epochs = args.remain_epochs - self.no_l1_epochs

        for epoch in range(args.remain_epochs):
            # The coefficient depends only on the epoch, so compute it once
            # per epoch rather than once per batch.
            if epoch < l1_epochs:
                current_alpha = self.alpha * (1 - epoch / l1_epochs)
            else:
                current_alpha = 0

            for image, target in tqdm(self.train_remain_loader):
                image = image.to(device)
                target = target.to(device)

                # compute output
                output_clean = model(image)
                loss = criterion(output_clean, target)
                if current_alpha:
                    # Skip the full-parameter norm when the penalty is off.
                    loss = loss + current_alpha * self.l1_regularization(model)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        remain_rate = self.check_sparsity(model)
        print(f'remain weight rate: {remain_rate}')

        return model

    def l1_regularization(self, model):
        """Return the L1 norm of all parameters of ``model`` as a scalar tensor.

        Args:
            model: ``nn.Module`` whose parameters are penalised.
        """
        # reshape(-1) instead of view(-1): view raises on non-contiguous
        # parameters (e.g. channels_last conv weights), reshape copies if
        # it has to and is identical otherwise.
        params_vec = [param.reshape(-1) for param in model.parameters()]
        return torch.linalg.norm(torch.cat(params_vec), ord=1)

    def check_sparsity(self, model):
        """Print and return the percentage of non-zero Conv2d weights.

        Code adapted from https://github.com/OPTML-Group/Unlearn-Sparse/

        Args:
            model: ``nn.Module`` to inspect; only ``nn.Conv2d`` weights are
                counted.

        Returns:
            Percentage (0-100) of Conv2d weight entries that are not exactly
            zero, or ``None`` when the model has no Conv2d weights.
        """
        total_weights = 0
        zero_weights = 0

        for module in model.modules():
            if isinstance(module, nn.Conv2d):
                total_weights += module.weight.numel()
                zero_weights += int((module.weight == 0).sum().item())

        # Branch on the total count, not the zero count: a dense model
        # (no exact zeros) still has a well-defined ratio of 100%.
        if total_weights == 0:
            print("no weight for calculating sparsity")
            return None

        remain_weight_ratio = 100 * (1 - zero_weights / total_weights)
        print("* remain weight ratio = ", f"{remain_weight_ratio:.3f}", "%")
        return remain_weight_ratio