import os
import torch
import torch.optim
import torch.utils.data


def save_gradient_ratio(forget_loader, model, criterion, optimizer, save_dir):
    """Save binary masks marking the parameters with the largest gradients.

    Gradients of ``-criterion(output, target)`` are summed over every batch of
    ``forget_loader``; the absolute value of each summed entry is then ranked
    globally across all parameters that received a gradient. For each ratio
    ``r`` in 0.1, 0.2, ..., 1.0 a dict ``{parameter_name: mask}`` is written
    to ``<save_dir>/with_<r>.pt``, where ``mask`` has the parameter's shape
    and holds 1 for entries among the ``int(total * r)`` largest magnitudes
    and 0 elsewhere.

    Args:
        forget_loader: Iterable yielding ``(image, target)`` tensor batches.
        model: Module whose forward pass returns a 2-tuple; only the first
            element is passed to ``criterion``. It is switched to eval mode
            and left that way.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        optimizer: Only used to clear gradients via ``zero_grad()``; no
            optimisation step is taken.
        save_dir: Directory for the mask files; created if it is missing.

    Raises:
        RuntimeError: If no parameter received a gradient (for example when
            ``forget_loader`` is empty), since there is nothing to rank.
    """
    model.eval()

    # Move batches to wherever the model lives rather than assuming the
    # default CUDA device; fall back to CUDA for a parameter-less model.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda")

    # Pre-populate in named_parameters() order so the saved dicts keep that
    # order; None marks "no gradient seen yet".
    grad_sums = {name: None for name, _ in model.named_parameters()}

    for image, target in forget_loader:
        image = image.to(device)
        target = target.to(device)

        output_clean, _ = model(image)
        # The sign is irrelevant to the final masks (absolute values are
        # ranked) but is kept so the accumulated values match the original.
        loss = -criterion(output_clean, target)

        optimizer.zero_grad()
        loss.backward()

        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is None:
                    continue
                if grad_sums[name] is None:
                    # Clone so later in-place accumulation never aliases
                    # the parameter's own .grad buffer.
                    grad_sums[name] = param.grad.detach().clone()
                else:
                    grad_sums[name] += param.grad.detach()

    # Parameters that never received a gradient are left out of the masks.
    with torch.no_grad():
        magnitudes = {
            name: grad.abs_()
            for name, grad in grad_sums.items()
            if grad is not None
        }

    if not magnitudes:
        raise RuntimeError(
            "save_gradient_ratio: no parameter received a gradient; "
            "cannot build gradient masks"
        )

    # The ranking does not depend on the threshold, so compute it once.
    # Sorting the negated magnitudes ascending puts the largest first.
    negated = -torch.cat([grad.flatten() for grad in magnitudes.values()])
    order = torch.argsort(negated)
    # ranks[j] is the position of element j in the sorted order, i.e. the
    # inverse permutation of `order` (same result as argsort(order), in O(n)).
    ranks = torch.empty_like(order)
    ranks[order] = torch.arange(order.numel(), device=order.device)

    sizes = [grad.numel() for grad in magnitudes.values()]
    rank_by_param = {
        name: chunk.reshape(grad.shape)
        for (name, grad), chunk in zip(
            magnitudes.items(), torch.split(ranks, sizes)
        )
    }

    total_elements = ranks.numel()
    os.makedirs(save_dir, exist_ok=True)

    threshold_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    for ratio in threshold_list:
        # Number of top-ranked elements to keep for this ratio.
        threshold_index = int(total_elements * ratio)
        hard_dict = {
            name: (param_ranks < threshold_index).to(param_ranks.dtype)
            for name, param_ranks in rank_by_param.items()
        }
        torch.save(hard_dict, os.path.join(save_dir, "with_{}.pt".format(ratio)))
