"""
Computes Frechet Video Distance (FVD) between original and enhanced videos.
"""

import os
import sys
import argparse
import math
import numpy as np
import cv2
import torch
import time
from collections import OrderedDict

# optional imports
try:
    from scipy.linalg import sqrtm
except Exception:
    sqrtm = None

# ---------- Utilities ----------
def log(msg):
    """Write ``msg`` to stdout and flush right away so progress is visible immediately."""
    print(msg, flush=True)

def get_video_frame_count(path):
    """
    Gets the total number of frames in a video.

    Args:
        path (str): Path to the video file.

    Returns:
        int: The value OpenCV reports for CAP_PROP_FRAME_COUNT, truncated to int.
            NOTE(review): this is whatever the backend reports; for some
            containers it may be an estimate -- confirm if exact counts matter.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {path}")
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the capture on every exit path, including the error above.
        cap.release()

def sample_frame_indices(min_len, num_samples=300):
    """
    Generates a list of uniformly sampled frame indices.

    Args:
        min_len (int): Number of frames available; indices fall in [0, min_len - 1].
        num_samples (int): How many indices to produce.

    Returns:
        list: ``num_samples`` Python ints, evenly spaced and truncated to integers.

    Raises:
        ValueError: If fewer than ``num_samples`` frames are available.
    """
    if min_len < num_samples:
        raise ValueError(f"min_len ({min_len}) < required samples ({num_samples})")
    last_frame = min_len - 1
    # Both endpoints are included; fractional positions are truncated by the int64 cast.
    positions = np.linspace(0, last_frame, num=num_samples, dtype=np.int64)
    return positions.tolist()

def read_frames_at_indices(video_path, indices):
    """
    Reads frames at given 0-based frame indices.

    The capture is positioned explicitly (CAP_PROP_POS_FRAMES) before every
    read, so the indices do not have to be consecutive.

    Args:
        video_path (str): Path to the video file.
        indices (iterable): 0-based frame indices to read.

    Returns:
        list: A list of BGR frames (np.uint8), one per index, in the order given.

    Raises:
        RuntimeError: If the video cannot be opened or any requested frame
            cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video {video_path}")
        frames = []
        for i in indices:
            # Seek to frame i, then decode it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frm = cap.read()
            if not ret or frm is None:
                # No sequential-read fallback is attempted: a missing frame
                # would silently misalign the two videos, so fail loudly.
                raise RuntimeError(f"Failed to read frame {i} from {video_path}")
            frames.append(frm)
        return frames
    finally:
        # Single release point covering success and every failure path.
        cap.release()

def preprocess_frames(frames_bgr, resize_hw=(224,224)):
    """
    Turns a list of BGR frames into one normalized, channels-first float array.

    Each frame is converted to RGB, resized to ``resize_hw`` with bilinear
    interpolation and scaled to [0, 1]; the stacked result is then normalized
    per channel with the ImageNet mean and standard deviation.

    Args:
        frames_bgr (list): List of BGR uint8 frames.
        resize_hw (tuple): Target (height, width).

    Returns:
        np.ndarray: A (T, C, H, W) float32 array, normalized by ImageNet stats.
    """
    target_h, target_w = resize_hw

    def _to_unit_rgb(frame):
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # cv2.resize takes the target size as (width, height).
        scaled = cv2.resize(rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
        return scaled.astype(np.float32) / 255.0

    stacked = np.stack([_to_unit_rgb(frame) for frame in frames_bgr], axis=0)  # (T,H,W,3)
    channels_first = stacked.transpose(0, 3, 1, 2).astype(np.float32)  # (T,3,H,W)

    # Per-channel statistics, shaped to broadcast over (T,3,H,W).
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)
    return (channels_first - imagenet_mean) / imagenet_std

def clips_from_frames(frames_arr, clip_len=16):
    """
    Splits a frame array into consecutive, non-overlapping clips.

    Trailing frames that do not fill a whole clip are dropped.

    Args:
        frames_arr (np.ndarray): (T, C, H, W) array of frames.
        clip_len (int): The number of frames in each clip.

    Returns:
        list: ``T // clip_len`` clips, each an independent (C, clip_len, H, W)
            numpy array (a copy, not a view of ``frames_arr``).
    """
    total_frames = frames_arr.shape[0]
    starts = [k * clip_len for k in range(total_frames // clip_len)]
    # Slice (clip_len,C,H,W), swap to (C,clip_len,H,W), then copy to detach from the input.
    return [
        frames_arr[start:start + clip_len].transpose(1, 0, 2, 3).copy()
        for start in starts
    ]

# ---------- FVD compute ----------
def compute_stats_from_feats(feats):
    """
    Returns the sample mean and covariance of a feature set.

    Args:
        feats (np.ndarray): (N, D) feature array.

    Returns:
        tuple: (mu, cov) as float64 numpy arrays. The covariance is the
            unbiased estimate (N-1 denominator); with at most one sample it
            is a (D, D) zero matrix.
    """
    mean_vec = np.mean(feats, axis=0).astype(np.float64)
    if feats.shape[0] <= 1:
        # A single sample carries no spread information.
        dim = feats.shape[1]
        return mean_vec, np.zeros((dim, dim), dtype=np.float64)
    # rowvar=False: rows are observations, columns are variables.
    return mean_vec, np.cov(feats, rowvar=False).astype(np.float64)

def sqrtm_newton_schulz(a):
    """
    Fallback matrix square root via eigen-decomposition.

    Despite the name, no Newton-Schulz iteration is performed: the matrix is
    decomposed with ``np.linalg.eigh`` (which treats ``a`` as symmetric),
    negative eigenvalues are clamped to zero, and the root is rebuilt from
    the square-rooted eigenvalues.

    Args:
        a (np.ndarray): Square matrix, expected to be symmetric.

    Returns:
        np.ndarray: Matrix ``r`` such that ``r @ r`` approximates ``a`` for
            symmetric positive semi-definite input.
    """
    eigvals, eigvecs = np.linalg.eigh(a)
    # Small negative eigenvalues are numerical noise for PSD input; zero them.
    clamped = np.where(eigvals < 0, 0.0, eigvals)
    scaled_vecs = eigvecs * np.sqrt(clamped)  # scales column j by sqrt(lambda_j)
    return scaled_vecs @ eigvecs.T

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """
    Computes the Frechet distance between two multivariate Gaussians.

    d^2 = ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 * Tr((S1 S2)^(1/2))

    where S1 = sigma1 + eps*I and S2 = sigma2 + eps*I. The same regularized
    covariances are used in every term; mixing regularized and raw matrices
    would bias the result by about -2*D*eps and make the distance between
    two identical distributions negative.

    Args:
        mu1, mu2: Mean vectors of equal shape.
        sigma1, sigma2: Covariance matrices of equal shape.
        eps (float): Diagonal regularization added to both covariances.

    Returns:
        float: The Frechet distance. Values in (-1e-6, 0) are treated as
            round-off and clamped to 0; larger negative values are returned
            unchanged so that numerical failures remain visible.

    Raises:
        AssertionError: If the mean or covariance shapes do not match.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)
    assert mu1.shape == mu2.shape
    assert sigma1.shape == sigma2.shape

    diff = mu1 - mu2
    # add small regularization to covariances
    cov_eps = eps * np.eye(sigma1.shape[0], dtype=np.float64)
    sigma1_eps = sigma1 + cov_eps
    sigma2_eps = sigma2 + cov_eps

    if sqrtm is not None:
        covmean = sqrtm(sigma1_eps.dot(sigma2_eps))
        if np.iscomplexobj(covmean):
            # take real part if imaginary components are small
            if not np.allclose(np.imag(covmean), 0, atol=1e-6):
                log("Warning: sqrtm produced complex values with significant imaginary part; taking real part.")
            covmean = np.real(covmean)
        tr_covmean = np.trace(covmean)
    else:
        # The eigh-based fallback is only valid for symmetric input, and the
        # product S1 S2 is generally not symmetric. S1 S2 has the same
        # eigenvalues as the symmetric PSD matrix S1^(1/2) S2 S1^(1/2), so
        # the trace of the square root can be taken from that instead.
        root1 = sqrtm_newton_schulz(sigma1_eps)
        inner = root1.dot(sigma2_eps).dot(root1)
        inner = 0.5 * (inner + inner.T)  # remove round-off asymmetry
        tr_covmean = np.trace(sqrtm_newton_schulz(inner))

    fd = diff.dot(diff) + np.trace(sigma1_eps) + np.trace(sigma2_eps) - 2.0 * tr_covmean
    # numerical negatives -> clamp to zero
    if fd < 0 and fd > -1e-6:
        fd = 0.0
    return float(fd)

# ---------- Model loading & feature extractor ----------
def try_load_i3d(device):
    """
    Tries to load a pretrained I3D model from PyTorchVideo.

    pytorchvideo's hub module exposes each model as a builder function
    (e.g. ``hub.i3d_r50(pretrained=True)``), so candidates are looked up by
    name on the module. A ``hub.load(name, pretrained=True)`` entry point is
    still tried when the module provides one. Every failure is logged and
    treated as "not available"; this function never raises.

    Args:
        device: Device the model is moved to.

    Returns:
        torch.nn.Module or None: The loaded model in eval mode, or None if it fails.
    """
    log("Trying to load I3D from pytorchvideo (if installed)...")
    try:
        import pytorchvideo.models.hub as hub
    except Exception as exc:
        log(f"pytorchvideo hub not importable ({exc!r}).")
        return None

    # Names are tried in order; those the installed version lacks are skipped.
    candidate_names = ['i3d_slow', 'i3d', 'i3d_r50', 'i3d_nl']
    for name in candidate_names:
        builder = getattr(hub, name, None)
        try:
            if callable(builder):
                model = builder(pretrained=True)
            elif hasattr(hub, "load"):
                model = hub.load(name, pretrained=True)
            else:
                continue
            model = model.to(device).eval()
        except Exception as exc:
            # Best effort: report why this candidate failed and try the next.
            log(f"Could not load pytorchvideo model '{name}': {exc!r}")
            continue
        log(f"Loaded pytorchvideo model via hub: {name}")
        return model

    # fallback: try creating I3D via pytorchvideo.models.i3d if available
    try:
        from pytorchvideo.models import i3d
        # NOTE(review): no pretrained flag is passed here, so this model may
        # have randomly initialised weights -- confirm before trusting scores.
        model = i3d.create_i3d()  # may fail if API differs
        model = model.to(device).eval()
    except Exception as exc:
        log(f"pytorchvideo.models.i3d.create_i3d() unavailable: {exc!r}")
        return None
    log("Loaded pytorchvideo.models.i3d.create_i3d()")
    return model

def load_fallback_model(device):
    """
    Loads a fallback video model (r3d_18) from torchvision.

    The classifier head is left in place; the model's forward output is used
    as the feature vector by the caller.

    Args:
        device: Device the model is moved to.

    Returns:
        torch.nn.Module: The loaded model in eval mode.

    Raises:
        RuntimeError: If torchvision cannot be imported or the model cannot
            be built or loaded (the original error is chained as the cause).
    """
    try:
        import torchvision
        log("Loading fallback model torchvision.models.video.r3d_18(pretrained=True)")
        video_models = torchvision.models.video
        weights_enum = getattr(video_models, "R3D_18_Weights", None)
        if weights_enum is not None:
            # Newer torchvision deprecates ``pretrained=True`` in favour of the
            # weights enum; DEFAULT selects the standard pretrained weights.
            model = video_models.r3d_18(weights=weights_enum.DEFAULT)
        else:
            # Older torchvision without the weights API.
            model = video_models.r3d_18(pretrained=True)
        return model.to(device).eval()
    except Exception as e:
        raise RuntimeError("Failed to load any video model. Please install pytorchvideo or torchvision.") from e

def get_model(device):
    """
    Returns the video feature extractor to use on ``device``.

    I3D is attempted first; when that loader returns None the torchvision
    r3d_18 fallback is loaded instead.
    """
    i3d_model = try_load_i3d(device)
    return load_fallback_model(device) if i3d_model is None else i3d_model

def extract_features_from_clips(model, clips_list, device, batch_size=8):
    """
    Runs the model over the clips in batches and collects one feature row per clip.

    Args:
        model: Callable feature extractor; receives a (B, C, T, H, W) tensor.
        clips_list (list): Preprocessed clips, each a (C, T, H, W) numpy array.
        device: The device input batches are moved to.
        batch_size (int): Number of clips per forward pass.

    Returns:
        np.ndarray: (N, D) feature array, where N is the number of clips.
            An empty clip list yields a (0, 0) float32 array.

    Raises:
        RuntimeError: If the model output (or the first value of a dict
            output) is not a torch.Tensor.
    """

    def _as_feature_matrix(output):
        # Dict outputs: use the first value (None for an empty dict).
        if isinstance(output, dict):
            output = next(iter(output.values()), None)
        if not isinstance(output, torch.Tensor):
            raise RuntimeError(f"Unexpected model output type: {type(output)}")
        # Anything beyond (B, C) is averaged over its trailing dimensions.
        if output.dim() != 2:
            output = output.mean(dim=tuple(range(2, output.dim())))
        return output.cpu().numpy()

    per_batch = []
    with torch.no_grad():
        for start in range(0, len(clips_list), batch_size):
            stacked = np.stack(clips_list[start:start + batch_size], axis=0)  # (B,C,T,H,W)
            inputs = torch.from_numpy(stacked).to(device)
            per_batch.append(_as_feature_matrix(model(inputs)))

    if not per_batch:
        return np.zeros((0, 0), dtype=np.float32)
    return np.concatenate(per_batch, axis=0)

# ---------- Main flow for one view pair ----------
def _load_view_clips(video_path, indices, resize, clip_len):
    """Reads the frames at ``indices``, preprocesses them and splits them into clips."""
    frames = read_frames_at_indices(video_path, indices)
    arr = preprocess_frames(frames, resize_hw=resize)  # (T,C,H,W)
    # Drop the raw frames before building clips to keep peak memory down.
    del frames
    return clips_from_frames(arr, clip_len=clip_len)


def compute_fvd_for_view(orig_path, enh_path, model, device, clip_len=16, resize=(224,224), batch_size=8, save_feats_outdir=None, view_name="view", max_frames=1500, num_samples=300):
    """
    Computes FVD for a single pair of original and enhanced videos.

    The same uniformly spaced frame indices are sampled from both videos,
    restricted to the first ``min(max_frames, n_orig, n_enh)`` frames. The
    sampled frames are preprocessed, split into clips, passed through the
    model, and the Frechet distance between the two feature sets is returned.

    Args:
        orig_path (str): Path to the original video.
        enh_path (str): Path to the enhanced video.
        model: Feature extraction model.
        device: Device used for inference.
        clip_len (int): Frames per clip.
        resize (tuple): Target (height, width) for each frame.
        batch_size (int): Clips per forward pass.
        save_feats_outdir (str or None): If set, features of both videos are
            saved there as ``{view_name}_orig_feats.npy`` / ``{view_name}_enh_feats.npy``.
        view_name (str): Label used in log messages and saved file names.
        max_frames (int): Upper bound on the frame range that is sampled.
        num_samples (int): Number of frames sampled from each video.

    Returns:
        tuple: (fvd, feats_orig, feats_enh) -- the distance as a float and
            the two feature arrays.

    Raises:
        RuntimeError: If fewer than ``num_samples`` frames are available, or
            no complete clip can be formed.
    """
    # count frames
    n_orig = get_video_frame_count(orig_path)
    n_enh = get_video_frame_count(enh_path)
    min_len = min(max_frames, n_orig, n_enh)
    if min_len < num_samples:
        raise RuntimeError(f"For view {view_name}: min({max_frames}, n_orig={n_orig}, n_enh={n_enh}) = {min_len} < {num_samples}. Cannot proceed.")
    indices = sample_frame_indices(min_len, num_samples=num_samples)

    # Handle one video at a time so only one set of raw frames is in memory.
    log(f"[{view_name}] Reading frames from original: {orig_path}")
    clips_orig = _load_view_clips(orig_path, indices, resize, clip_len)
    log(f"[{view_name}] Reading frames from enhanced: {enh_path}")
    clips_enh = _load_view_clips(enh_path, indices, resize, clip_len)

    nclips_o = len(clips_orig)
    nclips_e = len(clips_enh)
    log(f"[{view_name}] {nclips_o} clips from original, {nclips_e} clips from enhanced (clip_len={clip_len})")
    if nclips_o == 0 or nclips_e == 0:
        raise RuntimeError(f"[{view_name}] Not enough frames to create any clips (clip_len={clip_len}).")
    if min(nclips_o, nclips_e) < 2:
        # With a single clip the covariance is all zeros, so the result
        # reduces to the squared distance between the two feature vectors.
        log(f"[{view_name}] Warning: fewer than 2 clips; covariance is degenerate.")

    # extract features
    feats_o = extract_features_from_clips(model, clips_orig, device, batch_size=batch_size)  # (N,C)
    feats_e = extract_features_from_clips(model, clips_enh, device, batch_size=batch_size)
    # optionally save
    if save_feats_outdir:
        os.makedirs(save_feats_outdir, exist_ok=True)
        np.save(os.path.join(save_feats_outdir, f"{view_name}_orig_feats.npy"), feats_o)
        np.save(os.path.join(save_feats_outdir, f"{view_name}_enh_feats.npy"), feats_e)

    # compute stats and the distance between the two Gaussians
    mu_o, cov_o = compute_stats_from_feats(feats_o)
    mu_e, cov_e = compute_stats_from_feats(feats_e)
    fvd_val = frechet_distance(mu_o, cov_o, mu_e, cov_e)
    return float(fvd_val), feats_o, feats_e

# ---------- CLI & main ----------
def parse_args():
    """
    Builds the command-line parser and parses ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options (orig_dir, enh_dir, device, no_gpu,
            clip_len, resize, batch_size, save_feats_outdir).
    """
    parser = argparse.ArgumentParser(description="Compute FVD per-view with I3D features")
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--orig_dir", dict(type=str, default="gt",
                            help="Directory with original videos (containing head.mp4/right_hand.mp4/left_hand.mp4).")),
        ("--enh_dir", dict(type=str, required=True,
                           help="Directory with enhanced videos (containing head.mp4/right_hand.mp4/left_hand.mp4).")),
        ("--device", dict(type=str, default="cuda:0",
                          help="Device to use, e.g., 'cuda:0' or 'cpu'.")),
        ("--no_gpu", dict(action="store_true", help="Force CPU usage (overrides device).")),
        ("--clip_len", dict(type=int, default=16, help="Length of each clip (in frames), default 16.")),
        ("--resize", dict(type=int, default=224, help="Resize dimension, default 224 -> 224x224.")),
        ("--batch_size", dict(type=int, default=64, help="Batch size for inference (clips per batch).")),
        ("--save_feats_outdir", dict(type=str, default=None,
                                     help="If specified, saves features for each view as npy files.")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def main():
    """
    CLI entry point: computes FVD for the head, right_hand and left_hand views.

    Checks that all six videos exist, loads the feature extractor once,
    computes the FVD of each view, and logs the per-view values together with
    their mean and (population) standard deviation.

    Raises:
        FileNotFoundError: If any expected video file is missing.
    """
    args = parse_args()

    # CPU when forced or when CUDA is unavailable; otherwise the requested device.
    use_cpu = args.no_gpu or not torch.cuda.is_available()
    device = torch.device("cpu" if use_cpu else args.device)
    log(f"Device: {device}")

    # Map each view key to its file name, preserving processing order.
    view_files = OrderedDict((key, f"{key}.mp4") for key in ("head", "right_hand", "left_hand"))

    # validate files exist
    for filename in view_files.values():
        orig_candidate = os.path.join(args.orig_dir, filename)
        if not os.path.isfile(orig_candidate):
            raise FileNotFoundError(f"Original video not found: {orig_candidate}")
        enh_candidate = os.path.join(args.enh_dir, filename)
        if not os.path.isfile(enh_candidate):
            raise FileNotFoundError(f"Enhanced video not found: {enh_candidate}")

    # load model
    log("Loading model (I3D preferred)...")
    model = get_model(device)
    log(f"Using model: {model.__class__}")

    results = OrderedDict()
    for key, filename in view_files.items():
        log(f"Processing view {key} ...")
        started = time.time()
        score, _, _ = compute_fvd_for_view(
            os.path.join(args.orig_dir, filename),
            os.path.join(args.enh_dir, filename),
            model,
            device,
            clip_len=args.clip_len,
            resize=(args.resize, args.resize),
            batch_size=args.batch_size,
            save_feats_outdir=args.save_feats_outdir,
            view_name=key,
        )
        elapsed = time.time() - started
        log(f"[{key}] FVD = {score:.6f}  (time {elapsed:.1f}s)")
        results[key] = float(score)

    # summary statistics across views
    scores = np.array(list(results.values()), dtype=np.float64)
    log("----- Final Results -----")
    for key, score in results.items():
        log(f"{key}: {score:.6f}")
    log(f"Mean: {float(np.mean(scores)):.6f}")
    log(f"Std : {float(np.std(scores, ddof=0)):.6f}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
