import torch
from masking_ops import chunked_sum


def tv_merging(vectors, weights=None, merging_type='sum', **kwargs):
    """Stack task vectors and reduce them over the task (row) axis.

    Args:
        vectors: Sequence of tensors that ``torch.vstack`` can stack.
        weights: Optional indexable. Only ``weights[0]`` is read, and it is
            multiplied in place into every stacked row.
            NOTE(review): all rows share ``weights[0]``; confirm that
            per-row weights were not intended.
        merging_type: ``'unmerged'`` returns the stacked rows as they are,
            ``'mean'`` averages over rows, and any other value sums over rows.
        **kwargs: Accepted and ignored.

    Returns:
        For ``'unmerged'``: the 2-tuple ``(stacked, keep_masks)``.
        Otherwise: the 3-tuple ``(merged, keep_masks, keep_masks)`` where both
        trailing items are the same list object. ``keep_masks`` holds one
        all-ones tensor per input vector.
        NOTE(review): the two branches return tuples of different lengths.
    """
    keep_masks = [torch.ones_like(vec) for vec in vectors]
    stacked = torch.vstack(vectors).clone()

    # The unmerged path leaves early, before any weighting is applied.
    if merging_type == 'unmerged':
        return stacked, keep_masks

    if weights is not None:
        # Rows are views into ``stacked``, so this scales it in place.
        for row in stacked:
            row *= weights[0]

    reduce_rows = torch.mean if merging_type == 'mean' else torch.sum
    return reduce_rows(stacked, dim=0), keep_masks, keep_masks


# TIES MERGING UTILS
def topk_values_mask(M, K=0.7, return_mask=False):
    """Zero out all but the largest-magnitude entries of each row of ``M``.

    Args:
        M: 1-D tensor (treated as a single row) or 2-D tensor ``(n, d)``.
        K: Fraction of entries to keep per row. Values greater than 1 are
            read as percentages and divided by 100. After that conversion,
            ``K >= 1`` keeps everything.
        return_mask: When true, the keep-mask is returned as a third item.

    Returns:
        ``(pruned, kept_fraction)`` or ``(pruned, kept_fraction, mask)``.
        ``pruned`` has the same shape as the input, and ``kept_fraction`` is
        the mean of the mask over the last dimension. On the keep-everything
        path the mask is ``torch.ones_like(M)``; otherwise it is a boolean
        tensor.
    """
    if K > 1:
        K /= 100

    if K >= 1:
        keep_all = torch.ones_like(M)
        kept_fraction = keep_all.mean(dim=-1)
        if return_mask:
            return M, kept_fraction, keep_all
        return M, kept_fraction

    # Remember the input rank so the output can be restored to it exactly.
    # Comparing against ``M.squeeze().shape`` fails for a 1-D input with a
    # single element, because squeeze() collapses it to a 0-d tensor.
    was_1d = M.dim() == 1
    if was_1d:
        M = M.unsqueeze(0)

    _, d = M.shape
    # Rank (1-based, ascending by magnitude) of the smallest entry to keep.
    k = d - int(d * K)

    magnitudes = M.abs()
    if M.numel() == 1:
        kth_values = magnitudes
    else:
        kth_values, _ = magnitudes.kthvalue(k, dim=1, keepdim=True)

    # Everything at or above the threshold magnitude survives, so ties at
    # the threshold are all kept.
    mask = magnitudes >= kth_values

    if was_1d:
        M = M.squeeze(0)
        mask = mask.squeeze(0)

    pruned = M * mask
    kept_fraction = mask.float().mean(dim=-1)
    if return_mask:
        return pruned, kept_fraction, mask
    return pruned, kept_fraction


def resolve_zero_signs(sign_to_mult, method="majority"):
    """Overwrite zero entries of a sign tensor in place and return it.

    The replacement is derived from the sign of the sum of all entries:
    ``"majority"`` writes that sign into the zero slots, ``"minority"`` writes
    its negation. Any other ``method`` leaves the tensor untouched. If the
    entries sum to zero, the replacement is itself zero.

    Args:
        sign_to_mult: Tensor of signs; modified in place.
        method: ``"majority"`` or ``"minority"``.

    Returns:
        The same tensor object that was passed in.
    """
    majority_sign = torch.sign(sign_to_mult.sum())

    if method == "majority":
        replacement = majority_sign
    elif method == "minority":
        replacement = -1 * majority_sign
    else:
        # Unknown methods are a silent no-op.
        return sign_to_mult

    sign_to_mult[sign_to_mult == 0] = replacement
    return sign_to_mult


def disjoint_merge(Tensor, merge_func, reference_sign_to_mult, weights=None):
    """Aggregate, per column, only the entries that agree with a reference sign.

    Args:
        Tensor: Stacked values; aggregation runs over dim 0.
        merge_func: One of ``"mean"``, ``"sum"``, ``"max"`` or ``'unmerged'``.
        reference_sign_to_mult: Per-column sign tensor, or ``None``. Where it
            is positive only positive entries are kept; elsewhere only
            negative entries are kept. With ``None`` every non-zero entry is
            kept. The ``"max"`` branch multiplies by this value, so it needs
            a tensor there.
        weights: Optional indexable. Only ``weights[0]`` is read, and it is
            multiplied in place into every kept row.
            NOTE(review): all rows share ``weights[0]``; confirm intent.

    Returns:
        ``(aggregated, rows_to_keep)`` where ``rows_to_keep`` is the boolean
        selection mask with the same shape as ``Tensor``.

    Raises:
        ValueError: If ``merge_func`` is not one of the supported names.
    """
    if reference_sign_to_mult is None:
        # No reference sign: every non-zero entry takes part.
        rows_to_keep = Tensor != 0
    else:
        wants_positive = reference_sign_to_mult.unsqueeze(0) > 0
        rows_to_keep = torch.where(wants_positive, Tensor > 0, Tensor < 0)

    kept = Tensor * rows_to_keep
    if weights is not None:
        # Rows are views into ``kept``, so this scales it in place.
        for row in kept:
            row *= weights[0]

    if merge_func == "mean":
        # Average over contributing entries only; the clamp avoids dividing
        # by zero for columns where nothing was kept.
        contributors = (kept != 0).sum(dim=0).float()
        merged = torch.sum(kept, dim=0) / torch.clamp(contributors, min=1)
        return merged, rows_to_keep

    if merge_func == "sum":
        return chunked_sum(kept), rows_to_keep

    if merge_func == "max":
        # Largest magnitude per column, with the reference sign re-applied.
        merged = kept.abs().max(dim=0)[0]
        merged *= reference_sign_to_mult
        return merged, rows_to_keep

    if merge_func == 'unmerged':
        return kept, rows_to_keep

    raise ValueError(f"Merge method {merge_func} is not defined.")


def resolve_sign(Tensor, mode=None):
    """Pick one sign per column from the sign of the column sums.

    Columns whose sum is exactly zero are filled by ``resolve_zero_signs``
    using the ``"majority"`` rule.

    Args:
        Tensor: Stacked values; the sum runs over dim 0.
        mode: Unused.

    Returns:
        Tensor of per-column signs.
    """
    column_signs = Tensor.sum(dim=0).sign()
    return resolve_zero_signs(column_signs, "majority")


def ties_merging(vectors, topK=10, merging_type='mean', weights=None, **kwargs):
    """Merge task vectors in three steps: trim, elect signs, disjoint-merge.

    Args:
        vectors: A list of tensors (stacked with ``torch.vstack``) or an
            already stacked tensor, which is used as given.
        topK: Passed to ``topk_values_mask`` as ``K``.
        merging_type: Passed to ``disjoint_merge`` as the merge function name.
        weights: Passed through to ``disjoint_merge``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(merged_tv, rows_to_keep, mask)``: the merged result and selection
        mask from ``disjoint_merge``, plus the mask from ``topk_values_mask``.
    """
    # TODO: allow some layers to be exempt from pruning or skipped entirely.
    if isinstance(vectors, list):
        stacked = torch.vstack(vectors)
    else:
        stacked = vectors

    trimmed, _, mask = topk_values_mask(stacked, K=topK, return_mask=True)

    elected_signs = resolve_sign(trimmed)
    assert elected_signs is not None

    merged_tv, rows_to_keep = disjoint_merge(
        trimmed, merging_type, elected_signs, weights
    )
    return merged_tv, rows_to_keep, mask


def normalize(x, dim=0):
    """Min-max scale ``x`` along ``dim``.

    Each slice along ``dim`` is shifted by its minimum and divided by its
    range (max minus min). A slice whose values are all equal has zero range;
    it is mapped to zeros, where a plain division would produce NaN.

    Args:
        x: Input tensor.
        dim: Dimension along which the minimum and maximum are taken.

    Returns:
        Tensor with the same shape as ``x``.
    """
    min_values, _ = torch.min(x, dim=dim, keepdim=True)
    max_values, _ = torch.max(x, dim=dim, keepdim=True)
    span = max_values - min_values
    # The numerator is already zero wherever the span is zero, so dividing
    # by 1 there yields 0 and leaves every other slice unchanged.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    return (x - min_values) / safe_span

def clamp(x, min_ratio=0, max_ratio=0):
    """Clip ``x`` to bounds chosen by rank among its sorted values.

    The values are sorted ascending. The lower bound is the entry at rank
    ``int(d * min_ratio)`` and the upper bound is the entry at rank
    ``int(d * (1 - max_ratio) - 1)``, where ``d`` is the number of entries
    being ranked. Both ranks are limited to ``[0, d - 1]``. Without that
    limit, ``max_ratio=1`` produces index ``-1``, which wraps around to the
    largest entry and disables the upper clip.

    Args:
        x: 1-D tensor (ranked as a whole) or a tensor with at least two
            dimensions (ranked independently along dim 1).
        min_ratio: Fraction of the smallest entries to raise to the lower
            bound.
        max_ratio: Fraction of the largest entries to lower to the upper
            bound.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    axis = 0 if x.dim() == 1 else 1
    d = x.size(axis)
    ordered, _ = torch.sort(x, dim=axis)

    last = d - 1
    lower_idx = min(max(int(d * min_ratio), 0), last)
    upper_idx = min(max(int(d * (1 - max_ratio) - 1), 0), last)

    # narrow() keeps the ranked axis with length 1, so the bounds broadcast
    # against ``x`` for both the 1-D and the N-D case.
    lower = ordered.narrow(axis, lower_idx, 1)
    upper = ordered.narrow(axis, upper_idx, 1)
    return torch.clamp(x, lower, upper)

def act(x):
    """Return the element-wise hyperbolic tangent of ``x``."""
    # The original author noted torch.relu as an alternative activation.
    return torch.tanh(x)

def PCB_merge(flat_task_checks, pcb_ratio=0.1):
    """Merge stacked task vectors using per-entry balancing weights.

    Steps, as implemented here:
      1. Magnitudes are rank-clipped with ``clamp`` (ratios of 0.0001 at both
         ends) and recombined with the original signs.
      2. A per-row term is ``exp(n * normalize(magnitudes, 1) ** 2)``, where
         ``n`` is the number of rows.
      3. A cross-row term is ``act`` applied to each entry times its column
         sum; this uses the unclipped values.
      4. Their product is rank-clipped from below at ``1 - pcb_ratio`` and
         min-max normalised along dim 1 to form ``scale``.
      5. The result is the ``scale``-weighted column average of the clipped
         vectors, with the denominator floored at 1e-12.

    Args:
        flat_task_checks: 2-D tensor; the input is cloned first.
        pcb_ratio: ``1 - pcb_ratio`` is passed to ``clamp`` as ``min_ratio``.

    Returns:
        ``(merged_tv, clamped_all_checks, scale)``.
    """
    all_checks = flat_task_checks.clone()
    n, _ = all_checks.shape

    magnitudes = clamp(all_checks.abs(), min_ratio=0.0001, max_ratio=0.0001)
    clamped_all_checks = all_checks.sign() * magnitudes

    # Per-row term, computed from the clipped magnitudes.
    intra_term = torch.exp(n * normalize(magnitudes, 1) ** 2)
    # Cross-row term, computed from the unclipped values.
    inter_term = act(all_checks * all_checks.sum(dim=0))
    balance = intra_term * inter_term

    scale = normalize(clamp(balance, 1 - pcb_ratio, 0), dim=1)

    weighted_total = (clamped_all_checks * scale).sum(dim=0)
    total_scale = scale.sum(dim=0).clamp(min=1e-12)
    merged_tv = weighted_total / total_scale
    return merged_tv, clamped_all_checks, scale