# -*- coding: utf-8 -*-
import os
import sys
import types
import pickle
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat
from typing import List, Tuple, Dict, Optional

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# ======== Globals ========
# Lazily-initialised model singletons; each is filled in on first use by the function noted.
smooth_model = None  # VFIMamba interpolation model; set by init_smooth_model()
flow_model = None  # SEA-RAFT network; set by init_flow_model()
flow_args = None  # SEA-RAFT config namespace (read for .iters / .scale); set by init_flow_model()
subjective_quality_model = None  # pyiqa 'clipiqa+' metric; set by get_subjective_quality()
scenedino_model = None  # NOTE(review): not referenced in the visible code; possibly unused
dinov3 = None  # (model, processor) tuple; set by init_dinov3_model()
TTA = True  # passed as both TTA and fast_TTA to smooth_model.inference()

# ======== Local metrics (package-relative) ========
from .metrics.lpips_metric import LearnedPerceptualImagePatchSimilarityMetric
from .metrics.ssim_metric import StructuralSimilarityIndexMeasureMetric
from .p2020 import single_frame_metrics, video_metrics
from .p2020_v2 import single_frame_metrics as single_frame_metrics_v2
from .p2020_v2 import video_metrics as video_metrics_v2


# ----------------------------- Utils -----------------------------
def _device() -> torch.device:
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class InputPadder:
    """Replicate-pad (B, C, H, W) tensors so that H and W become multiples of `divisor`.

    The padding on each axis is split between its two sides; when it is odd, the
    extra pixel goes to the bottom/right. `unpad` removes exactly what `pad` added.
    """

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        # Smallest non-negative amount that rounds each size up to a multiple of divisor.
        extra_h = -self.ht % divisor
        extra_w = -self.wd % divisor
        left = extra_w // 2
        top = extra_h // 2
        # F.pad order for the last two dims: (left, right, top, bottom).
        self._pad = [left, extra_w - left, top, extra_h - top]

    def pad(self, *inputs):
        """Return a list containing every input replicate-padded with the same amounts."""
        return [F.pad(tensor, self._pad, mode='replicate') for tensor in inputs]

    def unpad(self, x):
        """Crop the padding added by `pad` from the last two dimensions of `x`."""
        left, right, top, bottom = self._pad
        height, width = x.shape[-2:]
        return x[..., top:height - bottom, left:width - right]


def _imread_rgb(path: str) -> np.ndarray:
    """Read `path` with OpenCV as a 3-channel colour image and return it in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returned None).
    """
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Image not found: {path}")


def load_image(path: str, device: Optional[torch.device] = None) -> torch.Tensor:
    """Load an image file as a (1, C, H, W) float32 tensor scaled to [0, 1].

    Channels are in RGB order (read via `_imread_rgb`). When `device` is None the
    tensor is placed on `_device()`.
    """
    target = _device() if device is None else device
    rgb = _imread_rgb(path)
    chw = torch.tensor(rgb, dtype=torch.float32).permute(2, 0, 1) / 255.0
    return chw[None].to(target)


# ----------------------------- VFIMamba (frame interpolation) -----------------------------
def init_smooth_model():
    """Lazily construct the VFIMamba frame-interpolation model.

    Stores the loaded model in the module-global `smooth_model`; once it is set,
    later calls return immediately.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    global smooth_model
    if smooth_model is not None:
        return
    if not torch.cuda.is_available():
        raise RuntimeError("VFIMamba requires CUDA GPU.")
    # Assumes a VFIMamba checkout at ./VFIMamba. The entry is relative, so it is resolved
    # against the working directory; it is appended again on every call that reaches here
    # (e.g. a retry after a failed load).
    sys.path.append('VFIMamba')
    from Trainer_finetune import Model
    import cfg
    # NOTE(review): the config module is presumably read by Model(); keep these
    # assignments before the model is constructed.
    cfg.MODEL_CONFIG['LOGNAME'] = 'VFIMamba'
    cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(F=32, depth=[2, 2, 2, 3, 3])
    model = Model(-1)  # NOTE(review): -1 presumably means "no distributed rank" — confirm
    model.load_model(name='VFIMamba.pkl')
    model.eval()
    model.device()  # move to cuda
    smooth_model = model


# ----------------------------- SEA-RAFT (optical flow) -----------------------------
def init_flow_model():
    """Lazily construct the SEA-RAFT optical-flow model.

    Sets the module globals `flow_model` (the RAFT network, moved to CUDA and put in
    eval mode) and `flow_args` (the parsed config, later read for `.iters` and
    `.scale`). Once `flow_model` is set, later calls return immediately.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    global flow_model, flow_args
    if flow_model is not None:
        return
    if not torch.cuda.is_available():
        raise RuntimeError("SEA-RAFT requires CUDA GPU.")
    # NOTE: keep your absolute paths; consider param/env later if needed
    # The repo root provides `core.*` / `config.*`; `core/` itself is added too,
    # presumably so modules inside it can import each other by bare name.
    sys.path.append('/mnt/cache/zhouyang/dg-bench/SEA-RAFT')
    sys.path.append('/mnt/cache/zhouyang/dg-bench/SEA-RAFT/core')
    import argparse as _argparse
    from core.raft import RAFT
    from core.utils.utils import load_ckpt
    from config.parser import parse_args

    # Evaluation config (JSON) and checkpoint path.
    args = {
        "cfg": "/mnt/cache/zhouyang/dg-bench/SEA-RAFT/config/eval/kitti-M.json",
        "path": "/mnt/cache/zhouyang/dg-bench/Tartan-C-T-TSKH-kitti432x960-M.pth",
    }
    args = _argparse.Namespace(**args)
    # NOTE(review): presumably merges the JSON at args.cfg into the namespace — confirm.
    args = parse_args(args)

    model = RAFT(args)
    load_ckpt(model, args.path)
    model.to('cuda').eval()

    flow_model = model
    flow_args = args


@torch.inference_mode()
def forward_flow(image1: torch.Tensor, image2: torch.Tensor):
    """Run SEA-RAFT on an image pair and return its final-iteration outputs.

    Args:
        image1, image2: (1, C, H, W) tensors on CUDA; callers in this file pass values
            in [0, 1]. NOTE(review): SEA-RAFT's own demo appears to feed 0-255 images —
            confirm which range the model expects.

    Returns:
        ``(flow, info)``: the last entries of the model's ``'flow'`` and ``'info'``
        outputs. They are computed under CUDA autocast and may be half precision.

    Requires `init_flow_model()` to have been called.
    """
    with torch.amp.autocast(device_type="cuda"):
        outputs = flow_model(image1, image2, iters=flow_args.iters, test_mode=True)
    return outputs['flow'][-1], outputs['info'][-1]


def _compute_flow(image1: torch.Tensor, image2: torch.Tensor) -> np.ndarray:
    """Estimate optical flow from `image1` to `image2` with SEA-RAFT.

    Both inputs are (1, C, H, W) tensors on CUDA; callers pass RGB in [0, 1] from
    `load_image`. The pair is resampled by ``2 ** flow_args.scale`` before inference.
    The predicted flow is then resized back to exactly (H, W), and each displacement
    component is rescaled by the actual resize ratio along its own axis.

    Returns:
        (H, W, 2) float32 array of per-pixel (dx, dy) displacements, measured in pixels
        of the original resolution.

    Requires `init_flow_model()` to have been called.
    """
    height, width = image1.shape[-2:]
    factor = 2 ** flow_args.scale
    img1 = F.interpolate(image1, scale_factor=factor, mode='bilinear', align_corners=False)
    img2 = F.interpolate(image2, scale_factor=factor, mode='bilinear', align_corners=False)
    flow, _info = forward_flow(img1, img2)
    flow = flow.float()  # autocast may return half precision
    flow_h, flow_w = flow.shape[-2:]
    # Resize to the exact input size: `scale_factor` rounds odd sizes down, so resizing
    # back by a scale factor could return e.g. 374 rows for a 375-row input.
    flow = F.interpolate(flow, size=(height, width), mode='bilinear', align_corners=False)
    # Channel 0 is horizontal and channel 1 vertical, so each gets its own axis ratio.
    # For sizes divisible by the factor this equals the uniform 0.5 ** scale.
    ratio = flow.new_tensor([width / flow_w, height / flow_h]).view(1, 2, 1, 1)
    flow = flow * ratio
    return flow[0].permute(1, 2, 0).detach().cpu().numpy().astype(np.float32)


# ----------------------------- Motion helpers -----------------------------
def compute_motion_series(images: List[str]) -> np.ndarray:
    """Return the median optical-flow magnitude for every adjacent frame pair.

    Args:
        images: ordered frame paths; at least two are required.

    Returns:
        float32 array of length ``len(images) - 1``. Entry ``k`` is the median per-pixel
        SEA-RAFT flow magnitude (pixels/frame) between ``images[k]`` and ``images[k + 1]``.

    Raises:
        ValueError: if fewer than two frames are given.
    """
    if len(images) < 2:
        raise ValueError(f"compute_motion_series needs at least 2 frames, got {len(images)}")
    init_flow_model()
    cuda = torch.device('cuda')
    mags = []
    prev_img = load_image(images[0], device=cuda)
    for path in images[1:]:
        next_img = load_image(path, device=cuda)
        flow = _compute_flow(prev_img, next_img)
        mags.append(float(np.median(np.hypot(flow[..., 0], flow[..., 1]))))
        # Reuse the decoded frame as the next pair's source so each file is read once.
        prev_img = next_img
    return np.asarray(mags, dtype=np.float32)


def select_indices_by_arc_length_abs(
    mags: np.ndarray,
    v_low: float = 0.4,      # px/frame, nearly static threshold
    v_high: float = 4.0,     # px/frame, very dynamic threshold
    min_k: int = 4,
    max_k: int = 16,
    force_odd_gap: bool = False
) -> List[int]:
    """Pick keyframe indices spaced evenly along the cumulative motion ("arc length").

    ``mags[k]`` is the motion (px/frame) between frames ``k`` and ``k + 1``, so there are
    ``N = len(mags) + 1`` frames.

    The keyframe count ``K`` grows linearly from `min_k` (mean motion <= `v_low`) to
    `max_k` (mean motion >= `v_high`), and is capped at N. No normalization is applied.
    Keyframes go to the frames whose cumulative motion is closest to ``K`` evenly
    spaced targets. If the video is (almost) static, they are spread uniformly over the
    frame range instead.

    Args:
        mags: 1-D per-pair motion magnitudes (any array-like).
        v_low, v_high: mean-motion thresholds in px/frame.
        min_k, max_k: bounds on the number of keyframes.
        force_odd_gap: best-effort nudge so gaps between consecutive keyframes are odd.
            A gap that cannot be fixed without leaving the frame range stays even.
            Ignored when the uniform (static) branch is taken.

    Returns:
        Strictly increasing Python ``int`` indices. With at least two keyframes
        requested, the list starts at 0 and ends at N - 1. Targets landing on the same
        frame (e.g. a single large motion spike) contribute that frame once, so fewer
        than K indices may be returned.
    """
    mags = np.asarray(mags).reshape(-1)
    n_frames = mags.size + 1
    if n_frames <= 2:
        return list(range(n_frames))

    m_bar = float(mags.mean())
    if v_high <= v_low:
        v_high = v_low + 1e-6
    r = float(np.clip((m_bar - v_low) / (v_high - v_low), 0.0, 1.0))
    k = int(np.clip(int(round(min_k + r * (max_k - min_k))), min_k, max_k))
    num = min(k, n_frames)

    cum = np.concatenate([[0.0], np.cumsum(mags)])  # cum[i]: motion from frame 0 to i
    total = cum[-1]
    if total <= 1e-6:
        # Almost still: uniform pick over the frame range.
        return [int(i) for i in np.linspace(0, n_frames - 1, num=num, dtype=int)]

    targets = np.linspace(0.0, total, num=num)
    idxs = [0]
    ptr = 1
    for t in targets[1:-1]:
        # Advance to the first frame whose cumulative motion reaches t ...
        while ptr < n_frames and cum[ptr] < t:
            ptr += 1
        i = ptr
        # ... then step back one frame if the previous one is strictly closer.
        if i < n_frames and abs(cum[i] - t) > abs(cum[i - 1] - t):
            i -= 1
        # Keep indices strictly increasing and leave the last frame to the final append.
        # Several targets can fall on one frame when motion is concentrated in few pairs.
        if idxs[-1] < i < n_frames - 1:
            idxs.append(int(i))
    idxs.append(n_frames - 1)

    if force_odd_gap and n_frames >= 3 and len(idxs) >= 2:
        adj = [idxs[0]]
        for a, b in zip(idxs[:-1], idxs[1:]):
            if (b - a) % 2 == 0 and (b - a) >= 2:
                if b + 1 < n_frames:
                    b = b + 1
                elif a - 1 >= 0:
                    a = a - 1
            if a <= adj[-1]:
                a = adj[-1]
            if b <= a:
                b = a + 1
            adj[-1] = a
            adj.append(b)
        # Keep only strictly increasing, in-range indices. The adjustments above can
        # push the final index to N, one past the last frame.
        clean = [adj[0]]
        for x in adj[1:]:
            if clean[-1] < x < n_frames:
                clean.append(x)
        idxs = clean
    return idxs


def build_pairs_with_mid(idxs: List[int]) -> List[Tuple[int, int, int]]:
    """Turn consecutive keyframe indices into ``(start, end, mid)`` triples.

    Only neighbours at least two frames apart are kept, because they leave room for a
    middle frame. ``mid`` is the floor midpoint ``(start + end) // 2``.
    """
    return [
        (start, end, (start + end) // 2)
        for start, end in zip(idxs[:-1], idxs[1:])
        if end - start >= 2
    ]


# ----------------------------- Smoothness -----------------------------
@torch.no_grad()
def get_video_smoothness(video_list: List[List[str]]) -> np.ndarray:
    """Score temporal smoothness by re-synthesising odd frames from their even neighbours.

    For every video, each odd frame ``2k+1`` is predicted with VFIMamba from frames
    ``2k`` and ``2k+2`` and compared with the real frame. Per-triplet scores are
    averaged with weights given by the median SEA-RAFT flow magnitude between the two
    even frames, clipped to 10 px and mapped to [0, 1]. If every weight is ~0, equal
    weights are used.

    Returns:
        float32 array ``[mse_like, ssim_like, lpips_like]`` averaged over videos:
          - ``mse_like`` is ``1 - mean|pred - gt| / 255`` (a mean-absolute-error score);
          - ``ssim_like`` is the SSIM metric's score;
          - ``lpips_like`` is ``1 - LPIPS``, or 0 when LPIPS > 1.
        Videos with fewer than three frames score 0; an empty `video_list` yields zeros.

    Raises:
        FileNotFoundError: if a frame cannot be read.
    """
    init_smooth_model()
    init_flow_model()
    cuda = torch.device('cuda')

    scores_ssim, scores_lpips, scores_mse = [], [], []
    print('=== Start Smoothness (flow-weighted) ===')
    for frames in tqdm(video_list):
        even, odd = frames[::2], frames[1::2]
        if len(even) < 2 or len(odd) < 1:
            scores_ssim.append(0.0)
            scores_lpips.append(0.0)
            scores_mse.append(0.0)
            continue

        ssim_metric = StructuralSimilarityIndexMeasureMetric()
        lpips_metric = LearnedPerceptualImagePatchSimilarityMetric()

        ssim_this, lpips_this, mse_this, weights = [], [], [], []
        # Triplets (2k, 2k+1, 2k+2): interpolate the odd frame from its even neighbours.
        for p0, p1, p2 in zip(even[:-1], odd, even[1:]):
            I0 = load_image(p0, device=cuda)
            I2 = load_image(p2, device=cuda)

            # Motion weight from the unpadded pair. Padding below creates new tensors,
            # so the same tensors also feed the interpolator and each file is decoded once.
            flow = _compute_flow(I0, I2)
            m = float(np.median(np.sqrt(flow[..., 0] ** 2 + flow[..., 1] ** 2)))
            weights.append(float(np.clip(m, 0.0, 10.0) / 10.0))  # 0~1

            padder = InputPadder(I0.shape, divisor=32)
            I0p, I2p = padder.pad(I0, I2)
            pred = padder.unpad(
                smooth_model.inference(I0p, I2p, True, TTA=TTA, fast_TTA=TTA, scale=0.0)
            )
            pred = pred[0].detach().float().cpu().numpy().transpose(1, 2, 0)
            # Round and clip before casting: a bare uint8 cast truncates and wraps
            # values outside [0, 255].
            mid_rgb = np.clip(np.rint(pred * 255.0), 0, 255).astype(np.uint8)
            # The prediction is RGB (inputs come from `load_image`) while cv2.imread
            # returns BGR; reverse channels so both images use the same order.
            mid_pred = np.ascontiguousarray(mid_rgb[..., ::-1])

            gt = cv2.imread(p1)  # BGR
            if gt is None:
                raise FileNotFoundError(f"Image not found: {p1}")
            ssim_score = ssim_metric._compute_scores(mid_pred, gt)
            lpips_score = lpips_metric._compute_scores(mid_pred, gt)
            lpips_score = 0.0 if lpips_score > 1 else 1 - lpips_score
            mse_score = (255.0 - np.abs(mid_pred.astype(np.float32) - gt.astype(np.float32)).mean()) / 255.0

            ssim_this.append(ssim_score)
            lpips_this.append(lpips_score)
            mse_this.append(mse_score)

        w = np.asarray(weights, dtype=np.float32)
        if w.sum() <= 1e-8:
            w = np.ones_like(w)  # all pairs (nearly) static: fall back to equal weights
        w /= w.sum()

        scores_ssim.append(float(np.sum(np.asarray(ssim_this) * w)))
        scores_lpips.append(float(np.sum(np.asarray(lpips_this) * w)))
        scores_mse.append(float(np.sum(np.asarray(mse_this) * w)))

    if scores_mse:
        out = np.array([np.mean(scores_mse), np.mean(scores_ssim), np.mean(scores_lpips)], dtype=np.float32)
    else:
        out = np.zeros(3, dtype=np.float32)  # empty video_list: avoid NaN from np.mean([])
    print(f'raw: (mse_like, ssim_like, lpips_like) = {tuple(map(float, out))}')
    return out


@torch.no_grad()
def get_video_smoothness_v2(video_list: List[List[str]]) -> np.ndarray:
    """Score smoothness on motion-downsampled keyframes.

    Keyframes are chosen per video with `select_indices_by_arc_length_abs` (v_low=1,
    v_high=10, 4..20 keyframes) from the SEA-RAFT motion series. For every window of
    three consecutive keyframes ``(u, u+1, u+2)``, VFIMamba interpolates ``u+1`` from
    ``u`` and ``u+2``, and the prediction is compared with the real keyframe ``u+1``.

    Returns:
        float32 ``[mse_like, ssim_like, lpips_like]`` averaged over videos, using an
        unweighted mean over the windows within each video:
          - ``mse_like`` is ``1 - mean|pred - gt| / 255``;
          - ``lpips_like`` is ``1 - LPIPS``, or 0 when LPIPS > 1.
        Videos with fewer than three frames or keyframes score 0; an empty
        `video_list` yields zeros.

    Raises:
        FileNotFoundError: if a keyframe cannot be read.
    """
    init_smooth_model()
    cuda = torch.device('cuda')

    def _predict_mid_bgr(path_a: str, path_b: str) -> np.ndarray:
        # VFIMamba midpoint of two frames as contiguous BGR uint8 (cv2.imread layout).
        I0 = load_image(path_a, device=cuda)
        I2 = load_image(path_b, device=cuda)
        padder = InputPadder(I0.shape, divisor=32)
        I0p, I2p = padder.pad(I0, I2)
        pred = padder.unpad(
            smooth_model.inference(I0p, I2p, True, TTA=TTA, fast_TTA=TTA, scale=0.0)
        )
        pred = pred[0].detach().float().cpu().numpy().transpose(1, 2, 0)
        # Round and clip before casting: a bare uint8 cast truncates and wraps.
        rgb = np.clip(np.rint(pred * 255.0), 0, 255).astype(np.uint8)
        # Inputs are RGB (load_image) but the ground truth is BGR (cv2.imread).
        return np.ascontiguousarray(rgb[..., ::-1])

    scores_ssim, scores_lpips, scores_mse = [], [], []
    print('=== Start Smoothness v2 (motion-downsampled, px-scale) ===')
    for frames in tqdm(video_list):
        # Fewer than 3 frames can never give a window; skipping early also avoids
        # indexing into an empty frame list.
        if len(frames) < 3:
            scores_ssim.append(0.0)
            scores_lpips.append(0.0)
            scores_mse.append(0.0)
            continue
        mags = compute_motion_series(frames)
        idxs = select_indices_by_arc_length_abs(mags, v_low=1.0, v_high=10.0, min_k=4, max_k=20, force_odd_gap=False)
        frames_ds = [frames[k] for k in idxs] if idxs else frames
        if len(frames_ds) < 3:
            scores_ssim.append(0.0)
            scores_lpips.append(0.0)
            scores_mse.append(0.0)
            continue

        ssim_metric = StructuralSimilarityIndexMeasureMetric()
        lpips_metric = LearnedPerceptualImagePatchSimilarityMetric()

        ssim_this, lpips_this, mse_this = [], [], []
        for u in range(len(frames_ds) - 2):
            mid_path = frames_ds[u + 1]
            mid_gt = cv2.imread(mid_path)  # BGR
            if mid_gt is None:
                raise FileNotFoundError(f"Image not found: {mid_path}")
            mid_pred = _predict_mid_bgr(frames_ds[u], frames_ds[u + 2])

            ssim_score = ssim_metric._compute_scores(mid_pred, mid_gt)
            lpips_score = lpips_metric._compute_scores(mid_pred, mid_gt)
            lpips_score = 0.0 if lpips_score > 1 else 1 - lpips_score
            mse_score = (255.0 - np.abs(mid_pred.astype(np.float32) - mid_gt.astype(np.float32)).mean()) / 255.0

            ssim_this.append(ssim_score)
            lpips_this.append(lpips_score)
            mse_this.append(mse_score)

        scores_ssim.append(float(np.mean(ssim_this)))
        scores_lpips.append(float(np.mean(lpips_this)))
        scores_mse.append(float(np.mean(mse_this)))

    if scores_mse:
        out = np.array([np.mean(scores_mse), np.mean(scores_ssim), np.mean(scores_lpips)], dtype=np.float32)
    else:
        out = np.zeros(3, dtype=np.float32)  # empty video_list: avoid NaN from np.mean([])
    print(f'raw: (mse_like, ssim_like, lpips_like) = {tuple(map(float, out))}')
    return out


# ----------------------------- Motion magnitude -----------------------------
@torch.no_grad()
def get_video_magnitude(video_list: List[List[str]]) -> float:
    """Average per-video motion magnitude (px/frame) across videos.

    A video's magnitude is the mean, over adjacent frame pairs, of the median per-pixel
    SEA-RAFT flow magnitude. Videos with fewer than two frames count as 0.0, and an
    empty `video_list` returns 0.0.
    """
    init_flow_model()
    print('=== Start Magnitude ===')
    cuda = torch.device('cuda')
    flows = []
    for images in tqdm(video_list):
        if len(images) < 2:
            flows.append(0.0)
            continue
        per_pair = []
        prev_img = load_image(images[0], device=cuda)
        for path in images[1:]:
            next_img = load_image(path, device=cuda)
            flow = _compute_flow(prev_img, next_img)
            per_pair.append(float(np.median(np.sqrt(flow[..., 0] ** 2 + flow[..., 1] ** 2))))
            # Reuse the decoded frame for the next pair so each file is read once.
            prev_img = next_img
        flows.append(float(np.mean(per_pair)))
    mean_flow = float(np.mean(flows)) if flows else 0.0
    print(f'raw: {mean_flow}')
    print('=== Done ===')
    return mean_flow


# ----------------------------- Objective quality -----------------------------
@torch.no_grad()
def get_objective_quality(video_list: List[List[str]], include_video_metrics: bool = False) -> Dict[str, float]:
    """Average p2020 quality proxies over videos and return them by metric name.

    Frame-level metrics from `single_frame_metrics` are averaged per video and then
    across videos; both averages ignore NaN. Keys in `drop_keys` are removed before
    returning.

    Args:
        video_list: per-video lists of frame paths.
        include_video_metrics: also compute video-level `video_metrics` and merge their
            means into the result. A video-level key overrides a frame-level key with
            the same name. Off by default, which matches the previous behaviour.

    Raises:
        FileNotFoundError: if a frame cannot be read by OpenCV.
    """
    val_dict: Dict[str, List[float]] = {}
    video_val_dict: Dict[str, List[float]] = {}
    print('=== Start Objective Quality ===')
    for images in tqdm(video_list):
        imgs = []
        for f in images:
            im = cv2.imread(f)
            if im is None:
                # Fail with the offending path instead of deep inside the metrics.
                raise FileNotFoundError(f"Image not found: {f}")
            imgs.append(im)

        if include_video_metrics:
            for k, v in video_metrics(imgs).items():
                video_val_dict.setdefault(k, []).append(v)

        per_frame: Dict[str, List[float]] = defaultdict(list)
        for im in imgs:
            for k, v in single_frame_metrics(im).items():
                per_frame[k].append(v)
        for k, v in per_frame.items():
            val_dict.setdefault(k, []).append(float(np.nanmean(np.asarray(v))))

    normed: Dict[str, float] = {}
    for k, v in val_dict.items():
        score = float(np.nanmean(np.asarray(v)))
        normed[k] = score
        print(f'{k} score mean: {score:.6f}')
    # Video-level metrics are merged only on request.
    for k, v in video_val_dict.items():
        score = float(np.nanmean(np.asarray(v)))
        normed[k] = score
        print(f'video {k} score mean: {score:.6f}')

    # Entries excluded from the reported scores; edit this list as needed.
    drop_keys = [
        'mean_luminance', 'std_luminance',
        'under_exposure_ratio', 'over_saturation_ratio',
        'local_rms_contrast', 'color_sat_mean', 'color_sat_std', 'row_noise'
    ]
    for k in drop_keys:
        normed.pop(k, None)
    print(normed)
    return normed


def _process_one_video(images: List[str]):
    """Compute p2020_v2 metrics for one video (used as a process-pool task).

    Returns:
        ``(video_level, frame_mean)``:
          - ``video_level`` is the dict returned by `video_metrics_v2` on all frames;
          - ``frame_mean`` maps each `single_frame_metrics_v2` key to its NaN-aware
            mean over the frames.

    Raises:
        FileNotFoundError: if a frame cannot be read by OpenCV.
    """
    imgs = []
    for path in images:
        im = cv2.imread(path)
        if im is None:
            # Surface the bad path instead of a confusing error inside the metrics.
            raise FileNotFoundError(f"Image not found: {path}")
        imgs.append(im)
    video_level = video_metrics_v2(imgs)
    per_frame = defaultdict(list)
    for im in imgs:
        for k, v in single_frame_metrics_v2(im).items():
            per_frame[k].append(v)
    frame_mean = {k: float(np.nanmean(np.asarray(v))) for k, v in per_frame.items()}
    return video_level, frame_mean


@torch.no_grad()
def get_objective_quality_v2(video_list: List[List[str]]) -> Dict[str, float]:
    """Compute p2020_v2 metrics for all videos in parallel and summarise them.

    Videos are processed by `_process_one_video` in a process pool (CPU count - 1
    workers, at least one). Each frame-level and video-level metric is averaged across
    videos, ignoring NaN. Keys containing ``'fmp_alias'`` are reported as
    ``1 - mean``, presumably to make them higher-is-better. ``'avg'`` is the plain
    mean of all reported entries, or 0.0 when there are none.
    """
    frame_scores: Dict[str, List[float]] = defaultdict(list)
    video_scores: Dict[str, List[float]] = defaultdict(list)
    print('=== Start Objective Quality (v2, parallel) ===')

    workers = max(1, (os.cpu_count() or 2) - 1)
    print(f'using {workers} workers')

    with ProcessPoolExecutor(max_workers=workers) as pool:
        results = pool.map(_process_one_video, video_list)
        for video_level, frame_level in tqdm(results, total=len(video_list)):
            for key, value in video_level.items():
                video_scores[key].append(value)
            for key, value in frame_level.items():
                frame_scores[key].append(value)

    summary: Dict[str, float] = {}
    # Frame-level first, then video-level (which overrides a same-named frame key).
    for prefix, table in (('', frame_scores), ('video ', video_scores)):
        for key, values in table.items():
            score = float(np.nanmean(np.asarray(values)))
            if 'fmp_alias' in key:
                score = 1.0 - score
            summary[key] = score
            print(f'{prefix}{key} score mean: {score:.6f}')

    summary['avg'] = float(np.mean(list(summary.values()))) if summary else 0.0
    print(summary)
    return summary


# ----------------------------- Subjective (CLIP-IQA+) -----------------------------
@torch.no_grad()
def get_subjective_quality(video_list: List[List[str]], batch_size: Optional[int] = None) -> float:
    """Return the mean CLIP-IQA+ score over videos.

    Each frame is resized to 512x512 and scored by the pyiqa ``clipiqa+`` metric, which
    is created on first use and cached in `subjective_quality_model`. A video's score
    is the mean over its frames; the result is the mean over videos (0.0 when
    `video_list` is empty).

    Args:
        video_list: per-video lists of frame paths.
        batch_size: frames per forward pass. ``None`` or a non-positive value scores a
            whole video in one batch, as before. Set it to bound GPU memory on long
            videos.

    Raises:
        ValueError: if a video has no frames.
    """
    global subjective_quality_model
    device = _device()
    if subjective_quality_model is None:
        import pyiqa
        subjective_quality_model = pyiqa.create_metric('clipiqa+').to(device).eval()

    from torchvision import transforms
    from .metrics.base_metrics import open_image

    tfm = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])

    print('=== Start Subjective Quality ===')
    scores = []
    for imgs in tqdm(video_list):
        if not imgs:
            raise ValueError('get_subjective_quality: got a video with no frames')
        step = batch_size if batch_size and batch_size > 0 else len(imgs)
        chunk_scores = []
        for start in range(0, len(imgs), step):
            batch = torch.stack([tfm(open_image(p)) for p in imgs[start:start + step]]).to(device)
            # Flatten so per-image outputs of shape (B,) or (B, 1) concatenate uniformly.
            chunk_scores.append(subjective_quality_model(batch).float().reshape(-1))
        scores.append(float(torch.cat(chunk_scores).mean().item()))
    mean_score = float(np.mean(scores)) if scores else 0.0
    print(f'raw: {mean_score}')
    return mean_score


# ----------------------------- Scene consistency (DINOv3) -----------------------------
def init_dinov3_model():
    """Load the DINOv3 image model and processor via ModelScope's Auto classes, once.

    The checkpoint is ``facebook/dinov3-vith16plus-pretrain-lvd1689m``, loaded with
    ``device_map='auto'``. The result is stored as the ``(model, processor)`` tuple in
    the module-global `dinov3`; later calls do nothing.
    """
    global dinov3
    if dinov3 is None:
        from modelscope import AutoImageProcessor, AutoModel
        repo_id = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
        image_processor = AutoImageProcessor.from_pretrained(repo_id)
        backbone = AutoModel.from_pretrained(repo_id, device_map='auto')
        dinov3 = (backbone, image_processor)


def _load_or_compute_pickle(path: Optional[str], compute):
    """Return the object cached at `path`, or compute, cache and return it.

    With ``path=None`` caching is disabled and `compute()` is always called. Writes go
    to a temporary file that is then renamed into place, so an interrupted run never
    leaves a truncated cache file behind.
    """
    if path and os.path.exists(path):
        # NOTE(review): pickle.load can execute arbitrary code; only point the cache at
        # directories written by this tool.
        with open(path, 'rb') as f:
            return pickle.load(f)
    obj = compute()
    if path:
        tmp_path = f'{path}.tmp.{os.getpid()}'
        with open(tmp_path, 'wb') as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, path)
    return obj


@torch.no_grad()
def get_scene_consistency_v3(video_list: List[List[str]], names: Optional[List[str]] = None) -> float:
    """Mean adjacent-keyframe DINOv3 similarity over videos (motion-downsampled).

    For each video:
      1. keyframes are chosen from the motion series (`compute_motion_series`) by
         `select_indices_by_arc_length_abs` (v_low=1, v_high=10, 3..20 keyframes);
      2. the DINOv3 ``pooler_output`` of every frame is extracted;
      3. the score is the mean cosine similarity between consecutive keyframe features,
         with negative values clamped to 0. Videos yielding < 2 keyframes score 0.

    Args:
        video_list: per-video lists of frame paths.
        names: optional list of directories parallel to `video_list`, used to cache the
            motion series (``flow.pkl``) and per-frame features (``dino.pkl``).
            Videos past ``len(names)``, or all videos when `names` is falsy, are not
            cached. Existing cache files are reused without checking that they match
            the frames.

    Returns:
        Mean score over videos, or 0.0 for an empty `video_list`.
    """
    init_flow_model()
    init_dinov3_model()
    model, processor = dinov3
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    print('=== Start Scene Consistency v3 (motion-downsampled, cached) ===')
    scores = []
    for idx, frames in enumerate(tqdm(video_list)):
        cache_dir = names[idx] if names and idx < len(names) else None
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        flow_pkl = os.path.join(cache_dir, 'flow.pkl') if cache_dir else None
        dino_pkl = os.path.join(cache_dir, 'dino.pkl') if cache_dir else None

        mags = _load_or_compute_pickle(
            flow_pkl,
            lambda frames=frames: compute_motion_series(frames) if len(frames) >= 2 else np.array([]),
        )

        def _extract_features(frames=frames):
            # One CPU feature tensor per frame, in frame order.
            feats = []
            for p in frames:
                inputs = processor(images=_imread_rgb(p), return_tensors="pt").to(model.device)
                with torch.inference_mode():
                    out = model(**inputs)
                feats.append(out.pooler_output.squeeze(0).detach().cpu())
            return feats

        feats = _load_or_compute_pickle(dino_pkl, _extract_features)

        idxs = select_indices_by_arc_length_abs(mags, v_low=1.0, v_high=10.0, min_k=3, max_k=20)
        if len(idxs) < 2:
            scores.append(0.0)
            continue

        # Named `keyframe_feats` rather than `F` so torch.nn.functional is not shadowed.
        keyframe_feats = torch.stack([feats[k] for k in idxs], dim=0).to(torch.float32)  # (M, D) on CPU
        s2 = cos(keyframe_feats[:-1], keyframe_feats[1:]).clamp_min(0).mean().item()
        scores.append(float(s2))

    score = float(np.mean(scores)) if scores else 0.0
    print(f'[motion-downsampled] dino consistency raw: {score}')
    return score


# -------- Parallel readonly scorer for cached flow/dino --------
# Added to the norm product in `_cos_adjacent_np` so zero-norm features do not divide by zero.
EPS = 1e-8

def _safe_load_cpu(path):
    try:
        return torch.load(path, map_location='cpu')
    except Exception:
        with open(path, 'rb') as f:
            return pickle.load(f)

def _to_numpy_rows(feats):
    out = []
    for x in feats:
        if isinstance(x, torch.Tensor):
            out.append(x.detach().cpu().numpy().astype(np.float32, copy=False))
        else:
            arr = np.asarray(x)
            if arr.ndim == 0:
                arr = arr[None]
            out.append(arr.astype(np.float32, copy=False))
    return out

def _cos_adjacent_np(F):
    A, B = F[:-1], F[1:]
    num = (A * B).sum(axis=1)
    da = np.linalg.norm(A, axis=1)
    db = np.linalg.norm(B, axis=1)
    sim = num / (da * db + EPS)
    sim = np.clip(sim, 0.0, 1.0)
    return float(sim.mean())

def _score_one_video_worker(name_dir, v_low, v_high, min_k, max_k):
    """Score one video from its cached ``flow.pkl`` / ``dino.pkl`` (process-pool task).

    This is the NumPy counterpart of the scoring step in `get_scene_consistency_v3`, up
    to epsilon handling. It selects keyframes from the cached motion series and returns
    the mean clipped cosine similarity of adjacent keyframe features, or 0.0 when fewer
    than two keyframes are selected.
    """
    # Single intra-op thread per worker, presumably to avoid oversubscribing cores
    # when many worker processes run at once.
    torch.set_num_threads(1)
    mags = _safe_load_cpu(os.path.join(name_dir, "flow.pkl"))
    feats = _safe_load_cpu(os.path.join(name_dir, "dino.pkl"))
    idxs = select_indices_by_arc_length_abs(mags, v_low=v_low, v_high=v_high, min_k=min_k, max_k=max_k)
    if len(idxs) < 2:
        return 0.0
    # Convert only the selected keyframes instead of every frame's feature.
    rows = _to_numpy_rows([feats[k] for k in idxs])
    keyframe_feats = np.stack(rows, axis=0).astype(np.float32, copy=False)
    return _cos_adjacent_np(keyframe_feats)

@torch.no_grad()
def get_scene_consistency_v3_parallel(
    names: List[str],
    max_workers: Optional[int] = 64,
    chunksize: int = 1,
    v_low: float = 1.0, v_high: float = 10.0, min_k: int = 3, max_k: int = 20,
) -> Tuple[float, List[float]]:
    """Score cached videos in parallel without computing any flow or features.

    Each directory in `names` must already hold ``flow.pkl`` and ``dino.pkl``, as
    written by `get_scene_consistency_v3`. Scoring runs in a process pool via
    `_score_one_video_worker`, using the given keyframe-selection parameters.

    Returns:
        ``(mean_score, per_video_scores)``; the mean is 0.0 when `names` is empty.
    """
    print('=== Start Scene Consistency v3_parallel (cached, readonly) ===')
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        results = pool.map(
            _score_one_video_worker,
            names,
            repeat(v_low),
            repeat(v_high),
            repeat(min_k),
            repeat(max_k),
            chunksize=chunksize,
        )
        scores = [
            float(s)
            for s in tqdm(results, total=len(names), desc="Scene Consistency (proc/fork)")
        ]
    mean_score = float(np.mean(scores)) if scores else 0.0
    print(f'[cached numpy] dino consistency raw: {mean_score}')
    return mean_score, scores
