import sys
import time

import torch

import utils

from .impl import iterative_unlearn

sys.path.append(".")
from imagenet import get_x_y_from_data_dict
from torch.utils.data import ConcatDataset, DataLoader, RandomSampler
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from tqdm import tqdm
class MaskedDataset(Dataset):
    """Expose a forget set and a retain set as one indexable dataset.

    ``mask`` has one entry per combined index. Where the entry equals 0 the
    item is read from ``forget_set`` at that same index; otherwise it is read
    from ``retain_set`` at the index shifted down by ``len(forget_set)``.
    Every item is returned as ``(image, target, source)`` where ``source`` is
    0 for forget items and 1 for retain items.
    """

    def __init__(self, forget_set, retain_set, mask):
        super().__init__()
        self.forget_set, self.retain_set = forget_set, retain_set
        self.mask = mask
        self.forget_len = len(forget_set)
        expected_len = len(forget_set) + len(retain_set)
        assert len(mask) == expected_len, "Mask length must match combined dataset length."

    def __len__(self):
        # One mask entry exists per sample of the combined dataset.
        return len(self.mask)

    def __getitem__(self, idx):
        # Forget samples: the combined index is used directly.
        if self.mask[idx] == 0:
            image, target = self.forget_set[idx]
            return image, target, 0

        # Retain samples: translate the combined index into a retain index.
        image, target = self.retain_set[idx - len(self.forget_set)]
        return image, target, 1



def l1_regularization(model):
    """Return the L1 norm (sum of absolute values) of all model parameters.

    The per-parameter sums are accumulated directly, so no concatenated copy
    of the parameters is built, and non-contiguous parameters (which
    ``Tensor.view(-1)`` rejects) are handled.

    Args:
        model: Object whose ``parameters()`` yields tensors.

    Returns:
        A 0-dim tensor holding the L1 norm; a zero tensor when the model has
        no parameters.
    """
    total = None
    for param in model.parameters():
        term = param.abs().sum()
        total = term if total is None else total + term

    if total is None:
        # No parameters at all: the norm of an empty vector is zero.
        return torch.zeros(())
    return total


@iterative_unlearn
def NegGrad_plus(data_loaders, model, criterion, optimizer, epoch, args):
    """Run one pass of NegGrad+ over the shuffled union of forget and retain data.

    For every mini-batch the retain loss enters the objective with weight
    ``beta`` and the forget loss with weight ``-(1 - beta)``, so an optimizer
    step descends on retain samples and ascends on forget samples.

    Args:
        data_loaders: Mapping with ``"forget"`` and ``"retain"`` loaders; only
            their ``.dataset`` attributes are used here.
        model: Network to update; switched to training mode.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per mini-batch.
        epoch: Not used inside this function body.
        args: Must provide ``batch_size`` and ``beta_neggrad_plus``.

    Returns:
        Always ``0``.
    """
    forget_set = data_loaders["forget"].dataset
    retain_set = data_loaders["retain"].dataset

    # Source flags: 0 marks forget samples, 1 marks retain samples.
    mask = [0] * len(forget_set) + [1] * len(retain_set)
    combined_dataset = MaskedDataset(forget_set, retain_set, mask)
    combined_loader = DataLoader(
        combined_dataset,
        batch_size=args.batch_size,
        num_workers=4,
        shuffle=True,
    )

    model.train()
    beta = args.beta_neggrad_plus

    for image, target, source in tqdm(combined_loader):
        image = image.cuda()
        target = target.cuda()
        source = source.cuda()

        output = model(image)

        retain_mask = source == 1
        forget_mask = source == 0
        # Count from the actual batch: the final batch can be smaller than
        # args.batch_size, so "batch_size - num_retain" would be wrong there.
        num_retain = int(retain_mask.sum())
        num_forget = int(forget_mask.sum())

        # NOTE(review): each term is divided by its sample count, as before.
        # If `criterion` already averages over the batch this normalizes
        # twice -- confirm that is intended.
        # A term is skipped when its subset is empty, which avoids a division
        # by zero and a criterion call on empty tensors.
        loss = None
        if num_retain > 0:
            retain_loss = criterion(output[retain_mask], target[retain_mask])
            loss = beta * retain_loss / num_retain
        if num_forget > 0:
            forget_loss = criterion(output[forget_mask], target[forget_mask])
            forget_term = (1 - beta) * forget_loss / num_forget
            loss = -forget_term if loss is None else loss - forget_term
        if loss is None:
            # Empty batch: nothing to optimize.
            continue

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return 0
