import copy
import gc

import torch

from scripts.merge.task_vector import TaskVector


def _mask_input_with_mask_rate(
    input_tensor: torch.Tensor, mask_rate: float, use_rescale: bool, mask_strategy: str
):
    assert (
        0.0 <= mask_rate <= 1.0
    ), f"wrong range of mask_rate {mask_rate}, should be [0.0, 1.0]!"
    if mask_strategy == "random":
        torch.manual_seed(42)
        mask = (
            torch.bernoulli(torch.full_like(input=input_tensor, fill_value=mask_rate))
            .half()
            .to(input_tensor.device)
        )
        # num_elements = input_tensor.numel()
        # mask = torch.zeros(num_elements)
        # mask[: int(mask_rate * num_elements)] = 1
        # mask = shuffle(mask, random_state=42)
        # mask = mask.reshape(input_tensor.shape).to(input_tensor.device)

        # masked_input_tensor = input_tensor * (1 - mask)
        input_tensor *= 1 - mask
        del mask
        gc.collect()
        torch.cuda.empty_cache()
    else:
        raise NotImplementedError
        # assert (
        #     mask_strategy == "magnitude"
        # ), f"wrong setting for mask_strategy {mask_strategy}!"
        # original_shape = input_tensor.shape
        # input_tensor = input_tensor.flatten()
        # num_mask_params = int(len(input_tensor) * mask_rate)
        # # Tensor, shape (1, ), find the num_mask_params-th smallest magnitude element of all the parameters in the model
        # kth_values, _ = input_tensor.abs().kthvalue(
        #     k=num_mask_params, dim=0, keepdim=True
        # )
        # # Tensor, shape (num_total_params, ), where True is for parameters that we want to perform mask
        # mask = input_tensor.abs() <= kth_values
        # masked_input_tensor = input_tensor * (~mask)
        # masked_input_tensor = masked_input_tensor.reshape(original_shape)
    if use_rescale and mask_rate != 1.0:
        # masked_input_tensor = torch.div(input=masked_input_tensor, other=1 - mask_rate)
        input_tensor /= 1 - mask_rate
    # return masked_input_tensor


def mask_model_weights(
    task_vector: TaskVector,
    weight_mask_rate: float,
    use_weight_rescale: bool,
    mask_strategy: str,
):
    """Mask every tensor of ``task_vector`` in place, on whatever device it lives.

    Each value of ``task_vector.task_vector_param_dict`` is handed to
    ``_mask_input_with_mask_rate`` together with the given rate, rescale flag
    and strategy; that helper mutates the tensor in place. The CUDA caching
    allocator is emptied after each tensor. Nothing is returned.
    """
    with torch.no_grad():
        for tensor in task_vector.task_vector_param_dict.values():
            _mask_input_with_mask_rate(
                input_tensor=tensor,
                mask_rate=weight_mask_rate,
                use_rescale=use_weight_rescale,
                mask_strategy=mask_strategy,
            )
            torch.cuda.empty_cache()


def mask_model_weights_gpu(
    task_vector: TaskVector,
    weight_mask_rate: float,
    use_weight_rescale: bool,
    mask_strategy: str,
    device: torch.device,
):
    """Mask every tensor of ``task_vector`` on ``device`` and write results back.

    Each value of ``task_vector.task_vector_param_dict`` is copied to
    ``device``, masked there by ``_mask_input_with_mask_rate``, and copied back
    into the original tensor's storage. The dict entries therefore keep their
    original device and identity.

    Args:
        task_vector: Object whose ``task_vector_param_dict`` values are the
            tensors to mask.
        weight_mask_rate: Drop probability forwarded as ``mask_rate``.
        use_weight_rescale: Forwarded as ``use_rescale``.
        mask_strategy: Forwarded unchanged.
        device: Device on which the masking computation runs.

    Returns:
        None. The task vector's tensors are modified in place.
    """
    with torch.no_grad():
        for param_value in task_vector.task_vector_param_dict.values():
            # copy=True always gives a private working buffer (even if
            # param_value already lives on ``device``). It is allocated
            # directly on ``device`` instead of being cloned on the host first.
            param_value_on_device = param_value.to(device, copy=True)
            _mask_input_with_mask_rate(
                input_tensor=param_value_on_device,
                mask_rate=weight_mask_rate,
                use_rescale=use_weight_rescale,
                mask_strategy=mask_strategy,
            )
            # copy_ performs the cross-device transfer itself, so no
            # intermediate host tensor is needed.
            param_value.copy_(param_value_on_device)
            # Drop the device buffer before the next tensor is allocated, and
            # hand cached blocks back to the driver.
            del param_value_on_device
            torch.cuda.empty_cache()
    # Tensors are freed by refcounting on ``del``. A single collection at the
    # end replaces the costly per-tensor gc.collect() for any cyclic garbage.
    gc.collect()
