#!/usr/bin/env python3
"""
Computes temporal LPIPS for a set of videos.
"""

import os
import sys
import argparse
import time
import numpy as np
import cv2
import torch

def log(msg):
    """Prints a message to stdout and flushes immediately.

    Flushing keeps progress output visible in real time when stdout is
    redirected to a file or pipe.

    Args:
        msg: The object to print (formatted with ``str``).
    """
    print(msg, flush=True)

def ensure_dir(path):
    """Makes sure the directory ``path`` exists, creating parents as needed.

    A falsy ``path`` (empty string or None) is ignored, and an already
    existing directory is not an error.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

def ensure_lpips():
    """Checks that the lpips package can be imported.

    Raises:
        RuntimeError: If importing ``lpips`` fails with an ImportError. The
            original ImportError is attached as ``__cause__``. Any other
            failure raised while importing propagates unchanged, so it is
            not misreported as a missing package.
    """
    try:
        import lpips  # noqa: F401
    except ImportError as err:
        raise RuntimeError("lpips package not found. Install with: pip install lpips") from err

def make_lpips_model(device, net='alex'):
    """
    Builds an LPIPS model, moves it to ``device`` and puts it in eval mode.

    The ``lpips`` package is imported lazily so that the module can be
    imported without it.

    Args:
        device: The device to load the model on.
        net (str): The network backbone to use ('alex' or 'vgg').

    Returns:
        torch.nn.Module: The LPIPS model in evaluation mode.
    """
    import lpips

    # Both .to() and .eval() return the module itself, so the calls chain.
    return lpips.LPIPS(net=net).to(device).eval()

def read_frames_from_video(video_path, max_frames=None, resize_hw=(368,496)):
    """
    Reads frames sequentially from a video file.

    Frames whose size differs from ``resize_hw`` are resized with bilinear
    interpolation. The capture handle is always released, even when reading
    or resizing fails.

    Args:
        video_path (str): Path to the video file.
        max_frames (int, optional): Maximum number of frames to read. Defaults to None (all frames).
            A value <= 0 reads nothing and therefore raises RuntimeError.
        resize_hw (tuple): The (height, width) to resize frames to.

    Returns:
        list: A non-empty list of BGR frames (np.uint8).

    Raises:
        RuntimeError: If the video cannot be opened or no frame could be read.
    """
    target_h, target_w = resize_hw
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        frames = []
        # Check the limit before reading so that at most max_frames frames
        # are decoded and returned.
        while max_frames is None or len(frames) < max_frames:
            ok, frame = cap.read()
            if not ok:
                break  # end of stream or decode failure
            if frame.shape[:2] != (target_h, target_w):
                # cv2.resize takes the size as (width, height).
                frame = cv2.resize(frame, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
            frames.append(frame)
    finally:
        cap.release()
    if not frames:
        raise RuntimeError(f"No frames read from {video_path}")
    return frames  # BGR uint8

def bgr_to_rgb_float01(frames_bgr):
    """Returns RGB float32 copies of the given BGR frames, scaled by 1/255.

    Each input frame is channel-swapped with OpenCV, cast to float32 and
    divided by 255, so uint8 input maps to values in [0, 1].
    """
    return [
        cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        for frame in frames_bgr
    ]

def _frame_to_lpips_tensor(frame_rgb01, device):
    """Converts one (H, W, 3) RGB float frame in [0, 1] to a (1, 3, H, W) tensor in [-1, 1] on ``device``."""
    chw = (frame_rgb01 * 2.0 - 1.0).transpose(2, 0, 1)
    return torch.from_numpy(chw).unsqueeze(0).to(device)


def compute_temporal_lpips_for_video(video_path, lpips_model, device,
                                     resize_hw=(368,496), nframes=None,
                                     save_per_frame_path=None, lpips_net='alex'):
    """
    Computes the temporal LPIPS for a single video.

    The metric is the LPIPS distance between each pair of adjacent frames
    (t, t+1), averaged over all pairs.

    Args:
        video_path (str): Path to the video file.
        lpips_model (torch.nn.Module): The pretrained LPIPS model.
        device: The device for torch computations.
        resize_hw (tuple): The (height, width) for resizing frames.
        nframes (int, optional): Maximum number of frames to process. Defaults to None.
        save_per_frame_path (str, optional): If provided, saves the per-frame LPIPS
            values to this .npy file. Defaults to None.
        lpips_net (str): The LPIPS network backbone. Unused here (the backbone is
            determined by ``lpips_model``); kept for backward compatibility.

    Returns:
        tuple: A tuple containing:
            - float: The mean temporal LPIPS value.
            - np.ndarray: A float32 array of LPIPS values for each frame pair.

    Raises:
        RuntimeError: If fewer than two frames could be read from the video.
    """
    log(f"Reading frames from {video_path} (max {nframes}) ...")
    frames_bgr = read_frames_from_video(video_path, max_frames=nframes, resize_hw=resize_hw)
    frames_rgb = bgr_to_rgb_float01(frames_bgr)  # list of H,W,3 float32 in [0,1]
    T = len(frames_rgb)
    if T < 2:
        raise RuntimeError(f"Need >=2 frames to compute temporal LPIPS for {video_path} (got {T})")

    per_vals = []
    t0 = time.time()
    with torch.no_grad():
        # Each frame is converted and moved to the device exactly once: the
        # second tensor of one pair is reused as the first of the next pair.
        prev_t = _frame_to_lpips_tensor(frames_rgb[0], device)
        for t in range(T - 1):
            next_t = _frame_to_lpips_tensor(frames_rgb[t + 1], device)
            d = lpips_model(prev_t, next_t)
            # lpips returns tensor-like; take scalar mean
            if isinstance(d, torch.Tensor):
                val = float(d.mean().item())
            else:
                val = float(np.array(d).mean())
            per_vals.append(val)
            # Report every 50th pair and always the final one.
            if (t + 1) % 50 == 0 or (t == T - 2):
                log(f"  frame pair {t}->{t+1}: LPIPS={val:.6f}")
            prev_t = next_t
    dur = time.time() - t0
    per_arr = np.array(per_vals, dtype=np.float32)
    mean_lpips = float(per_arr.mean())
    log(f"Finished {os.path.basename(video_path)}: mean temporal LPIPS = {mean_lpips:.6f}  (computed {T-1} pairs in {dur:.1f}s)")
    if save_per_frame_path:
        ensure_dir(os.path.dirname(save_per_frame_path) or ".")
        np.save(save_per_frame_path, per_arr)
        log(f"Saved per-frame LPIPS to {save_per_frame_path}")
    return mean_lpips, per_arr

def parse_args():
    """Builds the command-line parser and parses ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options. ``--enh_dir`` is required;
        every other option has a default.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--enh_dir", dict(
            type=str, required=True,
            help="Directory with enhanced videos (default: head.mp4/right_hand.mp4/left_hand.mp4), or specify files with --videos.")),
        ("--videos", dict(
            type=str, default=None,
            help="Comma-separated video filenames or paths (relative to enh_dir or absolute). Example: head.mp4,right_hand.mp4,left_hand.mp4")),
        ("--nframes", dict(
            type=int, default=300,
            help="Maximum number of frames to read (default 300). Reads all frames if video is shorter.")),
        ("--resize_h", dict(
            type=int, default=368,
            help="resize height (default 368)")),
        ("--resize_w", dict(
            type=int, default=496,
            help="resize width (default 496)")),
        ("--device", dict(
            type=str, default="cuda:0",
            help="Device to use, e.g., 'cuda:0' or 'cpu'.")),
        ("--no_gpu", dict(
            action="store_true",
            help="Force CPU usage (overrides device).")),
        ("--lpips_net", dict(
            type=str, default="alex", choices=["alex","vgg"],
            help="LPIPS backbone (alex or vgg)")),
        ("--save_per_frame_outdir", dict(
            type=str, default=None,
            help="If specified, saves per-frame LPIPS for each view as .npy files to this directory.")),
    )
    parser = argparse.ArgumentParser(description="Compute standard Temporal LPIPS (adjacent frames, no flow).")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()

def main():
    """Runs the command-line tool.

    Parses arguments, selects the torch device, resolves and validates the
    list of videos, computes the mean temporal LPIPS of each video, and logs
    the per-video values together with their mean and (population) standard
    deviation.

    Raises:
        ValueError: If ``--videos`` is given but names no video.
        FileNotFoundError: If any requested video file does not exist.
        RuntimeError: If the lpips package is missing or a video cannot be
            processed.
    """
    args = parse_args()

    # Device selection: CPU when forced or when CUDA is unavailable,
    # otherwise the requested device.
    if args.no_gpu or not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)
    log(f"Using device: {device}")

    # Determine the video list; relative names are resolved against enh_dir.
    if args.videos:
        names = [v.strip() for v in args.videos.split(",") if v.strip()]
        if not names:
            # e.g. "--videos ','" would otherwise yield a NaN mean below.
            raise ValueError(f"--videos contains no video names: {args.videos!r}")
        vids = [v if os.path.isabs(v) else os.path.join(args.enh_dir, v) for v in names]
    else:
        vids = [os.path.join(args.enh_dir, n) for n in ("head.mp4","right_hand.mp4","left_hand.mp4")]

    # Fail early, before loading the model, if any file is missing.
    for p in vids:
        if not os.path.isfile(p):
            raise FileNotFoundError(f"Video not found: {p}")

    # Ensure lpips is available and create the model once for all videos.
    ensure_lpips()
    lpips_model = make_lpips_model(device, net=args.lpips_net)

    results = {}
    per_outdir = args.save_per_frame_outdir
    if per_outdir:
        ensure_dir(per_outdir)

    for vpath in vids:
        base = os.path.splitext(os.path.basename(vpath))[0]
        # Videos in different directories may share a basename; give later
        # ones a numeric suffix so neither the result nor the .npy file of
        # an earlier video is overwritten.
        key = base
        suffix = 2
        while key in results:
            key = f"{base}_{suffix}"
            suffix += 1
        log(f"Processing {key} ...")
        save_path = None
        if per_outdir:
            save_path = os.path.join(per_outdir, f"{key}_per_frame_lpips.npy")
        mean_lpips, _ = compute_temporal_lpips_for_video(vpath, lpips_model, device,
                                                         resize_hw=(args.resize_h, args.resize_w),
                                                         nframes=args.nframes,
                                                         save_per_frame_path=save_path,
                                                         lpips_net=args.lpips_net)
        results[key] = float(mean_lpips)

    vals = np.array(list(results.values()), dtype=np.float64)
    mean_all = float(vals.mean())
    std_all = float(vals.std(ddof=0))
    log("----- Final Results -----")
    for k, v in results.items():
        log(f"{k}: {v:.6f}")
    log(f"Mean: {mean_all:.6f}")
    log(f"Std : {std_all:.6f}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()