import torch
from torch import Tensor

from typing import Dict

def topk_values_mask(M, K=0.8, return_mask=False):
    """Zero out all but the largest-magnitude fraction ``K`` of each row of ``M``.

    Args:
        M: 1-D or 2-D tensor. A 1-D input is treated as a single row.
        K: Fraction of entries to keep per row. Values greater than 1 are read
            as percentages (``20`` -> 0.2). Values >= 1.0 keep everything and
            values that round down to zero kept entries keep nothing. Entries
            whose magnitude ties the threshold are all kept, so slightly more
            than ``int(d * K)`` may survive.
        return_mask: If True, also return the boolean keep-mask.

    Returns:
        ``(trimmed, density)`` or ``(trimmed, density, mask)`` where
        ``trimmed`` is ``M * mask`` (note: a 1-D input comes back with shape
        ``(1, d)``), ``density`` is the fraction of kept entries along the
        last dimension (a scalar for 1-D input, one value per row for 2-D
        input), and ``mask`` is squeezed back to the input shape when possible.
    """
    if K > 1:
        K /= 100

    original_shape = M.shape
    if M.dim() == 1:
        M = M.unsqueeze(0)

    _, d = M.shape
    # Clamp so K >= 1.0 means "keep all" instead of indexing past the row.
    n_keep = min(int(d * K), d)

    magnitudes = M.abs()
    if n_keep <= 0:
        mask = torch.zeros_like(M, dtype=torch.bool)
    else:
        # The n_keep-th largest magnitude is the (d - n_keep + 1)-th smallest
        # (kthvalue is 1-indexed); keeping ">=" it retains the top n_keep.
        kth_values, _ = magnitudes.kthvalue(d - n_keep + 1, dim=1, keepdim=True)
        mask = magnitudes >= kth_values

    final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
    # Density along the last dim: per row for 2-D input, overall for 1-D input.
    density = final_mask.float().mean(dim=-1)

    if return_mask:
        return M * final_mask, density, final_mask
    return M * final_mask, density


def resolve_zero_signs(sign_to_mult, method="majority"):
    """Fill the zero entries of a sign tensor in place and return it.

    The "majority" sign is ``sign(sum(sign_to_mult))``, i.e. whichever sign
    the non-zero entries lean towards. If that sum is exactly zero the
    majority sign is 0 and zero entries are left unchanged.

    Args:
        sign_to_mult: Tensor of signs; modified in place.
        method: ``"majority"`` to fill zeros with the majority sign, or
            ``"minority"`` to fill them with its opposite.

    Returns:
        ``sign_to_mult`` itself (the same, mutated tensor).

    Raises:
        ValueError: If ``method`` is not ``"majority"`` or ``"minority"``.
    """
    # Reject typos explicitly instead of silently leaving zeros unresolved.
    if method not in ("majority", "minority"):
        raise ValueError(f"Unknown zero-sign resolution method: {method!r}")

    majority_sign = torch.sign(sign_to_mult.sum())
    zero_positions = sign_to_mult == 0

    if method == "majority":
        sign_to_mult[zero_positions] = majority_sign
    else:
        sign_to_mult[zero_positions] = -1 * majority_sign
    return sign_to_mult


def resolve_sign(Tensor):
    """Elect one sign per column of a stacked matrix.

    Each column's sign is the sign of that column's sum over ``dim=0``.
    Columns whose sum is exactly zero are filled, via
    ``resolve_zero_signs(..., "majority")``, with the sign that most of the
    other columns received.

    Note: the parameter name ``Tensor`` is kept for caller compatibility; it
    shadows ``torch.Tensor`` inside this function.
    """
    column_sums = Tensor.sum(dim=0)
    elected_signs = torch.sign(column_sums)
    return resolve_zero_signs(elected_signs, "majority")


def disjoint_merge(Tensor, merge_func, sign_to_mult):
    """Reduce a stacked matrix over ``dim=0``, ignoring entries that disagree in sign.

    Args:
        Tensor: Stacked task vectors; rows are reduced over ``dim=0``. (Named
            ``Tensor`` for caller compatibility; it shadows ``torch.Tensor``
            inside this function.)
        merge_func: ``"mean"``, ``"sum"`` or ``"max"``. Only the part after
            the last ``"-"`` is used, so ``"dis-mean"`` means ``"mean"``.
        sign_to_mult: Elected sign per column, or ``None`` to aggregate every
            non-zero entry regardless of its sign.

    Returns:
        The aggregated tensor, reduced over ``dim=0``.

    Raises:
        ValueError: If ``merge_func`` is not a supported reduction.
    """
    merge_func = merge_func.split("-")[-1]

    if sign_to_mult is not None:
        # Keep only entries agreeing with their column's elected sign; a
        # non-positive elected sign (including 0) selects the negative entries.
        rows_to_keep = torch.where(
            sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
        )
    else:
        # No elected sign: every non-zero entry participates.
        rows_to_keep = Tensor != 0
    selected_entries = Tensor * rows_to_keep

    if merge_func == "mean":
        # Average only over contributing entries; clamp avoids 0/0 when a
        # column has none.
        non_zero_counts = (selected_entries != 0).sum(dim=0).float()
        disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(
            non_zero_counts, min=1
        )
    elif merge_func == "sum":
        disjoint_aggs = torch.sum(selected_entries, dim=0)
    elif merge_func == "max":
        if sign_to_mult is not None:
            disjoint_aggs = selected_entries.abs().max(dim=0)[0]
            disjoint_aggs *= sign_to_mult
        else:
            # Without an elected sign, keep the largest-magnitude entry of
            # each column together with its own sign.
            largest = selected_entries.abs().argmax(dim=0, keepdim=True)
            disjoint_aggs = selected_entries.gather(0, largest).squeeze(0)
    else:
        raise ValueError(f"Merge method {merge_func} is not defined.")

    return disjoint_aggs


# def their_ties_merging(
#     flat_task_checks,
#     reset_thresh=None,
#     merge_func="",
# ):
#     all_checks = flat_task_checks.clone()
#     updated_checks, *_ = topk_values_mask(
#         all_checks, K=reset_thresh, return_mask=False
#     )
#     print(f"RESOLVING SIGN")
#     final_signs = resolve_sign(updated_checks)
#     assert final_signs is not None
    
#     print(f"Disjoint AGGREGATION: {merge_func}")
#     merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)
    
#     return merged_tv

def their_ties_merging(
    task_vectors: Dict[str, Tensor],
    reset_thresh=None,
    merge_func="",
):
    """Merge several task vectors with TIES: trim, elect signs, disjoint-merge.

    Args:
        task_vectors: Mapping from task name to that task's task vector. Each
            value is flattened; all values must have the same number of
            elements so they can be stacked into a ``(num_tasks, numel)``
            matrix (``torch.stack`` raises otherwise).
        reset_thresh: Fraction of entries kept per task during trimming
            (values > 1 are read as percentages, e.g. ``20`` -> 20%).
            ``None`` disables trimming.
        merge_func: Reduction forwarded to ``disjoint_merge`` (``"mean"``,
            ``"sum"``, ``"max"``, optionally prefixed like ``"dis-mean"``).

    Returns:
        The merged task vector of shape ``(numel,)``.

    Raises:
        ValueError: If ``task_vectors`` is empty, or (from ``disjoint_merge``)
            if ``merge_func`` is unknown.
    """
    if not task_vectors:
        raise ValueError("task_vectors must contain at least one task vector.")

    # Stack (not concatenate) so each task is one row: sign election and the
    # disjoint merge then compare the tasks against each other per parameter.
    # torch.stack already allocates a new tensor, so no extra clone is needed.
    flat_task_checks = torch.stack([v.flatten() for v in task_vectors.values()])
    print(f"flat_task_checks.shape: {flat_task_checks.shape}")

    if reset_thresh is None:
        updated_checks = flat_task_checks
    else:
        # Trimming is row-wise: each task keeps its own top-magnitude entries.
        updated_checks, *_ = topk_values_mask(
            flat_task_checks, K=reset_thresh, return_mask=False
        )

    print("RESOLVING SIGN")
    final_signs = resolve_sign(updated_checks)

    print(f"Disjoint AGGREGATION: {merge_func}")
    merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)

    return merged_tv


if __name__ == "__main__":
    # Demo run on four synthetic random task vectors of identical length.
    num_params = 86_000_000
    task_names = ("CIFAR100", "CIFAR10", "SVHN", "MNIST")
    task_vectors = {name: torch.randn(num_params) for name in task_names}

    reset_thresh = 0.8
    merge_func = "mean"

    their_ties_merging(
        task_vectors=task_vectors,
        reset_thresh=reset_thresh,
        merge_func=merge_func,
    )