import torch
from typing import Tuple, Callable


def do_nothing(x: torch.Tensor, mode: str = None):
    """Identity merge/unmerge: hand back ``x`` unchanged; ``mode`` is accepted only for
    signature compatibility with a real merge function and is ignored."""
    unchanged = x
    return unchanged


def mps_gather_workaround(input, dim, index):
    """
    Drop-in replacement for ``torch.gather(input, dim, index)``.

    When the last dimension of ``input`` has size 1, the gather is performed on a view with an
    extra trailing axis, and that axis is squeezed away afterwards. Negative ``dim`` values are
    shifted by one so they still address the same axis after the unsqueeze. Presumably this
    sidesteps an MPS backend problem with trailing size-1 dimensions (not verifiable here).
    """
    if input.shape[-1] != 1:
        return torch.gather(input, dim, index)

    # The extra trailing axis moves every negative axis index one step further from the end.
    shifted_dim = dim if dim >= 0 else dim - 1
    widened = torch.gather(input.unsqueeze(-1), shifted_dim, index.unsqueeze(-1))
    return widened.squeeze(-1)


def bipartite_soft_matching_random2d_(metric: torch.Tensor,
                                      w: int, h: int, sx: int, sy: int, r: int,
                                      no_rand: bool = False,
                                      generator: torch.Generator = None) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst; columns past the last full stride are always src
     - sy: stride in the y dimension for dst; rows past the last full stride are always src
     - r: number of tokens to remove (by merging); clamped to the number of src tokens
     - no_rand: if true, disable randomness (use top left corner only)
     - generator: RNG used to choose the dst token of each region when no_rand is false.
       If None, the global RNG is used on metric's device.

    Returns:
     - (merge, unmerge): merge(x, mode="mean") reduces each of the r best-matched src tokens
       into its matched dst token with scatter_reduce(reduce=mode); unmerge(x) maps a merged
       sequence back to N tokens, copying each merged src token from its dst. Both are the
       identity when r <= 0.
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            # torch.randint needs `device` to match the generator's device. Without a generator,
            # fall back to metric's device and the global RNG (previously generator.device
            # raised AttributeError for the default generator=None).
            rand_device = generator.device if generator is not None else metric.device
            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=rand_device,
                                     generator=generator).to(metric.device)

        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy so we need to move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Return [n, N - r, c]: unmerged src tokens followed by the (reduced) dst tokens."""
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Scatter a merged sequence back to [B, N, c]; merged src tokens copy their dst."""
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     mask: torch.Tensor,
                                     no_rand: bool = False,
                                     generator: torch.Generator = None, current=None,
                                     ratio=0.8) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and drops a fraction of the src tokens.

    Dst tokens are the top-left token of every (sy, sx) region plus every position where
    ``mask`` is nonzero. Each src token is matched to its most similar dst token (cosine
    similarity over a cached subset of high-variance channels). The ``int(num_src * ratio)``
    best-matched src tokens are removed by ``merge`` and restored by ``unmerge`` as a copy of
    their matched dst token. Because merge leaves dst values unchanged, masked dst tokens are
    never modified.

    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst; columns past the last full stride are always src
     - sy: stride in the y dimension for dst; rows past the last full stride are always src
     - r: only checked for r <= 0, which disables merging; the number of removed tokens is
       derived from ``ratio``
     - mask: viewable as [B, N, 1]; positions where mask is nonzero are forced to be dst tokens
     - no_rand: currently ignored; dst selection is always the top-left token of each region
     - generator: currently unused
     - current: optional dict caching the selected channel indices under 'topk_indices'
       (computed on first use). If None, the indices are recomputed on every call.
     - ratio: fraction of src tokens to remove

    Returns:
     - (merge, unmerge) callables, or (do_nothing, do_nothing) when r <= 0 or fewer than
       30 src tokens remain.
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # NOTE(review): `no_rand` and `generator` are ignored -- the dst token of every (sy, sx)
        # region is always its top-left corner (index 0 within the region).
        region_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)

        # Within each region mark the chosen token with -1 (dst); the rest stay 0 (src).
        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=region_idx, src=-torch.ones_like(region_idx))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Rows/columns past the last full region (when sy/sx do not divide h/w) stay src.
        idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
        idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view

        # One copy per batch element, with every masked position forced to dst.
        mask_flat = mask.view(B, N, 1)
        idx_buffer = idx_buffer.reshape(1, -1).expand(B, -1).clone()  # [B, N]
        idx_buffer[mask_flat.view(B, N).bool()] = -1

        # Count dst tokens from the buffer itself, so a masked position that already is a region
        # dst is counted once. Adding mask.sum() to the region count double-counts such
        # positions and misclassifies src tokens as dst.
        # NOTE(review): the fixed-size split below assumes every batch element has the same
        # number of dst tokens; only batch element 0 is counted.
        num_dst = int((idx_buffer[0] == -1).sum().item())

        # -1 sorts before 0, so argsort yields dst indices followed by src indices.
        rand_idx = idx_buffer.argsort(dim=1).unsqueeze(-1)  # [B, N, 1]

        # We're finished with these
        del idx_buffer, idx_buffer_view

        num_src = N - num_dst
        if num_src < 30:
            return do_nothing, do_nothing
        # The caller's `r` only gates merging; how many tokens are removed is driven by `ratio`.
        r = int(num_src * ratio)

        # Split rand_idx into src and dst indices
        a_idx = rand_idx[:, num_dst:, :]  # src indices
        b_idx = rand_idx[:, :num_dst, :]  # dst indices

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Normalize metric for cosine similarity
        metric_norm = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric_norm)

        # Compare tokens only on the (up to) 50 channels with the highest variance across src
        # tokens. The selection is stored in `current` (when provided) and reused afterwards.
        cache = current if current is not None else {}
        topk_indices = cache.get('topk_indices')
        if topk_indices is None:
            topk_indices = torch.topk(a.var(dim=(0, 1)), k=min(50, a.shape[-1])).indices
            cache['topk_indices'] = topk_indices
        scores = a[..., topk_indices] @ b[..., topk_indices].transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged (dropped) Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """
        Return [n, N - r, c]: the unmatched src tokens followed by the dst tokens.

        The r matched src tokens are dropped and dst values are left unchanged, so ``mode`` is
        accepted for interface compatibility but ignored. ToMe-style averaging would instead use
        ``dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)``.
        """
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Scatter a merged sequence back to [B, N, c]; dropped src tokens copy their dst."""
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2,
                     index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c),
                     src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c),
                     src=src)

        return out

    return merge, unmerge