# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import math
from typing import Callable, Tuple

import torch


def do_nothing(x, mode=None):
    """Return ``x`` unchanged.

    Serves as a no-op stand-in where a merge/unmerge callable is expected;
    ``mode`` is accepted only so the call signature matches and is ignored.
    """
    unchanged = x
    return unchanged


def bipartite_soft_matching(
        metric: torch.Tensor,
        r: int,
        list_: torch.Tensor,
        class_token: bool = False,
        distill_token: bool = False):
    """Build ``merge``/``unmerge`` closures for one bipartite soft-matching step.

    Tokens are split alternately into set A (even positions) and set B (odd
    positions). Each A token is paired with its most similar B token by
    cosine similarity of ``metric``; the ``r`` A tokens with the highest
    pairing scores are merged into their B partners, the rest are kept.

    Args:
        metric: similarity features, indexed as (batch, tokens, channels).
        r: number of tokens to remove; clamped to at most half of the
            non-protected tokens.
        list_: per-token values of shape (batch, tokens) (presumably token
            ids — confirm with callers) that are reordered alongside the
            tokens so the caller can track which tokens were kept/merged.
        class_token: if True, token 0 is never merged and stays first.
        distill_token: if True, token 1 (the first B token) is never chosen
            as a merge target.

    Returns:
        ``(merge, unmerge)``. ``merge(x, mode)`` returns
        ``(merged_x, list_now, list_src, src)``: ``list_now`` holds the
        ``list_`` values for the rows of ``merged_x``, ``list_src`` those of
        the A tokens merged away, and ``src`` their features (gathered from
        ``x``). ``unmerge`` maps a merged tensor back to ``metric``'s token
        count.
    """
    protected = int(class_token) + int(distill_token)

    # We can only reduce by a maximum of 50% tokens
    t = metric.shape[1]
    r = min(r, (t - protected) // 2)

    if r <= 0:
        # Nothing to merge. The no-op merge still returns the same 4-tuple as
        # the real one (with empty merged-away sets) so callers that unpack
        # four values keep working.
        def merge_noop(x: torch.Tensor, mode="mean") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            return x, list_, list_[..., :0], x[..., :0, :]

        return merge_noop, do_nothing

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        list_a, list_b = list_[..., ::2], list_[..., 1::2]
        scores = a @ b.transpose(-1, -2)

        if class_token:
            # The class token (A index 0) gets the lowest score -> never merged.
            scores[..., 0, :] = -math.inf
        if distill_token:
            # B index 0 (overall token 1) can never be a merge target.
            scores[..., :, 0] = -math.inf

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

        # Gather the tracked values only after unm_idx has its final order,
        # so list_unm lines up row-for-row with `unm` built in merge().
        list_unm = list_a.gather(dim=-1, index=unm_idx[..., 0])
        list_src = list_a.gather(dim=-1, index=src_idx[..., 0])

    def merge(x: torch.Tensor, mode="mean") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        # Fold each merged A token into its matched B token with `mode`.
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            # Keep the class and distillation tokens at the front.
            merged = torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
            list_now = torch.cat(
                [list_unm[:, :1], list_b[:, :1], list_unm[:, 1:], list_b[:, 1:]], dim=1
            )
        else:
            merged = torch.cat([unm, dst], dim=1)
            list_now = torch.cat([list_unm, list_b], dim=1)
        return merged, list_now, list_src, src

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): assumes the non-distill layout [unm, dst] produced by merge().
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        # Merged-away A tokens are restored as copies of their B partner.
        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)

        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge


def merge_wavg(
        merge: Callable,
        x: torch.Tensor,
        size: torch.Tensor = None,
        mode: str = "amax"):
    """
    Merge ``x`` scaled by token size, then renormalize by the merged size.

    Both ``x * size`` and ``size`` are passed through ``merge`` with the same
    reduction ``mode`` and the result is ``merged(x * size) / merged(size)``.
    With ``mode="mean"`` this is a size-weighted average of the merged tokens;
    with the default ``mode="amax"`` it is the element-wise maximum of
    ``x * size`` divided by the maximum merged size.

    Args:
        merge: merge callable (e.g. from ``bipartite_soft_matching``), called
            as ``merge(tensor, mode=...)`` and expected to return a 4-tuple.
        x: token features to merge.
        size: per-token sizes broadcastable against ``x``; defaults to ones of
            shape ``x.shape[:-1] + (1,)``.
        mode: reduction forwarded to ``merge`` (e.g. ``"amax"`` or ``"mean"``).

    Returns:
        ``(x, list_now, list_src, src)``: the renormalized merged features and
        the remaining outputs of ``merge`` for ``x * size`` (so ``src`` holds
        merged-away rows of ``x * size``, not renormalized). The merged sizes
        themselves are not returned.
    """
    if size is None:
        size = torch.ones_like(x[..., 0, None])

    x, list_now, list_src, src = merge(x * size, mode=mode)
    size, _, _, _ = merge(size, mode=mode)

    x = x / size
    return x, list_now, list_src, src


def tome(
        x: torch.Tensor,
        metric: torch.Tensor,
        r: int,
        list_: torch.Tensor,
        cls_token: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Merge up to ``r`` tokens of ``x`` using ``metric`` for similarity.

    Args:
        x: token features to merge.
        metric: features used to compute token similarity.
        r: number of tokens to remove; ``r <= 0`` disables merging.
        list_: per-token values (shape (batch, tokens)) tracked through the merge.
        cls_token: whether token 0 is a class token that must not be merged.

    Returns:
        ``(x, list_now, list_src, src)`` as produced by ``merge_wavg``. When
        ``r <= 0`` nothing is merged: ``x`` and ``list_`` are returned
        unchanged together with empty ``list_src`` / ``src`` tensors.
    """
    if r <= 0:
        # Nothing to merge: return outputs shaped like the merged path with
        # zero merged-away tokens instead of leaving them unbound.
        return x, list_, list_[..., :0], x[..., :0, :]

    # Apply ToMe here
    merge, _ = bipartite_soft_matching(
        metric,
        r,
        list_,
        cls_token,
        False,
    )

    return merge_wavg(merge, x, None)