"""
A tool for calculating Signal-to-Noise Ratio (SNR) of video features for given groups of videos.
"""
import os, sys, argparse, time, json
from collections import OrderedDict

import numpy as np
import cv2
import torch
import torch.nn as nn

# ---------- Utility Functions ----------
def log(*args, **kwargs):
    """Print a message exactly like ``print`` but always flush the stream.

    All positional and keyword arguments are forwarded to ``print``. ``flush``
    is always passed as True, so supplying ``flush`` yourself is a TypeError.
    """
    print(*args, flush=True, **kwargs)

def list_mp4s(dirpath):
    """Return the sorted entry names in *dirpath* whose name ends with ``.mp4``.

    The extension test is case-insensitive. Only bare names (not full paths)
    are returned, and entries are not checked to be regular files.
    """
    names = os.listdir(dirpath)
    return sorted(name for name in names if name.lower().endswith(".mp4"))

def read_first_n_frames_normalized(video_path, n=300, resize_hw=(299,299), norm_mean=None, norm_std=None):
    """
    Read the first ``n`` frames of a video, resize them and normalize per channel.

    Each frame is converted from OpenCV's BGR order to RGB, resized with
    bilinear interpolation to ``resize_hw``, scaled to [0, 1] and normalized as
    ``(x - mean) / std`` channel-wise.

    Args:
        video_path (str): Path to the video file.
        n (int): Number of leading frames to read. Must be >= 0.
        resize_hw (tuple): Target ``(height, width)`` after resizing.
        norm_mean (sequence of 3 floats): Per-channel (RGB) mean. Defaults to
            ``[0.485, 0.456, 0.406]``.
        norm_std (sequence of 3 floats): Per-channel (RGB) std. Defaults to
            ``[0.229, 0.224, 0.225]``.

    Returns:
        np.ndarray: float32 array of shape ``(n, 3, H, W)`` where
        ``(H, W) == resize_hw``. For ``n == 0`` an empty array of that shape.

    Raises:
        ValueError: If ``n`` is negative.
        RuntimeError: If the video cannot be opened or has fewer than ``n`` frames.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}.")
    if norm_mean is None:
        norm_mean = [0.485, 0.456, 0.406]
    if norm_std is None:
        norm_std = [0.229, 0.224, 0.225]

    target_h, target_w = int(resize_hw[0]), int(resize_hw[1])
    # cv2.resize expects dsize as (width, height) -- the reverse of resize_hw.
    dsize = (target_w, target_h)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        frames = []
        for _ in range(n):
            ret, frm = cap.read()
            if not ret:
                raise RuntimeError(f"Video {video_path} has fewer than {n} frames (found {len(frames)}).")
            frm_rgb = cv2.cvtColor(frm, cv2.COLOR_BGR2RGB)
            frm_resized = cv2.resize(frm_rgb, dsize, interpolation=cv2.INTER_LINEAR)
            frames.append(frm_resized.astype(np.float32) / 255.0)
    finally:
        # Release the decoder on every exit path, including decode/convert errors.
        cap.release()

    if not frames:
        # np.stack cannot stack an empty list; return a correctly shaped empty array.
        return np.zeros((0, 3, target_h, target_w), dtype=np.float32)

    arr = np.stack(frames, axis=0).transpose(0, 3, 1, 2)  # (n, H, W, 3) -> (n, 3, H, W)
    mean = np.asarray(norm_mean, dtype=np.float32).reshape(1, 3, 1, 1)
    std = np.asarray(norm_std, dtype=np.float32).reshape(1, 3, 1, 1)
    return (arr - mean) / std

# ---------- Model Loading & Feature Extraction ----------
class ClipWrapper(nn.Module):
    """Adapt an OpenAI CLIP model to the plain ``model(images) -> features`` convention.

    Calling the wrapper runs only the wrapped model's ``encode_image``; no
    normalization of the embeddings is applied here.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Stored under this attribute name; nn.Module registers it as a
        # submodule when it is itself an nn.Module.
        self.clip_model = clip_model

    def forward(self, x):
        """Return ``self.clip_model.encode_image(x)`` unchanged."""
        encode = self.clip_model.encode_image
        return encode(x)

def _load_inceptionv3(device):
    """Build torchvision's ImageNet-pretrained Inception-v3 with its classifier removed.

    Returns:
        tuple: (model, config_dict)
    """
    import torchvision

    weights_enum = getattr(torchvision.models, "Inception_V3_Weights", None)
    if weights_enum is not None:
        # Multi-weight API (torchvision >= 0.13). Pretrained weights must be built
        # with aux_logits=True; passing aux_logits=False together with weights
        # raises ValueError, so the aux head is dropped after loading instead.
        model = torchvision.models.inception_v3(weights=weights_enum.DEFAULT)
    else:
        # Older torchvision releases only understand the legacy `pretrained` flag.
        model = torchvision.models.inception_v3(pretrained=True)
    # Disable the auxiliary classifier post hoc (what torchvision itself does for
    # aux_logits=False); it is not needed for feature extraction.
    model.aux_logits = False
    model.AuxLogits = None
    # Replace the classification head so the model outputs its 2048-d features.
    model.fc = nn.Identity()
    model.to(device)
    model.eval()

    config = {
        "input_size": 299,
        "norm_mean": [0.485, 0.456, 0.406],
        "norm_std": [0.229, 0.224, 0.225],
        "feature_dim": 2048,
    }
    return model, config


def _load_clip(model_name, device):
    """Load an OpenAI CLIP image encoder wrapped in ``ClipWrapper``.

    Returns:
        tuple: (wrapped_model, config_dict)

    Raises:
        RuntimeError: If the `clip` package is not installed.
        ValueError: If ``model_name`` is not a known CLIP alias.
    """
    try:
        import clip
    except ImportError as err:
        raise RuntimeError("CLIP models require `clip-by-openai`. Run: pip install git+https://github.com/openai/CLIP.git") from err

    clip_model_name_map = {
        'clip_vit_b_32': 'ViT-B/32',
        'clip_vit_b_16': 'ViT-B/16',
        'clip_vit_l_14': 'ViT-L/14',
        'clip_rn50': 'RN50',
        'clip_rn101': 'RN101',
        'clip_rn50x4': 'RN50x4',
        'clip_rn50x16': 'RN50x16',
        'clip_rn50x64': 'RN50x64',
        'clip_vit_l_14_336': 'ViT-L/14@336px',
    }
    if model_name not in clip_model_name_map:
        raise ValueError(f"Unsupported CLIP model name: {model_name}")
    model, preprocess = clip.load(clip_model_name_map[model_name], device=device)

    # Reuse CLIP's own preprocessing constants. Assumes the first transform
    # exposes `.size` (the resize) and the last exposes `.mean`/`.std`
    # (the normalization) -- TODO confirm if the clip package changes.
    config = {
        "input_size": preprocess.transforms[0].size,
        "norm_mean": preprocess.transforms[-1].mean,
        "norm_std": preprocess.transforms[-1].std,
        "feature_dim": model.visual.output_dim,
    }

    wrapped_model = ClipWrapper(model)
    wrapped_model.to(device)
    wrapped_model.eval()
    return wrapped_model, config


def _load_dinov2(model_name, device):
    """Create a timm DINOv2 ViT backbone (no classifier head, no pretrained weights).

    Returns:
        tuple: (model, config_dict)

    Raises:
        RuntimeError: If `timm` is not installed.
        ValueError: If ``model_name`` is not a known DINOv2 alias.
    """
    try:
        import timm
    except ImportError as err:
        raise RuntimeError("DINOv2 models require `timm`. Run: pip install timm") from err

    timm_model_name_map = {
        'dinov2_vits14': 'vit_small_patch14_dinov2.lvd142m',
        'dinov2_vitb14': 'vit_base_patch14_dinov2.lvd142m',
        'dinov2_vitl14': 'vit_large_patch14_dinov2.lvd142m',
    }
    if model_name not in timm_model_name_map:
        raise ValueError(f"Unknown DINOv2 model: {model_name}. Choices: {list(timm_model_name_map.keys())}")

    # pretrained=False: nothing is downloaded, so the weights stay randomly
    # initialised until the caller loads a checkpoint. num_classes=0 drops the head.
    model = timm.create_model(timm_model_name_map[model_name], pretrained=False, num_classes=0)
    log("Note: DINOv2 model was created without pretrained weights; load a checkpoint into it before extracting features.")
    model.to(device)
    model.eval()

    data_config = model.default_cfg
    config = {
        "input_size": data_config['input_size'][-1],
        "norm_mean": data_config['mean'],
        "norm_std": data_config['std'],
        "feature_dim": model.embed_dim,
    }
    return model, config


def get_model_and_config(model_name, device):
    """
    Load a feature-extraction model and its preprocessing configuration.

    Supported names: ``'inceptionv3'``, ``'clip_*'`` aliases and
    ``'dinov2_*'`` aliases. The returned model is moved to ``device`` and set
    to eval mode.

    Returns:
        tuple: (model, config_dict) where config_dict has the keys
        ``input_size``, ``norm_mean``, ``norm_std`` and ``feature_dim``.

    Raises:
        ValueError: If ``model_name`` is not supported.
        RuntimeError: If the optional package a model family needs is missing.
    """
    log(f"Loading model: {model_name}...")
    if model_name == 'inceptionv3':
        return _load_inceptionv3(device)
    if model_name.startswith('clip_'):
        return _load_clip(model_name, device)
    if model_name.startswith('dinov2_'):
        return _load_dinov2(model_name, device)
    raise ValueError(f"Unsupported model name: {model_name}")

def extract_feats_from_frames_arr(model, frames_arr, device, batch_size=64):
    """
    Run ``model`` over an array of frames in mini-batches and collect features.

    Tuple outputs are reduced to their first element. Outputs are cast to
    float32 before anything else, so half-precision models still yield float32
    features. For ``ClipWrapper`` models each feature row is additionally
    L2-normalized.

    Args:
        model: The feature extraction model (called as ``model(batch)``).
        frames_arr (np.ndarray): ``(T, 3, H, W)`` numpy array of frames.
        device: The device to run inference on.
        batch_size (int): Number of frames per forward pass; must be positive.

    Returns:
        np.ndarray: ``(T, D)`` float32 feature array, or a ``(0, 0)`` array
        when ``T == 0``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    num_frames = frames_arr.shape[0]
    if num_frames == 0:
        # The feature dimension is unknown without running the model.
        return np.zeros((0, 0), dtype=np.float32)

    # Decided once: only CLIP embeddings are L2-normalized.
    normalize = isinstance(model, ClipWrapper)
    chunks = []
    with torch.no_grad():
        for start in range(0, num_frames, batch_size):
            batch = frames_arr[start:start + batch_size]
            batch_t = torch.from_numpy(batch).to(device=device, dtype=torch.float32)

            feats = model(batch_t)
            if isinstance(feats, tuple):
                feats = feats[0]

            # If the model returns half-precision outputs, up-cast so the norm
            # and the downstream statistics are computed in float32.
            feats = feats.float()
            if normalize:
                feats = feats / feats.norm(dim=-1, keepdim=True)

            chunks.append(feats.cpu().numpy())

    return np.concatenate(chunks, axis=0)


# ---------- SNR Calculation ----------
def compute_snr_from_features(video_feats_list, noise_floor=1e-9):
    """
    Compute a Signal-to-Noise Ratio (SNR) from per-video frame features.

    Signal: for each feature dimension, the variance across videos of the
    per-video mean feature, averaged over dimensions.
    Noise: for each video with more than one frame, the per-dimension variance
    across its frames averaged over dimensions; then averaged over those videos.
    Both use population variance (ddof=0) and are accumulated in float64.

    Args:
        video_feats_list (list of np.ndarray): One ``(frames, dim)`` array per
            video; all arrays must share the same ``dim``.
        noise_floor (float): Noise power below this value is treated as zero,
            giving an infinite SNR when the signal is positive.

    Returns:
        tuple: (snr, snr_db, signal_power, noise_power) as Python floats.
        Fewer than 2 videos gives ``(0.0, -inf, 0.0, 0.0)``; no video with more
        than one frame gives ``(nan, nan, signal_power, 0.0)``.
    """
    if len(video_feats_list) < 2:
        log("Warning: SNR calculation requires at least 2 videos. Returning 0.")
        return 0.0, -np.inf, 0.0, 0.0

    # Accumulate in float64 regardless of the input dtype (e.g. float16 features).
    feats64 = [np.asarray(feats, dtype=np.float64) for feats in video_feats_list]

    # (num_videos, feat_dim)
    mean_feats_per_video = np.stack([feats.mean(axis=0) for feats in feats64])
    signal_power = float(np.mean(np.var(mean_feats_per_video, axis=0)))

    noise_power_per_video = [float(np.mean(np.var(feats, axis=0))) for feats in feats64 if feats.shape[0] > 1]
    if not noise_power_per_video:
        log("Warning: No videos with more than 1 frame found for noise calculation.")
        return np.nan, np.nan, signal_power, 0.0

    noise_power = float(np.mean(noise_power_per_video))

    if noise_power < noise_floor:
        snr = np.inf if signal_power > 0 else 0.0
        snr_db = np.inf if signal_power > 0 else -np.inf
    elif signal_power <= 0:
        # Zero signal: report -inf dB explicitly instead of calling log10(0).
        snr, snr_db = 0.0, -np.inf
    else:
        snr = signal_power / noise_power
        snr_db = 10 * np.log10(snr)

    return float(snr), float(snr_db), signal_power, noise_power


# ---------- Main Workflow ----------
def parse_args():
    """Parse command-line arguments for the SNR tool.

    ``--group_dirs`` is repeatable: each occurrence defines one group and may
    list several directories, so ``args.group_dirs`` is a list of lists.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    # Every model name get_model_and_config can load.
    model_choices = [
        "inceptionv3",
        "clip_vit_b_32", "clip_vit_b_16", "clip_vit_l_14", "clip_vit_l_14_336",
        "clip_rn50", "clip_rn101", "clip_rn50x4", "clip_rn50x16", "clip_rn50x64",
        "dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14",
    ]
    p = argparse.ArgumentParser(description="Video feature SNR calculation tool.")
    p.add_argument("--group_dirs", type=str, nargs='+', action='append', required=True,
                   help="Video directories forming one group. Repeat the flag once per group; "
                        "each occurrence may list several directories.")
    p.add_argument("--group_labels", type=str, nargs='+', required=True,
                   help="Label for each group (must match the order of --group_dirs).")
    p.add_argument("--group_max_videos", type=int, nargs='*', default=None,
                   help="Max number of videos to process per group (optional, matches --group_dirs order).")
    p.add_argument("--model_name", type=str, default="inceptionv3",
                   choices=model_choices,
                   help="Model to use for feature extraction.")
    p.add_argument("--dinov2_model_path", type=str, default=None, help="Path to local DINOv2 model checkpoint.")
    p.add_argument("--device", type=str, default="cuda:0", help="Device to use: cpu or cuda:0.")
    p.add_argument("--batch_size", type=int, default=32, help="Batch size for feature extraction.")
    p.add_argument("--save_feats_outdir", type=str, default=None, help="If provided, save extracted features to this directory.")
    p.add_argument("--out_metrics_json", type=str, default="snr_results.json", help="Path to save the output metrics JSON file.")
    return p.parse_args()


def collect_mp4s_from_dirs_with_limit(dirs, max_videos_list=None):
    """
    Collect ``(dirpath, filename)`` pairs for the .mp4 files in several directories.

    Directories are visited in the given order; within a directory files come
    in ``list_mp4s`` order. Paths that are missing or are not directories are
    logged and skipped.

    Args:
        dirs (list of str): Directories to scan.
        max_videos_list (list of int or None, optional): Per-directory cap on the
            number of files taken (first files kept); ``None`` means no cap.
            Defaults to no cap for every directory.

    Returns:
        list of tuple: ``(dirpath, filename)`` pairs.

    Raises:
        ValueError: If ``max_videos_list`` and ``dirs`` differ in length, or a
            cap is negative.
    """
    if max_videos_list is None:
        max_videos_list = [None] * len(dirs)
    if len(max_videos_list) != len(dirs):
        raise ValueError(f"Number of max_videos ({len(max_videos_list)}) must match number of dirs ({len(dirs)}).")
    for max_count in max_videos_list:
        # A negative cap would silently slice from the end of the list.
        if max_count is not None and max_count < 0:
            raise ValueError(f"max_videos must be non-negative or None, got {max_count}.")

    all_mp4s = []
    for dirpath, max_count in zip(dirs, max_videos_list):
        # isdir (not exists): a plain file would make os.listdir raise.
        if not os.path.isdir(dirpath):
            log(f"Warning: Directory {dirpath} does not exist or is not a directory, skipping.")
            continue
        mp4s = [(dirpath, f) for f in list_mp4s(dirpath)]
        if max_count is not None and len(mp4s) > max_count:
            log(f"Limiting directory {dirpath} to {max_count} videos (from {len(mp4s)}).")
            mp4s = mp4s[:max_count]
        else:
            log(f"Using all {len(mp4s)} videos from {dirpath}.")
        all_mp4s.extend(mp4s)
    return all_mp4s

def main():
    """Run the SNR tool end to end from the command line.

    Steps: parse arguments, load the feature extractor, collect videos per
    group, extract features for the first ``num_frames`` frames of every video,
    optionally save the features, compute each group's SNR, write the metrics
    JSON and print a summary.

    Raises:
        ValueError: If the per-group argument lists have inconsistent lengths
            or a ``--group_max_videos`` value is negative.
    """
    args = parse_args()
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    log(f"Using device: {device}")

    # Leading frames read from every video; shorter videos fail and are skipped.
    num_frames = 210

    group_dirs = args.group_dirs
    group_labels = args.group_labels
    group_max_videos = args.group_max_videos or [None] * len(group_dirs)

    if len(group_dirs) != len(group_labels):
        raise ValueError("Number of --group_dirs must match number of --group_labels.")
    if len(group_max_videos) != len(group_dirs):
        raise ValueError("Number of --group_max_videos must match number of --group_dirs.")
    if any(m is not None and m < 0 for m in group_max_videos):
        raise ValueError("--group_max_videos values must be non-negative.")

    # Step 1: Load model and get its configuration
    model, model_config = get_model_and_config(args.model_name, device)
    if args.model_name.startswith('dinov2_') and args.dinov2_model_path:
        log(f"Loading local DINOv2 weights from: {args.dinov2_model_path}")
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        state_dict = torch.load(args.dinov2_model_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)

    input_size = model_config["input_size"]
    norm_mean = model_config["norm_mean"]
    norm_std = model_config["norm_std"]
    log(f"Model config: input_size={input_size}x{input_size}, feature_dim={model_config['feature_dim']}")

    # Step 2: Collect videos from each group. The --group_max_videos cap applies
    # to the whole group (first videos in directory order are kept), not to each
    # of the group's directories separately.
    all_groups_mp4s = []
    for label, dirs, max_videos in zip(group_labels, group_dirs, group_max_videos):
        mp4s = collect_mp4s_from_dirs_with_limit(dirs)
        if max_videos is not None and len(mp4s) > max_videos:
            log(f"Limiting group '{label}' to {max_videos} videos (from {len(mp4s)}).")
            mp4s = mp4s[:max_videos]
        all_groups_mp4s.append(mp4s)
        log(f"Group '{label}': Found {len(mp4s)} videos.")

    # Step 3: Extract features for each video in each group
    all_groups_video_feats = []
    for label, group_mp4s in zip(group_labels, all_groups_mp4s):
        video_feats_list_for_group = []
        for dirpath, fname in group_mp4s:
            p = os.path.join(dirpath, fname)
            log(f"[{label}] Reading frames from {p}...")
            try:
                arr = read_first_n_frames_normalized(p, n=num_frames,
                                                     resize_hw=(input_size, input_size),
                                                     norm_mean=norm_mean, norm_std=norm_std)
                feats = extract_feats_from_frames_arr(model, arr, device, batch_size=args.batch_size)
            except Exception as e:
                # Best effort: one unreadable or too-short video must not abort the run.
                log(f"[{label}] Error processing {p}: {e}")
                continue
            del arr  # free the frame buffer before decoding the next video
            log(f"[{label}] {fname}: Extracted {feats.shape[0]} frames -> features {feats.shape}")
            if feats.shape[0] > 0:
                video_feats_list_for_group.append(feats)

        if not video_feats_list_for_group:
            log(f"Warning: No videos were successfully processed for group '{label}'. Skipping SNR calculation for this group.")
        all_groups_video_feats.append(video_feats_list_for_group)

    # Step 4: Save features if requested
    if args.save_feats_outdir:
        os.makedirs(args.save_feats_outdir, exist_ok=True)
        for label, group_video_feats in zip(group_labels, all_groups_video_feats):
            # Frames of all videos are stacked into one array; per-video
            # boundaries are not preserved (offline analysis only).
            if group_video_feats:
                all_feats_for_group = np.vstack(group_video_feats)
                np.save(os.path.join(args.save_feats_outdir, f"feats_{label}_{args.model_name}.npy"), all_feats_for_group)
        log(f"Features saved to {args.save_feats_outdir}")

    # Step 5: Compute SNR for each group
    results = OrderedDict()
    for label, group_feats in zip(group_labels, all_groups_video_feats):
        log(f"Computing SNR for group '{label}'...")
        snr, snr_db, signal, noise = compute_snr_from_features(group_feats)
        results[label] = {
            "snr": snr,
            "snr_db": snr_db,
            "signal_power": signal,
            "noise_power": noise,
            "video_count": len(group_feats)
        }

    # Save metrics JSON: insert the model name before the file extension only
    # (str.replace would also rewrite ".json" occurring in directory names).
    root, ext = os.path.splitext(args.out_metrics_json)
    out_path = f"{root}_{args.model_name}{ext}"
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # NOTE: non-finite values (inf/nan SNRs) are written as Infinity/NaN, which
    # Python's json accepts but strict JSON parsers reject.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    log(f"Metrics JSON saved to {out_path}")

    # Step 6: Print summary
    log("\n---- SNR Results Summary ----")
    for label, metrics in results.items():
        log(f"[{label}] ({metrics['video_count']} videos)")
        log(f"  Signal Power : {metrics['signal_power']:.4e}")
        log(f"  Noise Power  : {metrics['noise_power']:.4e}")
        log(f"  SNR          : {metrics['snr']:.4f}")
        log(f"  SNR (dB)     : {metrics['snr_db']:.4f}")

if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()
