"""
Calculates the Ewarp metric for temporal consistency using RAFT optical flow.
"""

import os
import sys
import argparse
import subprocess
import urllib.request
import time
import shutil
import math
from types import SimpleNamespace

import numpy as np
import cv2
import torch

# ------------ Config defaults ------------
# Upstream RAFT repository, cloned on demand by clone_raft_repo().
RAFT_REPO_URL = "https://github.com/princeton-vl/RAFT.git"
# Default local clone directory (default arg of clone_raft_repo / setup_raft_module).
RAFT_LOCAL_DIR = "raft_repo"
# NOTE(review): not referenced anywhere else in this file; the --ckpt_dir CLI
# option defaults to 'raft_repo/models' rather than this value.
CHECKPOINTS_DIR = "checkpoints"
# Filename of the default RAFT checkpoint (the "things" weights).
RAFT_CHECKPOINT_NAME = "raft-things.pth"
# Source URL used for automatic checkpoint download.
# NOTE(review): confirm this URL still serves the raft-things weights.
RAFT_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/raft/raft-things.pth"

# ------------ Utilities ------------
def log(s):
    """Print ``s`` followed by a newline, then flush stdout right away.

    Flushing keeps progress messages visible immediately when stdout is
    piped or redirected (where it would otherwise be block-buffered).
    """
    stream = sys.stdout
    print(s, file=stream)
    stream.flush()

def ensure_dir(path):
    """Create directory ``path`` (including parents) unless it already exists.

    If anything already exists at ``path`` (a directory or a file), this is
    a no-op.
    """
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)

def run_cmd(cmd, cwd=None):
    """Log ``cmd`` and execute it, optionally inside directory ``cwd``.

    ``cmd`` is handed to ``subprocess.check_call`` without ``shell=True``,
    so a non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    pretty = " ".join(cmd)
    log(f"Running: {pretty}")
    subprocess.check_call(cmd, cwd=cwd)

def download_url(url, dst, desc=None):
    """Download ``url`` to ``dst`` unless ``dst`` already exists.

    The payload is first written to a sibling ``<dst>.part`` file. It is
    moved into place only after the transfer completes, so an interrupted
    download never leaves a truncated file at ``dst``. Such a file would
    otherwise be mistaken for a finished download on the next run.

    Args:
        url: Source URL.
        dst: Destination file path; parent directories are created if needed.
        desc: Optional human-readable label used in the progress log.

    Raises:
        RuntimeError: If the download (or the final rename) fails. Any
            partial ``.part`` file is removed on a best-effort basis.
    """
    if os.path.exists(dst):
        log(f"Checkpoint already exists: {dst}")
        return
    ensure_dir(os.path.dirname(dst) or ".")
    log(f"Downloading {desc or url} -> {dst} ...")
    tmp_path = os.fspath(dst) + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Same directory, so os.replace is a rename: dst is either absent or complete.
        os.replace(tmp_path, dst)
    except Exception as e:
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; the original error is what matters
        raise RuntimeError(f"Failed to download {url}: {e}") from e
    log("Download finished.")

def clone_raft_repo(local_dir=RAFT_LOCAL_DIR):
    """Clone ``RAFT_REPO_URL`` into ``local_dir`` unless that directory exists.

    An existing directory is accepted as-is; its contents are not inspected.
    """
    if not os.path.isdir(local_dir):
        log(f"Cloning RAFT repo into {local_dir} ...")
        run_cmd(["git", "clone", RAFT_REPO_URL, local_dir])
        return
    log(f"RAFT repo dir already exists: {local_dir}")

def setup_raft_module(raft_dir=RAFT_LOCAL_DIR):
    """Clone the RAFT repo if needed and make ``raft.RAFT`` importable.

    Prepends both the repository root and its ``core`` sub-directory to
    ``sys.path`` (each only if not already present). ``core`` is inserted
    last, so it ends up ahead of the root. The function then checks that
    ``from raft import RAFT`` succeeds.

    Args:
        raft_dir: Local directory of the RAFT clone.

    Returns:
        True on success.

    Raises:
        RuntimeError: If RAFT cannot be imported after setup.
    """
    clone_raft_repo(raft_dir)
    abs_dir = os.path.abspath(raft_dir)
    core_dir = os.path.join(abs_dir, "core")

    for path in (abs_dir, core_dir):
        if path not in sys.path:
            sys.path.insert(0, path)

    try:
        from raft import RAFT  # noqa: F401  -- import check only
    except Exception as e:
        raise RuntimeError(f"Failed to import RAFT module from {raft_dir}: {e}") from e
    return True

# ------------ RAFT model load & inference helper ------------
def make_raft_model(device, checkpoint_path):
    """Instantiate the full-size RAFT model and load weights from a checkpoint.

    Args:
        device: torch device to load the weights onto and run the model on.
        checkpoint_path: Path to a RAFT checkpoint. Accepts either a bare
            state dict or a dict with a ``'state_dict'`` entry. A leading
            ``module.`` prefix (from ``nn.DataParallel``) is stripped from keys.

    Returns:
        The RAFT model on ``device`` in eval mode.

    Raises:
        RuntimeError: If RAFT cannot be imported. Also raised if the
            checkpoint leaves any model parameter/buffer unset, e.g. a
            raft-small checkpoint loaded into the full model, which would
            otherwise silently run with random weights.
    """
    try:
        from raft import RAFT
    except Exception as e:
        raise RuntimeError("Cannot import RAFT module. Make sure RAFT repo is cloned and in sys.path.") from e

    # Minimal configuration object for the RAFT constructor.
    # NOTE(review): argparse.Namespace (not SimpleNamespace) is presumably
    # required because RAFT probes attributes with `in`; confirm before changing.
    args = argparse.Namespace()
    args.small = False
    args.mixed_precision = False
    args.alternate_corr = False

    model = RAFT(args)
    log(f"Loading RAFT checkpoint: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = ckpt.get('state_dict', ckpt)

    # Strip only a *leading* "module."; str.replace would also rewrite the
    # substring if it appeared elsewhere in a key.
    prefix = "module."
    new_state = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    result = model.load_state_dict(new_state, strict=False)
    if result.missing_keys:
        raise RuntimeError(
            f"RAFT checkpoint {checkpoint_path} does not provide "
            f"{len(result.missing_keys)} model key(s), e.g. {result.missing_keys[:5]}; "
            f"make sure it matches the model configuration (small={args.small})."
        )
    if result.unexpected_keys:
        log(f"Warning: ignoring {len(result.unexpected_keys)} unexpected checkpoint "
            f"key(s), e.g. {result.unexpected_keys[:5]}")
    model.to(device)
    model.eval()
    return model

def raft_flow_between_frames(model, img_curr, img_prev, device, iters=20):
    """
    Computes optical flow between two frames using RAFT.

    The RAFT reference implementation rescales its inputs with
    ``2 * (x / 255) - 1``, i.e. it expects pixel values in [0, 255]. Inputs
    are therefore converted to float32 in that range. Frames whose height or
    width is not a multiple of 8 are padded by edge replication, split
    evenly on both sides as RAFT's own ``InputPadder`` does. The padding is
    cropped off the resulting flow.

    Args:
        model: The loaded RAFT model. It is called as
            ``model(im1, im2, iters=..., test_mode=True)``.
        img_curr, img_prev: Numpy arrays (H,W,3) in RGB, either uint8
            [0,255] or float [0,1]. The dtype is checked per image.
        device: The torch device to run on.
        iters (int): Number of refinement iterations for RAFT.

    Returns:
        np.ndarray: A float32 (H,W,2) flow field. flow[y,x] = (dx, dy) is the
                    displacement that maps pixel (x,y) of img_curr to its
                    position in img_prev.

    Raises:
        RuntimeError: If the model output is not a tensor (or a list/tuple
            ending in one).
    """
    import torch.nn.functional as F

    def _to_0_255(img):
        # uint8 is already in [0, 255]; float input is documented as [0, 1].
        if img.dtype == np.uint8:
            return img.astype(np.float32)
        return img.astype(np.float32) * 255.0

    H, W = img_curr.shape[:2]
    im1 = torch.from_numpy(_to_0_255(img_curr).transpose(2, 0, 1)).unsqueeze(0).to(device)
    im2 = torch.from_numpy(_to_0_255(img_prev).transpose(2, 0, 1)).unsqueeze(0).to(device)

    # RAFT works at 1/8 resolution, so spatial dims must be multiples of 8.
    pad_h = (-H) % 8
    pad_w = (-W) % 8
    if pad_h or pad_w:
        pad = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
        im1 = F.pad(im1, pad, mode="replicate")
        im2 = F.pad(im2, pad, mode="replicate")

    with torch.no_grad():
        out = model(im1, im2, iters=iters, test_mode=True)

    # In test mode RAFT returns (low_res_flow, upsampled_flow); keep the last.
    flow = out[-1] if isinstance(out, (list, tuple)) else out
    if not isinstance(flow, torch.Tensor):
        raise RuntimeError(f"Unexpected RAFT output type: {type(flow)}")

    top, left = pad_h // 2, pad_w // 2
    flow = flow[..., top:top + H, left:left + W]
    flow_np = flow[0].cpu().numpy().transpose(1, 2, 0)  # H,W,2 (x-flow, y-flow)
    return flow_np.astype(np.float32)

# ------------ Video read / preprocessing ------------
def read_first_n_frames(video_path, nframes, resize_hw=(368,496)):
    """
    Reads the first N frames from a video and resizes them.

    Args:
        video_path: Path to the video file.
        nframes: Number of frames to read from the start of the video.
        resize_hw: Target (height, width). Frames of a different size are
            resized with bilinear interpolation.

    Returns:
        list: ``nframes`` BGR, uint8 frames, each of size resize_hw.

    Raises:
        RuntimeError: If the video cannot be opened, reports fewer than
            ``nframes`` frames, or a frame cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some backends/containers report 0 or a negative count when the length
        # is unknown; then rely on the per-frame read check below instead.
        if 0 < total < nframes:
            raise RuntimeError(f"Video {video_path} has {total} frames < required {nframes}")
        Ht, Wt = resize_hw
        frames = []
        for i in range(nframes):
            ret, frm = cap.read()
            if not ret or frm is None:
                raise RuntimeError(f"Failed to read frame {i} from {video_path}")
            if frm.shape[:2] != (Ht, Wt):
                frm = cv2.resize(frm, (Wt, Ht), interpolation=cv2.INTER_LINEAR)
            frames.append(frm)
    finally:
        # Release the capture on every exit path, including unexpected errors.
        cap.release()
    return frames

def bgr_to_rgb_float01(frames_bgr):
    """Return RGB float32 copies of BGR uint8 frames, scaled to [0, 1].

    Args:
        frames_bgr: Iterable of BGR uint8 images.

    Returns:
        list: One RGB float32 array per input frame, in input order.
    """
    return [
        cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        for frame in frames_bgr
    ]

# ------------ Warping & Ewarp computation ------------
def warp_prev_to_curr(prev_rgb01, curr_rgb01, flow_curr2prev):
    """
    Backward-warps the previous frame into the current frame's coordinates.

    Each pixel (x, y) of the output is sampled bilinearly from the previous
    frame at (x + dx, y + dy), where (dx, dy) = flow_curr2prev[y, x].
    Sample positions outside the image are filled by reflecting at the
    border (cv2.BORDER_REFLECT). No occlusion masking is applied.

    Args:
        prev_rgb01 (np.ndarray): Previous frame, float [0,1], shape (H,W,3).
        curr_rgb01 (np.ndarray): Current frame. Unused; kept for interface
            compatibility.
        flow_curr2prev (np.ndarray): (H,W,2) flow from current to previous,
            with channel 0 = x displacement and channel 1 = y displacement.

    Returns:
        np.ndarray: float32 (H,W,3) warped previous frame, aligned with the
                    current frame.
    """
    H, W = prev_rgb01.shape[:2]
    grid_x, grid_y = np.meshgrid(np.arange(W, dtype=np.float32),
                                 np.arange(H, dtype=np.float32))
    map_x = (grid_x + flow_curr2prev[..., 0]).astype(np.float32)
    map_y = (grid_y + flow_curr2prev[..., 1]).astype(np.float32)

    # Remap in float32 rather than via a uint8 round-trip. That round-trip
    # floors the input and quantizes the interpolated result to 1/255 steps,
    # which biases the L1 warping error computed downstream.
    src = np.ascontiguousarray(prev_rgb01, dtype=np.float32)
    warped = cv2.remap(src, map_x, map_y,
                       interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    return warped.astype(np.float32, copy=False)

def compute_ewarp_for_view(enh_video_path, model, device, resize_hw=(368,496), nframes=300, save_per_frame=None):
    """
    Scores the temporal consistency of one video with the Ewarp metric.

    For each consecutive pair (t-1, t) among the first ``nframes`` frames:
    RAFT flow from frame t to frame t-1 is used to warp frame t-1 onto
    frame t. The mean absolute difference (over all pixels and channels)
    between frame t and the warped frame is recorded as that pair's score.

    Args:
        enh_video_path: Path of the video to score.
        model: Loaded RAFT model.
        device: torch device the model runs on.
        resize_hw: (height, width) frames are resized to before scoring.
        nframes: Number of leading frames to use.
        save_per_frame: Optional ``.npy`` path. When given, the per-pair
            scores are saved there.

    Returns:
        tuple: (mean_ewarp_score, array_of_per_frame_scores). The mean is a
               float (0.0 when there are no frame pairs); the array is float32.
    """
    log(f"Reading first {nframes} frames from {enh_video_path} ...")
    rgb_frames = bgr_to_rgb_float01(
        read_first_n_frames(enh_video_path, nframes, resize_hw=resize_hw)
    )

    scores = []
    started = time.time()
    for idx in range(1, nframes):
        prev_frame, curr_frame = rgb_frames[idx - 1], rgb_frames[idx]
        flow = raft_flow_between_frames(model, curr_frame, prev_frame, device)
        aligned_prev = warp_prev_to_curr(prev_frame, curr_frame, flow)
        score = float(np.abs(curr_frame - aligned_prev).mean())
        scores.append(score)
        if idx % 50 == 0:
            log(f"  frame {idx}/{nframes-1}: per-frame L1 = {score:.6f}")

    elapsed = time.time() - started
    mean_ewarp = float(np.mean(scores)) if scores else 0.0
    log(f"Finished view. mean Ewarp over {nframes-1} frames = {mean_ewarp:.6f} (time {elapsed:.1f}s)")

    score_array = np.array(scores, dtype=np.float32)
    if save_per_frame is not None:
        ensure_dir(os.path.dirname(save_per_frame) or ".")
        np.save(save_per_frame, score_array)
        log(f"Saved per-frame values to {save_per_frame}")

    return mean_ewarp, score_array

# ------------ CLI & main ------------
def parse_args():
    """Parses and validates command-line arguments.

    Returns:
        argparse.Namespace: The parsed options.

    Exits via ``ArgumentParser.error`` in two cases. First, when ``--nframes``
    is below 2: there would be no frame pair, and Ewarp would silently be
    reported as 0.0. Second, when a resize dimension is not positive.
    """
    p = argparse.ArgumentParser(description="Compute Ewarp (L1-based) using RAFT optical flow on enhanced videos.")
    p.add_argument("--enh_dir", type=str, required=True, help="Directory with enhanced videos (e.g., head.mp4, right_hand.mp4).")
    p.add_argument("--device", type=str, default="cuda:0", help="Device to use, e.g., 'cuda:0' or 'cpu'.")
    p.add_argument("--no_gpu", action="store_true", help="Force CPU usage, overriding --device.")
    p.add_argument("--resize_h", type=int, default=368, help="Frame resize height.")
    p.add_argument("--resize_w", type=int, default=496, help="Frame resize width.")
    p.add_argument("--nframes", type=int, default=300, help="Number of frames to process from the start of the video.")
    # Defaults come from the module-level config constants so they cannot drift apart.
    p.add_argument("--raft_dir", type=str, default=RAFT_LOCAL_DIR, help="Local directory for the RAFT repository (will be cloned if not found).")
    p.add_argument("--ckpt_dir", type=str, default='raft_repo/models', help="Directory to store model checkpoints.")
    p.add_argument("--ckpt_name", type=str, default=RAFT_CHECKPOINT_NAME, help="RAFT checkpoint filename.")
    p.add_argument("--save_per_frame_outdir", type=str, default=None, help="If set, saves per-frame Ewarp values to this directory.")
    p.add_argument("--no_download_ckpt", action="store_true", help="Disable automatic checkpoint download (requires checkpoint to exist).")
    args = p.parse_args()

    if args.nframes < 2:
        p.error(f"--nframes must be >= 2 (got {args.nframes}); Ewarp needs at least one frame pair.")
    if args.resize_h <= 0 or args.resize_w <= 0:
        p.error(f"--resize_h/--resize_w must be positive (got {args.resize_h}x{args.resize_w}).")
    return args

def main():
    """CLI entry point.

    Computes Ewarp for head.mp4, right_hand.mp4 and left_hand.mp4 in
    --enh_dir, then logs the per-view scores and their mean and standard
    deviation.
    """
    args = parse_args()
    if args.no_gpu:
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)
        # Fall back to CPU only when a CUDA device was requested but CUDA is
        # unavailable; any other explicitly requested device is honoured.
        if device.type == "cuda" and not torch.cuda.is_available():
            log(f"CUDA not available; falling back to CPU (requested {args.device}).")
            device = torch.device("cpu")
    log(f"Device: {device}")

    # Check for view videos
    view_keys = ["head", "right_hand", "left_hand"]
    views = [f"{k}.mp4" for k in view_keys]
    for v in views:
        path = os.path.join(args.enh_dir, v)
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Enhanced video not found: {path}")

    # Setup RAFT and ensure checkpoint exists
    setup_raft_module(args.raft_dir)
    ensure_dir(args.ckpt_dir)
    ckpt_path = os.path.join(args.ckpt_dir, args.ckpt_name)
    if not os.path.exists(ckpt_path):
        if args.no_download_ckpt:
            raise FileNotFoundError(
                f"RAFT checkpoint not found at {ckpt_path} and --no_download_ckpt is set.")
        if args.ckpt_name != RAFT_CHECKPOINT_NAME:
            # RAFT_CHECKPOINT_URL serves the raft-things weights only; saving
            # them under a different name would silently load the wrong model.
            raise FileNotFoundError(
                f"RAFT checkpoint not found at {ckpt_path}. Automatic download is only "
                f"available for {RAFT_CHECKPOINT_NAME}; place the checkpoint there manually.")
        log("Checkpoint not found locally — downloading...")
        download_url(RAFT_CHECKPOINT_URL, ckpt_path, desc=RAFT_CHECKPOINT_NAME)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"RAFT checkpoint not found at {ckpt_path}.")

    # Load model
    model = make_raft_model(device, ckpt_path)

    results = {}
    per_frame_savedir = args.save_per_frame_outdir
    if per_frame_savedir:
        ensure_dir(per_frame_savedir)

    for vk, vf in zip(view_keys, views):
        enh_path = os.path.join(args.enh_dir, vf)
        log(f"Processing view {vk} ...")
        start = time.time()

        save_path = os.path.join(per_frame_savedir, f"{vk}_per_frame.npy") if per_frame_savedir else None
        mean_ewarp, _ = compute_ewarp_for_view(
            enh_path, model, device,
            resize_hw=(args.resize_h, args.resize_w),
            nframes=args.nframes,
            save_per_frame=save_path,
        )

        dur = time.time() - start
        log(f"[{vk}] Ewarp (mean L1) = {mean_ewarp:.6f}  (time {dur:.1f}s)")
        results[vk] = float(mean_ewarp)

    vals = np.array(list(results.values()), dtype=np.float64)
    mean_val = float(np.mean(vals))
    std_val = float(np.std(vals, ddof=0))  # population std across the views

    log("----- Final Results -----")
    for k, v in results.items():
        log(f"{k}: {v:.6f}")
    log(f"Mean: {mean_val:.6f}")
    log(f"Std : {std_val:.6f}")

if __name__ == "__main__":
    main()