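"""
Computes Temporal Consistency Jitter (TCJ) for multi-view videos.

For each of the ``head``, ``right_hand`` and ``left_hand`` views, frames are
sampled from the video, embedded with CLIP, and the variance of the cosine
similarities between consecutive frame embeddings is computed. The TCJ score is
the mean of these variances over the views. The script reports TCJ for a
ground-truth directory and for a directory of generated videos.
"""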
import os
import cv2
import torch
import argparse
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as T
import clip

def load_video_frames(path, num_frames=300, size=224, max_frame_index=1500):
    """
    Samples and loads a specified number of frames from a video file.

    Frame indices are spread uniformly (integer-truncated) over
    ``[0, min(max_frame_index, total - 1)]``, where ``total`` is the frame
    count reported by the container. When that range holds fewer than
    ``num_frames`` frames, several sample positions map to the same index and
    each such frame is returned only once, so the result can be shorter than
    ``num_frames``.

    Args:
        path (str): The path to the video file.
        num_frames (int): The number of frames to sample uniformly.
        size (int): The size to resize frames to (size x size).
        max_frame_index (int): Highest frame index that may be sampled.

    Returns:
        list: A list of resized RGB frames as numpy arrays, in temporal order.
        Empty if the reported frame count is not positive.

    Raises:
        OSError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {path}")

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total <= 0:
            return []

        idxs = np.linspace(0, min(max_frame_index, total - 1), num_frames, dtype=int)
        # A set gives O(1) membership instead of scanning the array per frame.
        wanted = set(idxs.tolist())
        if not wanted:
            return []
        last = max(wanted)

        frames = []
        # Frames are read sequentially; nothing after the largest wanted
        # index is needed, so stop decoding there.
        for i in range(min(total, last + 1)):
            ret, frame = cap.read()
            if not ret:
                # The reported frame count can overestimate; stop at EOF.
                break
            if i in wanted:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (size, size), interpolation=cv2.INTER_LINEAR)
                frames.append(frame)
        return frames
    finally:
        # Release the decoder even if reading or resizing raises.
        cap.release()

def extract_features(frames, model, preprocess, device):
    """
    Extracts L2-normalized CLIP image features for a list of frames.

    Args:
        frames (list): A list of RGB frames (numpy arrays accepted by
            ``PIL.Image.fromarray``).
        model: The CLIP model; its ``encode_image`` method is called on the
            stacked, preprocessed batch.
        preprocess: The CLIP preprocessing transform, applied to each PIL image.
        device: The device for torch computations.

    Returns:
        torch.Tensor: Float32 features, one row per frame, each divided by its
        L2 norm along the last dimension.

    Raises:
        ValueError: If ``frames`` is empty (e.g. the video was unreadable).
    """
    if len(frames) == 0:
        raise ValueError(
            "extract_features received no frames; the video may be empty or unreadable."
        )

    # Import locally so the function works when this module is imported,
    # not only when it runs as a script.
    from PIL import Image

    tensors = torch.stack([preprocess(Image.fromarray(f)) for f in frames]).to(device)
    with torch.no_grad():
        # .float() so downstream math is float32 even if the model runs in half precision.
        feats = model.encode_image(tensors).float()
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats

def compute_tcj(features_dict, views=("head", "right_hand", "left_hand")):
    """
    Computes the Temporal Consistency Jitter (TCJ) score.

    For each view, the cosine similarity between every pair of consecutive
    frame features is computed; the per-view score is the (population)
    variance of those similarities. TCJ is the mean of the per-view scores.

    Args:
        features_dict (dict): Maps view name to a 2-D feature tensor with one
            row per frame.
        views (Sequence[str]): View names to include. Defaults to
            ('head', 'right_hand', 'left_hand').

    Returns:
        float: The final TCJ score.

    Raises:
        KeyError: If a requested view is missing from ``features_dict``.
        ValueError: If ``views`` is empty or a view has fewer than two frames
            (no consecutive pair exists, so the variance is undefined).
    """
    if len(views) == 0:
        raise ValueError("compute_tcj needs at least one view.")

    variances = []
    for view in views:
        feats = features_dict[view]
        if len(feats) < 2:
            raise ValueError(
                f"View '{view}' has {len(feats)} frame(s); at least 2 are required."
            )
        # Similarity of frame t with frame t+1 for all t in one batched call.
        sims = F.cosine_similarity(feats[:-1], feats[1:], dim=-1)
        # Variance in float64, matching accumulation over Python floats.
        variances.append(np.var(sims.detach().cpu().double().numpy()))
    return float(np.mean(variances))

if __name__ == "__main__":
    # Binds a module-level ``Image`` name; keep it, since helpers in this file
    # may look it up as a global.
    import PIL.Image as Image

    parser = argparse.ArgumentParser(description="Compute Temporal Consistency Jitter (TCJ).")
    parser.add_argument("--gt", type=str, default="./ground_truth",
                        help="Directory with ground-truth videos (e.g., head.mp4).")
    parser.add_argument("--gen", type=str, required=True,
                        help="Directory with generated videos to evaluate (e.g., head.mp4).")
    args = parser.parse_args()

    views = ["head", "right_hand", "left_hand"]
    sources = {"GT": args.gt, "GEN": args.gen}

    # Validate inputs before the (slow) CLIP load. A missing file otherwise
    # yields zero frames and fails later with an unhelpful stacking error.
    missing = [
        os.path.join(path, f"{view}.mp4")
        for path in sources.values()
        for view in views
        if not os.path.isfile(os.path.join(path, f"{view}.mp4"))
    ]
    if missing:
        parser.error("missing video file(s): " + ", ".join(missing))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    results = {}
    for name, path in sources.items():
        features_dict = {}
        for view in views:
            video_path = os.path.join(path, f"{view}.mp4")
            frames = load_video_frames(video_path, size=224)
            features = extract_features(frames, model, preprocess, device)
            features_dict[view] = features
        results[name] = compute_tcj(features_dict)

    print("=== Temporal Consistency Jitter (TCJ) ===")
    for k, v in results.items():
        print(f"{k}: {v:.6f}")
