import torch
from diffusers.utils import export_to_video
from einops import rearrange
import torch.nn.functional as F
import os
from decord import VideoReader, bridge
import math

# Ask decord to hand back torch tensors from VideoReader calls; load_video
# relies on this when it calls ``.permute`` on the decoded batch.
bridge.set_bridge("torch")


def load_video(video_path, indices):
    """Decode the frames at ``indices`` from ``video_path`` and scale them by 1/255.

    Args:
        video_path: Path handed to decord's ``VideoReader``.
        indices: Frame indices passed to ``VideoReader.get_batch``.

    Returns:
        The decoded batch with its last axis moved to position 1 and divided
        by 255.0 (presumably (N, H, W, C) uint8 -> (N, C, H, W) in [0, 1];
        the input layout is decord's, not established by this file).
    """
    reader = VideoReader(video_path)
    batch = reader.get_batch(indices)
    frames = batch.permute(0, 3, 1, 2) / 255.0

    # Drop the reader explicitly so its handle is released before returning.
    del reader

    return frames


def center_and_resize(img: torch.Tensor, bbox, resolution, pad_value=1.0):
    """Crop ``img`` to ``bbox``, fit the crop into ``resolution`` and center it.

    The crop is resized isotropically (aspect ratio preserved) with bilinear
    interpolation so that it fits inside ``resolution``; the leftover border
    is filled with ``pad_value``.

    Args:
        img: Tensor of shape ``(C, 1, H, W)``.
        bbox: ``(top, left, bottom, right)`` in pixels, with ``bottom`` and
            ``right`` exclusive. Coordinates outside the image are clamped to
            the image bounds.
        resolution: ``(H_out, W_out)`` of the returned canvas.
        pad_value: Fill value for the border around the resized crop.

    Returns:
        ``(out, ok)``. On success ``out`` is a float32 tensor of shape
        ``(C, 1, H_out, W_out)`` and ``ok`` is ``True``. When the clamped box
        is empty, ``out`` is ``torch.zeros_like(img)`` (i.e. it keeps the
        *input* shape and dtype) and ``ok`` is ``False``.

    Raises:
        AssertionError: If the second dimension of ``img`` is not 1.
    """
    top, left, bottom, right = map(int, bbox)
    H_out, W_out = map(int, resolution)

    C, one, H, W = img.shape
    assert one == 1, "Second dim must be 1 if img is 4D."

    # Clamp the box to the image. Without this, a negative coordinate would be
    # interpreted by slicing as "count from the end" and silently select the
    # wrong region (or nothing at all).
    top, bottom = max(0, min(top, H)), max(0, min(bottom, H))
    left, right = max(0, min(left, W)), max(0, min(right, W))

    h, w = bottom - top, right - left
    if h <= 0 or w <= 0:
        return torch.zeros_like(img), False

    # Crop first, then promote to NCHW float32 as required by F.interpolate.
    img_crop = img[:, 0, top:bottom, left:right].unsqueeze(0).to(torch.float32)

    # Largest isotropic scale for which the crop still fits in the canvas.
    scale = min(H_out / h, W_out / w)
    new_h = max(1, int(math.floor(h * scale)))
    new_w = max(1, int(math.floor(w * scale)))

    img_resized = F.interpolate(
        img_crop, size=(new_h, new_w), mode="bilinear", align_corners=False
    )

    # Split the leftover space so the crop is centered; any odd pixel goes to
    # the bottom/right side.
    pad_top = (H_out - new_h) // 2
    pad_bottom = H_out - new_h - pad_top
    pad_left = (W_out - new_w) // 2
    pad_right = W_out - new_w - pad_left

    out = F.pad(
        img_resized, (pad_left, pad_right, pad_top, pad_bottom), value=float(pad_value)
    )  # (1, C, H_out, W_out)

    # Back to the input layout: (1, C, H_out, W_out) -> (C, 1, H_out, W_out).
    return out[0].unsqueeze(1), True


def bbox_mask(m: torch.Tensor, threshold: float = 0.5, margin: int = 0):
    """Replace every frame's mask by the filled rectangle of its bounding box.

    Only channel 0 is inspected; the resulting rectangle is replicated over
    all channels.

    Args:
        m: Mask of shape ``(C, F, H, W)``. Non-bool inputs are binarized with
            ``> threshold``; bool inputs are used as they are.
        threshold: Binarization threshold for non-bool masks.
        margin: Number of pixels by which the box is grown on every side
            (the result is clamped to the frame).

    Returns:
        ``(out, boxes)`` where ``out`` has shape ``(C, F, H, W)`` and the
        dtype of ``m``, and ``boxes`` is an int64 tensor of shape
        ``(C, F, 4)`` holding ``(top, left, bottom, right)`` with exclusive
        ``bottom``/``right``. Frames without any foreground get an all-zero
        box and an all-zero mask.
    """
    assert m.ndim == 4, f"expected (C,F,H,W), got {tuple(m.shape)}"
    n_ch, n_frames, height, width = m.shape
    dev = m.device
    is_bool = m.dtype == torch.bool

    # Foreground of channel 0, one (H, W) plane per frame.
    fg = m[0] if is_bool else (m[0] > threshold)

    rows_hit = fg.any(dim=-1)  # (F, H): rows containing foreground
    cols_hit = fg.any(dim=-2)  # (F, W): columns containing foreground
    has_fg = rows_hit.any(dim=1) & cols_hit.any(dim=1)  # (F,)

    def _extent(hit, size):
        """Return per-frame [lo, hi) along one axis, margin applied, 0 if empty."""
        idx = torch.arange(size, device=dev).view(1, size)
        # Sentinels (size / -1) make unoccupied positions lose the min / max.
        lo = torch.where(hit, idx, torch.full_like(idx, size)).min(dim=1).values
        hi = torch.where(hit, idx + 1, torch.full_like(idx, -1)).max(dim=1).values
        lo = (lo - margin).clamp(min=0, max=size)
        hi = (hi + margin).clamp(min=0, max=size)
        blank = torch.zeros_like(lo)
        return torch.where(has_fg, lo, blank), torch.where(has_fg, hi, blank)

    top, bottom = _extent(rows_hit, height)
    left, right = _extent(cols_hit, width)

    # Rasterize each frame's rectangle by comparing pixel indices to the box.
    ys = torch.arange(height, device=dev).view(1, height, 1)
    xs = torch.arange(width, device=dev).view(1, 1, width)
    in_rows = (ys >= top[:, None, None]) & (ys < bottom[:, None, None])
    in_cols = (xs >= left[:, None, None]) & (xs < right[:, None, None])
    rect = in_rows & in_cols & has_fg[:, None, None]  # (F, H, W)

    expanded = rect.unsqueeze(0).expand(n_ch, n_frames, height, width)
    out = expanded if is_bool else expanded.to(m.dtype)

    per_frame = torch.stack([top, left, bottom, right], dim=1).to(torch.int64)
    boxes = per_frame.unsqueeze(0).expand(n_ch, n_frames, 4).contiguous()

    return out, boxes




@torch.no_grad()
def scale_video_mask_CFHW(
    mask: torch.Tensor,               # (1, F, H, W)
    z1, z2,                            # float or 1D tensor of length F
    binary: bool = True,
    thresh: float = 0.5,
    padding_mode: str = "zeros",       # "zeros" | "border" | "reflection"
):
    """Zoom every frame of a mask about its center by the ratio ``z1 / z2``.

    A ratio above 1 enlarges the mask content, a ratio below 1 shrinks it.
    The frame size is unchanged; regions sampled from outside the input are
    filled according to ``padding_mode``.

    Args:
        mask: Tensor of shape ``(1, F, H, W)``.
        z1: Numerator of the scale; a Python number, a 0-d/1-element tensor
            (shared by all frames) or a 1D tensor of length ``F``.
        z2: Denominator of the scale; same accepted forms as ``z1``.
        binary: If ``True``, sample with nearest-neighbour and re-binarize
            the result with ``>= thresh``; otherwise sample bilinearly.
        thresh: Threshold used when ``binary`` is ``True``.
        padding_mode: Forwarded to ``F.grid_sample``.

    Returns:
        Tensor of shape ``(1, F, H, W)`` with the dtype of ``mask``.

    Raises:
        AssertionError: If ``mask`` is not ``(1, F, H, W)``.
        ValueError: If ``z1`` or ``z2`` is ``None`` or has a length other
            than 1 or ``F``.
    """
    assert mask.dim() == 4 and mask.shape[0] == 1, "mask must be (1,F,H,W)"
    num_frames = mask.shape[1]
    device = mask.device

    def _to_vec(x, name):
        """Normalize a scalar / tensor argument to a float32 vector of length F."""
        if x is None:
            raise ValueError(f"{name} must not be None")
        vec = torch.as_tensor(x, device=device, dtype=torch.float32)
        if vec.numel() == 1:
            vec = vec.reshape(1).expand(num_frames)
        if vec.shape != (num_frames,):
            raise ValueError(f"{name} must be length-F or scalar")
        return vec

    # Plain Python floats must become tensors here: torch.clamp below does not
    # accept a float, so scalar z1/z2 used to crash.
    z1 = _to_vec(z1, "z1")
    z2 = _to_vec(z2, "z2")

    s = z1 / (z2 + 1e-8)                        # (F,)
    # affine_grid maps output coordinates to input coordinates, so the matrix
    # carries the inverse of the desired zoom.
    inv_s = 1.0 / torch.clamp(s, min=1e-8)

    # Treat frames as the batch dimension: (1,F,H,W) -> (F,1,H,W).
    x = mask.permute(1, 0, 2, 3).float()

    # Pure scaling about the center: only the diagonal is set, translation
    # entries stay zero.
    theta = torch.zeros((num_frames, 2, 3), device=device, dtype=torch.float32)
    theta[:, 0, 0] = inv_s
    theta[:, 1, 1] = inv_s

    grid = F.affine_grid(theta, size=x.shape, align_corners=False)
    out = F.grid_sample(
        x,
        grid,
        mode="nearest" if binary else "bilinear",
        padding_mode=padding_mode,
        align_corners=False,
    )

    out = out.permute(1, 0, 2, 3)  # back to (1,F,H,W)

    if binary:
        out = (out >= thresh).to(mask.dtype)
    else:
        out = out.to(mask.dtype)
    return out






def shift_mask(src_masks, tgt_masks):
    """Translate ``src_masks`` so its first-frame centroid lands on the target's.

    The centroids are taken from channel 0 / frame 0 of each input (pixels
    ``> 0.5``). The integer offset between them is applied to every frame of
    ``src_masks`` with a circular roll along the two spatial axes.

    Args:
        src_masks: Mask tensor indexed as ``(C, F, H, W)``.
        tgt_masks: Mask tensor whose ``[0, 0]`` plane provides the target
            centroid.

    Returns:
        ``(shifted, True)`` on success. If anything fails (for example an
        empty mask, whose centroid is NaN and cannot be converted to an
        integer shift), the untouched ``src_masks`` and ``False``.
    """

    def _first_frame_centroid(masks):
        rows, cols = torch.nonzero(masks[0, 0, :, :] > 0.5, as_tuple=True)
        return rows.float().mean(), cols.float().mean()

    try:
        src_cy, src_cx = _first_frame_centroid(src_masks)
        tgt_cy, tgt_cx = _first_frame_centroid(tgt_masks)

        dx = int(tgt_cx - src_cx)
        dy = int(tgt_cy - src_cy)

        # Rolling the two trailing (spatial) axes moves every frame at once.
        moved = torch.roll(src_masks, shifts=(dy, dx), dims=(2, 3))
    except Exception:
        # Best effort: hand back the input and signal the failure to the caller.
        return src_masks, False

    return moved, True




def dilate_mask(mask, radius=10):
    """Dilate a mask with a square ``(2 * radius + 1)`` window.

    Every pixel of the result is set when any pixel of the input within
    ``radius`` (Chebyshev distance) is greater than zero. Dilation is done
    independently per channel and frame.

    Args:
        mask: Tensor of shape ``(C, F, H, W)``. Floating-point, integer and
            bool masks are supported.
        radius: Half-size of the square structuring element, in pixels.

    Returns:
        Tensor of shape ``(C, F, H, W)`` with the dtype of ``mask``, holding
        1 (or ``True``) inside the dilated region and 0 elsewhere.
    """
    c, f, h, w = mask.shape
    kernel = 2 * radius + 1

    # Fold channels and frames together so each (H, W) plane is pooled alone.
    planes = mask.reshape(c * f, h, w)

    # max_pool2d is only implemented for floating-point inputs; pool bool and
    # integer masks in float32 and cast back afterwards.
    pool_in = planes if planes.is_floating_point() else planes.to(torch.float32)

    # stride=1 with padding=radius keeps the spatial size unchanged.
    pooled = F.max_pool2d(pool_in, kernel_size=kernel, stride=1, padding=radius)

    dilated = (pooled > 0).to(mask.dtype)
    return dilated.reshape(c, f, h, w)


def shift_dilate_mask(src_masks, tgt_masks):
    """Move ``src_masks`` onto the target's first-frame centroid, then dilate.

    The centroids are taken from channel 0 / frame 0 of each input (pixels
    ``> 0.5``). The integer offset between them is applied to every frame of
    ``src_masks`` with a circular roll along the two spatial axes, and the
    shifted mask is passed through ``dilate_mask`` with its default radius.

    Args:
        src_masks: Mask tensor indexed as ``(C, F, H, W)``.
        tgt_masks: Mask tensor whose ``[0, 0]`` plane provides the target
            centroid.

    Returns:
        The shifted and dilated mask. If anything fails (for example an empty
        mask, whose centroid is NaN), a diagnostic line is printed and the
        untouched, un-dilated ``src_masks`` is returned.
    """
    # Pre-bind the values reported by the error message: if the failure
    # happens before they are computed, the handler must not itself raise
    # UnboundLocalError.
    src_y = src_x = tgt_y = tgt_x = None
    try:
        src_y, src_x = (src_masks[0, 0, :, :] > 0.5).nonzero(as_tuple=True)
        src_y, src_x = src_y.float().mean(), src_x.float().mean()

        tgt_y, tgt_x = (tgt_masks[0, 0, :, :] > 0.5).nonzero(as_tuple=True)
        tgt_y, tgt_x = tgt_y.float().mean(), tgt_x.float().mean()

        shift_x = int(tgt_x - src_x)
        shift_y = int(tgt_y - src_y)

        # Rolling the two trailing (spatial) axes moves every frame at once.
        shifted_src_masks = torch.roll(
            src_masks, shifts=(shift_y, shift_x), dims=(2, 3)
        )

        dilated_mask = dilate_mask(shifted_src_masks)
    except Exception:
        # Best effort by design: report and fall back to the input.
        print(
            f"Error found in dilating: tgt_x:{tgt_x}, src_x:{src_x}, tgt_y:{tgt_y},  src_y:{src_y}"
        )

        return src_masks

    return dilated_mask


def save_video(x, path, fps=8):
    """Write a video tensor to ``path`` via diffusers' ``export_to_video``.

    Values are clipped to ``[-1, 1]`` and mapped linearly to ``[0, 1]``
    before export. The parent directory of ``path`` is created if needed.

    Args:
        x: 4D tensor whose first axis is moved last before export
            (presumably ``(C, F, H, W)`` -> ``(F, H, W, C)``; confirm against
            callers).
        path: Output file path.
        fps: Frame rate forwarded to ``export_to_video``.
    """
    frames = x.permute(1, 2, 3, 0).float().cpu().numpy().clip(-1, 1)
    frames = (frames + 1) / 2

    # A bare file name has an empty dirname, and os.makedirs("") raises, so
    # only create the directory when the path actually has one.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    export_to_video(frames, path, fps=fps)