"""
Computes Cross-View Feature Consistency (CVFC) for multi-view videos.
"""
import os
import cv2
import torch
import argparse
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as T
import clip

def load_video_frames(path, num_frames=300, size=224, max_frame=1500):
    """
    Loads, samples, and resizes frames from a video file.

    Frame indices are chosen with ``np.linspace`` over
    ``[0, min(max_frame, total - 1)]``, where ``total`` is the frame count
    reported by OpenCV. If several sampled indices collapse onto the same
    integer (e.g. a short video), that frame is returned only once, so the
    result can contain fewer than ``num_frames`` frames.

    Args:
        path (str): Path to the video file.
        num_frames (int): The number of frame indices to sample uniformly.
        size (int): The size to resize each frame to (size x size).
        max_frame (int): Largest frame index that may be sampled (default
            1500), i.e. only the start of long videos is considered.

    Returns:
        A list of frames as numpy arrays (RGB format), in temporal order.
        Empty if the video cannot be opened or reports no frames.
    """
    cap = cv2.VideoCapture(path)
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total <= 0:
            return []
        idxs = np.linspace(0, min(max_frame, total - 1), num_frames, dtype=int)
        # A set gives O(1) membership instead of scanning the index array per frame.
        wanted = set(idxs.tolist())
        if not wanted:
            return []
        last = max(wanted)

        frames = []
        # Decode sequentially (cap.set() seeking can be imprecise), but stop as
        # soon as the last wanted index has been reached.
        for i in range(last + 1):
            # grab() advances without converting the frame; retrieve() is only
            # called for frames we actually keep.
            if not cap.grab():
                break
            if i not in wanted:
                continue
            ret, frame = cap.retrieve()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, (size, size), interpolation=cv2.INTER_LINEAR)
            frames.append(frame)
        return frames
    finally:
        # Release the capture even if decoding/resizing raises.
        cap.release()

def extract_features(frames, model, preprocess, device):
    """
    Extracts L2-normalized image features using a CLIP model.

    Args:
        frames (list): A non-empty list of frames as numpy arrays (RGB, as
            produced by ``load_video_frames``).
        model: The CLIP model (must provide ``encode_image``).
        preprocess: The CLIP preprocessing transform; presumably maps a PIL
            image to a fixed-shape tensor so the results can be stacked.
        device: The device to run the model on.

    Returns:
        A torch tensor of shape (N_frames, D) containing the L2-normalized
        features (float32).

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # Imported locally so the function also works when this file is imported
    # as a module, not only when it is run as a script.
    from PIL import Image

    if len(frames) == 0:
        # torch.stack would otherwise fail with an unhelpful message.
        raise ValueError("extract_features() received no frames; was the video readable?")

    tensors = torch.stack([preprocess(Image.fromarray(f)) for f in frames]).to(device)
    with torch.no_grad():
        feats = model.encode_image(tensors).float()
        # Normalize each row to unit length so dot products are cosine similarities.
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats

def compute_cvfc(features_dict, views=("head", "right_hand", "left_hand")):
    """
    Computes the Cross-View Feature Consistency score.

    For each timestamp, it calculates the mean pairwise cosine similarity
    between the features of the given views (by default head, right_hand,
    left_hand). The final score is the average of these means over all
    timestamps. Views with different frame counts are aligned on the first
    ``n`` timestamps, where ``n`` is the shortest view's length.

    Args:
        features_dict (dict): Maps view names to feature tensors of shape
            (N_frames, D).
        views (sequence of str): Keys of ``features_dict`` to compare; every
            unordered pair of them contributes a similarity. Defaults to
            ``("head", "right_hand", "left_hand")``.

    Returns:
        The final CVFC score as a float, or NaN if there are no common
        timestamps.

    Raises:
        ValueError: If fewer than two views are given.
        KeyError: If a requested view is missing from ``features_dict``.
    """
    from itertools import combinations

    views = tuple(views)
    if len(views) < 2:
        raise ValueError("compute_cvfc() needs at least two views to compare.")

    # Only timestamps present in every view are compared.
    n = min(features_dict[v].shape[0] for v in views)
    if n == 0:
        return float("nan")

    # float64 to keep the average as precise as a per-element Python reduction.
    feats = {v: features_dict[v][:n].double() for v in views}
    # Shape (num_pairs, n): one vectorized cosine similarity per view pair.
    pair_sims = torch.stack(
        [F.cosine_similarity(feats[a], feats[b], dim=-1) for a, b in combinations(views, 2)]
    )
    # Each timestamp has the same number of pairs, so the global mean equals
    # the mean over timestamps of the per-timestamp pair means.
    return float(pair_sims.mean().item())

if __name__ == "__main__":
    # extract_features() uses ``Image.fromarray``; binding it here keeps the
    # script runnable.
    import PIL.Image as Image

    view_names = ["head", "right_hand", "left_hand"]

    parser = argparse.ArgumentParser(description="Compute Cross-View Feature Consistency (CVFC) between GT and generated videos.")
    parser.add_argument("--gt", type=str, default="gt/",
                        help="Directory containing ground-truth videos (e.g., head.mp4, right_hand.mp4).")
    parser.add_argument("--gen", type=str, required=True,
                        help="Directory containing generated videos (e.g., head.mp4, right_hand.mp4).")
    parser.add_argument("--num-frames", type=int, default=300,
                        help="Number of frames to sample uniformly from each video.")
    args = parser.parse_args()

    if args.num_frames < 1:
        parser.error("--num-frames must be a positive integer")

    # Fail fast: cv2.VideoCapture silently yields zero frames for a missing
    # file, which would otherwise surface later as an obscure torch error.
    for flag, directory in (("--gt", args.gt), ("--gen", args.gen)):
        missing = [f"{v}.mp4" for v in view_names
                   if not os.path.isfile(os.path.join(directory, f"{v}.mp4"))]
        if missing:
            parser.error(f"{flag} directory {directory!r} is missing: {', '.join(missing)}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    results = {}
    for name, path in {"GT": args.gt, "GEN": args.gen}.items():
        features_dict = {}
        for view in view_names:
            video_path = os.path.join(path, f"{view}.mp4")
            frames = load_video_frames(video_path, num_frames=args.num_frames, size=224)
            if not frames:
                raise SystemExit(f"error: no frames could be decoded from {video_path}")
            features_dict[view] = extract_features(frames, model, preprocess, device)
        results[name] = compute_cvfc(features_dict)

    print("=== Cross-View Feature Consistency (CVFC) ===")
    for k, v in results.items():
        print(f"{k}: {v:.4f}")
