import math
from typing import Iterable, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F


def _normalize_depth(depth: torch.Tensor) -> torch.Tensor:
    """Min-max scale ``depth`` to the ``[0, 1]`` range as a float32 tensor.

    The denominator is floored at ``1e-6`` so a constant map yields all zeros
    instead of dividing by zero.
    """
    as_float = depth.float()
    low, high = as_float.min(), as_float.max()
    value_range = torch.clamp(high - low, min=1.0e-6)
    shifted = as_float - low
    return shifted / value_range


def _prepare_depth_grid(depth: torch.Tensor, grid_size: int) -> torch.Tensor:
    """Resample a depth map onto a ``grid_size x grid_size`` grid and flatten it.

    Args:
        depth: Depth tensor of rank 2, 3 or 4. Leading singleton dims are added
            so the tensor is rank 4 before interpolation.
        grid_size: Side length of the output grid.

    Returns:
        A ``(1, -1)`` tensor holding every resampled value in row-major order.
        For a single map (batch 1, channel 1) that is ``grid_size ** 2`` values.
        Integer inputs come back as float32; floating inputs keep their dtype.

    Raises:
        ValueError: If ``depth`` is not rank 2, 3 or 4.
    """
    if depth.dim() == 2:
        depth = depth.unsqueeze(0).unsqueeze(0)
    elif depth.dim() == 3:
        depth = depth.unsqueeze(0)
    elif depth.dim() != 4:
        raise ValueError(f"Unexpected depth shape: {depth.shape}")

    # Bilinear interpolation is not supported for most integer dtypes (raw depth
    # maps are often uint16/int32), so promote those; float inputs are untouched.
    if not depth.is_floating_point():
        depth = depth.float()

    depth = F.interpolate(depth, size=(grid_size, grid_size), mode="bilinear", align_corners=False)
    return depth.reshape(1, -1)


def _prepare_depth_vector(depth: torch.Tensor, length: int) -> torch.Tensor:
    """Flatten a depth map's spatial dims and linearly resample it to ``length``.

    Args:
        depth: Depth tensor of rank 2, 3 or 4. Leading singleton dims are added
            so the tensor is rank 4; its last two dims are then flattened into
            one sequence axis.
        length: Number of samples along the resampled sequence.

    Returns:
        A ``(1, -1)`` tensor of all resampled values. For a single map
        (batch 1, channel 1) that is ``length`` values. Integer inputs come
        back as float32; floating inputs keep their dtype.

    Raises:
        ValueError: If ``depth`` is not rank 2, 3 or 4.
    """
    if depth.dim() == 2:
        depth = depth.unsqueeze(0).unsqueeze(0)
    elif depth.dim() == 3:
        depth = depth.unsqueeze(0)
    elif depth.dim() != 4:
        raise ValueError(f"Unexpected depth shape: {depth.shape}")

    # Linear interpolation is not implemented for integer tensors; promote them
    # to float32 while leaving floating-point inputs as they are.
    if not depth.is_floating_point():
        depth = depth.float()

    # F.interpolate(mode="linear") expects a 3-D (N, C, L) input.
    depth = depth.reshape(depth.shape[0], depth.shape[1], -1)
    depth = F.interpolate(depth, size=length, mode="linear", align_corners=False)
    return depth.reshape(1, -1)


def _build_position_grid(grid_size: int, device: torch.device) -> torch.Tensor:
    """Return normalized ``(row, col)`` coordinates of a square grid.

    Row ``i`` of the result corresponds to flat index ``i`` in row-major order,
    i.e. ``(i // grid_size, i % grid_size)``, divided by ``grid_size - 1`` so the
    coordinates span ``[0, 1]``. The divisor is floored at 1 so a 1x1 grid maps
    to ``(0, 0)`` instead of dividing by zero.

    Args:
        grid_size: Side length of the grid; values <= 0 yield an empty grid.
        device: Device on which to allocate the result.

    Returns:
        A float32 tensor of shape ``(grid_size ** 2, 2)`` (``(0, 2)`` when empty).
    """
    side = max(grid_size, 0)
    axis = torch.arange(side, dtype=torch.float32, device=device)
    # repeat_interleave gives i // side (row), repeat gives i % side (column),
    # vectorizing what used to be a Python-level list comprehension.
    rows = axis.repeat_interleave(side)
    cols = axis.repeat(side)
    coords = torch.stack((rows, cols), dim=1)
    return coords / max(grid_size - 1, 1)


def build_dscr_matrix(
    depth: torch.Tensor,
    image_len: int,
    alpha: float,
    beta: float,
    sigma: float,
    keep_ratio: Optional[float],
) -> torch.Tensor:
    """Build a row-normalized token-affinity matrix from depth and position.

    Before row normalization, entry ``[i, j]`` is
    ``exp(-dd_ij**2 / (2*sigma**2)) ** alpha * exp(-dp_ij**2 / (2*sigma**2)) ** beta``.
    Here ``dd`` is the difference of min-max-normalized inverse depth and ``dp``
    is the Euclidean distance between token positions; both kernels share ``sigma``.

    If ``image_len`` is a perfect square, tokens are laid out on a
    ``sqrt(image_len)``-sided grid with 2-D positions in ``[0, 1]``. Otherwise
    the depth map is flattened and resampled to a 1-D sequence whose positions
    are evenly spaced in ``[0, 1]``.

    Args:
        depth: Depth tensor of rank 2, 3 or 4. Assumed to hold a single map
            (batch 1, channel 1), so the resampled patch has exactly
            ``image_len`` entries -- TODO confirm with callers.
        image_len: Number of image tokens.
        alpha: Exponent applied to the depth kernel.
        beta: Exponent applied to the position kernel.
        sigma: Bandwidth shared by both Gaussian kernels.
        keep_ratio: When not ``None`` and ``< 1``, replaces the diagonal
            (self) weight before normalization.

    Returns:
        An ``(image_len, image_len)`` matrix whose rows each sum to ~1.
    """
    import logging

    grid_size = int(math.sqrt(image_len))
    if grid_size * grid_size == image_len:
        depth_patch = _prepare_depth_grid(depth, grid_size).squeeze(0)
        # The helper floors its divisor at 1, so a 1x1 grid (image_len == 1)
        # gets position (0, 0) instead of 0/0 = NaN spreading through D.
        positions = _build_position_grid(grid_size, depth.device)
    else:
        depth_patch = _prepare_depth_vector(depth, image_len).squeeze(0)
        positions = torch.linspace(0, 1, image_len, device=depth.device).unsqueeze(1)

    # The epsilon is added after the division: zero depth gives inf, which the
    # clamp caps at 1000. Outside the clamp range the constant offset cancels in
    # the min-max step below.
    inv_depth = (1.0e-6 + 1.0 / depth_patch).clamp(0.001, 1000)
    dmin = inv_depth.min()
    dmax = inv_depth.max()
    inv_depth = (inv_depth - dmin) / (dmax - dmin + 1e-6)

    kernel_denom = 2 * sigma ** 2 + 1.0e-12
    depth_diff = torch.abs(inv_depth.unsqueeze(0) - inv_depth.unsqueeze(1))
    gaussian_weight_depth = torch.exp(-(depth_diff ** 2) / kernel_denom)
    position_diff = torch.cdist(positions, positions, p=2)
    gaussian_weight_position = torch.exp(-(position_diff ** 2) / kernel_denom)

    gaussian_weight = (gaussian_weight_depth ** alpha) * (gaussian_weight_position ** beta)

    if keep_ratio is not None and keep_ratio < 1.0:
        eye = torch.eye(image_len, dtype=gaussian_weight.dtype, device=gaussian_weight.device)
        gaussian_weight = gaussian_weight * (1.0 - eye) + keep_ratio * eye

    D = gaussian_weight / (gaussian_weight.sum(dim=-1, keepdim=True) + 1e-6)

    # Diagnostics go through logging rather than an unconditional print guarded
    # by a flag stashed on the sys module. The isEnabledFor check skips the
    # .item() device syncs entirely when debug output is off.
    logger = logging.getLogger(__name__)
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("[DSCR DEBUG] D shape: %s", tuple(D.shape))
        logger.debug(
            "[DSCR DEBUG] D min/max/mean: %.6f / %.6f / %.6f",
            D.min().item(), D.max().item(), D.mean().item(),
        )
        logger.debug(
            "[DSCR DEBUG] inv_depth min/max/mean: %.6f / %.6f / %.6f",
            inv_depth.min().item(), inv_depth.max().item(), inv_depth.mean().item(),
        )

    return D


def _blend_image_span(
    tensor: torch.Tensor,
    mixing: torch.Tensor,
    image_start: int,
    image_len: int,
    dscr_lambda: Optional[float],
) -> torch.Tensor:
    """Return a copy of a rank-4 tensor (sequence on dim 2) with its image span mixed by ``mixing``."""
    image_end = image_start + image_len
    span = tensor[:, :, image_start:image_end, :]
    refined = torch.einsum("ij,bhjd->bhid", mixing, span)
    # A lambda of None or >= 1 replaces the span outright; smaller values
    # interpolate between the original and the refined tokens.
    if dscr_lambda is not None and dscr_lambda < 1.0:
        refined = (1 - dscr_lambda) * span + dscr_lambda * refined
    return torch.cat((tensor[:, :, :image_start, :], refined, tensor[:, :, image_end:, :]), dim=2)


def refine_past_key_values(
    past_key_values: Sequence[Tuple[torch.Tensor, torch.Tensor]],
    depth: torch.Tensor,
    image_start: int,
    image_len: int,
    alpha: float,
    beta: float,
    sigma: float,
    keep_ratio: Optional[float],
    dscr_lambda: Optional[float],
    start_layer: int,
    end_layer: Optional[int],
    key_only: bool,
    value_only: bool,
    key_value: bool,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
    """Smooth the image-token span of selected cache layers with the DSCR matrix.

    For every layer index in ``[start_layer, end_layer)``, the tokens at
    ``[image_start, image_start + image_len)`` along the sequence axis are left-
    multiplied by ``build_dscr_matrix(...)``. Keys are refined when ``key_only``
    or ``key_value`` is set; values when ``value_only`` or ``key_value`` is set.
    Inputs are never modified; refined layers get new tensors.

    The sequence axis is taken to be dim 2 if the first key is long enough
    there, otherwise dim 1 (tensors are then permuted to put it on dim 2 for
    the update and permuted back). Presumably these are the (B, H, S, D) and
    (B, S, H, D) cache layouts -- confirm against callers.

    Args:
        past_key_values: Per-layer ``(key, value)`` pairs.
        depth: Depth map forwarded to ``build_dscr_matrix``.
        image_start: First image-token index along the sequence axis.
        image_len: Number of image tokens.
        alpha, beta, sigma, keep_ratio: Forwarded to ``build_dscr_matrix``.
        dscr_lambda: Blend factor; ``None`` or ``>= 1`` means full replacement.
        start_layer: First layer index to refine (inclusive).
        end_layer: Last layer index (exclusive); ``None`` means all layers.
        key_only, value_only, key_value: Select which tensors are refined.

    Returns:
        A tuple of ``(key, value)`` pairs. The input is returned unchanged (as a
        tuple) when nothing would be refined or the image span does not fit.
    """
    num_layers = len(past_key_values)
    if end_layer is None:
        end_layer = num_layers
    refine_keys = key_only or key_value
    refine_values = value_only or key_value

    # Early exits: skip building the O(image_len^2) matrix when no layer or
    # tensor would actually change, and tolerate an empty cache.
    if (
        num_layers == 0
        or image_start < 0
        or image_len <= 0
        or max(start_layer, 0) >= min(end_layer, num_layers)
        or not (refine_keys or refine_values)
    ):
        return tuple(past_key_values)

    key0 = past_key_values[0][0]
    image_end = image_start + image_len
    if key0.shape[2] >= image_end:
        seq_dim = 2
    elif key0.shape[1] >= image_end:
        seq_dim = 1
    else:
        return tuple(past_key_values)

    base_matrix = build_dscr_matrix(depth, image_len, alpha, beta, sigma, keep_ratio).float()
    matrix_cache = {}

    def matrix_like(tensor: torch.Tensor) -> torch.Tensor:
        # Look up per tensor, so keys and values with different dtypes/devices
        # each get a matching copy of the matrix for the einsum.
        cache_key = (tensor.device, tensor.dtype)
        if cache_key not in matrix_cache:
            matrix_cache[cache_key] = base_matrix.to(device=tensor.device, dtype=tensor.dtype)
        return matrix_cache[cache_key]

    def refine(tensor: torch.Tensor) -> torch.Tensor:
        work = tensor.permute(0, 2, 1, 3) if seq_dim == 1 else tensor
        out = _blend_image_span(work, matrix_like(work), image_start, image_len, dscr_lambda)
        return out.permute(0, 2, 1, 3).contiguous() if seq_dim == 1 else out

    updated = []
    for layer_idx, (key, value) in enumerate(past_key_values):
        if start_layer <= layer_idx < end_layer:
            if refine_keys:
                key = refine(key)
            if refine_values:
                value = refine(value)
        updated.append((key, value))

    return tuple(updated)
