from tqdm import tqdm
import torch
from torch import nn
import numpy as np
from methods.method import Method
from utils.backbone import get_model
from utils.utils import resetFinalResnet

class EUK(Method):
    """EU-k unlearning baseline: re-initialise the final layers and retrain them.

    Goel et al. "Evaluating inexact unlearning requires revisiting forgetting",
    arXiv 2022. Code adapted from
    https://github.com/meghdadk/SCRUB/blob/main/MIA_experiments.ipynb

    ``unlearn`` freezes every parameter, hands the model to
    ``resetFinalResnet(model, self.k, reinit=True)`` and then trains with
    cross-entropy on ``self.train_remain_loader`` using a step-decayed LR.
    """

    def set_hyperparameters(self, args):
        """Initialise EU-k hyperparameters.

        Args:
            args: run configuration; ``args.lr`` is used as the initial LR.
        """
        # Forwarded to resetFinalResnet (presumably the number of final
        # layers to reset — confirm against utils.utils).
        self.k = 10
        # Completed-epoch counts after which the LR is multiplied once more
        # by ``lr_decay_rate``.
        self.lr_decay_epochs = [10, 15, 20]
        self.sgda_learning_rate = args.lr
        self.lr_decay_rate = 0.1  # NOTE: value not confirmed against the reference

    def unlearn(self, model, loaders, args):
        """Retrain the re-initialised final layers on the retain set.

        Args:
            model: network to unlearn; frozen, then passed to resetFinalResnet.
            loaders: not read here; batches come from
                ``self.train_remain_loader`` (presumably set up by ``Method``
                — confirm).
            args: run configuration; reads ``args.device`` and
                ``args.remain_epochs``.

        Returns:
            The model returned by resetFinalResnet, after training.
        """
        # Freeze everything; resetFinalResnet is expected to leave only the
        # re-initialised layers trainable (NOTE(review): confirm).
        for param in model.parameters():
            param.requires_grad_(False)

        model = resetFinalResnet(model, self.k, reinit=True)

        # Debug dump of which parameters will receive gradients.
        for name, param in model.named_parameters():
            print(name, param.requires_grad)

        device = args.device
        criterion = nn.CrossEntropyLoss()
        optimizer = self.get_optimizer(model)

        # NOTE(review): train() mode also updates BatchNorm running statistics
        # in the frozen layers, if the model has any; confirm this is intended.
        model.train()
        for epoch in range(args.remain_epochs):
            for x_remain, y_remain in self.train_remain_loader:
                x_remain = x_remain.to(device)
                y_remain = y_remain.to(device)

                logits = model(x_remain)
                self.statistics.add_forward_flops(x_remain.size(0))

                loss = criterion(logits, y_remain)

                model.zero_grad()
                optimizer.zero_grad()
                loss.backward()
                self.statistics.add_backward_flops(x_remain.size(0))
                optimizer.step()

            # adjust_learning_rate writes the scheduled LR into every param
            # group, so no extra per-group assignment is needed here.
            new_lr = self.adjust_learning_rate(epoch + 1, optimizer)
            print(f'Epoch: {epoch+1}, LR: {new_lr}')

        return model

    def adjust_learning_rate(self, epoch, optimizer):
        """Set every param group's LR to the initial LR decayed once per passed milestone.

        Args:
            epoch: number of completed epochs; each entry of
                ``self.lr_decay_epochs`` strictly below it applies one decay.
            optimizer: optimizer whose ``param_groups`` are updated in place.

        Returns:
            The learning rate that was applied.
        """
        steps = np.sum(epoch > np.asarray(self.lr_decay_epochs))
        new_lr = self.sgda_learning_rate
        if steps > 0:
            new_lr = self.sgda_learning_rate * (self.lr_decay_rate ** steps)
        # Always write the LR (also before the first milestone) so that all
        # param groups follow the same schedule, not just group 0.
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
        return new_lr
    
class EU1(EUK):
    """EU-k variant that passes ``k = 1``; other settings come from EUK."""

    def set_hyperparameters(self, args):
        # Populate the shared EU-k settings first, then pin k.
        super(EU1, self).set_hyperparameters(args)
        self.k = 1

class EU2(EUK):
    """EU-k variant that passes ``k = 2``; other settings come from EUK."""

    def set_hyperparameters(self, args):
        # Populate the shared EU-k settings first, then pin k.
        super(EU2, self).set_hyperparameters(args)
        self.k = 2

class EU3(EUK):
    """EU-k variant that passes ``k = 3``; other settings come from EUK."""

    def set_hyperparameters(self, args):
        # Populate the shared EU-k settings first, then pin k.
        super(EU3, self).set_hyperparameters(args)
        self.k = 3

class EU4(EUK):
    """EU-k variant that passes ``k = 4``; other settings come from EUK."""

    def set_hyperparameters(self, args):
        # Populate the shared EU-k settings first, then pin k.
        super(EU4, self).set_hyperparameters(args)
        self.k = 4

class EU5(EUK):
    """EU-k variant that passes ``k = 5``; other settings come from EUK."""

    def set_hyperparameters(self, args):
        # Populate the shared EU-k settings first, then pin k.
        super(EU5, self).set_hyperparameters(args)
        self.k = 5

class EU10(EUK):
    """EU-k variant that passes ``k = 10``; other settings come from EUK."""

    def set_hyperparameters(self, args):
        # Populate the shared EU-k settings first, then pin k.
        super(EU10, self).set_hyperparameters(args)
        self.k = 10

class EU20(EUK):
    """EU-k variant that passes ``k = 20``; other settings come from EUK."""

    def set_hyperparameters(self, args):
        # Populate the shared EU-k settings first, then pin k.
        super(EU20, self).set_hyperparameters(args)
        self.k = 20