import os
import torch
import torch.nn as nn
from src.unlearn.random_label import train_rl


def save_gradient_ratio(
    forget_loader,
    model,
    criterion,
    unlearn_lr,
    momentum,
    weight_decay,
    save_path_dir,
    device,
    thresholds=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
):
    """Rank parameters by forget-set gradient magnitude and save hard masks.

    Accumulates the gradient of the *negated* ``criterion`` over every batch
    of ``forget_loader`` and takes its element-wise absolute value. For each
    fraction ``f`` in ``thresholds`` it writes ``with_{f}.pt`` into
    ``save_path_dir``. That file holds a dict mapping every parameter name to
    an int64 CPU tensor of the parameter's shape. The tensor is 1 for entries
    among the ``int(N * f)`` largest magnitudes across the whole model and 0
    elsewhere, where ``N`` is the total number of parameter elements.

    Args:
        forget_loader: iterable of ``(image, target)`` batches.
        model: module whose forward output exposes ``.logits``. It is moved
            to ``device`` and left in eval mode.
        criterion: loss called as ``criterion(logits, target)``.
        unlearn_lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        save_path_dir: output directory; created if it does not exist.
        device: device that the model and batches are moved to.
        thresholds: fractions of parameter elements to keep, one mask file
            per fraction. Defaults to 0.1, 0.2, ..., 1.0.

    Note:
        The optimizer is only used to zero gradients between batches. Its
        hyperparameters therefore do not influence the saved masks.
    """
    model.to(device)
    optimizer = torch.optim.SGD(
        model.parameters(),
        unlearn_lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    gradients = _accumulate_abs_gradients(
        forget_loader, model, criterion, optimizer, device
    )

    os.makedirs(save_path_dir, exist_ok=True)

    # The global ranking does not depend on the kept fraction, so flatten
    # and sort once instead of once per threshold.
    scores = torch.cat([grad.flatten() for grad in gradients.values()])
    # A stable sort makes tie-breaking (e.g. among zero gradients)
    # deterministic: among equal magnitudes, earlier elements are kept first.
    order = torch.argsort(scores, descending=True, stable=True)
    sizes = [grad.numel() for grad in gradients.values()]
    total = scores.numel()

    for fraction in thresholds:
        keep = int(total * fraction)
        flat_mask = torch.zeros(total, dtype=torch.long)
        flat_mask[order[:keep]] = 1
        # Split the flat mask back into per-parameter pieces. clone() so each
        # saved tensor owns its storage instead of viewing flat_mask.
        hard_dict = {
            name: chunk.reshape(grad.shape).clone()
            for (name, grad), chunk in zip(
                gradients.items(), torch.split(flat_mask, sizes)
            )
        }
        torch.save(
            hard_dict, os.path.join(save_path_dir, "with_{}.pt".format(fraction))
        )


def _accumulate_abs_gradients(forget_loader, model, criterion, optimizer, device):
    """Return ``{name: |sum over batches of d(-loss)/d param|}`` on the CPU."""
    model.eval()
    # Start from zero tensors (not the int 0). Parameters that never get a
    # gradient (frozen or unused) then yield a valid all-zero entry instead
    # of crashing the abs/flatten steps.
    gradients = {
        name: torch.zeros_like(param) for name, param in model.named_parameters()
    }

    for image, target in forget_loader:
        image = image.to(device)
        target = target.to(device)

        output = model(image)
        loss = -criterion(output.logits, target)

        optimizer.zero_grad()
        loss.backward()

        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is not None:
                    gradients[name] += param.grad

    with torch.no_grad():
        return {name: grad.abs().cpu() for name, grad in gradients.items()}


def generate_mask(
    forget_loader, model, device, lr, momentum, weight_decay, save_path_dir
):
    """Save forget-set gradient-magnitude masks using a cross-entropy loss.

    Thin convenience wrapper around ``save_gradient_ratio``. It fixes the
    criterion to ``nn.CrossEntropyLoss()`` and maps ``lr`` onto that
    function's ``unlearn_lr`` argument. All other arguments are passed
    through unchanged. Nothing is returned; the masks are written as files
    under ``save_path_dir``.
    """
    save_gradient_ratio(
        model=model,
        forget_loader=forget_loader,
        criterion=nn.CrossEntropyLoss(),
        device=device,
        save_path_dir=save_path_dir,
        unlearn_lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )


def train_unlearn(
    model,
    forget_loader,
    retain_loader,
    loss_fn,
    num_classes,
    device,
    lr,
    num_epochs,
    mask_path: str,
    val_retain_loader=None,
    val_forget_loader=None,
    lr_scheduler: None | str = None,
    patience: int = 5,
    weight_decay: float = 1e-6,
    save_path: None | str = None,
):
    """Load a saved mask and run ``train_rl`` with it.

    ``mask_path`` is read with ``torch.load``; a file written by
    ``save_gradient_ratio`` in this module (``with_<fraction>.pt``) is the
    expected input. The loaded object is passed to ``train_rl`` (imported
    from ``src.unlearn.random_label``) as ``mask``. Every other argument is
    forwarded unchanged, and ``train_rl``'s return value is returned as-is.
    NOTE(review): how ``train_rl`` applies the mask is not visible here.
    """
    saliency_mask = torch.load(mask_path)
    return train_rl(
        model=model,
        mask=saliency_mask,
        forget_loader=forget_loader,
        retain_loader=retain_loader,
        val_forget_loader=val_forget_loader,
        val_retain_loader=val_retain_loader,
        loss_fn=loss_fn,
        num_classes=num_classes,
        device=device,
        lr=lr,
        lr_scheduler=lr_scheduler,
        weight_decay=weight_decay,
        num_epochs=num_epochs,
        patience=patience,
        save_path=save_path,
    )
