import random
import torch
import torch.nn as nn
from tqdm import tqdm
import time
import torch.nn.functional as F
from utils import evaluate, get_metric_scores

def pretrain(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs):
    """Score the given model as-is; no training or unlearning step is run.

    All positional arguments are forwarded unchanged to
    ``get_metric_scores``. Extra keyword arguments are accepted and ignored.

    Returns:
        A ``(scores, elapsed)`` tuple: ``scores`` is the value returned by
        ``get_metric_scores`` and ``elapsed`` is always ``0``.
    """
    scores = get_metric_scores(
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    elapsed = 0  # nothing was trained, so no time is reported
    return scores, elapsed

def retrain(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs):
    """Score the given model without modifying it.

    No training happens here; presumably the caller passes in a model that
    was already retrained elsewhere (not visible in this function). All
    positional arguments are forwarded unchanged to ``get_metric_scores``;
    extra keyword arguments are accepted and ignored.

    Returns:
        A ``(scores, elapsed)`` tuple: ``scores`` is the value returned by
        ``get_metric_scores`` and ``elapsed`` is always ``0``.
    """
    scores = get_metric_scores(
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    elapsed = 0  # this function itself does no work worth timing
    return scores, elapsed




def jit(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs):
    """Delegate to ``Alg_Jit.jit`` with the arguments passed through unchanged.

    ``Alg_Jit`` is imported when this function is called rather than at
    module import time. The return value is whatever ``Alg_Jit.jit`` returns.
    """
    import Alg_Jit as backend

    forwarded = (
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return backend.jit(*forwarded, **kwargs)




def amnesiac(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs):
    """Train on retain data plus forget data relabelled with random wrong classes.

    Every forget-set sample gets a label drawn uniformly from the classes
    other than its own; retain-set samples keep their labels. The model is
    then trained with Adam and cross-entropy on the combined, shuffled set.

    Required ``kwargs``: ``num_classes``, ``lr``, ``batch_size``, ``epochs``.

    Returns:
        ``(scores, elapsed)`` where ``scores`` comes from
        ``get_metric_scores`` and ``elapsed`` is the wall-clock time in
        seconds, including construction of the relabelled dataset.
    """
    start = time.time()
    num_classes = kwargs['num_classes']
    lr = kwargs['lr']

    # Forget samples go in first, each with a randomly chosen incorrect label.
    relabelled = []
    for sample, aux, true_label in forget_train_dl.dataset:
        wrong_labels = list(range(num_classes))
        wrong_labels.remove(true_label)
        relabelled.append((sample, aux, random.choice(wrong_labels)))
    # Retain samples follow with their labels untouched.
    relabelled.extend(
        (sample, aux, label) for sample, aux, label in retain_train_dl.dataset
    )

    train_dl = torch.utils.data.DataLoader(
        relabelled,
        batch_size=kwargs['batch_size'],
        shuffle=True,
        pin_memory=True,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    epochs = kwargs['epochs']

    for epoch in range(epochs):
        model.train()
        running_loss = 0  # sum (not mean) of the per-batch losses
        for images, _, clabels in train_dl:
            images = images.to(device)
            clabels = clabels.to(device)
            loss = F.cross_entropy(model(images), clabels)
            running_loss += loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f'Epoch:{epoch}| Loss:{running_loss:.3f}')

    elapsed = time.time() - start

    scores = get_metric_scores(
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return scores, elapsed

def badteacher(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs):
    """Delegate to ``Alg_BT.badteacher`` with the arguments passed through unchanged.

    ``Alg_BT`` is imported when this function is called rather than at
    module import time. The return value is whatever
    ``Alg_BT.badteacher`` returns.
    """
    import Alg_BT as backend

    forwarded = (
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return backend.badteacher(*forwarded, **kwargs)


def scrub(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs):
    """Delegate to ``Alg_Scrub.scrub`` with the arguments passed through unchanged.

    ``Alg_Scrub`` is imported when this function is called rather than at
    module import time. The return value is whatever
    ``Alg_Scrub.scrub`` returns.
    """
    import Alg_Scrub as backend

    forwarded = (
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return backend.scrub(*forwarded, **kwargs)

def salun(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs):
    """Delegate to ``Alg_Salun.salun`` with the arguments passed through unchanged.

    ``Alg_Salun`` is imported when this function is called rather than at
    module import time. The return value is whatever
    ``Alg_Salun.salun`` returns.
    """
    import Alg_Salun as backend

    forwarded = (
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return backend.salun(*forwarded, **kwargs)


def mumis(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs):
    """Delegate to ``Alg_Mumis.mumis`` with the arguments passed through unchanged.

    ``Alg_Mumis`` is imported when this function is called rather than at
    module import time. The return value is whatever
    ``Alg_Mumis.mumis`` returns.
    """
    import Alg_Mumis as backend

    forwarded = (
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return backend.mumis(*forwarded, **kwargs)


def finetune(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs,):
    """Fine-tune the model on the retain set only, then score it.

    Uses Adam when ``kwargs['model_name'] == 'ViT'`` and SGD (momentum 0.9,
    weight decay 5e-4) otherwise, with cross-entropy loss and a cosine
    learning-rate schedule stepped once per batch.

    Required ``kwargs``: ``lr``, ``epochs``, ``model_name``.

    Returns:
        ``(scores, elapsed)`` where ``scores`` comes from
        ``get_metric_scores`` and ``elapsed`` is the wall-clock training
        time in seconds (optimizer/scheduler construction is not timed).
    """
    lr = kwargs['lr']
    epochs = kwargs['epochs']
    if kwargs['model_name'] == 'ViT':
        optimizer = torch.optim.Adam(model.parameters(), lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9, weight_decay=5e-4)
    # The scheduler is stepped after every batch, so T_max is the total
    # number of optimizer steps: the learning rate follows a single cosine
    # decay over the whole run instead of restarting each epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=epochs * len(retain_train_dl)
    )

    start = time.time()
    for _ in tqdm(range(epochs)):
        model.train()
        for images, _, clabels in retain_train_dl:
            images, clabels = images.to(device), clabels.to(device)
            out = model(images)
            loss = F.cross_entropy(out, clabels)
            # The loss value is never reported, so it is not read back with
            # .item(); that would force a host/device sync on every step.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
    end = time.time()

    return get_metric_scores(
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        device,
    ), end - start
    
    



def ssd(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs,
):
    """Delegate to ``Alg_SSD.ssd`` with the arguments passed through unchanged.

    ``Alg_SSD`` is imported when this function is called rather than at
    module import time. The return value is whatever ``Alg_SSD.ssd`` returns.
    """
    import Alg_SSD as backend

    forwarded = (
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return backend.ssd(*forwarded, **kwargs)





def RandomLabels(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs,):
    """Unlearn by training on the forget set with freshly sampled random labels.

    For every forget batch a new set of labels is drawn uniformly from the
    allowed classes and the model takes one SGD step on the cross-entropy
    against them. When ``kwargs['mode'] == "CR"`` the class
    ``kwargs['forget_class']`` is excluded from the candidates; in any other
    mode every class is a candidate (a sample's true class included).

    Forget-set accuracy is measured on ``forget_train_dl`` every 5 epochs
    (starting with epoch 0) and training stops as soon as it is below
    ``target_accuracy``.

    Required ``kwargs``: ``lr``, ``num_classes``, ``epochs``, ``wd_unlearn``,
    ``target_accuracy``, ``mode`` and, in "CR" mode, ``forget_class``.

    Returns:
        ``(scores, elapsed)`` where ``scores`` comes from
        ``get_metric_scores`` and ``elapsed`` is the wall-clock time in
        seconds. The model is left in eval mode.
    """
    start = time.time()

    lr = kwargs['lr']
    num_classes = kwargs['num_classes']
    epochs_unlearn = kwargs['epochs']

    wd_unlearn = float(kwargs['wd_unlearn'])
    target_accuracy = float(kwargs['target_accuracy'])
    mode = kwargs['mode']
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=wd_unlearn)

    class_to_remove = [kwargs['forget_class']] if mode == "CR" else []
    # Candidate labels, kept as int64 so they can be used as targets directly.
    random_possible = torch.tensor(
        [i for i in range(num_classes) if i not in class_to_remove],
        dtype=torch.int64,
    ).to(device)

    def loss_f(inputs, targets):
        """Cross-entropy of the model output against random candidate labels."""
        outputs = model(inputs)
        # Indices are drawn with the default (CPU) generator so that seeded
        # runs consume the same random stream as before.
        picks = torch.randint(low=0, high=random_possible.shape[0], size=targets.shape)
        random_labels = random_possible[picks]
        return criterion(outputs, random_labels)

    model.train()
    for ep in tqdm(range(epochs_unlearn)):
        for inputs, _, targets in forget_train_dl:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = loss_f(inputs, targets)
            loss.backward()
            optimizer.step()

        # The accuracy only changes when it is re-measured, so the stop test
        # lives with the measurement instead of re-reading a stale value.
        if ep % 5 == 0:
            model.eval()
            with torch.no_grad():
                # Evaluate on the device the caller asked for rather than a
                # hard-coded 'cuda'.
                curr_acc = evaluate(model, forget_train_dl, device)['Acc']
            model.train()
            print(f"ACCURACY FORGET SET: {curr_acc:.3f}, target is {target_accuracy:.3f}")
            if curr_acc < target_accuracy:
                break

    model.eval()
    end = time.time()
    return get_metric_scores(
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        device,
    ), end - start

def NegativeGradient(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs,):
    """Unlearn by gradient ascent on the forget set.

    Each SGD step minimises the negated cross-entropy on a forget batch,
    i.e. it pushes the loss on the forget data up.

    Accuracy is measured on ``forget_valid_dl`` every 5 epochs (starting
    with epoch 0) and training stops as soon as it is below
    ``target_accuracy``.

    Required ``kwargs``: ``lr``, ``epochs``, ``wd_unlearn``,
    ``target_accuracy``.

    Returns:
        ``(scores, elapsed)`` where ``scores`` comes from
        ``get_metric_scores`` and ``elapsed`` is the wall-clock time in
        seconds. The model is left in eval mode.
    """
    start = time.time()
    lr = kwargs['lr']
    epochs_unlearn = kwargs['epochs']

    wd_unlearn = float(kwargs['wd_unlearn'])
    target_accuracy = float(kwargs['target_accuracy'])

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=wd_unlearn)
    criterion = nn.CrossEntropyLoss()

    def loss_f(inputs, targets):
        """Negated cross-entropy, so that minimising it is gradient ascent."""
        return -criterion(model(inputs), targets)

    model.train()
    for ep in tqdm(range(epochs_unlearn)):
        for inputs, _, targets in forget_train_dl:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = loss_f(inputs, targets)
            loss.backward()
            optimizer.step()

        # The accuracy only changes when it is re-measured, so the stop test
        # lives with the measurement instead of re-reading a stale value.
        if ep % 5 == 0:
            model.eval()
            with torch.no_grad():
                # Evaluate on the device the caller asked for rather than a
                # hard-coded 'cuda'.
                # NOTE(review): the message says "FORGET SET" but the loader
                # is the forget *validation* split - confirm this is intended.
                curr_acc = evaluate(model, forget_valid_dl, device)['Acc']
            model.train()
            print(f"ACCURACY FORGET SET: {curr_acc:.3f}, target is {target_accuracy:.3f}")
            if curr_acc < target_accuracy:
                break

    model.eval()
    end = time.time()
    return get_metric_scores(
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        device,
    ), end - start


from copy import deepcopy

def scar(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs,
):
    """Unlearn by pulling forget-sample embeddings toward another class.

    Outline of what the code below does:

    1. Embed the retain set with the backbone and, per class, compute the
       mean and a shrunk, normalised, pseudo-inverted covariance of the
       power-transformed embeddings.
    2. For every forget batch, pick for each sample the closest class under
       the resulting quadratic-form distance (skipping the sample's own
       label if that is the closest) and fix that choice during the first
       epoch.
    3. Train with Adam on ``lambda_1 * mean(distance to the chosen class)``
       plus ``lambda_2 * KL`` between the current outputs on surrogate
       retain batches and reference outputs.

    Training stops once forget accuracy is below ``target_accuracy`` after
    epoch index 1.

    Required ``kwargs``: ``lr``, ``dataset``, ``epochs``, ``num_classes``,
    ``mode``, ``retain_sur``, ``gamma1``, ``gamma2``, ``delta``,
    ``wd_unlearn``, ``model_name``, ``lambda_1``, ``lambda_2``,
    ``temperature``, ``target_accuracy`` and, when ``mode == 'CR'``,
    ``forget_class``. ``retain_sur`` is iterated as batches of
    ``(images, labels)`` in 'CR' mode and
    ``(images, labels, reference_outputs)`` otherwise.

    Returns:
        ``(scores, elapsed)`` where ``scores`` comes from
        ``get_metric_scores`` and ``elapsed`` is the wall-clock time in
        seconds. The model is left in eval mode.

    Raises:
        ValueError: if ``kwargs['dataset']`` is not one of the datasets a
            retain-batch budget is defined for.
    """
    start = time.time()

    lr_unlearn = kwargs['lr']
    dataset = kwargs['dataset']
    epochs_unlearn = kwargs['epochs']
    num_classes = kwargs['num_classes']
    mode = kwargs['mode']
    if mode == 'CR':
        class_to_remove = [kwargs['forget_class']]
    else:
        class_to_remove = None

    retain_sur = kwargs['retain_sur']
    gamma1 = float(kwargs['gamma1'])
    gamma2 = float(kwargs['gamma2'])
    delta = float(kwargs['delta'])
    wd_unlearn = float(kwargs['wd_unlearn'])
    model_name = kwargs['model_name']

    lambda1 = float(kwargs['lambda_1'])
    lambda2 = float(kwargs['lambda_2'])
    temperature = float(kwargs['temperature'])
    target_accuracy = float(kwargs['target_accuracy'])

    # Upper bound on the surrogate-retain batch index used per forget batch.
    if dataset == 'Cifar10' or dataset == 'Svhn':
        num_retain_samp = 5  # original author's note: "1 for cr"
    elif dataset == 'Cifar100' or dataset == 'Cifar20':
        num_retain_samp = 5  # original author's note: "3 for cr"
    elif dataset == 'TinyImagenet':
        num_retain_samp = 90
    else:
        # Previously this fell through and surfaced as a NameError in the
        # first training step; fail up front with a clear message instead.
        raise ValueError(
            f"scar: no retain-batch budget defined for dataset {dataset!r}"
        )

    def single_accuracy(net, loader, single_class=False):
        """Return accuracy on ``loader``; with ``single_class`` also per-class accuracy.

        ``loader`` must yield ``(inputs, <ignored>, targets)`` batches, like
        the forget/retain loaders used elsewhere in this function.
        """
        correct = 0
        total = 0

        total_sc = torch.zeros((num_classes))
        correct_sc = torch.zeros((num_classes))

        pred_all = []
        target_all = []

        for inputs, _, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            predicted = outputs.max(1)[1]
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if single_class:
                pred_all.append(predicted.detach().cpu())
                target_all.append(targets.detach().cpu())

        if single_class:
            pred_all = torch.cat(pred_all)
            target_all = torch.cat(target_all)
            for i in range(num_classes):
                buff_tar = target_all[target_all == i]
                buff_pred = pred_all[target_all == i]
                total_sc[i] = buff_tar.shape[0]
                correct_sc[i] = (buff_pred == i).sum().item()
            # Classes absent from the loader yield 0/0 = NaN entries here.
            return correct / total, correct_sc / total_sc
        else:
            return correct / total

    def cov_mat_shrinkage(cov_mat, gamma1=gamma1, gamma2=gamma2):
        """Add gamma1*mean(diag) on the diagonal and gamma2*mean(non-zero off-diag) elsewhere."""
        I = torch.eye(cov_mat.shape[0]).to(device)
        V1 = torch.mean(torch.diagonal(cov_mat))
        off_diag = cov_mat.clone()
        off_diag.fill_diagonal_(0.0)
        mask = off_diag != 0.0
        V2 = (off_diag*mask).sum() / mask.sum()
        cov_mat_shrinked = cov_mat + gamma1*I*V1 + gamma2*(1-I)*V2
        return cov_mat_shrinked

    def normalize_cov(cov_mat):
        """Divide each entry by the product of the corresponding standard deviations."""
        sigma = torch.sqrt(torch.diagonal(cov_mat))  # standard deviations of the variables
        cov_mat = cov_mat/(torch.matmul(sigma.unsqueeze(1), sigma.unsqueeze(0)))
        return cov_mat

    def tuckey_transf(vectors, delta=delta):
        """Element-wise power transform ``vectors ** delta``."""
        return torch.pow(vectors, delta)

    def mahalanobis_dist(samples, samples_lab, mean, S_inv):
        """Quadratic form diff @ S_inv @ diff for every (class, sample) pair.

        ``samples`` are power-transformed and L2-normalised, ``mean`` is
        L2-normalised, and no square root is taken. ``samples_lab`` is
        unused. Assumes ``samples`` is (N, D), ``mean`` is (C, D) and
        ``S_inv`` is (C, D, D), giving a (C, N) result - TODO confirm.
        """
        diff = F.normalize(tuckey_transf(samples), p=2, dim=-1)[:, None, :] - F.normalize(mean, p=2, dim=-1)
        right_term = torch.matmul(diff.permute(1, 0, 2), S_inv)
        mahalanobis = torch.diagonal(torch.matmul(right_term, diff.permute(1, 2, 0)), dim1=1, dim2=2)
        return mahalanobis

    def distill(outputs_ret, outputs_original):
        """KL divergence (batchmean) between log-softmaxed reference and current outputs."""
        soft_log_old = torch.nn.functional.log_softmax(outputs_original+10e-5, dim=1)
        soft_log_new = torch.nn.functional.log_softmax(outputs_ret+10e-5, dim=1)
        kl_div = torch.nn.functional.kl_div(soft_log_new+10e-5, soft_log_old+10e-5, reduction='batchmean', log_target=True)

        return kl_div

    def run():
        """Compute the class statistics, run the unlearning loop, return the model."""
        # Split the network into a feature extractor and its final classifier.
        if model_name != 'ViT':
            bbone = torch.nn.Sequential(*(list(model.children())[:-1] + [nn.Flatten()]))
            if model_name == 'AllCNN':
                fc = model.classifier
            else:
                fc = model.fc
        else:
            bbone = model
            fc = model.base.head

        # Frozen snapshot, used as the distillation reference in 'CR' mode.
        original_model = deepcopy(model)
        original_model.eval()
        bbone.eval()

        # Embeddings of the retain set.
        with torch.no_grad():
            ret_embs = []
            labs = []
            for img_ret, _, lab_ret in retain_train_dl:
                img_ret, lab_ret = img_ret.to(device), lab_ret.to(device)

                if model_name == 'ViT':
                    logits_ret = bbone.base(img_ret)
                else:
                    logits_ret = bbone(img_ret)

                ret_embs.append(logits_ret)
                labs.append(lab_ret)
            ret_embs = torch.cat(ret_embs)
            labs = torch.cat(labs)

        # Per-class mean and inverse covariance; removed classes are skipped,
        # so in 'CR' mode row k is not necessarily class k.
        distribs = []
        cov_matrix_inv = []
        for i in range(num_classes):
            if class_to_remove is not None and i in class_to_remove:
                continue
            samples = tuckey_transf(ret_embs[labs == i])
            distribs.append(samples.mean(0))
            cov = torch.cov(samples.T)
            # Shrinkage is deliberately applied twice, as in the original code.
            cov_shrinked = cov_mat_shrinkage(cov_mat_shrinkage(cov))
            cov_shrinked = normalize_cov(cov_shrinked)
            cov_matrix_inv.append(torch.linalg.pinv(cov_shrinked))

        distribs = torch.stack(distribs)
        cov_matrix_inv = torch.stack(cov_matrix_inv)

        bbone.train()
        fc.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr_unlearn, weight_decay=wd_unlearn)

        init = True  # True during the first epoch, while target classes are chosen
        flag_exit = False
        all_closest_class = []  # chosen target class per forget batch, indexed by n_batch

        vec_forg = None  # per-class forget accuracy, only populated outside 'CR' mode
        if 'Tiny' in dataset:
            th = .4
        else:
            th = .8

        print('Num batch forget: ',len(forget_train_dl), 'Num batch retain: ',len(retain_sur))

        for epoch in tqdm(range(epochs_unlearn)):
            for n_batch, (img_fgt, _, lab_fgt) in enumerate(forget_train_dl):
                for n_batch_ret, all_batch in enumerate(retain_sur):

                    if mode == 'CR':
                        img_ret, lab_ret = all_batch
                    else:
                        img_ret, lab_ret, outputs_original = all_batch
                        outputs_original = outputs_original.to(device)

                    img_ret, lab_ret, img_fgt, lab_fgt = img_ret.to(device), lab_ret.to(device), img_fgt.to(device), lab_fgt.to(device)
                    optimizer.zero_grad()
                    if model_name == 'ViT':
                        embs_fgt = bbone.forward_encoder(img_fgt)
                    else:
                        embs_fgt = bbone(img_fgt)

                    # Distance between every forget embedding and every class cluster.
                    dists = mahalanobis_dist(embs_fgt, lab_fgt, distribs, cov_matrix_inv).T

                    if init and n_batch_ret == 0:
                        # First visit of this forget batch: choose the nearest
                        # cluster, falling back to the second nearest when the
                        # nearest index equals the sample's own label.
                        closest_class = torch.argsort(dists, dim=1)
                        tmp = closest_class[:, 0]
                        closest_class = torch.where(tmp == lab_fgt, closest_class[:, 1], tmp)
                        all_closest_class.append(closest_class)
                        closest_class = all_closest_class[-1]
                    else:
                        closest_class = all_closest_class[n_batch]

                    dists = dists[torch.arange(dists.shape[0]), closest_class[:dists.shape[0]]]

                    if vec_forg is not None and mode == 'HR':
                        # Weight each sample by its class's current forget
                        # accuracy; classes already below `th` get weight 0.
                        vec_forg[vec_forg < th] = 0
                        for i in range(num_classes):
                            dists[lab_fgt == i] = dists[lab_fgt == i]*(vec_forg[i])

                    loss_fgt = torch.mean(dists) * lambda1

                    if model_name == 'ViT':
                        outputs_ret = fc(bbone.forward_encoder(img_ret))
                    else:
                        outputs_ret = fc(bbone(img_ret))

                    if mode == 'CR':
                        with torch.no_grad():
                            outputs_original = original_model(img_ret)
                            label_out = torch.argmax(outputs_original, dim=1)
                            # Drop samples the snapshot assigns to the removed
                            # class and floor that class's logit for the rest.
                            outputs_original = outputs_original[label_out != class_to_remove[0], :]
                            outputs_original[:, torch.tensor(class_to_remove, dtype=torch.int64)] = torch.min(outputs_original)

                        outputs_ret = outputs_ret[label_out != class_to_remove[0], :]

                    loss_ret = distill(outputs_ret, outputs_original/temperature)*lambda2
                    loss = loss_ret+loss_fgt

                    # Retain-batch budget for this forget batch is exhausted.
                    if n_batch_ret > num_retain_samp:
                        del loss, loss_ret, loss_fgt, embs_fgt, dists
                        break

                    loss.backward()
                    optimizer.step()

                    # Check forget accuracy after every optimizer step.
                    with torch.no_grad():
                        model.eval()
                        if mode == 'CR':
                            curr_acc = evaluate(model, forget_train_dl, device)['Acc']
                        else:
                            # NOTE(review): this returns a fraction in [0, 1];
                            # confirm it is on the same scale as
                            # evaluate(...)['Acc'] and target_accuracy.
                            curr_acc, vec_forg = single_accuracy(model, forget_train_dl, single_class=True)
                        model.train()
                        if curr_acc < target_accuracy and epoch > 1:
                            flag_exit = True

                    if flag_exit:
                        break
                if flag_exit:
                    break

            # Evaluate forget and retain-validation accuracy once per epoch.
            with torch.no_grad():
                model.eval()
                curr_acc = evaluate(model, forget_train_dl, device)['Acc']
                test_acc = evaluate(model, retain_valid_dl, device)['Acc']
                model.train()
                print(f"AAcc forget: {curr_acc:.3f}, target is {target_accuracy:.3f}, test is {test_acc:.3f}")
                if curr_acc < target_accuracy and epoch > 1:
                    flag_exit = True

            if flag_exit:
                break

            init = False

        model.eval()
        return model

    model = run()
    end = time.time()

    return get_metric_scores(
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        device,
    ), end - start
    



def cfk(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs,):
    """Fine-tune only ``model.conv5_x`` on the retain set, then score the model.

    Every parameter is frozen and gradients are then re-enabled for
    ``model.conv5_x`` alone. Training uses SGD (weight decay 5e-4) with
    cross-entropy loss and a cosine learning-rate schedule stepped once per
    batch.

    Required ``kwargs``: ``lr``, ``epochs``.

    Returns:
        ``(scores, elapsed)`` where ``scores`` comes from
        ``get_metric_scores`` and ``elapsed`` is the wall-clock time in
        seconds.
    """
    start = time.time()

    lr = kwargs['lr']
    epochs = kwargs['epochs']

    # Freeze the whole network, then unfreeze the conv5_x block.
    for frozen in model.parameters():
        frozen.requires_grad_(False)
    for trainable in model.conv5_x.parameters():
        trainable.requires_grad_(True)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=5e-4)
    # T_max is the total number of optimizer steps because the scheduler is
    # stepped after every batch.
    total_steps = epochs * len(retain_train_dl)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=total_steps)

    criterion = nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs)):
        model.train()
        epoch_loss = 0
        for images, _, targets in tqdm(retain_train_dl):
            images = images.to(device)
            targets = targets.to(device)
            loss = criterion(model(images), targets)
            model.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
        mean_loss = epoch_loss / len(retain_train_dl)
        print(f"Epoch:{epoch}| loss:{mean_loss:.3f}")

    elapsed = time.time() - start

    scores = get_metric_scores(
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return scores, elapsed



def euk(model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs,):
    """Overwrite ``model.conv5_x`` with the teacher's weights and retrain that block.

    Every parameter is frozen, ``model.conv5_x`` is loaded with the state
    dict of ``unlearning_teacher.conv5_x``, and gradients are re-enabled for
    that block alone. Training on the retain set uses SGD (weight decay
    5e-4) with cross-entropy loss and a cosine learning-rate schedule
    stepped once per batch.

    Required ``kwargs``: ``epochs``, ``lr``.

    Returns:
        ``(scores, elapsed)`` where ``scores`` comes from
        ``get_metric_scores`` and ``elapsed`` is the wall-clock time in
        seconds.
    """
    start = time.time()

    epochs = kwargs['epochs']
    lr = kwargs['lr']

    # Freeze the whole network first.
    for frozen in model.parameters():
        frozen.requires_grad_(False)
    # Replace the conv5_x block's weights with the teacher's.
    teacher_block_state = unlearning_teacher.conv5_x.state_dict()
    with torch.no_grad():
        model.conv5_x.load_state_dict(teacher_block_state)
    # Only the replaced block is trained.
    for trainable in model.conv5_x.parameters():
        trainable.requires_grad_(True)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=5e-4)
    # T_max is the total number of optimizer steps because the scheduler is
    # stepped after every batch.
    total_steps = epochs * len(retain_train_dl)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=total_steps)

    criterion = nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs)):
        model.train()
        epoch_loss = 0
        for images, _, targets in tqdm(retain_train_dl):
            images = images.to(device)
            targets = targets.to(device)
            loss = criterion(model(images), targets)
            model.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
        mean_loss = epoch_loss / len(retain_train_dl)
        print(f"Epoch:{epoch}| loss:{mean_loss:.3f}")

    elapsed = time.time() - start

    scores = get_metric_scores(
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return scores, elapsed


def duck(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs,
):
    """Delegate to ``Alg_Duck.DUCK`` with the arguments passed through unchanged.

    ``Alg_Duck`` is imported when this function is called rather than at
    module import time. The return value is whatever ``Alg_Duck.DUCK``
    returns.
    """
    import Alg_Duck as backend

    forwarded = (
        model, unlearning_teacher,
        retain_train_dl, retain_valid_dl,
        forget_train_dl, forget_valid_dl,
        valid_dl, device,
    )
    return backend.DUCK(*forwarded, **kwargs)
