from typing import Optional, Tuple

import numpy as np
import torch
from transformers import DynamicCache


def _build_dscr_matrix(
    depth: np.ndarray,
    h_tokens: int,
    w_tokens: int,
    alpha: float,
    beta: float,
    sigma: float,
    self_keep: float,
) -> torch.Tensor:
    """Return the [T, T] row-normalized depth/position affinity matrix (T = h_tokens * w_tokens).

    Rows and columns follow the row-major flattening of the (h_tokens, w_tokens) grid.
    """
    # Two leading singleton dims make a 2-D (H, W) map acceptable to 4-D bilinear interpolate.
    depth_tensor = torch.tensor(depth, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    depth_tensor = torch.nn.functional.interpolate(
        depth_tensor,
        size=(h_tokens, w_tokens),
        mode="bilinear",
        align_corners=False,
    )
    # Reciprocal of depth, clamped so zero / tiny depths stay finite, then min-max normalized.
    inv_depth = torch.clamp(1.0e-6 + 1.0 / depth_tensor, 0.001, 1000).view(1, -1)
    d_min = inv_depth.min()
    d_max = inv_depth.max()
    inv_depth = (inv_depth - d_min) / (d_max - d_min + 1e-6)

    two_sigma_sq = 2 * sigma ** 2

    # [1, T] - [T, 1] -> [T, T] pairwise differences.
    depth_diff = inv_depth - inv_depth.transpose(1, 0)
    gw_depth = torch.exp(-(depth_diff ** 2) / two_sigma_sq)

    # Grid coordinates (row, col), each normalized to [0, 1], flattened row-major to match depth.
    rows = torch.arange(h_tokens, dtype=torch.float32)
    cols = torch.arange(w_tokens, dtype=torch.float32)
    row_grid, col_grid = torch.meshgrid(rows, cols, indexing="ij")
    pos = torch.stack([row_grid, col_grid], dim=-1).view(-1, 2)
    pos[:, 0] /= max(h_tokens - 1, 1)
    pos[:, 1] /= max(w_tokens - 1, 1)

    pos_diff = torch.cdist(pos, pos, p=2)
    gw_pos = torch.exp(-(pos_diff ** 2) / two_sigma_sq)

    gw = (gw_depth ** alpha) * (gw_pos ** beta)

    if self_keep < 1.0:
        # Override the self (diagonal) weight with self_keep.
        eye = torch.eye(gw.shape[0], device=gw.device, dtype=gw.dtype)
        gw = gw * (1.0 - eye) + self_keep * eye

    row_sums = gw.sum(dim=-1, keepdim=True)
    zero_rows = (row_sums < 1e-12).squeeze(-1)
    if zero_rows.any():
        # Rows whose weights all vanished fall back to uniform weights over every image token.
        gw[zero_rows] = 1.0
        row_sums = gw.sum(dim=-1, keepdim=True)

    return gw / (row_sums + 1e-6)


def _smooth_image_segment(
    states: torch.Tensor,
    weights: torch.Tensor,
    cut_idx: int,
    num_image_tokens: int,
    dscr_lambda: float,
) -> torch.Tensor:
    """Blend the image-token span of one cached tensor toward its ``weights``-averaged version.

    ``states`` is sliced along dim 2 (sequence axis; layout [batch, heads, seq, head_dim]).
    ``weights`` is the [T, T] mixing matrix. Returns a new tensor; ``states`` is not modified.
    """
    # Match the layer being edited (layers may sit on different GPUs under device_map).
    weights = weights.to(device=states.device, dtype=states.dtype)
    end = cut_idx + num_image_tokens
    seg = states[:, :, cut_idx:end, :]
    # [T, T] @ [B, H, T, dim]: matmul broadcasts the matrix over batch and head dims.
    ref = torch.matmul(weights, seg)
    mix = (1.0 - dscr_lambda) * seg + dscr_lambda * ref
    return torch.cat([states[:, :, :cut_idx, :], mix, states[:, :, end:, :]], dim=2)


def apply_dscr_to_cache(
    cache: DynamicCache,
    input_ids: torch.Tensor,
    depth: np.ndarray,
    image_token_id: int,
    image_grid_thw: torch.Tensor,
    merge_size: int,
    alpha: float,
    beta: float,
    sigma: float,
    start_layer_idx: int,
    end_layer_idx: Optional[int],
    key_only: bool,
    value_only: bool,
    key_value: bool,
    dscr_lambda: float,
    self_keep: float,
) -> Tuple[DynamicCache, int, int]:
    """Smooth the image-token span of a KV cache with a depth- and position-aware kernel.

    A row-normalized affinity matrix ``D`` over the image-token grid is built from
    Gaussian similarities of (min-max normalized reciprocal) depth and normalized 2-D
    grid position. For each layer in ``[start_layer_idx, end_layer_idx)``, the image-token
    slice ``x`` of the cached keys and/or values becomes
    ``(1 - dscr_lambda) * x + dscr_lambda * D @ x``.

    Args:
        cache: KV cache to read. A new cache is built and returned; this function does
            not write into ``cache``.
        input_ids: Token ids; only row 0 is inspected to locate image tokens.
        depth: 2-D (H, W) depth map, bilinearly resized to the image-token grid.
        image_token_id: Id of the image placeholder token in ``input_ids``.
        image_grid_thw: Per-image (t, h, w) grid; only row 0 is used and t is ignored.
        merge_size: Spatial merge factor; token grid is (h // merge_size, w // merge_size).
        alpha: Exponent applied to the depth-similarity kernel.
        beta: Exponent applied to the position-similarity kernel.
        sigma: Gaussian bandwidth shared by both kernels; must be > 0.
        start_layer_idx: First layer to modify (inclusive).
        end_layer_idx: Layer to stop at (exclusive); ``None`` means through the last layer.
        key_only: Smooth keys.
        value_only: Smooth values.
        key_value: Smooth both keys and values.
        dscr_lambda: Blend weight toward the smoothed states, clamped to [0, 1].
        self_keep: Diagonal (self) weight before row normalization, clamped to [0, 1];
            values below 1 override the diagonal.

    Returns:
        ``(new_cache, cut_idx, num_image_tokens)``: the rebuilt cache, the index of the
        first image token, and the number of image tokens.

    Raises:
        ValueError: If ``sigma`` is not a positive number.
        RuntimeError: If no image tokens are found, the token grid does not match the
            image-token count, or the image tokens are not one contiguous run.
    """
    dscr_lambda = max(0.0, min(1.0, float(dscr_lambda)))
    self_keep = max(0.0, min(1.0, float(self_keep)))
    sigma = float(sigma)
    # sigma == 0 would yield 0/0 = NaN weights and silently corrupt the cache.
    if not sigma > 0.0:
        raise ValueError(f"sigma must be a positive number, got {sigma}.")

    positions = torch.nonzero(input_ids[0] == image_token_id, as_tuple=False).squeeze(-1)
    if positions.numel() == 0:
        raise RuntimeError("No image tokens found in input_ids.")
    cut_idx = int(positions[0].item())
    num_image_tokens = positions.shape[0]

    _grid_t, grid_h, grid_w = image_grid_thw[0].tolist()
    h_tokens = grid_h // merge_size
    w_tokens = grid_w // merge_size
    if h_tokens * w_tokens != num_image_tokens:
        raise RuntimeError(
            f"Token grid mismatch. H_tokens * W_tokens = {h_tokens * w_tokens}, "
            f"num_image_tokens = {num_image_tokens}"
        )
    # The cache edit slices [cut_idx, cut_idx + num_image_tokens); only valid for one run.
    if int(positions[-1].item()) - cut_idx + 1 != num_image_tokens:
        raise RuntimeError(
            "Image tokens in input_ids are not contiguous; expected a single image span."
        )

    D = _build_dscr_matrix(depth, h_tokens, w_tokens, alpha, beta, sigma, self_keep)

    legacy = cache.to_legacy_cache()
    total_k = [kv[0] for kv in legacy]
    total_v = [kv[1] for kv in legacy]

    if end_layer_idx is None:
        end_layer_idx = len(total_k)

    smooth_keys = key_only or key_value
    smooth_values = value_only or key_value

    for li in range(start_layer_idx, end_layer_idx):
        if smooth_keys:
            total_k[li] = _smooth_image_segment(
                total_k[li], D, cut_idx, num_image_tokens, dscr_lambda
            )
        if smooth_values:
            total_v[li] = _smooth_image_segment(
                total_v[li], D, cut_idx, num_image_tokens, dscr_lambda
            )

    new_cache = DynamicCache.from_legacy_cache(tuple(zip(total_k, total_v)))
    return new_cache, cut_idx, num_image_tokens
