import sys
import time

import torch

import utils

from .impl import iterative_unlearn
import copy
import os
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
sys.path.append(".")
from imagenet import get_x_y_from_data_dict
from torch.utils.data import ConcatDataset, DataLoader, RandomSampler
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from thirdparty.repdistiller.distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss


from tqdm import tqdm
class MaskedDataset(Dataset):
    """Serve a forget set and a retain set through one shared index space.

    The dataset has ``len(mask)`` items. For an index ``idx``:

    * if ``mask[idx] == 0`` the sample is ``forget_set[idx]`` and the
      returned source flag is ``0``;
    * otherwise the sample is ``retain_set[idx - len(forget_set)]`` and the
      returned source flag is ``1``.

    Because the index is used directly for the forget set and shifted by
    ``len(forget_set)`` for the retain set, the lookups only land on the
    intended samples when the zero entries of ``mask`` sit in the first
    ``len(forget_set)`` positions and the non-zero entries after them.
    """

    def __init__(self, forget_set, retain_set, mask):
        super().__init__()
        self.forget_set, self.retain_set = forget_set, retain_set
        self.mask = mask
        self.forget_len = len(forget_set)
        combined_len = len(forget_set) + len(retain_set)
        assert len(mask) == combined_len, "Mask length must match combined dataset length."

    def __len__(self):
        """Return the number of mask entries (one per addressable sample)."""
        return len(self.mask)

    def __getitem__(self, idx):
        """Return ``(image, target, source)`` for position ``idx``.

        ``source`` is ``0`` when the sample was read from the forget set and
        ``1`` when it was read from the retain set.
        """
        if self.mask[idx] == 0:
            # Forget samples are addressed with the raw index.
            sample, source = self.forget_set[idx], 0
        else:
            # Retain samples are addressed relative to the end of the forget set.
            sample, source = self.retain_set[idx - len(self.forget_set)], 1

        image, target = sample
        return image, target, source



def l1_regularization(model):
    """Return the L1 norm of all parameters of ``model`` as a scalar tensor.

    Every tensor yielded by ``model.parameters()`` is flattened, the pieces
    are concatenated, and the 1-norm of the resulting vector is returned.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors
            (typically an ``nn.Module``).

    Returns:
        A 0-dim tensor holding the sum of absolute parameter values. When the
        model has no parameters a 0-dim zero tensor (default dtype, CPU) is
        returned instead of raising.
    """
    # reshape(-1) instead of view(-1): view raises for non-contiguous
    # parameters (e.g. channels_last conv weights), reshape copies only
    # when it has to and is a plain view otherwise.
    params_vec = [param.reshape(-1) for param in model.parameters()]
    if not params_vec:
        # torch.cat rejects an empty list; the L1 norm of nothing is zero.
        return torch.zeros(())
    return torch.linalg.norm(torch.cat(params_vec), ord=1)


@iterative_unlearn
def SCRUB(data_loaders, model_t, model, criterion, optimizer, epoch, args):
    """Run one SCRUB pass over the forget and retain loaders.

    The pass has two phases:

    1. Forget phase (only while ``epoch < args.scrub_forget_epoch``): for
       each forget batch, take an optimizer step that *maximizes* the
       ``DistillKL`` divergence between ``model`` and ``model_t`` outputs
       (the loss is the negated divergence).
    2. Retain phase (every call): for each retain batch, take an optimizer
       step that minimizes ``0.99 * cross_entropy + 0.001 * divergence``.

    Args:
        data_loaders: Mapping with ``"forget"`` and ``"retain"`` entries,
            each an iterable of ``(image, target)`` batches.
        model_t: Teacher network; only evaluated under ``torch.no_grad()``.
        model: Student network that is updated in place.
        criterion: Unused here; kept for signature compatibility.
        optimizer: Optimizer that steps ``model``'s parameters.
        epoch: Current epoch index, compared to ``args.scrub_forget_epoch``.
        args: Namespace providing ``scrub_forget_epoch``.

    Returns:
        Always ``0``.

    Note:
        Batches are moved with ``.cuda()``, so a CUDA device is required.
        NOTE(review): ``model_t`` is not switched to eval mode in this
        function — confirm the caller / ``iterative_unlearn`` does so.
    """
    # Loss weights for the retain phase and the distillation temperature.
    s_gamma = 0.99
    s_alpha = 0.001
    s_kd_T = 4

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(s_kd_T)

    forget_loader = data_loaders["forget"]
    remain_loader = data_loaders["retain"]

    model.train()

    if epoch < args.scrub_forget_epoch:
        # Forget phase: labels are not needed, only the teacher/student gap.
        for image, _ in tqdm(forget_loader):
            image = image.cuda()
            logit_s = model(image)
            with torch.no_grad():
                logit_t = model_t(image)

            # Negated so that the optimizer step increases the divergence.
            loss = -criterion_div(logit_s, logit_t)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Retain phase: fit the labels while staying close to the teacher.
    for image, target in tqdm(remain_loader):
        image = image.cuda()
        target = target.cuda()
        logit_s = model(image)
        with torch.no_grad():
            logit_t = model_t(image)
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)

        loss = s_gamma * loss_cls + s_alpha * loss_div

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return 0