import os, cv2, torch, numpy as np
from PIL import Image
from torchvision.transforms import (Compose, Resize, CenterCrop, ToTensor, Normalize)

def _fit_span(lo, hi, limit, min_size):
    """Clamp the 1-D span [lo, hi) to [0, limit) and widen it to ``min_size``."""
    lo = min(max(lo, 0), limit - 1)
    hi = min(max(hi, 0), limit)

    shortfall = min_size - (hi - lo)
    if shortfall > 0:
        # Grow symmetrically (the extra pixel of an odd shortfall goes to the
        # low side), then hand whatever one side cannot absorb to the other
        # so a span next to an image border still reaches min_size.
        before = min((shortfall + 1) // 2, lo)
        after = min(shortfall - before, limit - hi)
        before = min(shortfall - after, lo)
        lo -= before
        hi += after

    # Final guard: keep the span inside the image and at least one pixel wide.
    lo = min(max(lo, 0), limit - 1)
    hi = min(max(hi, lo + 1), limit)
    return lo, hi


def safe_crop_expand(frame: "Image.Image",
                     box,
                     min_size: int = 3):
    """Crop ``box`` from an image, enlarging boxes smaller than ``min_size``.

    The box is first clamped to the image bounds. If its width or height is
    below ``min_size`` it is grown around its current position; growth that
    does not fit on one side is moved to the opposite side. The result can
    only be smaller than ``min_size`` along an axis when the image itself is
    smaller than ``min_size`` along that axis.

    Args:
        frame: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((x1, y1, x2, y2))`` method, as PIL images do.
        box: Iterable ``(x1, y1, x2, y2)``; each value is passed through
            ``int()``.
        min_size: Minimum width and height of the crop, in pixels.

    Returns:
        Whatever ``frame.crop`` returns for the adjusted box.
    """
    w, h = frame.size
    x1, y1, x2, y2 = map(int, box)

    x1, x2 = _fit_span(x1, x2, w, min_size)
    y1, y2 = _fit_span(y1, y2, h, min_size)

    return frame.crop((x1, y1, x2, y2))


# ---- DINOv3 setup ----
# Holds the (model, processor) pair once init_dinov3_model() has run.
dinov3 = None

import types  # needed for the SimpleNamespace fallback below

# Compatibility shim: make sure torch.compiler.is_compiling exists before the
# imports below run. NOTE(review): presumably those libraries call it on torch
# builds that predate it -- confirm which versions need this.
if not hasattr(torch, "compiler"):
    torch.compiler = types.SimpleNamespace()
if not hasattr(torch.compiler, "is_compiling"):
    torch.compiler.is_compiling = lambda: False
from modelscope import AutoImageProcessor, AutoModel
from transformers.image_utils import load_image as load_image_hf

def init_dinov3_model():
    """Load the DINOv3 model and image processor into the ``dinov3`` global.

    The first call fetches both objects with ``from_pretrained`` and stores
    them as a ``(model, processor)`` tuple; later calls do nothing because
    the global is already populated.
    """
    global dinov3
    if dinov3 is None:
        model_dir = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
        # The processor is loaded first, then the model weights.
        processor = AutoImageProcessor.from_pretrained(model_dir)
        model = AutoModel.from_pretrained(model_dir, device_map='auto')
        dinov3 = (model, processor)

import pickle
import torch.nn as nn

def stability_metric(
    img_dir: str,
    boxes: list[tuple[int, tuple[int, int, int, int]]],
    label: str | None = None,
    embed_dir='',
    weights=(0.5, 0.5, 0),
):

    global dinov3
    init_dinov3_model()
    model, processor = dinov3

    boxes = sorted(boxes, key=lambda x: x[0])
    w_R, w_A, w_S = weights
    if label is None:
        w_R, w_A = w_R / (w_R + w_A), w_A / (w_R + w_A)
        w_S = 0.0

    if not os.path.exists(embed_dir):

        embs_dino = []
        crops = []
        for fid, box in boxes:
            img_path = os.path.join(img_dir, f"{fid:05}.png")
            if not os.path.exists(img_path):
                img_path = img_path.replace('.png', '.jpg')
            # frame = cv2.imread(img_path)
            frame = Image.open(img_path)
            if frame is None:
                raise FileNotFoundError(img_path)
            crop = safe_crop_expand(frame, box, min_size=32)   # PIL RGB
            crops.append(crop)
            # embs_dino.append(dino_embed(crop))                 # 1024-d for R/A
            image_input = load_image_hf(crop)
            inputs = processor(images=image_input, return_tensors="pt").to(model.device)
            with torch.inference_mode():
                outputs = model(**inputs)
            scene_feat = outputs.pooler_output.squeeze(0)
            embs_dino.append(scene_feat)
        embs_dino = torch.stack(embs_dino)                     # (T,1024)
        print(f'store to: {embed_dir}')
        with open(embed_dir, 'wb') as f:
            # Pickle the dictionary and write it to the file
            pickle.dump(embs_dino.cpu().numpy(), f, protocol=pickle.HIGHEST_PROTOCOL)
    else:

        with open(embed_dir, 'rb') as f:
            embs_dino = torch.from_numpy(pickle.load(f)).cuda()


    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    sims_ref = cos(embs_dino, embs_dino[0:1]).clamp_min(0).cpu().numpy()
    R = float(sims_ref[1:].mean())

    sims_adj = cos(embs_dino[1:], embs_dino[:-1]).clamp_min(0).cpu().numpy()
    A = float(sims_adj.mean())

    if w_S > 0:
        pass
    else:
        RS = 0.0

    score = w_R * R + w_A * A + w_S * RS

    return {"R": R, "A": A, "S": RS, "score": score}

# ─────────────────────────── Demo ───────────────────────────
if __name__ == "__main__":
    # Sample input: frame ids paired with (x1, y1, x2, y2) boxes.
    frame_ids = (0, 3, 7)
    rects = (
        (120, 80, 260, 220),
        (122, 79, 259, 219),
        (125, 82, 260, 220),
    )
    boxes = list(zip(frame_ids, rects))
    res = stability_metric("/path/to/frames", boxes, label="pedestrian")
    print(res)
