# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import math
import torch
import time

from typing import Callable, Tuple, List, Tuple, Union

def do_nothing(x, mode=None, dst_select_only=None, pos_ids_info=None):
    """Identity stand-in for a merge callable: hand ``x`` back untouched.

    The keyword arguments exist only so this can be called wherever a merge
    function is expected; every one of them is accepted and ignored.
    """
    del mode, dst_select_only, pos_ids_info  # deliberately unused
    return x


def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
    dst_mask: torch.Tensor,
) -> Tuple[Callable, Union[Callable, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """
    Applies ToMe-style bipartite soft matching with a caller-chosen split.

    Tokens selected by ``dst_mask`` form the destination set and all other
    tokens form the source set. Each source token is paired with its most
    cosine-similar destination token. The ``r`` source tokens with the highest
    similarity are then folded into their destinations by the returned
    ``merge`` function.

    Input size is [batch, tokens, channels].

    Args:
        metric: Features used to compute similarity, [batch, tokens, channels].
        r: Number of source tokens to remove. Values larger than the number of
            source tokens are clamped to it. If ``r`` is <= 0, or there are no
            destination tokens, merging is disabled.
        dst_mask: Boolean mask over the token axis; True marks destination
            tokens (presumably 1-D of length ``tokens`` — confirm at call site).

    Returns:
        ``(merge, (unm_idx, src_idx, dst_idx))`` when merging happens:
          - ``merge(x, mode="mean", dst_select_only=False)`` returns the
            unmerged source tokens followed by all destination tokens.
          - ``unm_idx``/``src_idx`` index into the source tokens and
            ``dst_idx`` into the destination tokens.
        ``(do_nothing, do_nothing)`` when no merging happens.
    """

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        # Cosine similarity: L2-normalise along channels before the dot product.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ~dst_mask, :], metric[..., dst_mask, :]

        # Cannot remove more source tokens than exist. Without this clamp the
        # ``expand(n, t1 - r, c)`` in ``merge`` receives a negative size.
        r = min(r, a.shape[-2])
        if r <= 0 or b.shape[-2] == 0:
            # Nothing to merge from, or nothing to merge into.
            return do_nothing, do_nothing

        scores = a @ b.transpose(-1, -2)

        # Best destination (and its similarity) for every source token.
        node_max, node_idx = scores.max(dim=-1)
        # Source tokens ordered from most to least similar to their best match.
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean", dst_select_only=False) -> torch.Tensor:
        """Apply the precomputed matching to ``x`` ([batch, tokens, channels]).

        ``mode`` is forwarded as ``reduce`` to ``scatter_reduce``. Its default
        ``include_self=True`` means the destination's own value takes part in
        the reduction.

        If ``dst_select_only`` is True, destinations are left unchanged and the
        matched source tokens are simply dropped.
        """
        src, dst = x[..., ~dst_mask, :], x[..., dst_mask, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))

        if not dst_select_only:
            src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
            dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    return merge, (unm_idx, src_idx, dst_idx)


def merge_wavg(
    merge: Callable, x: torch.Tensor, size: torch.Tensor = None, mode="sum", avg_procedure=True, dst_select_only=False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge ``x`` with ``merge`` while tracking per-token sizes.

    ``size`` defaults to a tensor of ones shaped like ``x[..., 0, None]``.

    When ``avg_procedure`` is True, tokens are weighted by ``size`` before
    merging and divided by the merged size afterwards. This yields a
    size-weighted average when ``mode`` is "sum". When it is False, ``x`` and
    ``size`` are merged as-is.

    ``mode`` and ``dst_select_only`` are forwarded to both ``merge`` calls.

    Returns ``(merged_x, merged_size)``.
    """
    merge_kwargs = dict(mode=mode, dst_select_only=dst_select_only)
    weights = torch.ones_like(x[..., 0, None]) if size is None else size

    if not avg_procedure:
        return merge(x, **merge_kwargs), merge(weights, **merge_kwargs)

    merged_x = merge(x * weights, **merge_kwargs)
    merged_weights = merge(weights, **merge_kwargs)
    return merged_x / merged_weights, merged_weights
