import gc
import sys

import torch
import torch.nn as nn

from scripts.merge.task_vector import TaskVector
from scripts.utils.pure import print_cpu_memory_usage


def _task_vector_param_dict_to_single_vector(task_vector: TaskVector):
    """
    Concatenate every tensor of a task vector into one flat 1-D tensor.

    Each tensor is flattened and laid end to end in the order given by
    ``task_vector.base_model_param_names``. Afterwards the
    ``task_vector_param_dict`` attribute is removed from ``task_vector`` (the
    caller's object is mutated) so its tensors can be released if nothing
    else references them.

    :param task_vector: TaskVector whose ``task_vector_param_dict`` has one
        tensor per name in ``base_model_param_names``
    :return: Tensor, shape (num_total_params, )
    """
    param_dict = task_vector.task_vector_param_dict
    flat_tensors = [
        param_dict[param_name].flatten()
        for param_name in task_vector.base_model_param_names
    ]
    single_vector = nn.utils.parameters_to_vector(flat_tensors)

    # Drop every local reference to the per-parameter tensors before removing
    # the attribute, so the cleanup below can actually reclaim them.
    del flat_tensors, param_dict
    delattr(task_vector, "task_vector_param_dict")

    torch.cuda.empty_cache()
    gc.collect()
    return single_vector


def _single_vector_to_task_vector_param_dict(
    single_vector: torch.Tensor, base_model: nn.Module
):
    """
    Convert a single flat vector back into a ``{parameter name: tensor}`` dict.

    WARNING: this writes into ``base_model`` itself. ``nn.utils.vector_to_parameters``
    replaces the data of every parameter of ``base_model`` with the matching
    slice of ``single_vector`` (taken in ``base_model.parameters()`` order), and
    the returned dict holds ``base_model``'s own ``nn.Parameter`` objects, not
    copies. Pass a model whose current weights are no longer needed (or a copy).

    :param single_vector: Tensor, vector holding all parameter values, laid out
        in ``base_model.parameters()`` order
    :param base_model: nn.Module, supplies parameter names and shapes; its
        parameters are overwritten in place
    :return: dict[str, nn.Parameter], ``base_model.named_parameters()`` as a dict
    """
    # NOTE(review): single_vector must have been built in the same order as
    # base_model.parameters() (e.g. from base_model_param_names) -- confirm at
    # call sites.
    nn.utils.vector_to_parameters(single_vector, base_model.parameters())
    param_dic = dict(base_model.named_parameters())

    torch.cuda.empty_cache()
    gc.collect()

    return param_dic


def _mask_smallest_magnitude_param_values(
    flattened_models_to_merge_param: torch.Tensor,
    flattened_model_names: list[str],
    param_value_mask_rate: float | dict[str, float],
):
    """
    Zero out, in place, the smallest-magnitude values of every model row.

    Row ``i`` of ``flattened_models_to_merge_param`` belongs to
    ``flattened_model_names[i]``. For each row,
    ``int(num_total_params * mask_rate)`` values are targeted, where
    ``mask_rate`` is ``param_value_mask_rate`` itself or, when it is a dict,
    ``param_value_mask_rate[name]``. The threshold search and zeroing for a
    row are delegated to ``_mask_smallest``.

    :param flattened_models_to_merge_param: Tensor, shape
        (num_models_to_merge, num_total_params), flattened parameters of the
        individual models to merge; modified in place
    :param flattened_model_names: list[str], model name for each row, in row order
    :param param_value_mask_rate: float, or dict mapping model name -> float;
        fraction of each row's values to mask
    :return: None (the input tensor is modified in place)
    :raises ValueError: if the number of names differs from the number of rows
    :raises KeyError: if a dict mask rate has no entry for a model name
    """
    num_models = flattened_models_to_merge_param.shape[0]
    # Check up front: a plain zip would silently skip models on a mismatch,
    # and a late failure would leave some rows already masked.
    if num_models != len(flattened_model_names):
        raise ValueError(
            f"got {num_models} model rows but "
            f"{len(flattened_model_names)} model names"
        )

    print(">masking smallest magnitude")
    print(f"device: {flattened_models_to_merge_param.device}")
    print_cpu_memory_usage()
    sys.stdout.flush()

    num_total_params = flattened_models_to_merge_param.shape[1]
    # Iterating a tensor yields row views, so the in-place masking done by
    # _mask_smallest updates flattened_models_to_merge_param itself.
    for flattened_model, name in zip(
        flattened_models_to_merge_param, flattened_model_names
    ):
        mask_rate = (
            param_value_mask_rate[name]
            if isinstance(param_value_mask_rate, dict)
            else param_value_mask_rate
        )
        num_mask_params = int(num_total_params * mask_rate)

        _mask_smallest(flattened_model, num_mask_params)


def _mask_smallest(
    flattened_model: torch.Tensor,
    num_mask_params: int,
    *,
    num_iterations: int = 22,
):
    """
    Zero out, in place, approximately the ``num_mask_params`` smallest-magnitude entries.

    The magnitude threshold is found by bisection instead of ``torch.kthvalue``.
    Each step only counts how many entries lie below the midpoint. The search
    interval is ``[0, max|x|]``, so the threshold resolution is
    ``max|x| / 2**num_iterations``. The final threshold ``t`` always satisfies
    ``count(|x| < t) < num_mask_params`` (``t == 0`` when ``num_mask_params``
    is 0). Exactly the entries with ``|x| < t`` are set to 0.

    :param flattened_model: Tensor to mask, modified in place. Callers in this
        file pass one row view of a 2-D tensor, so the parent tensor is updated too.
    :param num_mask_params: int, target number of entries to zero
    :param num_iterations: int, number of bisection steps (default 22)
    :return: None
    :raises ValueError: if ``flattened_model`` contains NaN or infinite values
    """
    print(">>masking smallest magnitude...")
    print_cpu_memory_usage()
    flatten_model_abs = flattened_model.abs()
    sys.stdout.flush()
    if flatten_model_abs.numel() == 0:
        # Nothing to mask (and a full max() reduction raises on empty tensors).
        return

    # Upper end of the bisection interval, taken from the data. This replaces
    # the former hard-coded bound of 64.0, which was enforced by an `assert`
    # (stripped by `python -O`). It also gives a much finer threshold
    # resolution when magnitudes are small.
    max_abs = flatten_model_abs.max()
    if not torch.isfinite(max_abs):
        raise ValueError(
            "flattened_model contains NaN or infinite values; "
            "cannot search for a magnitude threshold"
        )
    gc.collect()

    left = 0.0
    right = max_abs.item()
    # Invariant: `left` only ever takes a value whose count(|x| < left) is below
    # num_mask_params; it stays 0.0 when num_mask_params == 0.
    # The boolean temporaries are freed by reference counting right after each
    # comparison, so no per-iteration gc / cache flushing is needed.
    for _ in range(num_iterations):
        mid = (left + right) / 2
        if (flatten_model_abs < mid).sum(dtype=torch.float32) < num_mask_params:
            left = mid
        else:
            right = mid
    kth_value = left

    print_cpu_memory_usage()
    sys.stdout.flush()
    mask = flatten_model_abs >= kth_value
    print_cpu_memory_usage()
    sys.stdout.flush()
    del flatten_model_abs
    gc.collect()
    print_cpu_memory_usage()
    sys.stdout.flush()
    # In-place multiply by the keep-mask zeroes every entry below the threshold.
    flattened_model *= mask
    print_cpu_memory_usage()
    sys.stdout.flush()
    del mask
    gc.collect()
    print_cpu_memory_usage()
    sys.stdout.flush()


def _get_param_signs(
    flattened_models_to_merge_param: torch.Tensor | list[torch.Tensor],
):
    """
    Elect one sign per parameter position across the models to merge.

    The sign at each position is the sign of the element-wise sum over all
    models. Positions whose sum is exactly zero receive the majority sign,
    i.e. the sign of the sum of all elected signs. If that sum is itself 0,
    those positions stay 0.

    :param flattened_models_to_merge_param: Tensor, shape
        (num_models_to_merge, num_total_params), or a list of equally-shaped
        flattened tensors, one per model; not modified
    :return: half-precision Tensor with the shape of one model row, values in {-1, 0, 1}
    :raises ValueError: if no models are given
    """
    print(">getting param signs...")
    # Element-wise sum over models, accumulated one model at a time.
    # The first `s += param` rebinds `s` to a NEW tensor (0.0 + param), so the
    # caller's first row is never mutated; later additions are in place on `s`.
    s = 0.0
    for param in flattened_models_to_merge_param:
        s += param
        torch.cuda.empty_cache()
        gc.collect()
    if not isinstance(s, torch.Tensor):
        # Explicit check rather than `assert`, which `python -O` would strip.
        raise ValueError("_get_param_signs needs at least one flattened model, got none")
    print_cpu_memory_usage()
    # sign(s) built from two comparisons, materialized directly as float16.
    pos = (s > 0.0).half()
    neg = (s < 0.0).half()
    param_signs = pos - neg
    print_cpu_memory_usage()
    del s
    del pos
    del neg
    torch.cuda.empty_cache()
    gc.collect()
    print(">getting param signs done")
    print_cpu_memory_usage()
    # Majority sign: sign of the summed elected signs (a scalar for 1-D rows).
    # A half-precision sum may overflow to +/-inf on huge rows, but torch.sign
    # still returns the correct sign in that case.
    majority_sign = torch.sign(param_signs.sum(dim=0))
    # Add the majority sign only where the elected sign is 0.
    majority_sign = (param_signs == 0).half() * majority_sign
    param_signs += majority_sign
    print(">getting majority sign done")
    print_cpu_memory_usage()
    del majority_sign
    torch.cuda.empty_cache()
    gc.collect()
    return param_signs


# def _disjoint_merge(
#     flattened_models_to_merge_param: torch.Tensor,
#     param_signs: torch.Tensor,
#     flattened_model_names: list[str],
#     each_model_weight: dict[str, float] | None = None,
# ):
#     """
#     disjoint merge that only keeps the parameter values in individual models whose signs are the same as the param_signs, and calculates the averaged parameters.
#     :param flattened_models_to_merge_param: Tensor, shape (num_models_to_merge, num_total_params), flattened parameters of individual models that need to be merged
#     :param param_signs: Tensor, shape (num_total_params, ), the signs of parameters aggregated across individual models that need to be merged
#     :return:
#     """

#     num_preserved = (
#         torch.zeros_like(param_signs).half().to(flattened_models_to_merge_param.device)
#     )
#     for name, param in zip(flattened_model_names, flattened_models_to_merge_param):
#         mask = ((param_signs > 0) & (param > 0)) | ((param_signs < 0) & (param < 0))
#         param *= mask
#         if each_model_weight is not None:
#             param *= each_model_weight[name]
#         num_preserved += mask
#         del mask
#         del param
#         torch.cuda.empty_cache()
#     del param_signs
#     torch.cuda.empty_cache()

#     # # Tensor, shape (num_models_to_merge, num_total_params), where True is for parameters that we want to preserve
#     # param_to_preserve_mask = (
#     #     (param_signs.unsqueeze(dim=0) > 0) & (flattened_models_to_merge_param > 0)
#     # ) | ((param_signs.unsqueeze(dim=0) < 0) & (flattened_models_to_merge_param < 0))

#     # if each_model_weight is not None:
#     #     weight = torch.tensor(
#     #         [w for w in OrderedDict(sorted(each_model_weight.items())).values()]
#     #     ).unsqueeze(1)
#     #     param_to_preserve_mask *= weight

#     # # Tensor, shape (num_models_to_merge, num_total_params), the preserved parameters
#     # param_to_preserve = flattened_models_to_merge_param * param_to_preserve_mask

#     # # Tensor, shape (num_total_params, ), the number of models whose parameters can be preserved
#     # num_models_param_preserved = (param_to_preserve_mask != 0).sum(dim=0).float()
#     # # Tensor, shape (num_total_params, ), the averaged flattened parameters
#     # merged_flattened_param = torch.sum(param_to_preserve, dim=0) / torch.clamp(
#     #     num_models_param_preserved, min=1.0
#     # )

#     # return merged_flattened_param


# def ties_merging(
#     base_model: nn.Module,
#     task_vectors: dict[str, TaskVector],
#     param_value_mask_rate: float | dict[str, float],
#     scaling_coefficient: float,
#     each_model_weight: dict[str, float] | None = None,
# ):
#     assert isinstance(
#         scaling_coefficient, float
#     ), "wrong type of scaling_coefficient, should be float!"

#     with torch.no_grad():
#         flattened_model_names = list(task_vectors.keys())
#         flattened_models_to_merge_param_ls = []
#         for name in flattened_model_names:
#             flattened_models_to_merge_param_ls.append(
#                 _task_vector_param_dict_to_single_vector(task_vector=task_vectors[name])
#             )
#             del task_vectors[name]
#             torch.cuda.empty_cache()
#             gc.collect()

#         print("making flattened models done")
#         # print_gpu_memory_usage()

#         # Tensor, shape (num_models_to_merge, num_total_params), flattened parameters of individual models that need to be merged
#         flattened_models_to_merge_param = torch.vstack(
#             flattened_models_to_merge_param_ls
#         )
#         del flattened_models_to_merge_param_ls

#         # Tensor, shape (num_models_to_merge, num_total_params), mask the smallest-magnitude parameter values using param_value_mask_rate
#         _mask_smallest_magnitude_param_values(
#             flattened_models_to_merge_param,
#             flattened_model_names,
#             param_value_mask_rate,
#         )
#         print("masking smallest magnitude done")
#         print_gpu_memory_usage()

#         # Tensor, shape (num_total_params, ), get the signs for each parameter in flattened_models_to_merge_param
#         param_signs = _get_param_signs(flattened_models_to_merge_param)
#         print("getting param signs done")
#         print_gpu_memory_usage()

#         # # Tensor, shape (num_total_params, ), disjoint merge
#         # merged_flattened_param = _disjoint_merge(
#         #     flattened_models_to_merge_param,
#         #     param_signs,
#         #     each_model_weight,
#         # )

#         ## merge -----------------------------------------------------------------------
#         num_preserved = 0.0
#         merged_param = 0.0
#         flattened_models_to_merge_param_dic = {
#             name: param
#             for name, param in zip(
#                 flattened_model_names, flattened_models_to_merge_param
#             )
#         }
#         del flattened_models_to_merge_param
#         for name in flattened_model_names:
#             param = flattened_models_to_merge_param_dic[name]
#             mask = ((param_signs > 0) & (param > 0)) | ((param_signs < 0) & (param < 0))
#             param *= mask
#             if each_model_weight is not None:
#                 param *= each_model_weight[name]
#             if isinstance(merged_param, float):
#                 merged_param = param
#             else:
#                 merged_param += param
#             del flattened_models_to_merge_param_dic[name]
#             del param
#             # torch.cuda.empty_cache()
#             gc.collect()
#             if isinstance(num_preserved, float):
#                 num_preserved = mask
#             else:
#                 num_preserved += mask
#             del mask
#             # torch.cuda.empty_cache()
#             gc.collect()
#         del param_signs
#         del flattened_models_to_merge_param_dic
#         # torch.cuda.empty_cache()
#         gc.collect()
#         ## ------------------------------------------------------------------------------

#         # merged parameter dictionary
#         assert isinstance(merged_param, torch.Tensor)
#         merged_task_vector_param_dict = _single_vector_to_task_vector_param_dict(
#             single_vector=merged_param,
#             base_model=base_model,
#         )
#         del merged_param
#         # torch.cuda.empty_cache()
#         gc.collect()

#         merged_task_vector = TaskVector.from_param_dict(
#             merged_task_vector_param_dict,
#             base_model,
#         )
#         del merged_task_vector_param_dict
#         # torch.cuda.empty_cache()
#         gc.collect()

#         # combine with parameters of the merged model based on scaling coefficient
#         merged_params = merged_task_vector.combine_with_pretrained_model(
#             base_model=base_model, scaling_coefficient=scaling_coefficient
#         )
#         del merged_task_vector.task_vector_param_dict
#         del merged_task_vector
#         # torch.cuda.empty_cache()
#         gc.collect()

#     return merged_params
