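"""
Computes per-view Frechet Inception Distance (FID) between original and enhanced videos.

For each view (head, right_hand, left_hand), 300 frames are sampled uniformly
from the first min(1500, len(original), len(enhanced)) frames of both videos.
The frames are embedded with torchvision's InceptionV3, with fc replaced by Identity.
The FID between the two feature distributions is reported per view, together
with the mean and standard deviation across views.
"""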
import os
import sys
import argparse
import math
import numpy as np
import cv2
import torch
import time
from collections import OrderedDict

try:
    from scipy.linalg import sqrtm
except Exception:
    sqrtm = None

import torch.nn as nn

# ---------- Utilities ----------
def log(msg):
    """Write ``msg`` to stdout and flush right away so progress is visible when piped."""
    stream = sys.stdout
    print(msg, file=stream)
    stream.flush()

def get_video_frame_count(path):
    """Return the frame count OpenCV reports for the video at ``path``.

    The value is ``cv2.CAP_PROP_FRAME_COUNT`` truncated to ``int``.
    NOTE(review): for many backends this property comes from container
    metadata and may differ from the number of frames that actually decode.
    Confirm this for the codecs in use.

    Raises:
        RuntimeError: If OpenCV cannot open ``path``.
    """
    capture = cv2.VideoCapture(path)
    if capture.isOpened():
        n_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        capture.release()
        return n_frames
    raise RuntimeError(f"Cannot open video: {path}")

def sample_frame_indices(min_len, num_samples=300):
    """Return ``num_samples`` evenly spaced frame indices within ``[0, min_len - 1]``.

    The indices come from ``np.linspace`` cast to int64. They are therefore
    non-decreasing, and with two or more samples they include both endpoints.

    Raises:
        ValueError: If ``min_len`` is smaller than ``num_samples``.
    """
    if num_samples > min_len:
        raise ValueError(f"min_len ({min_len}) < required samples ({num_samples})")
    positions = np.linspace(0, min_len - 1, num=num_samples, dtype=np.int64)
    return [int(p) for p in positions]

def read_frames_at_indices(video_path, indices):
    """
    Reads frames at given 0-based frame indices.

    The capture seeks to each index with ``CAP_PROP_POS_FRAMES`` before
    reading, so the indices do not have to be sorted or unique.

    Args:
        video_path (str): Path to the video file.
        indices (iterable of int): 0-based frame indices to read.

    Returns:
        list: A list of BGR frames (np.uint8), one per index, in the order given.

    Raises:
        RuntimeError: If the video cannot be opened or a frame cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video {video_path}")
    frames = []
    try:
        for i in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frm = cap.read()
            if not ret or frm is None:
                raise RuntimeError(f"Failed to read frame {i} from {video_path}")
            frames.append(frm)
    finally:
        # Release on every exit path, including cv2 errors or interrupts
        # raised from set()/read(), not just the explicit failure above.
        cap.release()
    return frames

def preprocess_frames_for_inception(frames_bgr, resize_hw=(299,299)):
    """
    Converts BGR frames to a normalized RGB array for InceptionV3.

    Each frame is processed in four steps:
    1. Convert BGR -> RGB.
    2. Resize to ``resize_hw`` with bilinear interpolation.
    3. Scale to [0, 1].
    4. Normalize with the ImageNet per-channel mean/std.

    Args:
        frames_bgr (list): List of BGR uint8 frames.
        resize_hw (tuple): Target (height, width). Default (299, 299).

    Returns:
        np.ndarray: A C-contiguous (T, 3, H, W) float32 array, normalized by
        ImageNet stats. An empty ``frames_bgr`` yields a (0, 3, H, W) array.
    """
    H, W = resize_hw
    frames = list(frames_bgr)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
    # Fill one preallocated NCHW buffer frame by frame. Building a list,
    # np.stack-ing it, then transposing and copying held roughly three
    # full-size float32 copies of the clip at peak. It also crashed on an
    # empty list ("need at least one array to stack").
    out = np.empty((len(frames), 3, H, W), dtype=np.float32)
    for t, frame in enumerate(frames):
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, (W, H), interpolation=cv2.INTER_LINEAR)
        chw = resized.astype(np.float32).transpose(2, 0, 1) / 255.0
        out[t] = (chw - mean) / std
    return out

# ---------- FID compute helpers ----------
def compute_stats_from_feats(feats):
    """
    Estimate the Gaussian statistics (mean and covariance) of a feature set.

    Args:
        feats (np.ndarray): (N, D) feature array, one row per sample.

    Returns:
        tuple: ``(mu, cov)`` as float64 arrays of shape (D,) and (D, D).
        With a single sample the covariance is all zeros.
    """
    mean_vec = np.mean(feats, axis=0)
    n_rows = feats.shape[0]
    if n_rows <= 1:
        # The unbiased (ddof=1) covariance np.cov computes is undefined for
        # fewer than two observations, so report zero spread instead.
        width = feats.shape[1]
        cov_mat = np.zeros((width, width), dtype=np.float64)
    else:
        cov_mat = np.cov(feats, rowvar=False)
    return mean_vec.astype(np.float64), cov_mat.astype(np.float64)

def sqrtm_newton_schulz(a):
    """Fallback matrix square root of a symmetric PSD matrix via eigen-decomposition.

    Despite the name, this does not run a Newton-Schulz iteration. It factors
    ``a = V diag(w) V^T`` with ``np.linalg.eigh`` and returns
    ``V diag(sqrt(w)) V^T``. Negative eigenvalues, typically round-off on
    PSD input, are clamped to zero first.

    ``eigh`` assumes symmetric/Hermitian input and reads only one triangle.
    The result is therefore meaningful only for symmetric ``a``.
    """
    eigvals, eigvecs = np.linalg.eigh(a)
    root = np.sqrt(np.where(eigvals < 0, 0.0, eigvals))
    # Scaling the columns of V by sqrt(w) is V @ diag(sqrt(w)) without building the diagonal.
    return (eigvecs * root) @ eigvecs.T

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Computes the Frechet distance between two multivariate Gaussians.

        d^2 = ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 * Tr(sqrt(S1 @ S2))

    Both covariances are regularized with ``eps * I``. This keeps the matrix
    square root well-behaved when a covariance is rank-deficient, for example
    when there are fewer samples than feature dimensions. The regularized
    covariances are used in *every* term. The result is therefore the exact
    distance between N(mu1, S1 + eps*I) and N(mu2, S2 + eps*I), and identical
    inputs give 0. Mixing regularized and unregularized terms would instead
    give a -2 * eps * D offset.

    Args:
        mu1, mu2: Mean vectors of shape (D,).
        sigma1, sigma2: Covariance matrices of shape (D, D).
        eps (float): Ridge added to the diagonal of both covariances.

    Returns:
        float: The (squared) Frechet distance. Round-off values in
        (-1e-6, 0) are clamped to 0.

    Raises:
        ValueError: If the mean shapes or covariance shapes differ.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if mu1.shape != mu2.shape:
        raise ValueError(f"Mean vectors have different shapes: {mu1.shape} vs {mu2.shape}")
    if sigma1.shape != sigma2.shape:
        raise ValueError(f"Covariance matrices have different shapes: {sigma1.shape} vs {sigma2.shape}")

    diff = mu1 - mu2
    offset = eps * np.eye(sigma1.shape[0], dtype=np.float64)
    sigma1_eps = sigma1 + offset
    sigma2_eps = sigma2 + offset

    if sqrtm is not None:
        covmean = sqrtm(sigma1_eps.dot(sigma2_eps))
        if np.iscomplexobj(covmean):
            if not np.allclose(np.imag(covmean), 0, atol=1e-6):
                log("Warning: sqrtm produced complex values with significant imaginary part; taking real part.")
            covmean = np.real(covmean)
    else:
        # The eigh-based fallback is only valid for symmetric input, and
        # S1 @ S2 is generally not symmetric. Use instead
        # Tr(sqrt(S1 S2)) == Tr(sqrt(S1^1/2 S2 S1^1/2)), whose inner matrix
        # is symmetric PSD. Both products share their nonzero eigenvalues.
        sigma1_half = sqrtm_newton_schulz(sigma1_eps)
        inner = sigma1_half @ sigma2_eps @ sigma1_half
        inner = 0.5 * (inner + inner.T)  # remove floating-point asymmetry
        covmean = sqrtm_newton_schulz(inner)

    tr_covmean = np.trace(covmean)
    fid = diff.dot(diff) + np.trace(sigma1_eps) + np.trace(sigma2_eps) - 2.0 * tr_covmean
    # The true distance is non-negative; tiny negatives are round-off.
    if -1e-6 < fid < 0:
        fid = 0.0
    return float(fid)


def get_inception_model(device):
    """Load a pretrained torchvision InceptionV3 configured as a feature extractor.

    The classifier ``fc`` is replaced with ``nn.Identity`` so the forward pass
    returns the pre-``fc`` features. The model is put in eval mode and moved
    to ``device``. It first tries the ``weights=`` API. If that raises, it
    retries with the legacy ``pretrained=True`` argument, presumably for
    older torchvision releases.

    NOTE(review): these are torchvision's ImageNet weights, not the TF
    Inception weights used by reference FID implementations. Scores may
    therefore not be comparable to published FID numbers. Confirm this is
    intended.

    Raises:
        RuntimeError: If torchvision is missing or the model cannot be built.
    """
    try:
        import torchvision
        try:
            weights_enum = getattr(torchvision.models, "Inception_V3_Weights", None)
            chosen = None if weights_enum is None else weights_enum.DEFAULT
            net = torchvision.models.inception_v3(weights=chosen, aux_logits=False)
        except Exception:
            net = torchvision.models.inception_v3(pretrained=True, aux_logits=False)
        net.fc = nn.Identity()
        net.eval()
        return net.to(device)
    except Exception as err:
        raise RuntimeError("Failed to load torchvision.models.inception_v3. Please install/upgrade torchvision.") from err


def extract_inception_features(model, frames_arr, device, batch_size=64):
    """
    Run ``model`` over ``frames_arr`` in batches and stack the outputs.

    Args:
        model: A callable returning a ``torch.Tensor`` per batch (the InceptionV3 extractor).
        frames_arr (np.ndarray): (T, C, H, W) preprocessed frame array.
        device: Device each batch is moved to before inference.
        batch_size (int): Number of frames per forward pass.

    Returns:
        np.ndarray: (T, D) feature array, or a (0, 0) float32 array when T == 0.

    Raises:
        RuntimeError: If the model returns something other than a tensor.
    """
    n_frames = frames_arr.shape[0]
    chunks = []
    with torch.no_grad():
        for start in range(0, n_frames, batch_size):
            inputs = torch.from_numpy(frames_arr[start:start + batch_size]).to(device)
            outputs = model(inputs)
            # e.g. a training-mode InceptionV3 returns a namedtuple, not a tensor.
            if not isinstance(outputs, torch.Tensor):
                raise RuntimeError(f"Unexpected Inception output type: {type(outputs)}")
            chunks.append(outputs.cpu().numpy())
    if not chunks:
        return np.zeros((0,0), dtype=np.float32)
    return np.concatenate(chunks, axis=0)

# ---------- Main flow for one view ----------
def _video_features(path, indices, model, device, resize, batch_size, view_name, label):
    """Read, preprocess and embed the sampled frames of one video. Returns (T, D) features."""
    log(f"[{view_name}] Reading frames from {label}: {path}")
    frames = read_frames_at_indices(path, indices)
    arr = preprocess_frames_for_inception(frames, resize_hw=resize)
    # Drop the full-resolution frames before inference, so that only one
    # video's raw frames are ever alive at a time.
    del frames
    log(f"[{view_name}] Extracting Inception features (batch_size={batch_size}) for {label}...")
    return extract_inception_features(model, arr, device, batch_size=batch_size)


def compute_fid_for_view(orig_path, enh_path, model, device, resize=(299,299), batch_size=64,
                         save_feats_outdir=None, view_name="view", max_frames=1500, num_samples=300):
    """Computes FID for a single pair of original and enhanced videos.

    Both videos are sampled at the same ``num_samples`` frame indices. The
    indices are spread uniformly over the first
    ``min(max_frames, n_orig, n_enh)`` frames, so sample k of one video
    corresponds to sample k of the other.

    Args:
        orig_path, enh_path (str): Paths of the original / enhanced videos.
        model: Feature extractor passed to ``extract_inception_features``.
        device: Device used for inference.
        resize (tuple): (H, W) frames are resized to before the model.
        batch_size (int): Inference batch size.
        save_feats_outdir (str | None): If set, features are saved there as
            ``{view_name}_orig_feats.npy`` / ``{view_name}_enh_feats.npy``.
        view_name (str): Label used in logs and saved file names.
        max_frames (int): Only the first ``max_frames`` frames are considered.
        num_samples (int): Number of frames sampled per video.

    Returns:
        tuple: ``(fid, feats_o, feats_e)``, where the features are (T, D) arrays.

    Raises:
        RuntimeError: If fewer than ``num_samples`` frames are available.
    """
    n_orig = get_video_frame_count(orig_path)
    n_enh = get_video_frame_count(enh_path)
    min_len = min(max_frames, n_orig, n_enh)
    if min_len < num_samples:
        raise RuntimeError(f"For view {view_name}: min({max_frames}, n_orig={n_orig}, n_enh={n_enh}) = {min_len} < {num_samples}. Cannot proceed.")
    indices = sample_frame_indices(min_len, num_samples=num_samples)

    # Process one video end to end before the next to keep peak memory low.
    feats_o = _video_features(orig_path, indices, model, device, resize, batch_size, view_name, "original")
    feats_e = _video_features(enh_path, indices, model, device, resize, batch_size, view_name, "enhanced")

    if save_feats_outdir:
        os.makedirs(save_feats_outdir, exist_ok=True)
        np.save(os.path.join(save_feats_outdir, f"{view_name}_orig_feats.npy"), feats_o)
        np.save(os.path.join(save_feats_outdir, f"{view_name}_enh_feats.npy"), feats_e)

    mu_o, cov_o = compute_stats_from_feats(feats_o)
    mu_e, cov_e = compute_stats_from_feats(feats_e)

    fid_val = frechet_distance(mu_o, cov_o, mu_e, cov_e)
    return float(fid_val), feats_o, feats_e

# ---------- CLI & main ----------
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: With orig_dir, enh_dir, device, no_gpu, resize,
        batch_size and save_feats_outdir.
    """
    parser = argparse.ArgumentParser(
        description="Compute per-view FID using InceptionV3 features.")
    add = parser.add_argument
    # Input locations.
    add("--orig_dir", type=str, default="gt/",
        help="Directory with original videos (e.g., head.mp4).")
    add("--enh_dir", type=str, required=True,
        help="Directory with enhanced videos (e.g., head.mp4).")
    # Device selection.
    add("--device", type=str, default="cuda:0",
        help="Device to use, e.g., 'cuda:0' or 'cpu'.")
    add("--no_gpu", action="store_true",
        help="Force CPU usage, overriding --device.")
    # Feature-extraction knobs.
    add("--resize", type=int, default=299,
        help="Frame resize dimension (default: 299 for 299x299).")
    add("--batch_size", type=int, default=64,
        help="Batch size for feature extraction.")
    add("--save_feats_outdir", type=str, default=None,
        help="If set, save per-view features as .npy files to this directory.")
    namespace = parser.parse_args()
    return namespace

def main():
    """CLI entry point: compute and log FID for each camera view.

    For each of ``head.mp4``, ``right_hand.mp4`` and ``left_hand.mp4``, the
    video must exist in both ``--orig_dir`` and ``--enh_dir``. The FID is
    computed for each pair, and the per-view values are logged together with
    their mean and population standard deviation.

    Raises:
        FileNotFoundError: If any expected video is missing from either directory.
    """
    args = parse_args()
    orig_dir = args.orig_dir
    enh_dir = args.enh_dir

    if args.no_gpu:
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)
        # Fall back to CPU only when CUDA was requested but is unavailable.
        # Other explicit requests (e.g. "mps") are honored rather than
        # silently replaced by CPU.
        if device.type == "cuda" and not torch.cuda.is_available():
            log(f"Warning: CUDA not available; falling back to CPU instead of {args.device}.")
            device = torch.device("cpu")
    log(f"Device: {device}")

    view_keys = ["head", "right_hand", "left_hand"]
    views = [f"{vk}.mp4" for vk in view_keys]
    # Validate every input before loading the model so a missing file fails fast.
    for v in views:
        if not os.path.isfile(os.path.join(orig_dir, v)):
            raise FileNotFoundError(f"Original video not found: {os.path.join(orig_dir, v)}")
        if not os.path.isfile(os.path.join(enh_dir, v)):
            raise FileNotFoundError(f"Enhanced video not found: {os.path.join(enh_dir, v)}")

    log("Loading InceptionV3 model (torchvision pretrained)...")
    model = get_inception_model(device)
    log("Using model: torchvision.models.inception_v3 with fc -> Identity (outputs 2048-d)")

    results = OrderedDict()
    for vk, vf in zip(view_keys, views):
        orig_path = os.path.join(orig_dir, vf)
        enh_path = os.path.join(enh_dir, vf)
        log(f"Processing view {vk} ...")
        start = time.time()
        fid_val, _, _ = compute_fid_for_view(orig_path, enh_path, model, device,
                                             resize=(args.resize, args.resize),
                                             batch_size=args.batch_size,
                                             save_feats_outdir=args.save_feats_outdir,
                                             view_name=vk)
        dur = time.time() - start
        log(f"[{vk}] FID = {fid_val:.6f}  (time {dur:.1f}s)")
        results[vk] = float(fid_val)

    vals = np.array(list(results.values()), dtype=np.float64)
    mean_val = float(np.mean(vals))
    std_val = float(np.std(vals, ddof=0))  # population std across the views
    log("----- Final Results -----")
    for k, v in results.items():
        log(f"{k}: {v:.6f}")
    log(f"Mean: {mean_val:.6f}")
    log(f"Std : {std_val:.6f}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
