"""
A comprehensive tool for comparing video feature distributions across multiple groups.
"""
import os, sys, argparse, time, json
from collections import OrderedDict

import numpy as np
import cv2
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
from sklearn.manifold import TSNE
from scipy.linalg import sqrtm
from scipy.special import logsumexp
from scipy.stats import entropy

# ---------- Utility Functions ----------
def log(*args, **kwargs):
    """Print a message via ``print`` and flush the output stream right away.

    All positional and keyword arguments are forwarded to ``print``
    unchanged; ``flush=True`` is always added.
    """
    print(*args, flush=True, **kwargs)

def list_mp4s(dirpath):
    """Return the sorted entry names in `dirpath` whose name ends in ".mp4".

    The extension check is case-insensitive. Only names are returned (not
    full paths), and the directory is not searched recursively.
    """
    entries = os.listdir(dirpath)
    matching = filter(lambda entry: entry.lower().endswith(".mp4"), entries)
    return sorted(matching)

def get_video_frame_count(path):
    """Return the frame count that OpenCV reports for the video at `path`.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    capture = cv2.VideoCapture(path)
    if capture.isOpened():
        # CAP_PROP_FRAME_COUNT comes back as a float; truncate to int.
        n_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        capture.release()
        return n_frames
    raise RuntimeError(f"Cannot open video: {path}")

def read_random_n_frames_normalized(video_path, n=300, resize_hw=(299, 299), norm_mean=None, norm_std=None, seed=None):
    """
    Reads N random frames from a video, resizes, and normalizes them.

    Frames are sampled without replacement and read in ascending index order.
    Pixel values are converted from BGR to RGB, scaled to [0, 1], and then
    standardized per channel with `norm_mean` / `norm_std`.

    Args:
        video_path (str): Path to the video file.
        n (int): Number of frames to sample.
        resize_hw (tuple): Target (height, width) for resizing.
        norm_mean (list): Per-channel mean for normalization. Defaults to
            [0.485, 0.456, 0.406].
        norm_std (list): Per-channel standard deviation for normalization.
            Defaults to [0.229, 0.224, 0.225].
        seed (int, optional): Random seed for frame selection. When given, a
            private ``RandomState`` is used so the global NumPy RNG state is
            left untouched; the selected indices are the same as seeding the
            global generator with this value.

    Returns:
        np.ndarray: A numpy array of shape (n, 3, H, W) with dtype float32.

    Raises:
        RuntimeError: If the video cannot be opened, reports fewer than `n`
            frames, or fewer than `n` frames could actually be decoded.
    """
    if norm_mean is None:
        norm_mean = [0.485, 0.456, 0.406]
    if norm_std is None:
        norm_std = [0.229, 0.224, 0.225]

    # cv2.resize expects dsize as (width, height), the reverse of resize_hw.
    target_h, target_w = resize_hw

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video: {video_path}")

    # try/finally guarantees the capture handle is released on every path,
    # including unexpected decoding/conversion errors.
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames < n:
            raise RuntimeError(f"Video {video_path} has only {total_frames} frames, but {n} are required.")

        # A local generator avoids reseeding the process-wide RNG as a side effect.
        rng = np.random.RandomState(seed) if seed is not None else np.random
        sorted_indices = np.sort(rng.choice(total_frames, n, replace=False))

        frames = []
        for frame_idx in sorted_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frm = cap.read()
            if not ret:
                log(f"Warning: Failed to read frame {frame_idx} from {video_path}. Skipping.")
                continue

            frm_rgb = cv2.cvtColor(frm, cv2.COLOR_BGR2RGB)
            frm_resized = cv2.resize(frm_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
            frames.append(frm_resized.astype(np.float32) / 255.0)
    finally:
        cap.release()

    if len(frames) < n:
        raise RuntimeError(f"Could only read {len(frames)} frames from {video_path} (required {n}).")

    # (n, H, W, 3) -> (n, 3, H, W), then per-channel standardization.
    arr = np.stack(frames, axis=0).transpose(0, 3, 1, 2)
    mean = np.array(norm_mean, dtype=np.float32).reshape(1, 3, 1, 1)
    std = np.array(norm_std, dtype=np.float32).reshape(1, 3, 1, 1)
    return (arr - mean) / std

def read_first_n_frames_normalized(video_path, n=300, resize_hw=(299, 299), norm_mean=None, norm_std=None):
    """
    Reads the first N frames from a video, resizes, and normalizes them.

    Pixel values are converted from BGR to RGB, scaled to [0, 1], and then
    standardized per channel with `norm_mean` / `norm_std`.

    Args:
        video_path (str): Path to the video file.
        n (int): Number of frames to read.
        resize_hw (tuple): Target (height, width) for resizing.
        norm_mean (list): Per-channel mean for normalization. Defaults to
            [0.485, 0.456, 0.406].
        norm_std (list): Per-channel standard deviation for normalization.
            Defaults to [0.229, 0.224, 0.225].

    Returns:
        np.ndarray: A numpy array of shape (n, 3, H, W) with dtype float32.

    Raises:
        RuntimeError: If the video cannot be opened or yields fewer than
            `n` frames.
    """
    if norm_mean is None:
        norm_mean = [0.485, 0.456, 0.406]
    if norm_std is None:
        norm_std = [0.229, 0.224, 0.225]

    # cv2.resize expects dsize as (width, height), the reverse of resize_hw.
    target_h, target_w = resize_hw

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video: {video_path}")

    frames = []
    # try/finally guarantees the capture handle is released on every path.
    try:
        for _ in range(n):
            ret, frm = cap.read()
            if not ret:
                raise RuntimeError(f"Video {video_path} has fewer than {n} frames (found {len(frames)}).")
            frm_rgb = cv2.cvtColor(frm, cv2.COLOR_BGR2RGB)
            frm_resized = cv2.resize(frm_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
            frames.append(frm_resized.astype(np.float32) / 255.0)
    finally:
        cap.release()

    # (n, H, W, 3) -> (n, 3, H, W), then per-channel standardization.
    arr = np.stack(frames, axis=0).transpose(0, 3, 1, 2)
    mean = np.array(norm_mean, dtype=np.float32).reshape(1, 3, 1, 1)
    std = np.array(norm_std, dtype=np.float32).reshape(1, 3, 1, 1)
    return (arr - mean) / std

# ---------- Model Loading & Feature Extraction ----------
class ClipWrapper(nn.Module):
    """Adapter that exposes a CLIP model's image encoder through ``forward``.

    Calling the wrapper on an input delegates to ``clip_model.encode_image``,
    so it can be invoked like the other feature extractors in this file.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, x):
        """Return ``clip_model.encode_image(x)``."""
        encode = self.clip_model.encode_image
        return encode(x)

def get_model_and_config(model_name, device):
    """
    Build a feature-extraction model and describe its expected preprocessing.

    Three families are handled: ``'inceptionv3'`` (torchvision, final ``fc``
    replaced by identity), ``'clip_*'`` (OpenAI CLIP, wrapped in
    ``ClipWrapper``) and ``'dinov2_*'`` (timm, created with
    ``pretrained=False`` and ``num_classes=0``). The model is moved to
    `device` and put in eval mode.

    Args:
        model_name (str): Identifier of the model to load.
        device: Device the model is moved to.

    Returns:
        tuple: (model, config_dict) where config_dict has the keys
        "input_size", "norm_mean", "norm_std" and "feature_dim".

    Raises:
        RuntimeError: If the package required by the model family is missing.
        ValueError: If `model_name` is not supported.
    """
    log(f"Loading model: {model_name}...")

    # ---- torchvision InceptionV3 ----
    if model_name == 'inceptionv3':
        import torchvision
        tv_models = torchvision.models
        try:
            # Newer torchvision exposes a weights enum; fall back to the
            # legacy `pretrained=True` call if this path fails for any reason.
            weights_enum = getattr(tv_models, "Inception_V3_Weights", None)
            chosen_weights = weights_enum.DEFAULT if weights_enum else None
            net = tv_models.inception_v3(weights=chosen_weights, aux_logits=False)
        except Exception:
            net = tv_models.inception_v3(pretrained=True, aux_logits=False)

        # Drop the classifier so the forward pass yields pooled features.
        net.fc = nn.Identity()
        net.to(device)
        net.eval()

        return net, {
            "input_size": 299,
            "norm_mean": [0.485, 0.456, 0.406],
            "norm_std": [0.229, 0.224, 0.225],
            "feature_dim": 2048,
        }

    # ---- OpenAI CLIP ----
    if model_name.startswith('clip_'):
        try:
            import clip
        except ImportError:
            raise RuntimeError("CLIP models require `clip-by-openai`. Run: pip install git+https://github.com/openai/CLIP.git")

        clip_names = {
            'clip_vit_b_32': 'ViT-B/32',
            'clip_vit_b_16': 'ViT-B/16',
            'clip_vit_l_14': 'ViT-L/14',
            'clip_rn50': 'RN50',
            'clip_rn101': 'RN101',
            'clip_rn50x4': 'RN50x4',
            'clip_rn50x16': 'RN50x16',
            'clip_rn50x64': 'RN50x64',
            'clip_vit_l_14_336': 'ViT-L/14@336px',
        }
        if model_name not in clip_names:
            raise ValueError(f"Unsupported CLIP model name: {model_name}")
        clip_net, preprocess = clip.load(clip_names[model_name], device=device)

        # The first transform carries the input size and the last one the
        # normalization statistics of the returned preprocessing pipeline.
        first_tf = preprocess.transforms[0]
        last_tf = preprocess.transforms[-1]
        clip_config = {
            "input_size": first_tf.size,
            "norm_mean": last_tf.mean,
            "norm_std": last_tf.std,
            "feature_dim": clip_net.visual.output_dim,
        }

        wrapper = ClipWrapper(clip_net)
        wrapper.to(device)
        wrapper.eval()
        return wrapper, clip_config

    # ---- DINOv2 via timm ----
    if model_name.startswith('dinov2_'):
        try:
            import timm
        except ImportError:
            raise RuntimeError("DINOv2 models require `timm`. Run: pip install timm")

        timm_names = {
            'dinov2_vits14': 'vit_small_patch14_dinov2.lvd142m',
            'dinov2_vitb14': 'vit_base_patch14_dinov2.lvd142m',
            'dinov2_vitl14': 'vit_large_patch14_dinov2.lvd142m',
        }
        if model_name not in timm_names:
            raise ValueError(f"Unknown DINOv2 model: {model_name}. Choices: {list(timm_names.keys())}")

        # Created without pretrained weights; callers load a checkpoint themselves.
        backbone = timm.create_model(timm_names[model_name], pretrained=False, num_classes=0)
        backbone.to(device)
        backbone.eval()

        cfg = backbone.default_cfg
        return backbone, {
            "input_size": cfg['input_size'][-1],
            "norm_mean": cfg['mean'],
            "norm_std": cfg['std'],
            "feature_dim": backbone.embed_dim,
        }

    raise ValueError(f"Unsupported model name: {model_name}")

def extract_feats_from_frames_arr(model, frames_arr, device, batch_size=64):
    """
    Run `model` over an array of frames in mini-batches and gather features.

    Args:
        model: The feature extraction model (called as ``model(batch)``).
        frames_arr (np.ndarray): (T, 3, H, W) float32 numpy array.
        device: The device to run inference on.
        batch_size (int): Number of frames per forward pass.

    Returns:
        np.ndarray: (T, D) numpy feature array, or an empty (0, 0) float32
        array when `frames_arr` has no frames.
    """
    n_frames = frames_arr.shape[0]
    chunks = []

    with torch.no_grad():
        for start in range(0, n_frames, batch_size):
            stop = start + batch_size
            inputs = torch.from_numpy(frames_arr[start:stop]).to(device=device, dtype=torch.float32)
            outputs = model(inputs)

            # Some models hand back a tuple; the features are its first item.
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            # CLIP embeddings are L2-normalized along the feature axis.
            if isinstance(model, ClipWrapper):
                outputs = outputs / outputs.norm(dim=-1, keepdim=True)

            chunks.append(outputs.cpu().numpy())

    if not chunks:
        return np.zeros((0, 0), dtype=np.float32)
    return np.concatenate(chunks, axis=0)


# ---------- Statistics and Distance Calculations ----------
def compute_stats_from_feats(feats):
    """Return the (mean, covariance) of a feature matrix as float64 arrays.

    Rows are treated as observations and columns as variables. With a single
    row (or none) the covariance is an all-zero (D, D) matrix.
    """
    mean_vec = np.mean(feats, axis=0)
    if feats.shape[0] <= 1:
        n_dims = feats.shape[1]
        cov_mat = np.zeros((n_dims, n_dims), dtype=np.float64)
    else:
        cov_mat = np.cov(feats, rowvar=False)
    return mean_vec.astype(np.float64), cov_mat.astype(np.float64)

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Computes Frechet distance between two multivariate Gaussians.

    The value is ``|mu1 - mu2|^2 + Tr(S1) + Tr(S2) - 2 Tr(sqrt(S1 S2))`` where
    ``S1``/``S2`` are the covariances with ``eps`` added to the diagonal. The
    same regularized matrices are used for every term, so two identical
    distributions give exactly zero instead of a small negative bias of
    about ``-2 * dim * eps``.

    Args:
        mu1, mu2: Mean vectors.
        sigma1, sigma2: Covariance matrices.
        eps (float): Diagonal regularization added to both covariances.

    Returns:
        float: The Frechet distance. Slightly negative values caused by
        round-off (above -1e-6) are clamped to 0.0.
    """
    mu1 = np.atleast_1d(mu1).astype(np.float64)
    mu2 = np.atleast_1d(mu2).astype(np.float64)
    sigma1 = np.atleast_2d(sigma1).astype(np.float64)
    sigma2 = np.atleast_2d(sigma2).astype(np.float64)

    diff = mu1 - mu2
    cov_eps = eps * np.eye(sigma1.shape[0], dtype=np.float64)
    sigma1_eps = sigma1 + cov_eps
    sigma2_eps = sigma2 + cov_eps

    tr_covmean = None
    try:
        covmean = sqrtm(sigma1_eps.dot(sigma2_eps))
        # Round-off can introduce a spurious imaginary part; keep the real one.
        if np.iscomplexobj(covmean):
            covmean = np.real(covmean)
        if np.all(np.isfinite(covmean)):
            tr_covmean = float(np.trace(covmean))
    except (ValueError, np.linalg.LinAlgError):
        tr_covmean = None

    if tr_covmean is None:
        # Fallback: S1 @ S2 is not symmetric, so eigh cannot be applied to it
        # directly. Use the identity
        #   Tr(sqrt(S1 S2)) = sum(sqrt(eig(S1^{1/2} S2 S1^{1/2})))
        # whose inner matrix is symmetric positive semi-definite.
        vals1, vecs1 = np.linalg.eigh(sigma1_eps)
        root1 = (vecs1 * np.sqrt(np.clip(vals1, 0.0, None))) @ vecs1.T
        inner = root1 @ sigma2_eps @ root1
        inner = 0.5 * (inner + inner.T)  # remove round-off asymmetry
        inner_vals = np.clip(np.linalg.eigvalsh(inner), 0.0, None)
        tr_covmean = float(np.sqrt(inner_vals).sum())

    fid = diff.dot(diff) + np.trace(sigma1_eps) + np.trace(sigma2_eps) - 2.0 * tr_covmean
    if -1e-6 < fid < 0:
        fid = 0.0
    return float(fid)

def bhattacharyya_distance(mu1, sigma1, mu2, sigma2, eps=1e-8):
    """Return the Bhattacharyya distance between two multivariate Gaussians.

    Both covariances get ``eps`` added on the diagonal. The result is the sum
    of a mean term (Mahalanobis-like, using the averaged covariance) and a
    log-determinant term; the latter is set to zero when any determinant sign
    is below 1. The final value is clamped to be non-negative.
    """
    mean_a = np.atleast_1d(mu1)
    mean_b = np.atleast_1d(mu2)
    n_dims = mean_a.shape[0]

    cov_a = np.atleast_2d(sigma1).astype(np.float64) + eps * np.eye(n_dims)
    cov_b = np.atleast_2d(sigma2).astype(np.float64) + eps * np.eye(n_dims)
    cov_avg = 0.5 * (cov_a + cov_b)

    delta = mean_a - mean_b

    try:
        cov_avg_inv = np.linalg.inv(cov_avg)
    except np.linalg.LinAlgError:
        # Singular average covariance: fall back to the pseudo-inverse.
        cov_avg_inv = np.linalg.pinv(cov_avg)

    mean_term = 0.125 * delta.T.dot(cov_avg_inv).dot(delta)

    (sign_avg, logdet_avg), (sign_a, logdet_a), (sign_b, logdet_b) = (
        np.linalg.slogdet(matrix) for matrix in (cov_avg, cov_a, cov_b)
    )

    cov_term = 0.0
    if not (sign_avg < 1 or sign_a < 1 or sign_b < 1):
        cov_term = 0.5 * (logdet_avg - 0.5 * (logdet_a + logdet_b))

    return max(0.0, float(mean_term + cov_term))

def hellinger_from_bhattacharyya(bd):
    """Derive (Hellinger distance, Bhattacharyya coefficient) from a distance.

    The coefficient is ``exp(-bd)`` and the Hellinger distance is
    ``sqrt(1 - coefficient)``, with the radicand floored at zero.
    """
    coeff = np.exp(-bd)
    gap = 1.0 - coeff
    radicand = gap if gap > 0.0 else 0.0
    return float(np.sqrt(radicand)), float(coeff)

def mahalanobis_mean_distance(mu1, sigma1, mu2, sigma2, eps=1e-8):
    """Return the squared Mahalanobis distance between two mean vectors.

    The metric uses the average of the two covariances with ``eps`` added to
    the diagonal; the pseudo-inverse is used if plain inversion fails.
    """
    delta = mu1 - mu2
    pooled = 0.5 * (np.atleast_2d(sigma1) + np.atleast_2d(sigma2))
    pooled = pooled + eps * np.eye(len(delta))
    try:
        precision = np.linalg.inv(pooled)
    except Exception:
        precision = np.linalg.pinv(pooled)
    return float(delta.T.dot(precision).dot(delta))

def gaussian_overlap_mc(mu1, sigma1, mu2, sigma2, n_samples=20000, seed=0):
    """Monte Carlo estimate of the overlap between two Gaussians.

    Half of the samples are drawn from each Gaussian (covariances get a 1e-6
    diagonal jitter for sampling). Both log-densities are evaluated on the
    pooled samples, each is normalized to sum to one over the samples, and
    the overlap is the sum of the element-wise minimum.
    """
    from scipy.stats import multivariate_normal

    generator = np.random.RandomState(seed)
    n_dims = mu1.shape[0]
    per_dist = n_samples // 2

    # Draw from the first Gaussian, then the second (order fixes the RNG stream).
    draws = [
        generator.multivariate_normal(mean, cov + 1e-6 * np.eye(n_dims), size=per_dist)
        for mean, cov in ((mu1, sigma1), (mu2, sigma2))
    ]
    pooled = np.vstack(draws)

    normalized = []
    for mean, cov in ((mu1, sigma1), (mu2, sigma2)):
        log_density = multivariate_normal(mean=mean, cov=cov, allow_singular=True).logpdf(pooled)
        normalized.append(np.exp(log_density - logsumexp(log_density)))

    return float(np.sum(np.minimum(normalized[0], normalized[1])))

# ---------- GMM-related Calculations ----------
def fit_gmm(X, n_components=16, random_state=0):
    """Fit a full-covariance Gaussian mixture to `X` and return the model.

    The mixture uses k-means++ initialization, covariance regularization of
    1e-2 and at most 500 EM iterations.
    """
    mixture = GaussianMixture(
        n_components=n_components,
        covariance_type='full',
        random_state=random_state,
        reg_covar=1e-2,
        max_iter=500,
        init_params='k-means++',
    )
    mixture.fit(X)
    return mixture

def jsd_between_gmms_mc(gmm_p, gmm_q, n_samples=10000):
    """Monte Carlo estimate of the Jensen-Shannon divergence of two GMMs.

    Half of `n_samples` (at least one) is drawn from each model. With
    M = (P + Q) / 2, the result is 0.5 * (KL(P||M) + KL(Q||M)), where each KL
    term is averaged over the samples of its own model.
    """
    per_model = max(1, n_samples // 2)
    draws_p, _ = gmm_p.sample(per_model)
    draws_q, _ = gmm_q.sample(per_model)
    log_half = np.log(0.5)

    def _log_densities(points):
        """Log-density of P, Q and the equal-weight mixture M at `points`."""
        log_p = gmm_p.score_samples(points)
        log_q = gmm_q.score_samples(points)
        log_m = logsumexp(np.vstack([log_p + log_half, log_q + log_half]), axis=0)
        return log_p, log_q, log_m

    log_p_on_p, _, log_m_on_p = _log_densities(draws_p)
    kl_p_to_m = np.mean(log_p_on_p - log_m_on_p)

    _, log_q_on_q, log_m_on_q = _log_densities(draws_q)
    kl_q_to_m = np.mean(log_q_on_q - log_m_on_q)

    return float(0.5 * (kl_p_to_m + kl_q_to_m))

def kl_between_gmms_mc(gmm_p, gmm_q, n_samples=10000):
    """Monte Carlo estimates of KL(P||Q) and KL(Q||P) for two GMMs.

    Half of `n_samples` (at least one) is drawn from each model; each
    divergence is the mean log-density ratio over the samples of its first
    argument.

    Returns:
        tuple: (kl_pq, kl_qp) as floats.
    """
    per_model = max(1, n_samples // 2)
    draws_p, _ = gmm_p.sample(per_model)
    draws_q, _ = gmm_q.sample(per_model)

    # KL(P||Q): log-ratio evaluated on samples from P.
    ratio_on_p = gmm_p.score_samples(draws_p) - gmm_q.score_samples(draws_p)
    # KL(Q||P): log-ratio evaluated on samples from Q.
    ratio_on_q = gmm_q.score_samples(draws_q) - gmm_p.score_samples(draws_q)

    return float(np.mean(ratio_on_p)), float(np.mean(ratio_on_q))

def gmm_component_center_stats(gmm_ref, gmm_other):
    """Summarize Euclidean distances between the component means of two GMMs.

    Every mean of `gmm_ref` is paired with every mean of `gmm_other`.

    Returns:
        tuple: (mean, variance, min, max, dists) where the first four are
        floats and `dists` is the flattened array of pairwise distances.
    """
    ref_centers = np.atleast_2d(gmm_ref.means_)
    other_centers = np.atleast_2d(gmm_other.means_)

    # Broadcast to (n_ref, n_other, dim) differences, then reduce over dim.
    offsets = ref_centers[:, np.newaxis, :] - other_centers[np.newaxis, :, :]
    dists = np.sqrt((offsets ** 2).sum(axis=2)).ravel()

    summary = tuple(float(reduce_fn(dists)) for reduce_fn in (np.mean, np.var, np.min, np.max))
    return summary + (dists,)

# ---------- Visualization ----------
def visualize_2d_embedding_multi(pca_groups, group_labels, out_png, method='umap'):
    """Visualizes 2D embeddings for multiple groups of features.

    All groups are stacked, embedded jointly into two dimensions and drawn as
    one scatter plot with a distinct color per group.

    Args:
        pca_groups (list): One 2-D feature array per group (rows are samples).
        group_labels (list): Legend label for each group, in the same order.
        out_png (str): Path the figure is saved to.
        method (str): 'umap' to use UMAP (falls back to t-SNE when the `umap`
            package is not installed); any other value uses t-SNE.
    """
    all_X = np.vstack(pca_groups)
    # Per-row group index used to mask the embedding when plotting.
    labels = np.hstack([i * np.ones(len(feats)) for i, feats in enumerate(pca_groups)])
    # `plt.cm.get_cmap` was removed in matplotlib 3.9; `plt.get_cmap` accepts
    # the same (name, lut) arguments and is still available.
    colors = plt.get_cmap('tab10', len(group_labels))

    def _tsne_embed(data):
        """Embed `data` into 2-D with t-SNE using fixed settings."""
        return TSNE(n_components=2, init='pca', random_state=42).fit_transform(data)

    if method == 'umap':
        try:
            import umap
            reducer = umap.UMAP(n_components=2, random_state=42)
            emb = reducer.fit_transform(all_X)
        except ImportError:
            log("UMAP library not found, falling back to t-SNE...")
            emb = _tsne_embed(all_X)
    else:
        emb = _tsne_embed(all_X)

    plt.figure(figsize=(10, 8), dpi=200)
    for i, label in enumerate(group_labels):
        mask = labels == i
        plt.scatter(emb[mask, 0], emb[mask, 1], s=5, alpha=0.2, label=label, color=colors(i), rasterized=True)
    plt.title("2D Embedding of Video Feature Distributions")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(out_png, bbox_inches='tight', dpi=200)
    plt.close()
    log(f"Visualization saved to {out_png}")


def visualize_marginal_distribution(pca_groups, group_labels, pca_model, out_png, dim_idx=0):
    """Visualizes the marginal distribution on a specific PCA component.

    For each group a density-normalized histogram of column `dim_idx` is
    drawn, overlaid with a normal curve fitted by the sample mean and
    standard deviation.

    Args:
        pca_groups (list): One 2-D array of PCA-projected features per group.
        group_labels (list): Legend label for each group, in the same order.
        pca_model: Not used by this function; kept for interface compatibility.
        out_png (str): Path the figure is saved to.
        dim_idx (int): Index of the PCA component (column) to plot.
    """
    from scipy.stats import norm

    plt.figure(figsize=(10, 6), dpi=200)
    # `plt.cm.get_cmap` was removed in matplotlib 3.9; `plt.get_cmap` accepts
    # the same (name, lut) arguments and is still available.
    colors = plt.get_cmap('tab10', len(group_labels))

    for i, feats in enumerate(pca_groups):
        data = feats[:, dim_idx]
        plt.hist(data, bins=50, alpha=0.5, label=group_labels[i], color=colors(i), density=True)

        mu, sigma = np.mean(data), np.std(data)
        if not sigma > 0:
            # A degenerate (constant) component has no meaningful normal fit;
            # norm.pdf with zero scale would only produce NaNs.
            continue
        x = np.linspace(data.min(), data.max(), 1000)
        pdf = norm.pdf(x, mu, sigma)
        plt.plot(x, pdf, color=colors(i), linestyle='--', linewidth=2,
                 label=f'{group_labels[i]} Normal Fit (μ={mu:.2f}, σ={sigma:.2f})')

    plt.title(f"Marginal Distribution on PCA Component {dim_idx}")
    plt.xlabel("Value")
    plt.ylabel("Density")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(out_png, bbox_inches='tight', dpi=200)
    plt.close()
    log(f"Marginal distribution plot saved to {out_png}")


# ---------- Main Workflow ----------
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Options are declared in a table and registered in order, so the help
    output lists them in the same sequence as the table.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Multi-group, multi-model video distribution comparison tool.")

    option_table = (
        ("--group_dirs", dict(
            type=str, nargs='+', action='append', required=True,
            help="List of video directories for a group (can be specified multiple times for one group).")),
        ("--group_labels", dict(
            type=str, nargs='+', required=True,
            help="Label for each group (must match the order of --group_dirs).")),
        ("--group_max_videos", dict(
            type=int, nargs='*', default=None,
            help="Max number of videos to process per group (optional, matches --group_dirs order).")),
        ("--model_name", dict(
            type=str, default="inceptionv3",
            choices=["inceptionv3", "clip_vit_b_32", "dinov2_vitb14"],
            help="Model to use for feature extraction.")),
        ("--dinov2_model_path", dict(
            type=str, default=None, help="Path to local DINOv2 model checkpoint.")),
        ("--device", dict(
            type=str, default="cuda:0", help="Device to use: cpu or cuda:0.")),
        ("--batch_size", dict(
            type=int, default=32, help="Batch size for feature extraction.")),
        ("--pca_dim", dict(
            type=int, default=5, help="PCA dimension for GMM/visualization.")),
        ("--gmm_components", dict(
            type=int, default=5, help="Number of GMM components.")),
        ("--mc_samples", dict(
            type=int, default=10000, help="Monte Carlo samples for JSD/KL calculation.")),
        ("--save_feats_outdir", dict(
            type=str, default=None, help="If provided, save extracted features to this directory.")),
        ("--vis_method", dict(
            type=str, default="umap", choices=["umap", "tsne"],
            help="2D visualization dimensionality reduction method.")),
        ("--out_metrics_json", dict(
            type=str, default="metrics_results.json", help="Path to save the output metrics JSON file.")),
        ("--mc_overlap_samples", dict(
            type=int, default=0,
            help="If > 0, computes MC overlap estimate for Gaussian fits (num samples).")),
        ("--circle_coverage_ratio", dict(
            type=float, default=0.5, help="Coverage ratio for circle visualization (default: 0.5).")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def collect_mp4s_from_dirs_with_limit(dirs, max_videos_list=None):
    """Gather (directory, filename) pairs of .mp4 files from several directories.

    Args:
        dirs (list): Directories to scan. Missing directories are skipped
            with a warning.
        max_videos_list (list, optional): Per-directory cap on the number of
            videos (``None`` entries mean no cap). Must have the same length
            as `dirs` when given.

    Returns:
        list: (dirpath, filename) tuples, in directory order.

    Raises:
        ValueError: If `max_videos_list` and `dirs` differ in length.
    """
    limits = [None] * len(dirs) if max_videos_list is None else max_videos_list
    if len(limits) != len(dirs):
        raise ValueError(f"Number of max_videos ({len(limits)}) must match number of dirs ({len(dirs)}).")

    collected = []
    for directory, limit in zip(dirs, limits):
        if not os.path.exists(directory):
            log(f"Warning: Directory {directory} does not exist, skipping.")
            continue

        entries = [(directory, name) for name in list_mp4s(directory)]
        if limit is None or len(entries) <= limit:
            log(f"Using all {len(entries)} videos from {directory}.")
        else:
            # Keep only the first `limit` files of the sorted listing.
            log(f"Limiting directory {directory} to {limit} videos (from {len(entries)}).")
            entries = entries[:limit]
        collected.extend(entries)
    return collected

def main():
    """Run the full comparison workflow from the command line.

    Steps: load the feature model, collect videos per group, extract frame
    features, optionally save them, compute distribution metrics against the
    first group (FID, GMM-based JSD/KL, Bhattacharyya/Hellinger/Mahalanobis,
    GMM center statistics, optional Monte Carlo overlap), write two plots to
    the current directory, save the metrics as JSON and print a summary.

    Raises:
        ValueError: If the numbers of group dirs, labels and max-video
            limits do not match.
        RuntimeError: If no video of a group could be processed.
    """
    args = parse_args()
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    log(f"Using device: {device}")

    group_dirs = args.group_dirs
    group_labels = args.group_labels
    group_max_videos = args.group_max_videos or [None] * len(group_dirs)

    if len(group_dirs) != len(group_labels):
        raise ValueError("Number of --group_dirs must match number of --group_labels.")
    if len(group_max_videos) != len(group_dirs):
        raise ValueError("Number of --group_max_videos must match number of --group_dirs.")

    # Step 1: Load model and get its configuration
    model, model_config = get_model_and_config(args.model_name, device)
    if args.model_name.startswith('dinov2_'):
        if args.dinov2_model_path:
            log(f"Loading local DINOv2 weights from: {args.dinov2_model_path}")
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            state_dict = torch.load(args.dinov2_model_path, map_location='cpu')
            model.load_state_dict(state_dict, strict=True)
        else:
            # The DINOv2 backbone is created with pretrained=False, so without
            # a checkpoint every metric is computed on random-weight features.
            log("Warning: no --dinov2_model_path given; DINOv2 features will come from randomly initialized weights.")

    input_size = model_config["input_size"]
    norm_mean = model_config["norm_mean"]
    norm_std = model_config["norm_std"]
    log(f"Model config: input_size={input_size}x{input_size}, feature_dim={model_config['feature_dim']}")

    # Step 2: Collect videos from each group
    all_groups_mp4s = []
    for i, dirs in enumerate(group_dirs):
        # The group's limit is applied to each of its directories separately.
        mp4s = collect_mp4s_from_dirs_with_limit(dirs, [group_max_videos[i]] * len(dirs))
        all_groups_mp4s.append(mp4s)
        log(f"Group '{group_labels[i]}': Found {len(mp4s)} videos.")

    # Step 3: Extract features
    frames_per_video = 210
    all_feats_list = []
    for i, group_mp4s in enumerate(all_groups_mp4s):
        feats_list = []
        for dirpath, fname in group_mp4s:
            p = os.path.join(dirpath, fname)
            log(f"[{group_labels[i]}] Reading frames from {p}...")
            try:
                # Use parameters from the model's configuration
                arr = read_first_n_frames_normalized(p, n=frames_per_video,
                                                     resize_hw=(input_size, input_size),
                                                     norm_mean=norm_mean, norm_std=norm_std)
                feats = extract_feats_from_frames_arr(model, arr, device, batch_size=args.batch_size)
                log(f"[{group_labels[i]}] {fname}: Extracted {feats.shape[0]} frames -> features {feats.shape}")
                feats_list.append(feats)
                del arr
            except Exception as e:
                # Best effort: a broken video is reported and skipped.
                log(f"[{group_labels[i]}] Error processing {p}: {e}")
                continue
        if len(feats_list) == 0:
            raise RuntimeError(f"No videos were successfully processed for group '{group_labels[i]}'.")
        all_feats_list.append(np.vstack(feats_list).astype(np.float32))

    # Step 4: Save features if requested
    if args.save_feats_outdir:
        os.makedirs(args.save_feats_outdir, exist_ok=True)
        for i, feats in enumerate(all_feats_list):
            np.save(os.path.join(args.save_feats_outdir, f"feats_{group_labels[i]}_{args.model_name}.npy"), feats)
        log(f"Features saved to {args.save_feats_outdir}")

    # Mean/covariance per group are computed once and reused by every metric.
    group_stats = [compute_stats_from_feats(feats) for feats in all_feats_list]
    mu_ref, cov_ref = group_stats[0]

    # Step 5: Compute FID (using the first group as reference)
    log("Computing FID (reference: first group)...")
    results = OrderedDict()
    for label, (mu, cov) in zip(group_labels, group_stats):
        fid = frechet_distance(mu_ref, cov_ref, mu, cov)
        results[label] = dict(fid=fid)

    # Step 6: PCA + GMM + JSD/KL
    all_feats_combined = np.vstack(all_feats_list)
    pca_dim = min(args.pca_dim, all_feats_combined.shape[1], all_feats_combined.shape[0])
    log(f"Running PCA to {pca_dim} dimensions...")
    pca = PCA(n_components=pca_dim, svd_solver='randomized', random_state=0)
    all_pca = pca.fit_transform(all_feats_combined)

    # Split the jointly projected rows back into their groups.
    start_idx = 0
    pca_groups = []
    for feats in all_feats_list:
        end_idx = start_idx + feats.shape[0]
        pca_groups.append(all_pca[start_idx:end_idx])
        start_idx = end_idx

    log("Fitting GMMs...")
    gmm_groups = [fit_gmm(pca_feats, n_components=args.gmm_components, random_state=0) for pca_feats in pca_groups]

    gmm0 = gmm_groups[0]

    for i, gmm_i in enumerate(gmm_groups):
        label = group_labels[i]
        jsd = jsd_between_gmms_mc(gmm0, gmm_i, n_samples=args.mc_samples)
        kl_pq, kl_qp = kl_between_gmms_mc(gmm0, gmm_i, n_samples=args.mc_samples)
        results[label].update(jsd=jsd, kl_pq=kl_pq, kl_qp=kl_qp)

        mu_i, cov_i = group_stats[i]
        bd = bhattacharyya_distance(mu_ref, cov_ref, mu_i, cov_i)
        hdist, bcoeff = hellinger_from_bhattacharyya(bd)
        mah = mahalanobis_mean_distance(mu_ref, cov_ref, mu_i, cov_i)
        results[label].update(bhattacharyya_distance=bd, hellinger=hdist, bhattacharyya_coeff=bcoeff, mahalanobis_mean_sq=mah)

        mean_d, var_d, min_d, max_d, _ = gmm_component_center_stats(gmm0, gmm_i)
        results[label].update(gmm_comp_center_mean=mean_d, gmm_comp_center_var=var_d, gmm_comp_center_min=min_d, gmm_comp_center_max=max_d)

        if args.mc_overlap_samples > 0:
            try:
                ovl = gaussian_overlap_mc(mu_ref, cov_ref, mu_i, cov_i, n_samples=args.mc_overlap_samples)
            except Exception as e:
                log(f"Failed to compute MC overlap for group {label}: {e}")
                ovl = None
            results[label].update(gaussian_overlap_mc=ovl)

    # Step 7: Visualization
    marginal_out_png = os.path.join('.', f'marginal_dist_pca0_{args.model_name}.png')
    log(f"Generating marginal distribution on PCA component 0: {marginal_out_png}")
    visualize_marginal_distribution(pca_groups, group_labels, pca, marginal_out_png, dim_idx=0)

    vis_2d_out_png = os.path.join('.', f'embedding_2d_{args.vis_method}_{args.model_name}.png')
    log(f"Generating 2D embedding visualization: {vis_2d_out_png} (method={args.vis_method})...")
    visualize_2d_embedding_multi(pca_groups, group_labels, vis_2d_out_png, method=args.vis_method)

    # Save metrics JSON. The model name is inserted before the extension via
    # splitext, so ".json" elsewhere in the path is left alone and paths
    # without a ".json" suffix still get a per-model file name.
    out_root, out_ext = os.path.splitext(args.out_metrics_json)
    out_path = f"{out_root}_{args.model_name}{out_ext}"
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    log(f"Metrics JSON saved to {out_path}")

    # Step 8: Print summary
    log("\n---- Results Summary ----")
    for label, metrics in results.items():
        log(f"[{label}] FID: {metrics['fid']:.4f}, JSD: {metrics['jsd']:.4e}, KL(ref||q): {metrics['kl_pq']:.4e}, KL(q||ref): {metrics['kl_qp']:.4e}")
        log(f"    Bhattacharyya Distance: {metrics['bhattacharyya_distance']:.4e}, Hellinger: {metrics['hellinger']:.4e}, Mahalanobis Mean Sq: {metrics['mahalanobis_mean_sq']:.4e}")
        log(f"    GMM Center Dist (mean/var/min/max): {metrics['gmm_comp_center_mean']:.4e} / {metrics['gmm_comp_center_var']:.4e} / {metrics['gmm_comp_center_min']:.4e} / {metrics['gmm_comp_center_max']:.4e}")
        if 'gaussian_overlap_mc' in metrics:
            log(f"    Gaussian Overlap (MC): {metrics['gaussian_overlap_mc']}")

# Run the CLI workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()