import os 
import torch
from torch import nn
from methods.method import Method
from copy import deepcopy
from methods.method import Method
from tqdm import tqdm
from utils.data import RandomDataset


class SALUN(Method):

    def set_hyperparameters(self, args):

        self.unlearn_lr = self.args.lr#0.001
        self.unlearn_epochs = self.args.remain_epochs
        self.decreasing_lr = "91,136"
        self.momentum = 0.9
        self.weight_decay = 5e-4
        self.saliency_thres = 0.5
        self.mask_ratio = 0.3
        self.warmup = 0 
        self.save_dir = f'./methods/saliency_ckpt/{args.test_mode}/{args.data_name}/{args.model_name}'
        self.saliency_mask_path = os.path.join(self.save_dir, f"with_{self.saliency_thres}.pt")


    def warmup_lr(self, epoch, step, optimizer, one_epoch_step):
        overall_steps = self.warmup * one_epoch_step
        current_steps = epoch * one_epoch_step + step

        lr = self.lr * current_steps / overall_steps
        lr = min(lr, self.lr)

        for p in optimizer.param_groups:
            p["lr"] = lr


    def generate_saliency_mask(self, loaders, model, args):
        criterion = nn.CrossEntropyLoss() 
        optimizer = torch.optim.SGD(
            model.parameters(),
            self.unlearn_lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )

        gradients = {}

        forget_loader = loaders['train_forget_loader']
        model.eval()

        for name, param in model.named_parameters():
            gradients[name] = 0

        for i, (image, target) in enumerate(forget_loader):
            image, target = image.cuda(), target.cuda()

            output_clean = model(image)
            loss = - criterion(output_clean, target)

            optimizer.zero_grad()
            loss.backward()

            with torch.no_grad():
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        gradients[name] = param.grad.data
                    else:
                        gradients[name] = torch.zeros_like(param)

        with torch.no_grad():
            for name in gradients:
                gradients[name] = torch.abs_(gradients[name])

        threshold_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

        for i in threshold_list:
            print(i)
            sorted_dict_positions = {}
            hard_dict = {}

            # Concatenate all tensors into a single tensor
            all_elements = - torch.cat([tensor.flatten() for tensor in gradients.values()])

            # Calculate the threshold index for the top 10% elements
            threshold_index = int(len(all_elements) * i)

            # Calculate positions of all elements
            positions = torch.argsort(all_elements)
            ranks = torch.argsort(positions)

            start_index = 0
            for key, tensor in gradients.items():
                num_elements = tensor.numel()
                # tensor_positions = positions[start_index: start_index + num_elements]
                tensor_ranks = ranks[start_index : start_index + num_elements]

                sorted_positions = tensor_ranks.reshape(tensor.shape)
                sorted_dict_positions[key] = sorted_positions

                # Set the corresponding elements to 1
                threshold_tensor = torch.zeros_like(tensor_ranks)
                threshold_tensor[tensor_ranks < threshold_index] = 1
                threshold_tensor = threshold_tensor.reshape(tensor.shape)
                hard_dict[key] = threshold_tensor
                start_index += num_elements

            torch.save(hard_dict, os.path.join(self.save_dir, "with_{}.pt".format(i)))

    
    def unlearn(self, model, loaders, args):
        """Fine-tune ``model`` with saliency-masked gradients and return it.

        The saliency mask at ``self.saliency_mask_path`` is generated on first
        use and loaded afterwards. Training then runs for
        ``self.unlearn_epochs`` epochs:

        * ``cifar100`` / ``imagenet``: one shuffled loader over the
          concatenation of the ``RandomDataset``-wrapped forget set and the
          retain set.
        * ``cifar10``: a pass over the forget loader with uniformly random
          targets, followed by a pass over the retain loader with the true
          targets.

        After every backward pass each gradient is multiplied by its mask
        before the optimizer step.

        Args:
            model: ``torch.nn.Module`` to unlearn; modified in place.
            loaders: Mapping with ``'train_forget_loader'``,
                ``'train_remain_loader'``, ``'train_forget_set'`` and
                ``'train_remain_set'``.
            args: Run configuration; ``data_name``, ``num_classes`` and
                ``remain_batch_size`` are read here.

        Returns:
            The same ``model`` instance after unlearning.

        Raises:
            NotImplementedError: If ``args.data_name`` is not one of
                ``cifar10``, ``cifar100`` or ``imagenet``.
        """
        # Generate the saliency masks if the expected file does not exist yet.
        if not os.path.exists(self.saliency_mask_path):
            os.makedirs(self.save_dir, exist_ok=True)
            self.generate_saliency_mask(loaders, model, args)

        # Keep data and mask on the model's device rather than mixing
        # ``.cuda()`` and ``args.device`` across branches.
        device = next(model.parameters()).device
        mask = torch.load(self.saliency_mask_path, map_location=device)

        criterion = nn.CrossEntropyLoss()
        optimizer = self.get_optimizer(model)
        milestones = [int(m) for m in self.decreasing_lr.split(",")]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=0.1
        )

        forget_loader, retain_loader = loaders['train_forget_loader'], loaders['train_remain_loader']
        forget_set, retain_set = loaders['train_forget_set'], loaders['train_remain_set']
        # Original note: "change forget_set to random data".
        # NOTE(review): presumably RandomDataset randomises the forget-set
        # labels -- confirm in utils.data.
        forget_set = RandomDataset(deepcopy(forget_set), args)

        joint_names = ('cifar100', 'imagenet')
        train_loader = None
        if args.data_name in joint_names:
            # The joint loader is epoch-invariant (it reshuffles on every
            # iteration), so build it once instead of once per epoch.
            train_dataset = torch.utils.data.ConcatDataset([forget_set, retain_set])
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=args.remain_batch_size, shuffle=True
            )

        for epoch in range(self.unlearn_epochs):
            in_warmup = epoch < self.warmup

            if args.data_name in joint_names:
                model.train()
                steps_per_epoch = len(train_loader)

                for it, (image, target) in enumerate(tqdm(train_loader)):
                    if in_warmup:
                        # Warm-up is applied per step with the real step index.
                        self.warmup_lr(epoch, it + 1, optimizer,
                                       one_epoch_step=steps_per_epoch)

                    image, target = image.to(device), target.to(device)
                    self._masked_step(model, criterion, optimizer, mask, image, target)

            elif args.data_name in ('cifar10',):
                # Mask generation leaves the model in eval mode; switch to
                # train mode explicitly so the mode does not depend on
                # whether the mask file was already cached.
                model.train()
                forget_steps = len(forget_loader)
                steps_per_epoch = forget_steps + len(retain_loader)

                # Forget pass: train towards uniformly random labels.
                for i, (image, target) in enumerate(tqdm(forget_loader)):
                    if in_warmup:
                        self.warmup_lr(epoch, i + 1, optimizer,
                                       one_epoch_step=steps_per_epoch)

                    image = image.to(device)
                    target = torch.randint(0, args.num_classes, target.shape).to(device)
                    self._masked_step(model, criterion, optimizer, mask, image, target)

                # Retain pass: train on the true labels.
                for i, (image, target) in enumerate(tqdm(retain_loader)):
                    if in_warmup:
                        self.warmup_lr(epoch, forget_steps + i + 1, optimizer,
                                       one_epoch_step=steps_per_epoch)

                    image, target = image.to(device), target.to(device)
                    self._masked_step(model, criterion, optimizer, mask, image, target)

            else:
                raise NotImplementedError

            scheduler.step()

        return model

    def _masked_step(self, model, criterion, optimizer, mask, image, target):
        """Run one optimisation step with gradients multiplied by ``mask``."""
        loss = criterion(model(image), target)

        optimizer.zero_grad()
        loss.backward()

        if mask:
            for name, param in model.named_parameters():
                if param.grad is not None:
                    param.grad *= mask[name]

        optimizer.step()