import torch

from .utils import topk_values_mask


## TIES MERGING UTILS
def resolve_zero_signs(sign_to_mult, method="majority"):
    """Replace zero entries of a sign tensor, modifying it in place.

    The fill value is derived from ``torch.sign`` of the tensor's total sum:
    ``"majority"`` fills zeros with that sign, ``"minority"`` with its
    negation. Any other ``method`` leaves the tensor untouched. If the total
    sum is itself zero, the fill value is zero and zeros stay zero.

    Args:
        sign_to_mult: Sign tensor to patch; it is mutated in place.
        method: ``"majority"`` or ``"minority"``.

    Returns:
        The same tensor object that was passed in.
    """
    dominant = torch.sign(sign_to_mult.sum())

    if method == "majority":
        fill = dominant
    elif method == "minority":
        fill = -1 * dominant
    else:
        # Unrecognized method: nothing to patch.
        return sign_to_mult

    sign_to_mult[sign_to_mult == 0] = fill
    return sign_to_mult


def resolve_sign(tensor: torch.Tensor):
    """Elect one sign per entry from task vectors stacked along dim 0.

    The elected sign is the sign of the sum over dim 0. Entries whose sum is
    exactly zero are then filled via ``resolve_zero_signs(..., "majority")``.

    Args:
        tensor: Task vectors stacked along dim 0.

    Returns:
        A sign tensor shaped like ``tensor.sum(dim=0)``.
    """
    column_totals = tensor.sum(dim=0)
    return resolve_zero_signs(torch.sign(column_totals), "majority")


def disjoint_merge(tensor, merge_func, sign_to_mult):
    """Aggregate task vectors entry-wise over a sign-filtered subset of entries.

    Args:
        tensor: Task vectors stacked along dim 0; every reduction below runs
            over dim 0.
        merge_func: Aggregation name. Only the text after the last ``"-"`` is
            used, so e.g. ``"dis-mean"`` selects ``"mean"``. Supported names
            are ``"mean"``, ``"sum"`` and ``"max"``.
        sign_to_mult: Elected sign per entry (broadcast against ``tensor``
            after ``unsqueeze(0)``), or ``None``. When given, an entry is kept
            only if it is strictly positive where the sign is > 0 and strictly
            negative elsewhere. When ``None``, every non-zero entry is kept.

    Returns:
        The entry-wise aggregate, shaped like ``tensor[0]`` and on
        ``tensor``'s device:

        * ``"sum"``: sum of the kept entries.
        * ``"mean"``: sum of the kept entries divided by how many were kept.
          Entries where nothing was kept give 0.
        * ``"max"``: the largest kept magnitude. With a sign, it is multiplied
          by ``sign_to_mult``. Without one, the kept entry of largest
          magnitude is returned with its own sign.

        Floating inputs keep their dtype; non-floating inputs are promoted to
        ``torch.get_default_dtype()``.

    Raises:
        ValueError: If ``merge_func`` is not a supported name.
    """
    merge_func = merge_func.split("-")[-1]
    # Reject unknown methods before doing any tensor work.
    if merge_func not in ("mean", "sum", "max"):
        raise ValueError(f"Merge method {merge_func} is not defined.")

    if sign_to_mult is not None:
        # Keep only entries whose sign agrees with the elected sign.
        keep_mask = torch.where(sign_to_mult.unsqueeze(0) > 0, tensor > 0, tensor < 0)
    else:
        keep_mask = tensor != 0
    selected_entries = tensor * keep_mask

    # Reducing the input directly keeps its device and floating dtype (a fresh
    # torch.zeros buffer would be CPU / default dtype). Integer inputs are
    # still promoted to a float type.
    if not selected_entries.is_floating_point():
        selected_entries = selected_entries.to(torch.get_default_dtype())

    if merge_func == "sum":
        return selected_entries.sum(dim=0)

    if merge_func == "mean":
        non_zero_counts = (selected_entries != 0).sum(dim=0)
        # clamp avoids 0/0 where nothing was kept; those sums are 0, so the mean is 0.
        divisor = torch.clamp(non_zero_counts, min=1).to(selected_entries.dtype)
        return selected_entries.sum(dim=0) / divisor

    # merge_func == "max"
    magnitudes = selected_entries.abs()
    if sign_to_mult is not None:
        disjoint_aggs = magnitudes.max(dim=0)[0]
        disjoint_aggs *= sign_to_mult
        return disjoint_aggs
    # No elected sign to multiply by: pick the kept entry of largest magnitude
    # and keep its own sign.
    winner = magnitudes.argmax(dim=0, keepdim=True)
    return selected_entries.gather(0, winner).squeeze(0)


def ties_merging(flat_task_checks, reset_thresh=None, merge_func=""):
    """Run the TIES pipeline: trim, elect signs, then disjoint-merge.

    Args:
        flat_task_checks: Flattened task vectors. The input is cloned before
            trimming, so the caller's tensor is handed to
            ``topk_values_mask`` only as a copy.
        reset_thresh: Forwarded as ``K`` to ``topk_values_mask``. Presumably
            this controls how many top-magnitude entries survive; confirm in
            ``.utils``.
        merge_func: Aggregation name forwarded to ``disjoint_merge``.

    Returns:
        The merged task vector produced by ``disjoint_merge``.
    """
    # NOTE(review): assumes topk_values_mask returns the trimmed tensor itself
    # (not a tuple) — confirm against .utils.
    trimmed = topk_values_mask(flat_task_checks.clone(), K=reset_thresh)

    print("RESOLVING SIGN")
    elected_signs = resolve_sign(trimmed)
    assert elected_signs is not None

    print(f"Disjoint AGGREGATION: {merge_func}")
    return disjoint_merge(trimmed, merge_func, elected_signs)
