"""
Computes CLIP-based scores to evaluate video quality and text alignment.
"""

import os
import sys
import argparse
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
import torch

# Requires the openai/CLIP package; model weights are loaded via clip.load() in main().
try:
    import clip
except Exception as e:
    raise ImportError("Please install CLIP (https://github.com/openai/CLIP).") from e

# ---------------- utilities (same sampling/reading logic as FVD) ----------------
def log(msg):
    """Write *msg* to standard output and flush at once so progress appears in real time."""
    print(msg, flush=True)

def get_video_frame_count(path):
    """Return the frame count OpenCV reports for the video at *path*.

    The value is read from ``cv2.CAP_PROP_FRAME_COUNT``.
    NOTE(review): this is container-reported and may not match the number of
    frames that can actually be decoded; callers should still handle read failures.

    Raises:
        RuntimeError: if OpenCV cannot open *path*.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {path}")
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the capture on every path, including the failed-open one.
        cap.release()

def sample_frame_indices(min_len, num_samples=300):
    """Return ``num_samples`` evenly spaced integer frame indices spanning ``[0, min_len - 1]``.

    Args:
        min_len: Number of frames available to sample from.
        num_samples: How many indices to produce.

    Returns:
        A list of Python ints, in increasing order.

    Raises:
        ValueError: if fewer than ``num_samples`` frames are available.
    """
    if num_samples > min_len:
        raise ValueError(f"min_len ({min_len}) < required samples ({num_samples})")
    positions = np.linspace(0, min_len - 1, num=num_samples, dtype=np.int64)
    return [int(p) for p in positions]

def read_frames_at_indices(video_path, indices):
    """Decode the frames at *indices* from *video_path* and return them as RGB arrays.

    Each index is reached by seeking (``cv2.CAP_PROP_POS_FRAMES``) before the
    read, so *indices* need not be sorted. Frames are returned in the same
    order as *indices*.

    Raises:
        RuntimeError: if the video cannot be opened or any requested frame
            cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video {video_path}")
        frames = []
        for i in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ret, frm = cap.read()
            if not ret or frm is None:
                raise RuntimeError(f"Failed to read frame {i} from {video_path}")
            # OpenCV decodes to BGR; convert so Image.fromarray treats channels as RGB.
            frames.append(cv2.cvtColor(frm, cv2.COLOR_BGR2RGB))
        return frames
    finally:
        # Single release point: covers success, read failure, failed open,
        # and any exception raised during conversion.
        cap.release()

def frames_to_pil(frames_rgb):
    """Wrap each RGB NumPy frame in a ``PIL.Image``; the returned list keeps the input order."""
    pil_images = []
    for frame in frames_rgb:
        pil_images.append(Image.fromarray(frame))
    return pil_images

def clips_from_frames_list(frames_list, clip_len=16):
    """Split *frames_list* into consecutive, non-overlapping chunks of exactly *clip_len* items.

    Trailing items that cannot fill a whole clip are dropped.
    """
    n_full = len(frames_list) // clip_len
    return [frames_list[k * clip_len:(k + 1) * clip_len] for k in range(n_full)]

# ---------------- CLIP feature extraction ----------------
def extract_image_features_clip(model, preprocess, pil_images, device, batch_size=32):
    """
    Extracts L2-normalized image features using a CLIP model.

    Args:
        model: The CLIP model (must provide ``to``, ``eval`` and ``encode_image``).
        preprocess: Transform mapping one image to a tensor; outputs are stacked per batch.
        pil_images: A list of PIL.Image objects.
        device: The device to run the model on.
        batch_size: Number of images per forward pass; must be >= 1.

    Returns:
        A float32 numpy array of shape (N_images, D) whose rows have unit L2 norm.
        For an empty input, a (0, D) array where D is ``model.visual.output_dim``
        when available, else 512.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        # A negative step would make the loop below silently skip every image.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    model.to(device).eval()
    all_feats = []
    with torch.no_grad():
        for start in range(0, len(pil_images), batch_size):
            batch_imgs = pil_images[start:start + batch_size]
            # Preprocess images and move the stacked batch to the model's device.
            batch_t = torch.stack([preprocess(img) for img in batch_imgs]).to(device)
            # encode_image may return half precision (e.g. an fp16 model on GPU).
            # Upcast so normalization runs in float32 and the output dtype does
            # not depend on the device/model precision.
            feats = model.encode_image(batch_t).float()
            feats = feats / feats.norm(dim=-1, keepdim=True)
            all_feats.append(feats.cpu().numpy())
    if not all_feats:
        out_dim = model.visual.output_dim if hasattr(model, "visual") else 512
        return np.zeros((0, out_dim), dtype=np.float32)
    return np.concatenate(all_feats, axis=0)

def compute_pairwise_upper_triangle_mean_cosine(feats_normed):
    """
    Return the mean cosine similarity over all distinct pairs of rows.

    Rows are treated as already L2-normalized, so the Gram matrix entries are
    cosine similarities. Only the strictly upper triangle (i < j) is averaged,
    so self-pairs and duplicate (j, i) pairs are excluded.

    Args:
        feats_normed: A numpy array of shape (N, D) with L2-normalized rows.

    Returns:
        The mean pairwise cosine as a Python float, or 0.0 when N < 2.
    """
    n_rows = feats_normed.shape[0]
    if n_rows < 2:
        return 0.0
    gram = np.matmul(feats_normed, feats_normed.T)
    strictly_upper = np.triu(np.ones((n_rows, n_rows), dtype=bool), k=1)
    return float(gram[strictly_upper].mean())

# ---------------- main compute for one view ----------------
def _l2_normalize_rows(x):
    """Return *x* with every row scaled to unit L2 norm; all-zero rows stay zero."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid division by zero for degenerate rows
    return x / norms


def _pooled_clip_features(clips, clip_model, preprocess, device, batch_size, label, view_name):
    """Embed every frame of *clips* with CLIP, mean-pool per clip, and L2-normalize each clip row.

    All clips must hold the same number of frames (``clips_from_frames_list``
    only emits full clips), so the frame features reshape to
    ``(n_clips, clip_len, D)`` and are averaged along axis 1.
    Returns a float32 array of shape ``(n_clips, D)``.
    """
    flat = [frame for clip_frames in clips for frame in clip_frames]
    log(f"[{view_name}] Extracting frame features for {label} ({len(flat)} frames)...")
    frame_feats = extract_image_features_clip(clip_model, preprocess, flat, device, batch_size=batch_size)
    # Pool in float32 even if the encoder produced half-precision features.
    frame_feats = frame_feats.astype(np.float32, copy=False)
    clip_feats = frame_feats.reshape(len(clips), -1, frame_feats.shape[1]).mean(axis=1)
    return _l2_normalize_rows(clip_feats)


def compute_clip_score_for_view(orig_path, enh_path, clip_model, preprocess, device,
                                clip_len=16, batch_size=32, save_feats_outdir=None,
                                text_prompts=None, view_name="view",
                                max_frames=1500, num_samples=300):
    """
    Computes all CLIP-based scores for a single view.

    The same ``num_samples`` evenly spaced frame indices, drawn from the first
    ``min(max_frames, n_orig, n_enh)`` frames, are read from both videos. The
    frames are grouped into non-overlapping clips of ``clip_len`` frames and
    embedded with CLIP frame by frame. Frame features are mean-pooled into one
    L2-normalized feature per clip.

    Args:
        orig_path (str): Path to the original video file.
        enh_path (str): Path to the enhanced video file.
        clip_model: The CLIP model.
        preprocess: The CLIP preprocessing transform.
        device: The device to run computations on.
        clip_len (int): Number of frames per clip; must be >= 1.
        batch_size (int): Batch size for CLIP inference.
        save_feats_outdir (str, optional): Directory to save computed clip features. Defaults to None.
        text_prompts (list of str, optional): Text prompts for CLIP score calculation. Defaults to None.
        view_name (str): Name of the view, used in log messages and saved file names.
        max_frames (int): Upper bound on the frame range sampled from. Defaults to 1500.
        num_samples (int): Number of frames sampled from each video. Defaults to 300.

    Returns:
        A tuple ``(results, feats_orig_clips, feats_enh_clips)``:
        - ``results``: dict with ``inter_clip_similarity``, ``real_self_similarity``,
          ``fake_self_similarity``, ``avg_feature_distance`` and ``n_clips``. When
          prompts are given, it also has ``clip_score_real``, ``clip_score_fake``
          and ``clip_score_difference``.
        - ``feats_orig_clips`` / ``feats_enh_clips``: float32 arrays of shape
          (n_clips, D) with L2-normalized rows.

    Raises:
        ValueError: if ``clip_len`` < 1.
        RuntimeError: if the videos are too short, cannot be read, or yield no full clip.
    """
    if clip_len < 1:
        raise ValueError(f"[{view_name}] clip_len must be >= 1, got {clip_len}")

    # --- Sample identical frame positions from both videos ---
    n_orig = get_video_frame_count(orig_path)
    n_enh = get_video_frame_count(enh_path)
    min_len = min(max_frames, n_orig, n_enh)
    if min_len < num_samples:
        raise RuntimeError(
            f"[{view_name}] min({max_frames}, n_orig={n_orig}, n_enh={n_enh}) = {min_len} "
            f"< {num_samples}. Cannot proceed."
        )
    indices = sample_frame_indices(min_len, num_samples=num_samples)
    log(f"[{view_name}] Reading frames from original: {orig_path}")
    frames_orig = read_frames_at_indices(orig_path, indices)
    log(f"[{view_name}] Reading frames from enhanced: {enh_path}")
    frames_enh = read_frames_at_indices(enh_path, indices)

    # --- Group into fixed-length clips ---
    clips_orig = clips_from_frames_list(frames_to_pil(frames_orig), clip_len=clip_len)
    clips_enh = clips_from_frames_list(frames_to_pil(frames_enh), clip_len=clip_len)
    # The raw decoded arrays are not used past this point. Drop our references
    # so they can be reclaimed before feature extraction.
    del frames_orig, frames_enh

    nclips = min(len(clips_orig), len(clips_enh))
    if nclips == 0:
        raise RuntimeError(f"[{view_name}] Not enough frames to form any clip (clip_len={clip_len}).")
    log(f"[{view_name}] {len(clips_orig)} clips orig, {len(clips_enh)} clips enh, using first {nclips} paired clips")

    # --- Clip-level CLIP features ---
    feats_orig_clips = _pooled_clip_features(clips_orig[:nclips], clip_model, preprocess, device,
                                             batch_size, "original", view_name)
    feats_enh_clips = _pooled_clip_features(clips_enh[:nclips], clip_model, preprocess, device,
                                            batch_size, "enhanced", view_name)

    if save_feats_outdir:
        os.makedirs(save_feats_outdir, exist_ok=True)
        np.save(os.path.join(save_feats_outdir, f"{view_name}_orig_clip_feats.npy"), feats_orig_clips)
        np.save(os.path.join(save_feats_outdir, f"{view_name}_enh_clip_feats.npy"), feats_enh_clips)

    # --- METRIC CALCULATIONS ---

    # 1. Aligned similarity: cosine between clip k of the original and clip k of the enhanced video.
    aligned_cosines = np.sum(feats_orig_clips * feats_enh_clips, axis=1)
    inter_clip_similarity = float(np.mean(aligned_cosines))

    # 2. Self-similarity within each video: mean pairwise cosine between its
    #    distinct clips (higher means the clips look more alike).
    real_self = compute_pairwise_upper_triangle_mean_cosine(feats_orig_clips)
    fake_self = compute_pairwise_upper_triangle_mean_cosine(feats_enh_clips)

    # 3. Mean L2 distance between aligned clip embeddings.
    dists = np.linalg.norm(feats_orig_clips - feats_enh_clips, axis=1)
    avg_feature_distance = float(np.mean(dists))

    results = {
        "inter_clip_similarity": inter_clip_similarity,
        "real_self_similarity": real_self,
        "fake_self_similarity": fake_self,
        "avg_feature_distance": avg_feature_distance,
        "n_clips": int(nclips),
    }

    # 4. Text-alignment CLIP score (only if prompts are provided).
    if text_prompts is not None and len(text_prompts) > 0:
        with torch.no_grad():
            text_tokens = clip.tokenize(text_prompts).to(device)
            # The text encoder may return half precision; normalize in float32.
            text_feats = clip_model.encode_text(text_tokens).float()
            text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
            # Average the per-prompt embeddings into one unit-length text direction.
            avg_text = text_feats.mean(dim=0, keepdim=True)
            avg_text = avg_text / avg_text.norm()
            avg_text_np = avg_text.cpu().numpy().reshape(-1)

        sim_real = feats_orig_clips @ avg_text_np
        sim_fake = feats_enh_clips @ avg_text_np
        # Report as 100 x mean cosine, the scaling this script uses for CLIP scores.
        clip_score_real = float(np.mean(sim_real) * 100.0)
        clip_score_fake = float(np.mean(sim_fake) * 100.0)
        results["clip_score_real"] = clip_score_real
        results["clip_score_fake"] = clip_score_fake
        results["clip_score_difference"] = abs(clip_score_real - clip_score_fake)

    return results, feats_orig_clips, feats_enh_clips

# ---------------- CLI ----------------
def parse_args():
    """Define the CLI for the three-view CLIP evaluation and parse ``sys.argv``.

    Only ``--enh_dir`` is required; every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description="Compute CLIP-based scores between original and enhanced videos (3 views)"
    )
    # (flag, add_argument keyword arguments), registered in this order so the
    # --help listing keeps its layout.
    option_specs = [
        ("--orig_dir", dict(
            type=str, default="gt/",
            help="Directory containing original videos (e.g., head.mp4, right_hand.mp4).")),
        ("--enh_dir", dict(
            type=str, required=True,
            help="Directory containing enhanced videos (e.g., head.mp4, right_hand.mp4).")),
        ("--model_name", dict(
            type=str, default="ViT-B/32",
            choices=["RN50", "RN101", "ViT-B/32", "ViT-B/16"],
            help="Name of the CLIP model to use.")),
        ("--device", dict(
            type=str, default=None,
            help="Device to use (e.g., 'cuda:0' or 'cpu'). Auto-detected by default.")),
        ("--clip_len", dict(
            type=int, default=16,
            help="Number of frames per clip.")),
        ("--batch_size", dict(
            type=int, default=64,
            help="Batch size for CLIP model inference.")),
        ("--save_feats_outdir", dict(
            type=str, default=None,
            help="If provided, save clip features to this directory.")),
        ("--text_prompts", dict(
            type=str, nargs='*', default=None,
            help="Optional text prompts to compute CLIP score.")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def main():
    """Run the pipeline: score every camera view, log per-view metrics, then log cross-view mean/std."""
    args = parse_args()
    if args.device is not None:
        device_name = args.device
    else:
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    log(f"Device: {device}")

    view_keys = ["head", "right_hand", "left_hand"]
    # Each view is stored as "<view>.mp4" in both the original and enhanced directories.
    view_files = [f"{key}.mp4" for key in view_keys]

    # Fail fast, before loading the model, if any input video is missing.
    for fname in view_files:
        orig_file = os.path.join(args.orig_dir, fname)
        if not os.path.isfile(orig_file):
            raise FileNotFoundError(f"Original video not found: {orig_file}")
        enh_file = os.path.join(args.enh_dir, fname)
        if not os.path.isfile(enh_file):
            raise FileNotFoundError(f"Enhanced video not found: {enh_file}")

    log(f"Loading CLIP model {args.model_name} ...")
    clip_model, preprocess = clip.load(args.model_name, device=device)
    clip_model.eval()

    per_view_results = {}
    for key, fname in zip(view_keys, view_files):
        log(f"Processing view {key} ...")
        res, _, _ = compute_clip_score_for_view(
            os.path.join(args.orig_dir, fname),
            os.path.join(args.enh_dir, fname),
            clip_model,
            preprocess,
            device=device,
            clip_len=args.clip_len,
            batch_size=args.batch_size,
            save_feats_outdir=args.save_feats_outdir,
            text_prompts=args.text_prompts,
            view_name=key,
        )
        per_view_results[key] = res
        log(f"[{key}] inter_clip_similarity: {res['inter_clip_similarity']:.6f}, "
            f"real_self: {res['real_self_similarity']:.6f}, fake_self: {res['fake_self_similarity']:.6f}, "
            f"avg_feat_dist: {res['avg_feature_distance']:.6f}")
        if "clip_score_real" in res:
            log(f"[{key}] clip_score_real: {res['clip_score_real']:.3f}, clip_score_fake: {res['clip_score_fake']:.3f}, diff: {res['clip_score_difference']:.3f}")

    def collect(metric):
        """Return the non-None values of *metric* across views as float64, or None if there are none."""
        values = [per_view_results[k][metric] for k in view_keys
                  if per_view_results[k].get(metric) is not None]
        return np.asarray(values, dtype=np.float64) if values else None

    for metric in ("inter_clip_similarity", "avg_feature_distance",
                   "real_self_similarity", "fake_self_similarity"):
        arr = collect(metric)
        if arr is None:
            continue
        log(f"[AGG] {metric}: mean={arr.mean():.6f}, std={arr.std(ddof=0):.6f}")

    # Text-based aggregates exist only when prompts were supplied.
    if not args.text_prompts:
        return
    arr_real = collect("clip_score_real")
    arr_fake = collect("clip_score_fake")
    if arr_real is not None:
        log(f"[AGG] clip_score_real: mean={arr_real.mean():.3f}, std={arr_real.std(ddof=0):.3f}")
    if arr_fake is not None:
        log(f"[AGG] clip_score_fake: mean={arr_fake.mean():.3f}, std={arr_fake.std(ddof=0):.3f}")
    if arr_real is not None and arr_fake is not None:
        diffs = np.abs(arr_real - arr_fake)
        log(f"[AGG] clip_score_difference: mean={diffs.mean():.3f}, std={diffs.std(ddof=0):.3f}")

if __name__ == "__main__":
    main()
