"""
Computes temporal LPIPS for videos over varying frame ranges.
"""

import os
import sys
import argparse
import time
import numpy as np
import cv2
import torch

def log(msg):
    """Print *msg* to stdout and flush immediately so output is not held in a buffer."""
    print(msg, flush=True)

def ensure_dir(path):
    """
    Create the directory *path*, including missing parents.

    A falsy *path* (``""`` or ``None``) is ignored. An already-existing
    directory is not an error.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

def ensure_lpips():
    """
    Checks that the lpips package can be imported.

    Raises:
        RuntimeError: If importing lpips raises ImportError. The original
            ImportError is chained as ``__cause__`` so the real reason stays
            visible. Other exceptions raised while importing lpips (for
            example from an incompatible dependency) propagate unchanged
            instead of being misreported as "package not found".
    """
    try:
        import lpips  # noqa: F401
    except ImportError as err:
        raise RuntimeError("lpips package not found. Install with: pip install lpips") from err

def make_lpips_model(device, net='alex'):
    """
    Build an LPIPS model on *device* and switch it to evaluation mode.

    Args:
        device: Target device passed to ``.to()``.
        net (str): Backbone name forwarded to ``lpips.LPIPS`` ('alex' or 'vgg').

    Returns:
        torch.nn.Module: The model returned by ``.to(device)``, after ``.eval()``.
    """
    import lpips

    model = lpips.LPIPS(net=net)
    model = model.to(device)
    model.eval()
    return model

def get_video_frame_count(path):
    """
    Gets the total number of frames in a video.

    The value is OpenCV's ``CAP_PROP_FRAME_COUNT`` property. Property values
    are backend-dependent, so this may not equal the number of frames that
    can actually be decoded.

    Args:
        path (str): Path to the video file.

    Returns:
        int: The frame count reported by the capture backend.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {path}")
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release on every path, including open failure or a raising get().
        cap.release()

def sample_frame_indices_uniform(max_frame_idx, num_samples=30):
    """
    Pick ``num_samples`` evenly spaced frame indices in ``[0, max_frame_idx]``.

    Both endpoints are included. Fractional positions from ``np.linspace`` are
    converted to int64, so interior indices are rounded down.

    Args:
        max_frame_idx (int): Largest frame index allowed (inclusive).
        num_samples (int): How many indices to return.

    Returns:
        list: Python ints, in non-decreasing order.

    Raises:
        ValueError: If the range holds fewer than ``num_samples`` frames.
    """
    available = max_frame_idx + 1
    if num_samples > available:
        raise ValueError(f"Not enough frames to sample {num_samples} from {available}")
    return np.linspace(0, max_frame_idx, num_samples, dtype=np.int64).tolist()

def read_frames_at_indices(video_path, indices, resize_hw=(368,496)):
    """
    Reads frames at given 0-based frame indices.

    Args:
        video_path (str): Path to the video file.
        indices (list): A list of 0-based frame indices to read.
        resize_hw (tuple): The (height, width) to resize frames to. Frames
            whose size already matches are returned as decoded.

    Returns:
        list: A list of BGR frames (np.uint8), one per entry of ``indices``.

    Raises:
        RuntimeError: If the video cannot be opened or any frame cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video {video_path}")
        Ht, Wt = resize_hw
        frames = []
        for i in indices:
            # Seek before every read, so indices need not be sorted or contiguous.
            # NOTE(review): CAP_PROP_POS_FRAMES seek accuracy is backend/codec
            # dependent — verify for the containers in use.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frm = cap.read()
            if not ret or frm is None:
                raise RuntimeError(f"Failed to read frame {i} from {video_path}")
            if frm.shape[:2] != (Ht, Wt):
                # cv2.resize takes the target size as (width, height).
                frm = cv2.resize(frm, (Wt, Ht), interpolation=cv2.INTER_LINEAR)
            frames.append(frm)
        return frames  # BGR uint8
    finally:
        # Always release the capture, including when a read or resize fails.
        cap.release()

def bgr_to_rgb_float01(frames_bgr):
    """
    Convert BGR frames to RGB float32 arrays divided by 255.

    For uint8 input (as produced by read_frames_at_indices), the values land
    in [0, 1].

    Args:
        frames_bgr: Iterable of BGR images.

    Returns:
        list: One float32 RGB array per input frame, in input order.
    """
    return [
        cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        for frame in frames_bgr
    ]

def compute_temporal_lpips_for_video_series(video_path, lpips_model, device, a,
                                           resize_hw=(368,496), lpips_net='alex',
                                           num_samples=30):
    """
    Computes temporal LPIPS for a video, sampled from its first 'a' frames.

    The first ``min(a, total_frames)`` frames form the candidate range.
    ``num_samples`` frame indices are sampled uniformly from that range, and
    LPIPS is computed between each pair of consecutive sampled frames.

    Args:
        video_path (str): Path to the video file.
        lpips_model (torch.nn.Module): The pretrained LPIPS model.
        device: The device for torch computations.
        a (int): The number of initial frames to consider for sampling.
        resize_hw (tuple): The (height, width) for resizing frames.
        lpips_net (str): The LPIPS network backbone. Unused here, since the
            model is already built; kept for backward compatibility.
        num_samples (int): Number of frames to sample uniformly (default 30).

    Returns:
        tuple: A tuple containing:
            - float: The mean temporal LPIPS value.
            - np.ndarray: float32 array of LPIPS values, one per consecutive
              sampled-frame pair (``num_samples - 1`` entries).

    Raises:
        RuntimeError: If the video reports fewer than 2 frames, cannot be
            opened/read, or fewer than 2 frames were sampled.
        ValueError: If the candidate range has fewer than ``num_samples``
            frames (raised by sample_frame_indices_uniform).
    """
    total_frames = get_video_frame_count(video_path)
    if total_frames < 2:
        raise RuntimeError(f"Need >=2 frames to compute temporal LPIPS for {video_path} (got {total_frames})")

    # Only the first `a` frames are candidates; the video may be shorter than `a`.
    actual_a = min(a, total_frames)
    indices = sample_frame_indices_uniform(actual_a - 1, num_samples=num_samples)

    frames_bgr = read_frames_at_indices(video_path, indices, resize_hw=resize_hw)
    frames_rgb = bgr_to_rgb_float01(frames_bgr)  # list of H,W,3 float32 in [0,1]
    T = len(frames_rgb)
    if T < 2:
        raise RuntimeError(f"Need >=2 frames to compute temporal LPIPS (got {T})")

    def _to_tensor(frame):
        # H,W,3 in [0,1] -> (1,3,H,W) tensor rescaled to [-1,1] for LPIPS.
        return torch.from_numpy((frame * 2.0 - 1.0).transpose(2, 0, 1)).unsqueeze(0).to(device)

    per_vals = []
    with torch.no_grad():
        # Carry the previous frame's tensor forward so each frame is converted
        # and transferred to the device once, not twice.
        prev = _to_tensor(frames_rgb[0])
        for frame in frames_rgb[1:]:
            cur = _to_tensor(frame)
            d = lpips_model(prev, cur)
            # Reduce the model output to a single scalar per pair.
            if isinstance(d, torch.Tensor):
                val = float(d.mean().cpu().item())
            else:
                val = float(np.array(d).mean())
            per_vals.append(val)
            prev = cur

    per_arr = np.array(per_vals, dtype=np.float32)
    mean_lpips = float(per_arr.mean())
    return mean_lpips, per_arr

def parse_args():
    """
    Define the CLI options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace: enh_dirs, orig_dir, resize_h, resize_w, device,
        no_gpu and lpips_net.
    """
    parser = argparse.ArgumentParser(
        description="Compute Temporal LPIPS for series of frame ranges (a=30,60,...,300)."
    )
    # (flag, keyword arguments) in the order they appear in --help.
    option_specs = (
        ("--enh_dirs", dict(type=str, required=True,
                            help="Comma-separated list of enhanced video directories.")),
        ("--orig_dir", dict(type=str, default="gt",
                            help="Directory with original videos.")),
        ("--resize_h", dict(type=int, default=368, help="resize height (default 368)")),
        ("--resize_w", dict(type=int, default=496, help="resize width (default 496)")),
        ("--device", dict(type=str, default="cuda:0",
                          help="Device to use, e.g., 'cuda:0' or 'cpu'.")),
        ("--no_gpu", dict(action="store_true",
                          help="Force CPU usage (overrides device).")),
        ("--lpips_net", dict(type=str, default="alex", choices=["alex", "vgg"],
                             help="LPIPS backbone (alex or vgg)")),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()

def _select_device(args):
    """Return CPU when --no_gpu is set or CUDA is unavailable, else ``args.device``."""
    if args.no_gpu or not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(args.device)


def _check_videos_exist(orig_dir, enh_dir, video_names):
    """Raise FileNotFoundError for the first original or enhanced video that is missing."""
    for vn in video_names:
        orig_path = os.path.join(orig_dir, vn)
        enh_path = os.path.join(enh_dir, vn)
        if not os.path.isfile(orig_path):
            raise FileNotFoundError(f"Original video not found: {orig_path}")
        if not os.path.isfile(enh_path):
            raise FileNotFoundError(f"Enhanced video not found: {enh_path}")


def _compute_model_results(enh_dir, view_keys, video_names, a_values,
                           lpips_model, device, resize_hw, lpips_net):
    """Return ``{a: {view_key: mean temporal LPIPS}}`` for the enhanced videos in *enh_dir*.

    A view that fails for a given ``a`` is recorded as NaN and reported on stderr.
    """
    results = {}
    for a in a_values:
        view_lpips_values = {}
        for vk, vn in zip(view_keys, video_names):
            enh_path = os.path.join(enh_dir, vn)
            try:
                mean_lpips, _ = compute_temporal_lpips_for_video_series(
                    enh_path, lpips_model, device, a,
                    resize_hw=resize_hw,
                    lpips_net=lpips_net
                )
            except Exception as e:
                # Best-effort per view: keep going with the other views and
                # `a` values, but make the failure visible instead of
                # silently turning it into NaN.
                print(f"  Error processing {vk} at a={a}: {e}", file=sys.stderr, flush=True)
                mean_lpips = float('nan')
            view_lpips_values[vk] = mean_lpips
        results[a] = view_lpips_values
    return results


def _log_model_summary(enh_dir, results, a_values, view_keys):
    """Print max/min/mean LPIPS across views for each ``a``, skipping NaN entries."""
    log(f"\n{'='*60}")
    log(f"Summary for Model: {enh_dir}")
    log(f"{'='*60}")
    log("a_value  max_lpips  min_lpips  mean_lpips")
    log("-" * 40)

    for a in a_values:
        vals = []
        for vk in view_keys:
            val = results[a].get(vk, float('nan'))
            if not np.isnan(val):
                vals.append(val)

        if vals:
            max_val = max(vals)
            min_val = min(vals)
            mean_val = sum(vals) / len(vals)
            log(f"{a:6d}   {max_val:8.4f}   {min_val:8.4f}   {mean_val:9.4f}")
        else:
            log(f"{a:6d}   No valid data")


def main():
    """
    Entry point: compute temporal LPIPS for each enhanced-video directory.

    For every directory in --enh_dirs:
    - check that each view video exists in both --orig_dir and that directory;
    - compute the enhanced videos' temporal LPIPS for each frame range ``a``;
    - print a per-``a`` max/min/mean summary across views.

    Note: the original videos are only checked for existence; no LPIPS is
    computed on them here.

    Raises:
        ValueError: If --enh_dirs contains no non-empty entries.
        FileNotFoundError: If an expected original or enhanced video is missing.
        RuntimeError: If the lpips package is not installed.
    """
    args = parse_args()
    device = _select_device(args)

    enh_dirs = [d.strip() for d in args.enh_dirs.split(",") if d.strip()]
    if not enh_dirs:
        raise ValueError("No enhancement directories provided")

    # Index i of video_names corresponds to index i of view_keys.
    video_names = ["head.mp4", "right_hand.mp4", "left_hand.mp4"]
    view_keys = ["head", "right_hand", "left_hand"]

    ensure_lpips()
    lpips_model = make_lpips_model(device, net=args.lpips_net)

    a_values = list(range(30, 301, 30))  # 30, 60, ..., 300 (10 values)
    resize_hw = (args.resize_h, args.resize_w)

    for enh_dir in enh_dirs:
        _check_videos_exist(args.orig_dir, enh_dir, video_names)
        results = _compute_model_results(
            enh_dir, view_keys, video_names, a_values,
            lpips_model, device, resize_hw, args.lpips_net,
        )
        _log_model_summary(enh_dir, results, a_values, view_keys)

if __name__ == "__main__":
    main()