"""
Computes Multi-View Depth Consistency (MVDC) for a set of multi-view videos.
"""
import os
import cv2
import torch
import argparse
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as T

def load_midas(device):
    """
    Fetch the MiDaS DPT-Hybrid depth model and its matching input transform.

    Both objects are obtained through ``torch.hub`` from the
    ``intel-isl/MiDaS`` repository.

    Args:
        device: The device the model is moved to.

    Returns:
        tuple: ``(model, transform)`` where ``model`` has been moved to
            ``device`` and switched to eval mode, and ``transform`` is the
            hub's ``dpt_transform``.
    """
    hub_repo = "intel-isl/MiDaS"

    depth_model = torch.hub.load(hub_repo, "DPT_Hybrid", pretrained=True)
    depth_model = depth_model.to(device).eval()

    # The "transforms" hub entry bundles the preprocessing pipelines; the DPT
    # variant is the one paired with the DPT_Hybrid model loaded above.
    hub_transforms = torch.hub.load(hub_repo, "transforms")
    return depth_model, hub_transforms.dpt_transform

def load_video_frames(path, num_frames=300, max_frames=1500):
    """
    Samples and loads a specified number of frames from a video file.

    Frame indices are spread uniformly (``np.linspace``) between 0 and
    ``min(max_frames, total - 1)`` and truncated to integers. The video is
    decoded sequentially and only the frames at those indices are kept.
    Decoding stops as soon as the last sampled index has been read.

    Args:
        path (str): The path to the video file.
        num_frames (int): The number of frames to sample uniformly.
        max_frames (int): Upper bound on the highest frame index that may be
            sampled from the video.

    Returns:
        list: A list of RGB frames as numpy arrays. It can be shorter than
            ``num_frames`` if decoding stops early (a warning is issued) or
            if the sampled indices contain duplicates.

    Raises:
        ValueError: If the video cannot be opened, or if it reports fewer
            than ``num_frames`` frames.
    """
    import warnings

    cap = cv2.VideoCapture(path)
    try:
        # Without this check an unreadable path reports 0 frames and would be
        # misreported as "too short" below.
        if not cap.isOpened():
            raise ValueError(f"Could not open video {path}!")

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total < num_frames:
            raise ValueError(f"Video {path} has less than {num_frames} frames!")

        idxs = np.linspace(0, min(max_frames, total - 1), num_frames, dtype=int)
        # A set gives O(1) membership tests instead of scanning the index
        # array once per decoded frame.
        wanted = set(idxs.tolist())
        if not wanted:
            return []
        last_wanted = max(wanted)

        frames = []
        # Nothing past the last sampled index is needed, so stop decoding there.
        for i in range(last_wanted + 1):
            ret, frame = cap.read()
            if not ret:
                break
            if i in wanted:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture handle on every exit path, including errors.
        cap.release()

    # Only non-negative indices can ever be matched by the decode loop.
    expected = sum(1 for idx in wanted if idx >= 0)
    if len(frames) < expected:
        # Keep the best-effort result, but make the short read visible.
        warnings.warn(
            f"Video {path}: decoded only {len(frames)} of {expected} sampled frames."
        )
    return frames

def extract_depth(frames, model, transform, device, size=256):
    """
    Predict a depth map per frame, resize it, and standardize it.

    For every frame the transform output is moved to ``device`` and passed
    through ``model`` with gradients disabled. The squeezed prediction is
    resized to ``size`` x ``size`` with bicubic interpolation, then shifted
    and scaled to zero mean and (approximately) unit standard deviation.

    Args:
        frames (list): A list of RGB frames.
        model (torch.nn.Module): The depth estimation model (MiDaS in this
            script).
        transform: Callable that turns one frame into the model's input.
        device: The device for torch computations.
        size (int): Side length of the square depth maps that are returned.

    Returns:
        np.ndarray: The per-frame normalized depth maps stacked along a new
            leading axis.
    """
    target_hw = (size, size)
    normalized_maps = []

    with torch.no_grad():
        for frame in frames:
            prediction = model(transform(frame).to(device)).squeeze()
            # interpolate expects leading batch and channel axes, so add two
            # singleton dimensions and squeeze them away again afterwards.
            resized = F.interpolate(
                prediction[None, None],
                size=target_hw,
                mode="bicubic",
                align_corners=False,
            )
            depth_map = resized.squeeze().cpu().numpy()
            # Per-frame z-score; the epsilon guards against a zero std on a
            # constant map.
            centered = depth_map - depth_map.mean()
            normalized_maps.append(centered / (depth_map.std() + 1e-6))

    return np.stack(normalized_maps)

def compute_mvdc(depths_dict):
    """
    Compute the Multi-View Depth Consistency (MVDC) score.

    For each frame, the mean absolute difference is taken between every pair
    of the three views (head/right_hand, head/left_hand,
    right_hand/left_hand) and the three values are averaged. The final score
    is the mean of those per-frame values.

    Frames are paired with ``zip``, so iteration stops at the shortest of the
    three sequences.

    Args:
        depths_dict (dict): Depth map arrays stored under the keys 'head',
            'right_hand', and 'left_hand'.

    Returns:
        float: The final MVDC score.
    """
    def _mean_abs_gap(a, b):
        # Mean absolute per-pixel difference between two depth maps.
        return np.mean(np.abs(a - b))

    frame_triplets = zip(
        depths_dict["head"],
        depths_dict["right_hand"],
        depths_dict["left_hand"],
    )
    per_frame = [
        (
            _mean_abs_gap(head, right)
            + _mean_abs_gap(head, left)
            + _mean_abs_gap(right, left)
        ) / 3.0
        for head, right, left in frame_triplets
    ]
    return float(np.mean(per_frame))

if __name__ == "__main__":
    # Command-line entry point: score the ground-truth and generated video
    # sets with MVDC and print both results.
    cli = argparse.ArgumentParser(description="Compute Multi-View Depth Consistency (MVDC).")
    cli.add_argument(
        "--gt",
        type=str,
        default="./ground_truth",
        help="Directory with ground-truth videos (e.g., head.mp4).",
    )
    cli.add_argument(
        "--gen",
        type=str,
        required=True,
        help="Directory with generated videos to evaluate (e.g., head.mp4).",
    )
    args = cli.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, transform = load_midas(device)

    # Each directory is expected to hold one <view>.mp4 per view listed here.
    view_names = ("head", "right_hand", "left_hand")

    results = {}
    for label, video_dir in (("GT", args.gt), ("GEN", args.gen)):
        depths_by_view = {}
        for view in view_names:
            clip_frames = load_video_frames(os.path.join(video_dir, f"{view}.mp4"))
            depths_by_view[view] = extract_depth(clip_frames, model, transform, device)
        results[label] = compute_mvdc(depths_by_view)

    print("=== Multi-View Depth Consistency (MVDC) ===")
    for label, score in results.items():
        print(f"{label}: {score:.4f}")
