#!/usr/bin/env python
"""
Enhanced Video Quality Assessment Tool - Enhanced Version
Extended Features:
1. Individual evaluation for each human mask region
2. Optional saving of visualization images with applied masks (Error Map same size as GT)
3. Advanced evaluation metrics: DISTS, CLIP-FID/CLIPScore, ST-SSIM/GMSD-Temporal
"""

import os
import cv2
import numpy as np
import torch
import lpips
from torchvision import models, transforms
import torch.nn.functional as F
from tqdm import tqdm
import argparse
import json
from tabulate import tabulate
from glob import glob
import math
from scipy import linalg
from torchvision.models import inception_v3
import re
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import time
from functools import partial
import psutil
import pandas as pd
import sys
import traceback
from tabulate import tabulate
from PIL import Image
import torch.nn as nn
from collections import OrderedDict
import multiprocessing
import torch.multiprocessing as mp


from torch.cuda.amp import autocast


# Optional dependencies: each flag records whether the import succeeded so the
# metric classes below can fall back (or return placeholders) when it did not.
try:
    import clip
    CLIP_AVAILABLE = True
except ImportError:
    # OpenAI's CLIP is installed from its GitHub repository; the PyPI package
    # named "clip" is an unrelated project.
    print("Warning: CLIP module not installed, CLIP-related metrics will be disabled. Please install with 'pip install git+https://github.com/openai/CLIP.git'.")
    CLIP_AVAILABLE = False


try:
    from DISTS_pytorch import DISTS
    DISTS_AVAILABLE = True
except ImportError:
    print("Warning: DISTS module not installed, using simplified DISTS. For full version please install from 'https://github.com/dingkeyan93/DISTS'.")
    DISTS_AVAILABLE = False

class DISTSMetric(nn.Module):
    """DISTS (Deep Image Structure and Texture Similarity) with a VGG16 fallback.

    When ``DISTS_pytorch`` is importable, the official ``DISTS`` model is used.
    Otherwise a simplified approximation is built on torchvision's VGG16
    convolutional trunk.

    NOTE(review): the two paths do not share a scale or direction. The official
    package is documented as returning a distance (lower = more similar). The
    fallback below returns a similarity-style score (higher = more similar;
    about 1.0 for identical inputs). Confirm how callers interpret the value
    before mixing results from both paths.
    """

    def __init__(self, device):
        super(DISTSMetric, self).__init__()
        self.device = device

        if DISTS_AVAILABLE:
            self.dists_fn = DISTS().to(device)
        else:
            # Fallback backbone: torchvision VGG16 feature trunk.
            vgg = models.vgg16(pretrained=True).to(device)
            self.vgg_layers = vgg.features.eval()
            # ReLU outputs compared by the fallback. In torchvision's VGG16
            # layout these are relu1_2, relu2_2, relu3_3, relu4_3, relu5_3.
            self.layer_indices = [3, 8, 15, 22, 29]

            # Used purely as a metric: no gradients are needed.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, img1, img2):
        """Score two image batches (official DISTS output, or the fallback similarity)."""
        with torch.no_grad():
            if DISTS_AVAILABLE:
                return self.dists_fn(img1, img2)
            return self.compute_similarity(self.extract_features(img1),
                                           self.extract_features(img2))

    def extract_features(self, x):
        """Run ``x`` through the VGG trunk and collect outputs at ``self.layer_indices``."""
        features = []
        last_index = max(self.layer_indices)
        for i, layer in enumerate(self.vgg_layers):
            x = layer(x)
            if i in self.layer_indices:
                features.append(x)
            if i >= last_index:
                break  # deeper layers never contribute to the score
        return features

    def compute_similarity(self, feat1, feat2):
        """Average a structure + texture score over paired feature maps.

        Each pair contributes ``0.5 * structure + 0.5 * (1 - texture_distance)``:
          * structure: mean over batch/channels of the per-channel Pearson
            correlation of the spatial maps;
          * texture_distance: MSE between the two Gram matrices (unbounded, so
            the texture term can be negative).
        Feature maps are indexed as (b, c, h, w) by the code below.
        """
        total = 0.0

        for f1, f2 in zip(feat1, feat2):
            mu1 = torch.mean(f1, dim=[2, 3], keepdim=True)
            mu2 = torch.mean(f2, dim=[2, 3], keepdim=True)
            # Population (biased) variance. With it, mean(z1 * z2) is the
            # Pearson correlation (~1 for identical maps). The unbiased
            # estimator gives (N-1)/N for identical maps and NaN for 1x1 maps.
            var1 = torch.var(f1, dim=[2, 3], keepdim=True, unbiased=False)
            var2 = torch.var(f2, dim=[2, 3], keepdim=True, unbiased=False)

            z1 = (f1 - mu1) / (torch.sqrt(var1) + 1e-8)
            z2 = (f2 - mu2) / (torch.sqrt(var2) + 1e-8)

            structure = torch.mean(torch.mean(z1 * z2, dim=[2, 3]))

            # Texture: compare channel Gram matrices normalised by map area.
            b, c, h, w = f1.shape
            flat1 = f1.view(b, c, -1)
            flat2 = f2.view(b, c, -1)
            gram1 = torch.bmm(flat1, flat1.transpose(1, 2)) / (h * w)
            gram2 = torch.bmm(flat2, flat2.transpose(1, 2)) / (h * w)
            texture_distance = F.mse_loss(gram1, gram2)

            total += 0.5 * structure + 0.5 * (1 - texture_distance)

        return total / len(feat1)

class CLIPScorer:
    """Image similarity (CLIPScore-style) and FID computed on CLIP image embeddings."""

    def __init__(self, device, local_clip_path="path/to/models/ViT-B-32.pt"):
        """Load CLIP ViT-B/32, preferring a local checkpoint.

        Args:
            device: torch device the model is loaded onto.
            local_clip_path: checkpoint tried first; if it does not exist (or is
                None) or fails to load, ``clip.load("ViT-B/32")`` is used.
        """
        self.device = device
        self.model = None
        self.preprocess = None

        if not CLIP_AVAILABLE:
            return

        try:
            if local_clip_path and os.path.exists(local_clip_path):
                print(f"Loading CLIP model from local path: {local_clip_path}")
                self.model, self.preprocess = clip.load(local_clip_path, device=device)
            else:
                print(f"Local CLIP model not found, using standard loading")
                self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        except Exception as e:
            print(f"Failed to load local CLIP model: {e}, trying standard loading")
            self.model, self.preprocess = clip.load("ViT-B/32", device=device)

    def get_clip_features(self, images):
        """Encode BGR uint8 images (OpenCV layout) into CLIP embeddings, one row per image.

        Returns a zero tensor of shape (len(images), 512) when CLIP is unavailable.
        """
        if not CLIP_AVAILABLE:
            print("Warning: CLIP not installed, cannot extract features")
            return torch.zeros((len(images), 512), device=self.device)

        features = []
        for img in images:
            img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            img_tensor = self.preprocess(img_pil).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features.append(self.model.encode_image(img_tensor))

        return torch.cat(features)

    def calculate_clip_fid(self, real_features, fake_features, eps=1e-6):
        """Frechet distance between Gaussians fitted to two sets of CLIP embeddings.

        Args:
            real_features, fake_features: (N, D) tensors; each needs N >= 2.
            eps: diagonal regulariser added to both covariances before ``sqrtm``.

        Returns:
            The FID as a float; NaN when either set has fewer than 2 samples;
            0.0 when CLIP is unavailable.
        """
        if not CLIP_AVAILABLE:
            print("Warning: CLIP not installed, cannot calculate CLIP-FID")
            return 0.0

        # Work in float64. Embeddings may come back in half precision
        # (presumably when CLIP runs on CUDA — confirm), and float16 means and
        # differences would lose most of their precision.
        real_np = real_features.detach().cpu().numpy().astype(np.float64)
        fake_np = fake_features.detach().cpu().numpy().astype(np.float64)

        if real_np.shape[0] < 2 or fake_np.shape[0] < 2:
            print("Warning: CLIP-FID needs at least 2 samples per set, returning NaN")
            return float("nan")

        mu1, sigma1 = np.mean(real_np, axis=0), np.cov(real_np, rowvar=False)
        mu2, sigma2 = np.mean(fake_np, axis=0), np.cov(fake_np, rowvar=False)

        diff = mu1 - mu2
        sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
        sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps

        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if np.iscomplexobj(covmean):
            # Numerical error can introduce a tiny imaginary part; keep the real part.
            covmean = covmean.real

        fid = np.sum(np.square(diff)) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
        return float(fid)

    def calculate_clip_score(self, image1, image2):
        """Cosine similarity of two images' CLIP embeddings mapped from [-1, 1] to [0, 1]."""
        if not CLIP_AVAILABLE:
            print("Warning: CLIP not installed, cannot calculate CLIPScore")
            return 0.0

        feat1 = self.get_clip_features([image1])
        feat2 = self.get_clip_features([image2])

        feat1 = feat1 / feat1.norm(dim=-1, keepdim=True)
        feat2 = feat2 / feat2.norm(dim=-1, keepdim=True)

        similarity = torch.nn.functional.cosine_similarity(feat1, feat2)
        return ((similarity + 1) / 2).item()

class SpatioTemporalMetrics:
    """Spatio-temporal quality metrics on sequences of H x W x 3 frames (NumPy)."""

    def __init__(self, device):
        # Stored for API symmetry with the other metric classes; the metrics
        # below are computed with NumPy and do not use it.
        self.device = device

    def calculate_st_ssim(self, gt_frames, pred_frames, window_size=3):
        """Spatio-temporal SSIM averaged over sliding windows of ``window_size`` frames.

        For each window, the frames are stacked into a volume. SSIM is then
        evaluated per colour channel from global (non-windowed) mean, std and
        correlation statistics of the whole volume. Only the first
        ``min(len(gt_frames), len(pred_frames))`` frames are compared.

        If fewer than ``window_size`` frames are available, falls back to the
        mean per-frame, per-channel skimage SSIM.
        """
        num_frames = min(len(gt_frames), len(pred_frames))

        if num_frames < window_size:
            # Imported lazily: only the short-sequence fallback needs skimage.
            from skimage.metrics import structural_similarity as ssim

            ssim_values = []
            for i in range(num_frames):
                per_channel = [
                    ssim(gt_frames[i][:, :, c], pred_frames[i][:, :, c],
                         data_range=255, win_size=11, gaussian_weights=True)
                    for c in range(3)
                ]
                ssim_values.append(np.mean(per_channel))
            return np.mean(ssim_values)

        # Standard SSIM stabilisers for 8-bit data.
        c1 = (0.01 * 255) ** 2
        c2 = (0.03 * 255) ** 2

        st_ssim_values = []
        # Windows are bounded by the shorter sequence so both stacks always have
        # window_size frames.
        for start in range(num_frames - window_size + 1):
            gt_vol = np.stack(gt_frames[start:start + window_size], axis=0)
            pred_vol = np.stack(pred_frames[start:start + window_size], axis=0)

            ssim_3d = []
            for c in range(3):
                gt_c = gt_vol[:, :, :, c]
                pred_c = pred_vol[:, :, :, c]

                gt_mean, pred_mean = np.mean(gt_c), np.mean(pred_c)
                gt_std, pred_std = np.std(gt_c), np.std(pred_c)

                # cross_corr * gt_std * pred_std approximates the covariance.
                gt_norm = (gt_c - gt_mean) / (gt_std + 1e-8)
                pred_norm = (pred_c - pred_mean) / (pred_std + 1e-8)
                cross_corr = np.mean(gt_norm * pred_norm)

                ssim_c = ((2 * gt_mean * pred_mean + c1) * (2 * cross_corr * gt_std * pred_std + c2)) / \
                         ((gt_mean**2 + pred_mean**2 + c1) * (gt_std**2 + pred_std**2 + c2))
                ssim_3d.append(ssim_c)

            st_ssim_values.append(np.mean(ssim_3d))

        return np.mean(st_ssim_values)

    def calculate_gmsd_temporal(self, gt_frames, pred_frames):
        """Gradient-magnitude-similarity deviation applied to frame differences.

        For each consecutive frame pair, the "gradient" is the per-pixel L2 norm
        (over channels) of the temporal difference, normalised by its mean.
        GMS is computed per pixel, its standard deviation is taken, and these
        are averaged over all pairs. Returns 0.0 if either sequence has fewer
        than 2 frames.
        """
        if len(gt_frames) < 2 or len(pred_frames) < 2:
            return 0.0

        total_gmsd = 0.0
        count = 0
        # Each pair is computed once and discarded, instead of storing every
        # gradient map and recomputing the same statistic in a second pass.
        for i in range(1, min(len(gt_frames), len(pred_frames))):
            gt_diff = gt_frames[i].astype(np.float32) - gt_frames[i - 1].astype(np.float32)
            pred_diff = pred_frames[i].astype(np.float32) - pred_frames[i - 1].astype(np.float32)

            gt_mag = np.sqrt(np.sum(gt_diff**2, axis=2))
            pred_mag = np.sqrt(np.sum(pred_diff**2, axis=2))

            gt_norm = gt_mag / (np.mean(gt_mag) + 1e-8)
            pred_norm = pred_mag / (np.mean(pred_mag) + 1e-8)

            gms = (2 * gt_norm * pred_norm + 1e-8) / (gt_norm**2 + pred_norm**2 + 1e-8)
            total_gmsd += np.std(gms)
            count += 1

        return total_gmsd / count if count > 0 else 0.0

class EnhancedVideoQualityEvaluator:
    def __init__(self, device, gpu_id=0, verbose_paths=False, weights_dir=None, save_visuals=False,
                batch_size=32, use_mixed_precision=False, prefetch_size=64, io_threads=24, use_dali=False):
        """Store configuration and load every model the evaluator uses.

        Args:
            device: torch device all models are moved to.
            gpu_id: label used in the "[GPU n]" log prefix.
            verbose_paths: when True, per-frame metric methods print the compared paths.
            weights_dir: local weights directory. If it is falsy or missing,
                "path/to/models" is tried; if that is missing too, online
                weights are used (``self.weights_dir`` is None).
            save_visuals: enables the per-frame visualisation in calculate_l1_loss.
            use_mixed_precision: routes LPIPS through the autocast batch path.
            batch_size, prefetch_size, io_threads, use_dali: stored as
                attributes; usage is not verified here.
        """
        # Plain configuration.
        self.device, self.gpu_id = device, gpu_id
        self.verbose_paths = verbose_paths
        self.save_visuals = save_visuals
        self.batch_size = batch_size
        self.use_mixed_precision = use_mixed_precision
        self.prefetch_size = prefetch_size
        self.io_threads = io_threads
        self.use_dali = use_dali

        # First existing directory wins: explicit argument, then the default.
        self.weights_dir = None
        for candidate in (weights_dir, "path/to/models"):
            if candidate and os.path.exists(candidate):
                self.weights_dir = candidate
                break

        print(f"[GPU {self.gpu_id}] Initializing models...")
        print(f"[GPU {self.gpu_id}] Using local weights directory: {self.weights_dir}"
              if self.weights_dir else f"[GPU {self.gpu_id}] Using online weights")

        # Network-backed metrics (order matters only for the log output).
        self.lpips_fn = self._load_lpips_model()
        self.i3d_model = self._load_i3d_model()
        self.inception_model = self._load_inception_model()

        # Metric helper objects.
        self.dists_model = DISTSMetric(device)
        self.clip_scorer = CLIPScorer(device)
        self.st_metrics = SpatioTemporalMetrics(device)

        print(f"[GPU {self.gpu_id}] Model initialization completed")
    
    def _load_lpips_model(self):
        """Build the LPIPS metric (VGG backbone) on ``self.device``.

        Preference order:
          1. ``<weights_dir>/vgg.pth`` loaded into ``lpips.LPIPS(net='vgg', pretrained=False)``;
          2. ``lpips.LPIPS(net='vgg')`` with the package's pretrained weights.
        Any exception on the local path is reported, and option 2 is returned instead.
        """
        try:
            if self.weights_dir:
                lpips_weight_path = os.path.join(self.weights_dir, 'vgg.pth')
                if os.path.exists(lpips_weight_path):
                    print(f"[GPU {self.gpu_id}] Loading LPIPS weights from local: {lpips_weight_path}")

                    # Point TORCH_HOME at weights_dir while the VGG backbone is
                    # built (presumably so torch's checkpoint cache is looked up
                    # there — confirm). try/finally restores the previous value
                    # even when construction fails. Comparing with None also
                    # restores an empty-string value instead of dropping it.
                    original_torch_home = os.environ.get('TORCH_HOME')
                    os.environ['TORCH_HOME'] = self.weights_dir
                    try:
                        lpips_fn = lpips.LPIPS(net='vgg', pretrained=False)
                    finally:
                        if original_torch_home is None:
                            os.environ.pop('TORCH_HOME', None)
                        else:
                            os.environ['TORCH_HOME'] = original_torch_home

                    state_dict = torch.load(lpips_weight_path, map_location=self.device)
                    # Load non-strictly. LPIPS loads its own released weight files
                    # with strict=False because they hold only the linear heads;
                    # strict loading of such a file fails on the missing backbone
                    # keys. Keys that match nothing mean this is not an LPIPS
                    # checkpoint, so raise and let the fallback below handle it.
                    result = lpips_fn.load_state_dict(state_dict, strict=False)
                    if result.unexpected_keys or not state_dict:
                        raise RuntimeError(
                            f"{lpips_weight_path} does not look like an LPIPS state dict "
                            f"(unexpected keys: {list(result.unexpected_keys)[:5]})")
                    return lpips_fn.to(self.device)

            print(f"[GPU {self.gpu_id}] Using pretrained LPIPS model")

            # Informational only: reports whether a cached VGG16 checkpoint exists.
            cache_dir = os.path.expanduser('path/to/models/')
            vgg_cache = os.path.join(cache_dir, 'vgg16-397923af.pth')
            if os.path.exists(vgg_cache):
                print(f"[GPU {self.gpu_id}] Using cached VGG weights: {vgg_cache}")

            return lpips.LPIPS(net='vgg').to(self.device)

        except Exception as e:
            print(f"[GPU {self.gpu_id}] LPIPS weights loading failed: {e}")
            # Fallback to the package's pretrained (online) weights.
            return lpips.LPIPS(net='vgg').to(self.device)
            
    def _load_i3d_model(self):
        """Return torchvision's R3D-18 video network without its final layer, in eval mode.

        (Named "I3D" throughout the file; the network built is ``r3d_18``.)
        ``<weights_dir>/r3d_18_weights.pth`` is preferred when present;
        otherwise torchvision's pretrained weights are used. Any exception on
        the preferred path falls back to the pretrained weights.
        """
        def _strip_head(net):
            # Keep every child module except the last (the classifier).
            trunk = torch.nn.Sequential(*list(net.children())[:-1])
            return trunk.to(self.device).eval()

        try:
            local_path = os.path.join(self.weights_dir, 'r3d_18_weights.pth') if self.weights_dir else None
            if local_path and os.path.exists(local_path):
                print(f"[GPU {self.gpu_id}] Loading I3D weights from local: {local_path}")
                net = models.video.r3d_18(pretrained=False)
                net.load_state_dict(torch.load(local_path, map_location=self.device))
                return _strip_head(net)

            print(f"[GPU {self.gpu_id}] Using pretrained I3D model")
            return _strip_head(models.video.r3d_18(pretrained=True))

        except Exception as e:
            print(f"[GPU {self.gpu_id}] I3D weights loading failed: {e}")
            return _strip_head(models.video.r3d_18(pretrained=True))
    
    def _load_inception_model(self):
        """Return InceptionV3 with ``fc`` replaced by identity, in eval mode.

        ``<weights_dir>/inception_v3_weights.pth`` is preferred when present;
        otherwise torchvision's pretrained weights are used. Any exception on
        the preferred path falls back to the pretrained weights.
        ``transform_input=False`` is passed on every path.
        """
        def _finalize(net):
            # Drop the classifier so forward() returns the pre-fc activations.
            net.fc = torch.nn.Identity()
            return net.to(self.device).eval()

        try:
            local_path = os.path.join(self.weights_dir, 'inception_v3_weights.pth') if self.weights_dir else None
            if local_path and os.path.exists(local_path):
                print(f"[GPU {self.gpu_id}] Loading InceptionV3 weights from local: {local_path}")
                net = inception_v3(pretrained=False, transform_input=False)
                net.load_state_dict(torch.load(local_path, map_location=self.device))
                return _finalize(net)

            print(f"[GPU {self.gpu_id}] Using pretrained InceptionV3 model")
            return _finalize(inception_v3(pretrained=True, transform_input=False))

        except Exception as e:
            print(f"[GPU {self.gpu_id}] InceptionV3 weights loading failed: {e}")
            return _finalize(inception_v3(pretrained=True, transform_input=False))
    
    def natural_sort_key(self, s):
        """Sort key that orders embedded ASCII digit runs numerically.

        ``"frame_2"`` sorts before ``"frame_10"``, and text parts compare
        case-insensitively. With one capturing group, ``re.split`` places the
        captured digit runs at odd indices. Converting by position (rather than
        ``str.isdigit``) keeps Unicode digit characters such as "²" as text,
        because ``isdigit`` accepts them but ``int()`` rejects them.
        """
        parts = re.split(r'([0-9]+)', s)
        return [int(part) if idx % 2 else part.lower() for idx, part in enumerate(parts)]
    
    def read_image_sequence(self, image_folder):
        """Read an image sequence from a folder.

        The first matching pattern among ``frame_*.png``, ``frame_*.jpg``,
        ``*.png``, ``*.jpg`` is used, and files are ordered with
        ``natural_sort_key``. Unreadable files are skipped with a warning.

        Returns:
            (frames, frame_paths): ``np.array`` of the decoded frames (assumes
            all frames share one size) and the matching list of paths.

        Raises:
            ValueError: the folder is missing, no file matches, or no matching
                file could be decoded.
        """
        from glob import escape as glob_escape

        if not os.path.exists(image_folder):
            raise ValueError(f"Folder does not exist: {image_folder}")

        # Escape the folder so names containing glob metacharacters ([, *, ?)
        # are matched literally.
        search_root = glob_escape(image_folder)
        image_files = []
        for pattern in ('frame_*.png', 'frame_*.jpg', '*.png', '*.jpg'):
            image_files = glob(os.path.join(search_root, pattern))
            if image_files:
                break

        if not image_files:
            raise ValueError(f"No image files found in folder {image_folder}")
        image_files.sort(key=self.natural_sort_key)

        frames = []
        frame_paths = []
        for img_path in image_files:
            img = cv2.imread(img_path)
            if img is None:
                print(f"Warning: Cannot read image {img_path}, skipping")
                continue
            frames.append(img)
            frame_paths.append(img_path)

        # Fail loudly rather than returning an empty array that breaks callers later.
        if not frames:
            raise ValueError(f"No readable image files in folder {image_folder}")

        print(f"[GPU {self.gpu_id}] Read {len(frames)} frames from {image_folder}")
        return np.array(frames), frame_paths
    
    def read_mask_sequence(self, video_path, mask_type="masks"):
        """Read per-person binary mask sequences from ``<video_path>/<mask_type>/person_*``.

        Each person folder uses the first matching pattern among
        ``frame_*.png``, ``frame_*.jpg``, ``*.png``, ``*.jpg``. Masks are read
        as grayscale and thresholded at 127 to {0, 255}; unreadable files are
        skipped.

        Returns:
            (persons_mask_frames, persons_mask_paths): dicts keyed by the person
            folder name (e.g. "person_0"), in natural-sort order. Paths line up
            one-to-one with the loaded frames. Returns (None, None) when the
            directory, the person folders, or all masks are missing.
        """
        from glob import escape as glob_escape

        mask_dir = os.path.join(video_path, mask_type)
        if not os.path.exists(mask_dir):
            print(f"Warning: Mask directory not found {mask_dir}")
            return None, None

        # Sorted so person order (and dict order) is deterministic;
        # os.listdir order is filesystem-dependent.
        person_dirs = sorted(
            (os.path.join(mask_dir, item) for item in os.listdir(mask_dir)
             if item.startswith('person_') and os.path.isdir(os.path.join(mask_dir, item))),
            key=self.natural_sort_key,
        )
        if not person_dirs:
            print(f"Warning: No person folders found in {mask_dir}")
            return None, None

        mask_patterns = ('frame_*.png', 'frame_*.jpg', '*.png', '*.jpg')
        persons_mask_frames = {}
        persons_mask_paths = {}

        for person_dir in person_dirs:
            person_id = os.path.basename(person_dir)

            mask_files = []
            for pattern in mask_patterns:
                mask_files = glob(os.path.join(glob_escape(person_dir), pattern))
                if mask_files:
                    break
            mask_files.sort(key=self.natural_sort_key)

            mask_frames = []
            loaded_paths = []
            for mask_path in mask_files:
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    continue
                _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
                mask_frames.append(mask)
                # Record the path only for masks that actually loaded, so paths
                # stay aligned with frames even if a middle file is unreadable.
                loaded_paths.append(mask_path)

            if mask_frames:
                persons_mask_frames[person_id] = np.array(mask_frames)
                persons_mask_paths[person_id] = loaded_paths

        if not persons_mask_frames:
            print(f"Warning: Failed to load any mask frames")
            return None, None

        print(f"[GPU {self.gpu_id}] Read mask frames for {len(persons_mask_frames)} persons from {video_path}/{mask_type}")
        return persons_mask_frames, persons_mask_paths
    
    def prefetch_data(self, paths_list, num_workers=16):
        """Read images concurrently with ``cv2.imread`` into a ``{path: image}`` dict.

        Paths whose read raises are reported and omitted. ``cv2.imread`` returns
        None for unreadable files without raising, so those paths map to None.
        """
        # as_completed was previously used without being imported (NameError).
        from concurrent.futures import ThreadPoolExecutor, as_completed

        data_cache = {}
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {executor.submit(cv2.imread, path): path for path in paths_list}
            for future in tqdm(as_completed(futures), total=len(futures)):
                path = futures[future]
                try:
                    data_cache[path] = future.result()
                except Exception as e:
                    print(f"Error loading {path}: {e}")
        return data_cache
        
    def apply_mask_to_frame(self, frame, mask):
        """Zero every pixel of ``frame`` that lies outside ``mask``.

        ``mask`` may be single-channel or 3-channel BGR (converted to gray first).
        Values above 127 count as inside. Returns a uint8 array, or ``frame``
        itself (same object) when ``mask`` is None.
        """
        if mask is None:
            return frame

        gray = mask
        if gray.ndim == 3 and gray.shape[2] == 3:
            gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

        keep = np.stack((binary,) * 3, axis=2) / 255.0  # 1.0 inside, 0.0 outside
        return (frame * keep).astype(np.uint8)
    
    def calculate_l1_loss(self, original, compressed, mask=None, path1="", path2="",
                          save_dir=None, frame_idx=None, person_id=None):
        """Mean absolute error between two frames, optionally restricted to a mask.

        Args:
            original: ground-truth frame (H x W x 3).
            compressed: predicted frame of the same shape.
            mask: optional mask; pixels it keeps (see apply_mask_to_frame) are scored.
            path1, path2: only used in the verbose log line.
            save_dir, frame_idx, person_id: when ``self.save_visuals`` is set and
                all three are given, a visualisation (GT | Pred on top, JET
                error map below) is written to
                ``save_dir/frame_{frame_idx:04d}_{person_id}.jpg``.

        Returns:
            Mean of |compressed - original| over all pixels and channels, or over
            the masked pixels only; 0.0 when the mask selects no pixel.
        """
        region = None
        if mask is not None:
            # H x W map of the pixels the mask keeps, using the same
            # thresholding as apply_mask_to_frame. Selecting by the mask itself
            # (not by "masked GT pixel is non-zero") keeps in-mask pixels that
            # are pure black in GT, so their errors are no longer dropped.
            region = self.apply_mask_to_frame(np.full_like(original, 255), mask).any(axis=2)
            original = self.apply_mask_to_frame(original, mask)
            compressed = self.apply_mask_to_frame(compressed, mask)

        if self.save_visuals and save_dir and frame_idx is not None and person_id is not None:
            os.makedirs(save_dir, exist_ok=True)
            h, w = original.shape[:2]

            # Per-pixel error averaged over channels, stretched to 0..255 for the colormap.
            diff = np.abs(original.astype(np.float32) - compressed.astype(np.float32)).mean(axis=2)
            diff_max = diff.max()
            if diff_max > 0:
                diff_normalized = (diff / diff_max * 255).astype(np.uint8)
            else:
                diff_normalized = np.zeros_like(diff, dtype=np.uint8)
            diff_color = cv2.applyColorMap(diff_normalized, cv2.COLORMAP_JET)

            # Top row: GT | Pred with a red separator and labels.
            font = cv2.FONT_HERSHEY_SIMPLEX
            comparison_img = np.hstack([original, compressed])
            cv2.line(comparison_img, (w, 0), (w, h), (0, 0, 255), 2)
            cv2.putText(comparison_img, "GT", (10, 30), font, 1, (0, 255, 0), 2)
            cv2.putText(comparison_img, "Pred", (w + 10, 30), font, 1, (0, 255, 0), 2)

            # Bottom row: error map next to a black tile of the same size.
            error_map_viz = np.hstack([diff_color, np.zeros((h, w, 3), dtype=np.uint8)])
            final_viz = np.vstack([comparison_img, error_map_viz])
            cv2.putText(final_viz, "Error Map (JET)", (10, h + 30), font, 1, (255, 255, 255), 2)

            # Heuristic flag: mean per-pixel error above 50 on the 0..255 scale.
            if np.mean(diff) > 50:
                cv2.putText(final_viz, "ID MISMATCH!", (w//2 - 100, h + 30),
                           font, 1, (0, 0, 255), 2)

            vis_filename = os.path.join(save_dir, f"frame_{frame_idx:04d}_{person_id}.jpg")
            cv2.imwrite(vis_filename, final_viz)

        if region is not None:
            if not region.any():
                return 0.0
            original_px = original[region].astype(np.float64)
            compressed_px = compressed[region].astype(np.float64)
            l1_loss = np.mean(np.abs(compressed_px - original_px))
        else:
            l1_loss = np.mean(np.abs(compressed.astype(np.float64) - original.astype(np.float64)))

        if self.verbose_paths:
            print(f"[GPU {self.gpu_id}] {l1_loss=}, Computing L1: {path1} <-> {path2}")

        return l1_loss
    
    def calculate_psnr(self, original, compressed, mask=None, path1="", path2="",
                       save_dir=None, frame_idx=None, person_id=None):
        """Peak signal-to-noise ratio (peak 255) between two frames, optionally masked.

        Args:
            original, compressed: frames of the same H x W x 3 shape.
            mask: optional mask; only the pixels it keeps (see
                apply_mask_to_frame) enter the MSE.
            path1, path2: only used in the verbose log line.
            save_dir, frame_idx, person_id: unused; kept for signature parity
                with calculate_l1_loss.

        Returns:
            PSNR in dB; ``inf`` for identical inputs; 0.0 when the mask selects
            no pixel.
        """
        region = None
        if mask is not None:
            # Select pixels by the mask itself, not by "masked GT pixel is
            # non-zero", so in-mask pixels that are black in GT still count.
            region = self.apply_mask_to_frame(np.full_like(original, 255), mask).any(axis=2)
            original = self.apply_mask_to_frame(original, mask)
            compressed = self.apply_mask_to_frame(compressed, mask)

        if region is not None:
            if not region.any():
                return 0.0
            original_px = original[region].astype(np.float64)
            compressed_px = compressed[region].astype(np.float64)
            mse = np.mean((compressed_px - original_px) ** 2)
        else:
            mse = np.mean((compressed.astype(np.float64) - original.astype(np.float64)) ** 2)

        if mse == 0:
            return float('inf')
        max_pixel = 255
        psnr = 20 * math.log10(max_pixel / math.sqrt(mse))

        if self.verbose_paths:
            print(f"[GPU {self.gpu_id}] {psnr=}, Computing PSNR: {path1} <-> {path2}")

        return psnr
    
    def calculate_ssim(self, original, compressed, mask=None, path1="", path2="",
                       save_dir=None, frame_idx=None, person_id=None):
        """skimage SSIM between two frames, optionally masked.

        3-D (colour) frames are scored per channel and averaged. With a mask,
        pixels outside it are zeroed in both frames before scoring, so SSIM
        still runs over the full frame. 0.0 is returned when the mask selects
        no pixel.

        Args:
            original, compressed: frames of the same shape, 8-bit range (data_range=255).
            mask: optional mask (see apply_mask_to_frame).
            path1, path2: only used in the verbose log line.
            save_dir, frame_idx, person_id: unused; kept for signature parity.
        """
        from skimage.metrics import structural_similarity as comp_ssim

        if mask is not None:
            # Emptiness is judged from the mask itself. The previous "masked GT
            # is non-zero" test wrongly treated masks over pure-black GT as empty.
            if not self.apply_mask_to_frame(np.full_like(original, 255), mask).any():
                return 0.0
            original = self.apply_mask_to_frame(original, mask)
            compressed = self.apply_mask_to_frame(compressed, mask)

        # Compute once; the verbose log previously recomputed it and printed a
        # garbled "ssim=np.array(ssims).mean()=..." string.
        if original.ndim == 3:
            ssim_val = np.array([comp_ssim(original[:, :, c], compressed[:, :, c], data_range=255)
                                 for c in range(3)]).mean()
        else:
            ssim_val = comp_ssim(original, compressed, data_range=255)

        if self.verbose_paths:
            print(f"[GPU {self.gpu_id}] ssim={ssim_val}, Computing SSIM: {path1} <-> {path2}")

        return ssim_val
    
    # Batch processing methods
    def calculate_lpips_batch(self, img_batch, img2_batch, masks=None):
        """LPIPS for paired BGR uint8 frames in one forward pass under autocast.

        Args:
            img_batch: sequence of frames; its length sets the batch size.
            img2_batch: frames compared index-by-index with ``img_batch``.
            masks: accepted for interface compatibility but not used here.
                Callers mask the frames beforehand, as calculate_lpips does.

        Returns:
            float32 array with one LPIPS value per pair (empty for an empty batch).
        """
        n = len(img_batch)
        if n == 0:
            # torch.stack would raise on an empty list.
            return np.array([], dtype=np.float32)

        def _to_tensor(img):
            # BGR uint8 -> RGB CHW float in [-1, 1], the range LPIPS expects.
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0 * 2 - 1

        batch1 = torch.stack([_to_tensor(img_batch[i]) for i in range(n)]).to(self.device)
        batch2 = torch.stack([_to_tensor(img2_batch[i]) for i in range(n)]).to(self.device)

        with torch.no_grad(), autocast():
            dist = self.lpips_fn(batch1, batch2)

        # Average each sample's output and copy to host once, instead of one
        # .item() sync per sample.
        per_sample = dist.reshape(dist.shape[0], -1).float().mean(dim=1)
        return per_sample.cpu().numpy().astype(np.float32)
        
        
    def calculate_lpips(self, img1, img2, mask=None, path1="", path2="",
                    save_dir=None, frame_idx=None, person_id=None):
        """Return the LPIPS distance between two BGR frames, optionally masked.

        With a mask, both frames go through ``apply_mask_to_frame``; if the
        masked ``img1`` is zero in every channel of every pixel, 0.0 is
        returned. When ``self.use_mixed_precision`` is set, the work is
        delegated to ``calculate_lpips_batch``. Otherwise the frames are scaled
        to [-1, 1] and scored directly in full precision.

        Args:
            img1, img2: BGR frames.
            mask: optional mask applied to both frames.
            path1, path2: used only in the verbose log line.
            save_dir, frame_idx, person_id: unused in this method.

        Returns:
            The LPIPS value as a Python float.
        """
        if mask is not None:
            img1 = self.apply_mask_to_frame(img1, mask)
            img2 = self.apply_mask_to_frame(img2, mask)
            # Fully blank after masking -> nothing to compare
            if not (img1 > 0).any(axis=2).any():
                return 0.0

        if self.use_mixed_precision:
            batch_masks = None if mask is None else [mask]
            lpips_val = float(self.calculate_lpips_batch([img1], [img2], batch_masks)[0])
        else:
            pair = []
            for frame in (img1, img2):
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                signed = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0 * 2 - 1
                pair.append(signed.unsqueeze(0).to(self.device))
            with torch.no_grad():
                distance = self.lpips_fn(pair[0], pair[1])
            lpips_val = float(distance.item())

        if self.verbose_paths:
            print(f"[GPU {self.gpu_id}] lpips={lpips_val}, Computing LPIPS: {path1} <-> {path2}")

        return lpips_val
        
    def calculate_dists(self, img1, img2, mask=None, path1="", path2=""):
        """Return the DISTS distance between two BGR frames, optionally masked.

        With a mask, both frames go through ``apply_mask_to_frame``; if the
        masked ``img1`` is zero everywhere, 0.0 is returned. Otherwise each
        frame is converted to RGB, resized to 256x256, turned into a tensor,
        normalized with the ImageNet mean/std and scored by
        ``self.dists_model``.

        NOTE(review): the reference DISTS_pytorch model documents [0, 1]
        inputs and normalizes internally. Confirm that ``self.dists_model``
        expects the ImageNet-normalized tensors built here.

        Returns:
            The DISTS score as a Python float.
        """
        if mask is not None:
            img1, img2 = (self.apply_mask_to_frame(frame, mask) for frame in (img1, img2))
            if not (img1 > 0).any(axis=2).any():
                return 0.0

        to_model_input = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        reference, candidate = (
            to_model_input(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))).unsqueeze(0).to(self.device)
            for frame in (img1, img2)
        )

        with torch.no_grad():
            score = self.dists_model(reference, candidate).item()

        if self.verbose_paths:
            print(f"[GPU {self.gpu_id}] dists={score}, Computing DISTS: {path1} <-> {path2}")

        return score
    
    def calculate_clip_score(self, img1, img2, mask=None, path1="", path2=""):
        """Return the CLIP similarity score between two frames, optionally masked.

        With a mask, both frames are masked via ``apply_mask_to_frame``. If
        the masked ``img1`` has no value above zero in any channel, 0.0 is
        returned without querying CLIP. Otherwise the score comes from
        ``self.clip_scorer.calculate_clip_score(img1, img2)``.
        """
        if mask is not None:
            img1, img2 = (self.apply_mask_to_frame(frame, mask) for frame in (img1, img2))
            has_content = (img1 > 0).any(axis=2).any()
            if not has_content:
                return 0.0

        score = self.clip_scorer.calculate_clip_score(img1, img2)

        if self.verbose_paths:
            print(f"[GPU {self.gpu_id}] clip_score={score}, Computing CLIPScore: {path1} <-> {path2}")

        return score
    
    def extract_i3d_features(self, video_frames, masks=None):
        """Extract I3D features from a frame sequence for FVD.

        Args:
            video_frames: list or array of frames. After conversion it is
                permuted as (T, H, W, C), so 3-channel frames are expected.
                Values are divided by 255.0 (presumably uint8 0-255 input;
                TODO confirm). Channel order is passed through unchanged.
                NOTE(review): frames read with cv2 are BGR; confirm what
                ``self.i3d_model`` expects.
            masks: optional per-frame masks. Frame i is masked with masks[i]
                via ``apply_mask_to_frame``; frames beyond ``len(masks)`` are
                used unmasked.

        Returns:
            2-D numpy array built by stacking the flattened model outputs of
            consecutive temporal chunks of up to 16 frames. How many rows a
            chunk contributes depends on the model's output shape.
        """
        if isinstance(video_frames, list):
            video_frames = np.array(video_frames)
        
        # Apply masks if provided (frames without a corresponding mask stay unmasked)
        if masks is not None:
            masked_frames = []
            for i, frame in enumerate(video_frames):
                if i < len(masks):
                    masked_frames.append(self.apply_mask_to_frame(frame, masks[i]))
                else:
                    masked_frames.append(frame)
            video_frames = np.array(masked_frames)
        
        # Convert to tensor: (T, H, W, C) -> (C, T, H, W) -> (1, C, T, H, W), scaled by 1/255
        frames = torch.tensor(video_frames).permute(3, 0, 1, 2).unsqueeze(0).float() / 255.0
        
        # Trilinear resize of H and W to 112x112; the temporal length T is kept
        frames = F.interpolate(frames, size=(frames.shape[2], 112, 112), mode='trilinear', align_corners=False)
        
        # Feed the model consecutive temporal chunks of up to 16 frames
        # (the last chunk may be shorter) to bound GPU memory use
        batch_size = 16
        features = []
        
        for i in range(0, frames.shape[2], batch_size):
            batch = frames[:, :, i:i+batch_size].to(self.device)
            
            with torch.no_grad():
                batch_features = self.i3d_model(batch)
                
                # squeeze() drops every singleton dimension, including the
                # leading batch dimension of size 1
                if batch_features.ndim == 5:
                    batch_features = batch_features.squeeze()
                    
                # Restore at least 2 dims so flatten(start_dim=1) yields (rows, features)
                while batch_features.ndim < 2:
                    batch_features = batch_features.unsqueeze(0)
                    
                batch_features = batch_features.flatten(start_dim=1)
                features.append(batch_features.cpu().numpy())
                
        # A single chunk is already 2-D, so features[0] matches np.vstack(features)
        return np.vstack(features) if len(features) > 1 else features[0]
    
    def extract_inception_features(self, images, masks=None):
        """Run ``self.inception_model`` over BGR images to obtain FID features.

        Each image is optionally masked (image i with masks[i]; images past
        ``len(masks)`` stay unmasked), converted to RGB, resized to 299x299,
        normalized with the ImageNet mean/std, and processed in chunks of 64.

        Returns:
            numpy array of the stacked model outputs, one row per image.
        """
        preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        if masks is not None:
            images = [
                self.apply_mask_to_frame(img, masks[idx]) if idx < len(masks) else img
                for idx, img in enumerate(images)
            ]

        chunk_size = 64
        collected = []
        starts = range(0, len(images), chunk_size)
        for start in tqdm(starts, desc=f"[GPU {self.gpu_id}] Extracting Inception features", leave=False):
            chunk = images[start:start + chunk_size]
            stacked = torch.stack(
                [preprocess(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) for img in chunk]
            ).to(self.device)

            with torch.no_grad():
                outputs = self.inception_model(stacked)
                collected.append(outputs.cpu().numpy())

        return np.vstack(collected)
    
    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """Compute the Frechet distance between two Gaussians.

            d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

        Args:
            mu1, mu2: mean vectors (scalars are promoted to 1-D).
            sigma1, sigma2: covariance matrices (scalars are promoted to 2-D).
                The caller's arrays are never modified.
            eps: ridge added to the diagonal of both covariances. A further
                ``eps`` offset is applied if the matrix square root is not
                finite.

        Returns:
            The Frechet distance (numpy float).

        Raises:
            ValueError: if the mean vectors or covariance matrices have
                mismatched shapes, or if sqrtm yields a significant imaginary
                component.
        """
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        # Raise instead of assert so the checks survive `python -O`
        if mu1.shape != mu2.shape:
            raise ValueError('Training and test mean vectors have different lengths')
        if sigma1.shape != sigma2.shape:
            raise ValueError('Training and test covariances have different dimensions')

        diff = mu1 - mu2

        # Regularize fresh copies. np.atleast_2d returns 2-D input unchanged,
        # so the previous in-place "+=" silently modified the caller's
        # covariance arrays (and failed outright for integer dtypes).
        ridge = np.eye(sigma1.shape[0]) * eps
        sigma1 = sigma1 + ridge
        sigma2 = sigma2 + ridge

        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # sqrtm may return a complex matrix with negligible imaginary parts
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)
        fd = (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)

        return fd
    
    def calculate_fvd(self, real_features, fake_features):
        """Calculate FVD between two sets of video feature vectors.

        Args:
            real_features, fake_features: arrays with one feature vector per
                row (e.g. the output of ``extract_i3d_features``).

        Returns:
            Frechet distance between Gaussians fitted to the two sets.
        """
        def gaussian_stats(features):
            # Mean and covariance of a (samples, dims) feature array.
            features = np.asarray(features)
            mu = np.mean(features, axis=0)
            if features.ndim == 2 and features.shape[0] == 1:
                # np.cov needs >= 2 samples: with one it yields an all-NaN
                # matrix, and scipy's sqrtm then raises, aborting the whole
                # video. Treat a lone sample as a zero-covariance Gaussian.
                dims = features.shape[1]
                sigma = np.zeros((dims, dims))
            else:
                sigma = np.cov(features, rowvar=False)
            return mu, sigma

        mu1, sigma1 = gaussian_stats(real_features)
        mu2, sigma2 = gaussian_stats(fake_features)

        return self.calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
    
    def calculate_fid(self, real_features, fake_features):
        """Calculate FID between two sets of image feature vectors.

        Args:
            real_features, fake_features: arrays with one feature vector per
                row (e.g. the output of ``extract_inception_features``).

        Returns:
            Frechet distance between Gaussians fitted to the two sets.
        """
        def gaussian_stats(features):
            # Mean and covariance of a (samples, dims) feature array.
            features = np.asarray(features)
            mu = np.mean(features, axis=0)
            if features.ndim == 2 and features.shape[0] == 1:
                # np.cov needs >= 2 samples: with one (e.g. a single-frame
                # video) it yields an all-NaN matrix, and scipy's sqrtm then
                # raises. Treat a lone sample as a zero-covariance Gaussian.
                dims = features.shape[1]
                sigma = np.zeros((dims, dims))
            else:
                sigma = np.cov(features, rowvar=False)
            return mu, sigma

        mu1, sigma1 = gaussian_stats(real_features)
        mu2, sigma2 = gaussian_stats(fake_features)

        return self.calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
    
    def evaluate_video_pair(self, gt_path, pred_path, gt_frames, pred_frames, gt_paths=None, pred_paths=None, output_dir=None):
        """Evaluate quality metrics for a pair of videos.

        Computes full-frame metrics and, when ``read_mask_sequence(gt_path)``
        returns human masks, the same metrics restricted to each person's mask.

        Args:
            gt_path: ground-truth sequence directory. Its basename names the
                video, and it is passed to ``read_mask_sequence``.
            pred_path: predicted sequence directory (not used in this method).
            gt_frames, pred_frames: frame sequences. Both are truncated to the
                shorter length.
            gt_paths, pred_paths: optional per-frame paths, only used for
                verbose logging. They are used only when both are given.
            output_dir: root for visualization output when
                ``self.save_visuals`` is enabled.

        Returns:
            dict with:
              - 'full_frame': metric name -> value, plus 'num_frames';
              - 'per_person_masked_region': person_id -> metric dict;
              - 'masked_region': per-metric mean over persons (plus
                'num_frames' and 'mask_type'), or None when no masks exist.

        Raises:
            ValueError: if either frame sequence is empty.
        """
        metric_names = ('L1', 'PSNR', 'SSIM', 'LPIPS', 'DISTS', 'CLIPScore',
                        'ST-SSIM', 'GMSD-Temporal', 'FVD', 'FID', 'CLIP-FID')

        # Align frame counts
        min_frames = min(len(gt_frames), len(pred_frames))
        if min_frames == 0:
            # Fail early with a clear message instead of an opaque unpacking
            # error further down.
            raise ValueError(
                f"Cannot evaluate {gt_path}: empty frame sequence "
                f"(gt={len(gt_frames)}, pred={len(pred_frames)})"
            )
        gt_frames = gt_frames[:min_frames]
        pred_frames = pred_frames[:min_frames]

        if gt_paths is not None and pred_paths is not None:
            gt_paths = gt_paths[:min_frames]
            pred_paths = pred_paths[:min_frames]
        else:
            gt_paths = [""] * min_frames
            pred_paths = [""] * min_frames

        video_name = os.path.basename(gt_path)

        # Setup visualization output directory
        vis_output_dir = None
        if self.save_visuals and output_dir:
            vis_output_dir = os.path.join(output_dir, "visualizations", video_name)
            os.makedirs(vis_output_dir, exist_ok=True)

        # Read human masks (the mask paths are not needed here)
        persons_mask_frames, _ = self.read_mask_sequence(gt_path)

        results = {
            'full_frame': {},
            'per_person_masked_region': {}
        }

        # cpu_count() // 8 is 0 on hosts with fewer than 8 CPUs, and
        # ThreadPoolExecutor rejects max_workers=0, so clamp to at least 1.
        max_threads = max(1, min(12, mp.cpu_count() // 8))

        def frame_metrics(i, mask=None, **l1_kwargs):
            # (L1, PSNR, SSIM, LPIPS, DISTS, CLIPScore) for frame i; only
            # calculate_l1_loss receives the visualization kwargs.
            args = (gt_frames[i], pred_frames[i], mask, gt_paths[i], pred_paths[i])
            return (
                self.calculate_l1_loss(*args, **l1_kwargs),
                self.calculate_psnr(*args),
                self.calculate_ssim(*args),
                self.calculate_lpips(*args),
                self.calculate_dists(*args),
                self.calculate_clip_score(*args),
            )

        def run_parallel(fn, desc):
            # Evaluate fn(i) for every frame index on a thread pool, in order.
            with ThreadPoolExecutor(max_workers=max_threads) as executor:
                futures = [executor.submit(fn, i) for i in range(min_frames)]
                return [future.result() for future in tqdm(futures, desc=desc)]

        # 1. Full frame evaluation
        print(f"[GPU {self.gpu_id}] Computing full frame metrics...")
        results_full = run_parallel(frame_metrics, f"[GPU {self.gpu_id}] Computing full frame metrics")
        (l1_values, psnr_values, ssim_values,
         lpips_values, dists_values, clip_score_values) = zip(*results_full)

        # Spatio-temporal metrics
        st_ssim = self.st_metrics.calculate_st_ssim(gt_frames, pred_frames)
        gmsd_temporal = self.st_metrics.calculate_gmsd_temporal(gt_frames, pred_frames)

        # Distribution-level metrics
        gt_i3d_features = self.extract_i3d_features(gt_frames)
        pred_i3d_features = self.extract_i3d_features(pred_frames)
        fvd = self.calculate_fvd(gt_i3d_features, pred_i3d_features)

        gt_inception_features = self.extract_inception_features(gt_frames)
        pred_inception_features = self.extract_inception_features(pred_frames)
        fid = self.calculate_fid(gt_inception_features, pred_inception_features)

        if CLIP_AVAILABLE:
            print(f"[GPU {self.gpu_id}] Computing CLIP-FID...")
            clip_features_gt = self.clip_scorer.get_clip_features(gt_frames)
            clip_features_pred = self.clip_scorer.get_clip_features(pred_frames)
            clip_fid = self.clip_scorer.calculate_clip_fid(clip_features_gt, clip_features_pred)
        else:
            clip_fid = 0.0

        results['full_frame'] = {
            'L1': np.mean(l1_values),
            'PSNR': np.mean(psnr_values),
            'SSIM': np.mean(ssim_values),
            'LPIPS': np.mean(lpips_values),
            'DISTS': np.mean(dists_values),
            'CLIPScore': np.mean(clip_score_values),
            'ST-SSIM': st_ssim,
            'GMSD-Temporal': gmsd_temporal,
            'FVD': fvd,
            'FID': fid,
            'CLIP-FID': clip_fid,
            'num_frames': min_frames
        }

        # 2. Per-person masked region evaluation
        if persons_mask_frames:
            print(f"[GPU {self.gpu_id}] Computing per-person masked region metrics...")

            for person_id, mask_frames in persons_mask_frames.items():
                print(f"[GPU {self.gpu_id}] Computing masked region metrics for {person_id}...")

                # Align mask frames with video frames
                mask_frames = mask_frames[:min_frames]

                person_vis_dir = None
                if vis_output_dir:
                    person_vis_dir = os.path.join(vis_output_dir, person_id)
                    os.makedirs(person_vis_dir, exist_ok=True)

                # Loop variables are bound as defaults so the closure never
                # depends on late binding.
                def process_frame_masked(i, mask_frames=mask_frames,
                                         person_vis_dir=person_vis_dir, person_id=person_id):
                    if i >= len(mask_frames):
                        # NOTE(review): frames without a mask contribute 0.0
                        # to every per-person metric average.
                        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
                    return frame_metrics(
                        i, mask_frames[i],
                        save_dir=person_vis_dir, frame_idx=i, person_id=person_id
                    )

                results_masked = run_parallel(
                    process_frame_masked,
                    f"[GPU {self.gpu_id}] Computing {person_id} masked region metrics"
                )
                (l1_masked_values, psnr_masked_values, ssim_masked_values,
                 lpips_masked_values, dists_masked_values, clip_score_masked_values) = (
                    np.array(column, dtype=np.float32) for column in zip(*results_masked)
                )

                # Spatio-temporal metrics on masked frames (zip stops at the
                # shorter of frames / masks).
                masked_gt_frames = [self.apply_mask_to_frame(frame, mask) for frame, mask in zip(gt_frames, mask_frames)]
                masked_pred_frames = [self.apply_mask_to_frame(frame, mask) for frame, mask in zip(pred_frames, mask_frames)]

                st_ssim_masked = self.st_metrics.calculate_st_ssim(masked_gt_frames, masked_pred_frames)
                gmsd_temporal_masked = self.st_metrics.calculate_gmsd_temporal(masked_gt_frames, masked_pred_frames)

                # FVD and FID for masked regions
                gt_i3d_features_masked = self.extract_i3d_features(gt_frames, mask_frames)
                pred_i3d_features_masked = self.extract_i3d_features(pred_frames, mask_frames)
                fvd_masked = self.calculate_fvd(gt_i3d_features_masked, pred_i3d_features_masked)

                gt_inception_features_masked = self.extract_inception_features(gt_frames, mask_frames)
                pred_inception_features_masked = self.extract_inception_features(pred_frames, mask_frames)
                fid_masked = self.calculate_fid(gt_inception_features_masked, pred_inception_features_masked)

                if CLIP_AVAILABLE:
                    print(f"[GPU {self.gpu_id}] Computing CLIP-FID for masked region {person_id}...")
                    clip_features_gt_masked = self.clip_scorer.get_clip_features(masked_gt_frames)
                    clip_features_pred_masked = self.clip_scorer.get_clip_features(masked_pred_frames)
                    clip_fid_masked = self.clip_scorer.calculate_clip_fid(clip_features_gt_masked, clip_features_pred_masked)
                else:
                    clip_fid_masked = 0.0

                results['per_person_masked_region'][person_id] = {
                    'L1': np.mean(l1_masked_values),
                    'PSNR': np.mean(psnr_masked_values),
                    'SSIM': np.mean(ssim_masked_values),
                    'LPIPS': np.mean(lpips_masked_values),
                    'DISTS': np.mean(dists_masked_values),
                    'CLIPScore': np.mean(clip_score_masked_values),
                    'ST-SSIM': st_ssim_masked,
                    'GMSD-Temporal': gmsd_temporal_masked,
                    'FVD': fvd_masked,
                    'FID': fid_masked,
                    'CLIP-FID': clip_fid_masked,
                    'num_frames': min_frames
                }

                # Assemble per-frame visualization images into a video
                if self.save_visuals and person_vis_dir:
                    if os.path.exists(person_vis_dir) and len(os.listdir(person_vis_dir)) > 0:
                        try:
                            first_frame_path = os.path.join(person_vis_dir, f"frame_0000_{person_id}.jpg")
                            if not os.path.exists(first_frame_path):
                                first_frame_path = os.path.join(person_vis_dir, os.listdir(person_vis_dir)[0])

                            first_frame = cv2.imread(first_frame_path)
                            if first_frame is None:
                                raise IOError(f"cannot read image {first_frame_path}")
                            h, w = first_frame.shape[:2]

                            video_path = os.path.join(person_vis_dir, f"{video_name}_{person_id}.mp4")
                            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                            video_writer = cv2.VideoWriter(video_path, fourcc, 15, (w, h))
                            try:
                                for i in range(min_frames):
                                    frame_path = os.path.join(person_vis_dir, f"frame_{i:04d}_{person_id}.jpg")
                                    if os.path.exists(frame_path):
                                        frame = cv2.imread(frame_path)
                                        video_writer.write(frame)
                            finally:
                                # Release even if a write fails, so the file handle is not leaked
                                video_writer.release()
                            print(f"Saved visualization video to: {video_path}")
                        except Exception as e:
                            print(f"Error saving video: {e}")

            # Average every metric across persons
            person_results = list(results['per_person_masked_region'].values())
            if person_results:
                masked_region = {
                    metric: np.mean([person_result[metric] for person_result in person_results])
                    for metric in metric_names
                }
                masked_region['num_frames'] = min_frames
                masked_region['mask_type'] = 'human_mask_average'
                results['masked_region'] = masked_region
        else:
            print(f"[GPU {self.gpu_id}] Warning: No human masks found, skipping masked region evaluation")
            results['per_person_masked_region'] = {}
            results['masked_region'] = None

        return results


def setup_dali_pipeline(self, image_paths, batch_size=16):
    """Create an NVIDIA DALI pipeline that reads, decodes and resizes images.

    Although defined at module level, this takes an object as ``self`` and
    reads ``self.gpu_id`` as the DALI device id.

    Args:
        self: any object exposing ``gpu_id``.
        image_paths: image files to read, in order (no shuffling).
        batch_size: pipeline batch size.

    Returns:
        The DALI pipeline instance (``build()`` is not called here). Images
        are decoded with device="mixed" and resized to 224x224.
    """
    from nvidia.dali import pipeline_def
    import nvidia.dali.fn as dali_fn
    import nvidia.dali.types as types  # kept from the original; not used below

    def decode_and_resize():
        encoded, _labels = dali_fn.readers.file(files=image_paths, random_shuffle=False)
        decoded = dali_fn.decoders.image(encoded, device="mixed")
        return dali_fn.resize(decoded, resize_x=224, resize_y=224)

    pipeline_factory = pipeline_def(decode_and_resize)
    return pipeline_factory(batch_size=batch_size, num_threads=8, device_id=self.gpu_id)

def worker_process(gpu_id, baseline_results_queue, gpu_work_queue, weights_dir, verbose_paths, save_visuals, output_dir):
    """Worker process: evaluate (gt_path, pred_path, baseline_name) tasks on one GPU.

    Pulls tasks from ``gpu_work_queue`` until a ``None`` sentinel arrives.
    For each task it reads the ``images`` sub-directories of both paths,
    evaluates them, and puts ``(baseline_name, video_name, results)`` on
    ``baseline_results_queue``. Failures are printed and skipped.

    ``task_done()`` is called exactly once for every item taken from
    ``gpu_work_queue`` (including the sentinel), so a ``join()`` on the queue
    can complete.
    """
    import queue  # multiprocessing queues raise queue.Empty when get() times out

    device = torch.device(f'cuda:{gpu_id}')
    torch.cuda.set_device(gpu_id)

    print(f"[GPU {gpu_id}] Worker process started")

    evaluator = EnhancedVideoQualityEvaluator(device, gpu_id, verbose_paths, weights_dir, save_visuals)

    while True:
        try:
            task = gpu_work_queue.get(timeout=10)
        except queue.Empty:
            # Idle timeout is not an error: keep waiting for work or the sentinel.
            continue
        except Exception as e:
            print(f"[GPU {gpu_id}] Error processing task: {e}")
            traceback.print_exc()
            continue

        # From here on an item was taken, so task_done() must run exactly once.
        try:
            if task is None:
                break  # termination signal; `finally` still marks it done

            gt_path, pred_path, baseline_name = task
            video_name = os.path.basename(gt_path)

            # Baseline-specific visualization directory
            baseline_vis_dir = None
            if save_visuals and output_dir:
                baseline_vis_dir = os.path.join(output_dir, baseline_name)
                os.makedirs(baseline_vis_dir, exist_ok=True)

            gt_frames, gt_paths = evaluator.read_image_sequence(os.path.join(gt_path, 'images'))
            pred_frames, pred_paths = evaluator.read_image_sequence(os.path.join(pred_path, 'images'))

            results = evaluator.evaluate_video_pair(gt_path, pred_path, gt_frames, pred_frames, gt_paths, pred_paths, baseline_vis_dir)

            baseline_results_queue.put((baseline_name, video_name, results))
        except Exception as e:
            print(f"[GPU {gpu_id}] Error processing task: {e}")
            traceback.print_exc()
        finally:
            gpu_work_queue.task_done()


def print_average_metrics_table(all_results, baseline_names, region_type="full_frame"):
    """Print per-method average metrics for one region type.

    Args:
        all_results: mapping region_type -> baseline_name -> video_name ->
            metric dict. A per-video entry may be ``None`` (e.g. a
            'masked_region' result when no masks were found); such entries
            are skipped.
        baseline_names: methods to report, in display order. Methods without
            results for ``region_type`` are omitted.
        region_type: key into ``all_results`` (missing keys give an empty table).

    The best method per metric (direction given by the arrow in each header)
    is marked with ``*``. Column widths are computed from the content, so
    rows stay aligned.
    """
    print(f"\n{region_type.replace('_', ' ').title()} Evaluation Average Results:")

    # (metric key, column header, higher_is_better, number format)
    columns = [
        ('L1', "L1↓", False, '.4f'),
        ('PSNR', "PSNR↑", True, '.2f'),
        ('SSIM', "SSIM↑", True, '.4f'),
        ('LPIPS', "LPIPS↓", False, '.4f'),
        ('DISTS', "DISTS↓", False, '.4f'),
        ('CLIPScore', "CLIPScore↑", True, '.4f'),
        ('ST-SSIM', "ST-SSIM↑", True, '.4f'),
        ('GMSD-Temporal', "GMSD-T↓", False, '.4f'),
        ('FVD', "FVD↓", False, '.1f'),
        ('FID', "FID↓", False, '.1f'),
        ('CLIP-FID', "CLIP-FID↓", False, '.1f'),
    ]

    region_results = all_results.get(region_type) or {}

    # Average every metric over the videos of each baseline
    all_avg_values = {}
    for baseline_name in baseline_names:
        if baseline_name not in region_results:
            continue
        per_video = region_results[baseline_name] or {}
        avg_values = {}
        for metric, _, _, _ in columns:
            values = [video_results[metric] for video_results in per_video.values()
                      if video_results and metric in video_results]
            avg_values[metric] = np.mean(values) if values else float('nan')
        all_avg_values[baseline_name] = avg_values

    # Best method per metric; ties keep the first method encountered
    best_methods = {}
    for metric, _, higher_is_better, _ in columns:
        candidates = [(name, avg[metric]) for name, avg in all_avg_values.items()
                      if not np.isnan(avg[metric])]
        if candidates:
            pick = max if higher_is_better else min
            best_methods[metric] = pick(candidates, key=lambda item: item[1])[0]

    header_cells = ["Method"] + [header for _, header, _, _ in columns]
    body_rows = []
    for baseline_name, avg_values in all_avg_values.items():
        cells = [baseline_name]
        for metric, _, _, fmt in columns:
            value = avg_values[metric]
            if np.isnan(value):
                cells.append("N/A")
            else:
                marker = '*' if best_methods.get(metric) == baseline_name else ' '
                cells.append(f"{value:{fmt}}{marker}")
        body_rows.append(cells)

    # Size each column to its widest cell so headers and values line up
    all_rows = [header_cells] + body_rows
    widths = [max(len(row[col]) for row in all_rows) for col in range(len(header_cells))]
    widths[0] = max(widths[0], 15)

    def format_row(cells):
        parts = [f"{cells[0]:<{widths[0]}}"]
        parts += [f"{cell:>{width}}" for cell, width in zip(cells[1:], widths[1:])]
        return "| " + " | ".join(parts) + " |"

    print(format_row(header_cells))
    print("|" + "|".join("-" * (width + 2) for width in widths) + "|")
    for cells in body_rows:
        print(format_row(cells))

    # Print metric explanations
    print("\nMetric Explanations:")
    print("- L1↓: Pixel-level error (lower is better)")
    print("- PSNR↑: Peak Signal-to-Noise Ratio (higher is better)")
    print("- SSIM↑: Structural Similarity (higher is better)")
    print("- LPIPS↓: Perceptual similarity distance (lower is better)")
    print("- DISTS↓: Deep Image Structure and Texture Similarity (lower is better)")
    print("- CLIPScore↑: CLIP model semantic similarity (higher is better)")
    print("- ST-SSIM↑: Spatio-temporal structural similarity (higher is better)")
    print("- GMSD-T↓: Temporal gradient magnitude similarity deviation (lower is better)")
    print("- FVD↓: Frechet Video Distance (lower is better)")
    print("- FID↓: Frechet Inception Distance (lower is better)")
    print("- CLIP-FID↓: CLIP feature-based Frechet Distance (lower is better)")
    print("- *: Indicates the best performing method for this metric")


def print_per_person_metrics_table(all_results, baseline_names):
    """Print one aligned summary table per person ID for per-person masked-region metrics.

    ``all_results['per_person_masked_region']`` is read as
    baseline name -> video name -> person ID -> metric dict. For every person ID
    found, each metric is averaged (``np.mean``) over all videos in which that
    person appears, separately for each baseline in ``baseline_names``. Metrics
    with no samples are shown as "N/A"; the best baseline per metric (respecting
    whether higher or lower is better) is marked with "*". Nothing is printed
    when no per-person results are present.

    Args:
        all_results: Consolidated results dictionary.
        baseline_names: Baselines to show, one row each, in this order.
    """
    per_person = all_results.get('per_person_masked_region')
    if not per_person:
        return

    # (metric key, column header, format spec, higher is better)
    columns = [
        ('L1', 'L1↓', '.4f', False),
        ('PSNR', 'PSNR↑', '.2f', True),
        ('SSIM', 'SSIM↑', '.2f', True),
        ('LPIPS', 'LPIPS↓', '.4f', False),
        ('DISTS', 'DISTS↓', '.2f', False),
        ('CLIPScore', 'CLIPScore↑', '.4f', True),
        ('ST-SSIM', 'ST-SSIM↑', '.4f', True),
        ('GMSD-Temporal', 'GMSD-T↓', '.4f', False),
    ]

    # Collect all person IDs seen for any of the requested baselines.
    all_person_ids = set()
    for baseline_name in baseline_names:
        for video_results in (per_person.get(baseline_name) or {}).values():
            all_person_ids.update(video_results.keys())

    def render(cells, widths):
        """Join one table line: method column left-aligned, metric columns right-aligned."""
        padded = [cells[0].ljust(widths[0])]
        padded += [cell.rjust(width) for cell, width in zip(cells[1:], widths[1:])]
        return "| " + " | ".join(padded) + " |"

    for person_id in sorted(all_person_ids):
        # Average each metric over the videos containing this person, per baseline.
        averages = {}
        for baseline_name in baseline_names:
            samples = {metric: [] for metric, _, _, _ in columns}
            for video_results in (per_person.get(baseline_name) or {}).values():
                person_results = video_results.get(person_id)
                if not person_results:
                    continue
                for metric, values in samples.items():
                    if metric in person_results:
                        values.append(person_results[metric])
            averages[baseline_name] = {
                metric: np.mean(values) if values else float('nan')
                for metric, values in samples.items()
            }

        # Best baseline per metric; max/min return the first extreme, so ties go
        # to the earliest baseline in order.
        best = {}
        for metric, _, _, higher_is_better in columns:
            scored = [(name, avg[metric]) for name, avg in averages.items()
                      if not np.isnan(avg[metric])]
            if scored:
                choose = max if higher_is_better else min
                best[metric] = choose(scored, key=lambda item: item[1])[0]

        header = ['Method'] + [title for _, title, _, _ in columns]
        body = []
        for name, avg in averages.items():
            cells = [str(name)]
            for metric, _, spec, _ in columns:
                value = avg[metric]
                if np.isnan(value):
                    # Trailing space occupies the best-marker slot so N/A lines up with numbers.
                    cells.append('N/A ')
                else:
                    cells.append(f"{value:{spec}}{'*' if best.get(metric) == name else ' '}")
            body.append(cells)

        # Size every column from its widest rendered cell so all lines align.
        widths = [max(len(row[col]) for row in [header] + body) for col in range(len(header))]
        widths[0] = max(widths[0], 15)

        print(f"\n{str(person_id).replace('_', ' ').title()} Evaluation Average Results:")
        print(render(header, widths))
        print("|" + "|".join("-" * (width + 2) for width in widths) + "|")
        for cells in body:
            print(render(cells, widths))

def optimized_worker_process(process_id, gpu_id, tasks, baseline_results_dict, 
                            weights_dir, verbose_paths, save_visuals, 
                            vis_output_dir, batch_size, prefetch_size,
                            io_threads, use_mixed_precision, use_dali):
    """Evaluate a list of (GT, prediction) video tasks on a single GPU.

    Binds this process to ``cuda:{gpu_id}``, builds one
    ``EnhancedVideoQualityEvaluator`` and processes ``tasks`` sequentially.
    For each task the outcome is stored in
    ``baseline_results_dict[(baseline_name, video_name)]``:
    ``(results, timings)`` on success, ``({"error": message}, {})`` when an
    exception is raised while handling that task.

    Args:
        process_id: Worker index, used only in log messages.
        gpu_id: CUDA device index this worker uses.
        tasks: Sequence of ``(gt_path, pred_path, baseline_name, video_name)``.
        baseline_results_dict: Shared mapping receiving the per-task outcomes.
        weights_dir, verbose_paths, save_visuals, batch_size, prefetch_size,
        io_threads, use_mixed_precision, use_dali: Forwarded to the evaluator.
        vis_output_dir: Root directory for visualizations; a per-baseline
            sub-directory is created when ``save_visuals`` is set.
    """
    tag = f"[Process {process_id}, GPU {gpu_id}]"

    device = torch.device(f'cuda:{gpu_id}')
    torch.cuda.set_device(gpu_id)

    print(f"{tag} Worker process started")

    evaluator = EnhancedVideoQualityEvaluator(
        device, gpu_id, verbose_paths, weights_dir, save_visuals,
        batch_size=batch_size, use_mixed_precision=use_mixed_precision,
        prefetch_size=prefetch_size, io_threads=io_threads,
        use_dali=use_dali
    )

    task_count = len(tasks)
    for task_no, (gt_path, pred_path, baseline_name, video_name) in enumerate(tasks, start=1):
        result_key = (baseline_name, video_name)
        try:
            t_start = time.time()
            print(f"{tag} Processing task {task_no}/{task_count}: {video_name} ({baseline_name})")

            # Visualizations are grouped per baseline under vis_output_dir.
            vis_dir = None
            if save_visuals and vis_output_dir:
                vis_dir = os.path.join(vis_output_dir, baseline_name)
                os.makedirs(vis_dir, exist_ok=True)

            t_load = time.time()
            gt_frames, gt_paths = evaluator.read_image_sequence(os.path.join(gt_path, 'images'))
            pred_frames, pred_paths = evaluator.read_image_sequence(os.path.join(pred_path, 'images'))
            load_seconds = time.time() - t_load

            t_eval = time.time()
            results = evaluator.evaluate_video_pair(
                gt_path, pred_path, gt_frames, pred_frames, gt_paths, pred_paths, vis_dir
            )
            eval_seconds = time.time() - t_eval

            elapsed = time.time() - t_start
            fps = len(gt_frames) / elapsed
            baseline_results_dict[result_key] = (results, {
                'data_loading': load_seconds,
                'evaluation': eval_seconds,
                'total': elapsed,
                'frames_per_second': fps,
            })

            print(f"{tag} Completed video {video_name} ({baseline_name}), "
                  f"time: {elapsed:.2f}s, {fps:.1f} fps")

        except Exception as exc:
            print(f"{tag} Error processing video {video_name} ({baseline_name}): {exc}")
            traceback.print_exc()
            # Record the failure so the parent can see which task broke.
            baseline_results_dict[result_key] = ({"error": str(exc)}, {})

    print(f"{tag} All tasks completed")

def convert_numpy_types(obj):
    """Recursively convert NumPy scalars and arrays into plain Python objects for JSON.

    ``json.dump`` rejects NumPy scalar types (including ``np.bool_``), both as
    values and as dict keys. This function converts them wherever they occur
    inside nested dicts, lists and tuples.

    Args:
        obj: Arbitrary, possibly nested, object.

    Returns:
        The same structure with NumPy booleans, integers and floats turned into
        ``bool``, ``int`` and ``float``, and arrays turned into (nested) lists.
        Dicts and lists are rebuilt, tuples stay tuples, and any other object is
        returned unchanged.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        # tolist() yields Python scalars for numeric dtypes; recurse for object arrays.
        return convert_numpy_types(obj.tolist())
    if isinstance(obj, dict):
        # Keys are converted too: json.dump cannot serialize e.g. np.int64 keys.
        return {convert_numpy_types(key): convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    return obj

def _format_metric_value(metric, value):
    """Return ``value`` formatted with the fixed precision used for ``metric`` in the CSV tables."""
    if metric in ('PSNR', 'SSIM', 'DISTS'):
        return f"{value:.2f}"
    if metric in ('FVD', 'FID', 'CLIP-FID'):
        return f"{value:.1f}"
    return f"{value:.4f}"


def _build_metric_table_rows(region_results, video_names, baseline_names, metrics):
    """Build one row per video (sorted by name) plus a trailing AVERAGE row for one region.

    Args:
        region_results: Mapping baseline name -> video name -> metric dict.
        video_names: Collection of video names to emit rows for.
        baseline_names: Baselines in column order.
        metrics: Metric keys, in the order they are joined inside each cell.

    Returns:
        List of rows ``[video_name, cell_for_baseline_1, ...]``. A cell is the
        metric values joined by " | " ("N/A" for a missing metric), or "N/A"
        when the baseline has no result for that video. The AVERAGE row holds
        the mean of each metric over the videos that report it.
    """
    def lookup(baseline_name, video_name):
        return (region_results.get(baseline_name) or {}).get(video_name)

    rows = []
    for video_name in sorted(video_names):
        row = [video_name]
        for baseline_name in baseline_names:
            video_results = lookup(baseline_name, video_name)
            if video_results is None:
                row.append("N/A")
                continue
            row.append(" | ".join(
                _format_metric_value(metric, video_results[metric]) if metric in video_results else "N/A"
                for metric in metrics
            ))
        rows.append(row)

    avg_row = ["AVERAGE"]
    for baseline_name in baseline_names:
        cells = []
        for metric in metrics:
            values = []
            for video_name in video_names:
                video_results = lookup(baseline_name, video_name)
                if video_results is not None and metric in video_results:
                    values.append(video_results[metric])
            cells.append(_format_metric_value(metric, np.mean(values)) if values else "N/A")
        avg_row.append(" | ".join(cells))
    rows.append(avg_row)
    return rows


def _write_metric_csv(path, columns, rows):
    """Write ``rows`` under a ``columns`` header as comma-separated lines.

    Commas inside any cell are replaced by ';' so the column count stays stable.
    """
    with open(path, 'w') as f:
        f.write(','.join(str(column).replace(',', ';') for column in columns) + '\n')
        for row in rows:
            f.write(','.join(str(cell).replace(',', ';') for cell in row) + '\n')


def _distribute_tasks(all_tasks, total_processes):
    """Split tasks into one list per worker process.

    Tasks are grouped by baseline. Each baseline's tasks are cut into contiguous
    chunks of ceil(n / total_processes) and handed to consecutive workers. The
    starting worker rotates between baselines so they begin on different workers.
    """
    tasks_by_baseline = {}
    for task in all_tasks:
        tasks_by_baseline.setdefault(task[2], []).append(task)

    process_tasks = [[] for _ in range(total_processes)]
    process_idx = 0
    for baseline_tasks in tasks_by_baseline.values():
        # Ceiling division caps any worker's share of this baseline at ceil(n / P).
        tasks_per_process = -(-len(baseline_tasks) // total_processes)
        for i, task in enumerate(baseline_tasks):
            process_tasks[(process_idx + i // tasks_per_process) % total_processes].append(task)
        num_chunks = (len(baseline_tasks) - 1) // tasks_per_process + 1
        process_idx = (process_idx + num_chunks) % total_processes
    return process_tasks


def main():
    """Command-line entry point: evaluate every baseline against the GT videos on several GPUs.

    The run proceeds in four steps:
    1. Parse the arguments.
    2. Build one (GT, prediction) task per GT video and baseline.
    3. Spread the tasks over ``num_gpus * processes_per_gpu`` worker processes.
    4. Merge the per-task outputs into a JSON file, two CSV tables (full frame
       and masked region) and console summary tables.

    Failed or missing tasks are reported instead of aborting the consolidation.
    """
    parser = argparse.ArgumentParser(description='Enhanced Video Quality Evaluation Tool (with advanced evaluation metrics)')
    parser.add_argument('--gt_list', type=str, required=True, help='GT video path list file')
    parser.add_argument('--baseline_roots', nargs='+', required=True, help='List of baseline method root directories')
    parser.add_argument('--baseline_names', nargs='+', required=True, help='List of baseline method names')
    parser.add_argument('--output', type=str, default='evaluation_results.json', help='Output JSON file path')
    parser.add_argument('--output_table', type=str, default='results_full_frame.csv', help='Output full frame evaluation table')
    parser.add_argument('--output_masked_table', type=str, default='results_masked_region.csv', help='Output masked region evaluation table')
    parser.add_argument('--num_gpus', type=int, default=8, help='Number of GPUs to use')
    parser.add_argument('--verbose_paths', action='store_true', help='Whether to print detailed file path information')
    parser.add_argument('--weights_dir', type=str, help='Weights file directory')
    parser.add_argument('--save_visuals', action='store_true', help='Whether to save masked visualization images')
    parser.add_argument('--vis_output_dir', type=str, default='vis_output', help='Visualization output directory')
    # Performance optimization arguments
    parser.add_argument('--processes_per_gpu', type=int, default=2, help='Number of processes per GPU')
    parser.add_argument('--batch_size', type=int, default=32, help='Model batch size')
    parser.add_argument('--prefetch_size', type=int, default=64, help='Number of frames to prefetch')
    # NOTE(review): --cpu_workers is parsed but not forwarded to the worker processes.
    parser.add_argument('--cpu_workers', type=int, default=48, help='Number of CPU worker threads')
    parser.add_argument('--use_mixed_precision', action='store_true', help='Use mixed precision computation')
    parser.add_argument('--io_threads', type=int, default=24, help='Number of IO operation threads')
    parser.add_argument('--use_dali', action='store_true', help='Use NVIDIA DALI for accelerated data loading')

    args = parser.parse_args()

    # Validate arguments
    if len(args.baseline_roots) != len(args.baseline_names):
        print("Error: Baseline directory and name lists have different lengths")
        sys.exit(1)
    if args.num_gpus < 1 or args.processes_per_gpu < 1:
        print("Error: --num_gpus and --processes_per_gpu must both be at least 1")
        sys.exit(1)

    # Each worker calls torch.cuda.set_device(gpu_id); an index beyond the visible
    # devices makes that worker die before it records any result.
    available_gpus = torch.cuda.device_count()
    if available_gpus == 0:
        print("Error: No CUDA device is visible, but every worker process requires one")
        sys.exit(1)
    if args.num_gpus > available_gpus:
        print(f"Warning: Requested {args.num_gpus} GPUs but only {available_gpus} are visible; "
              f"using {available_gpus}")
        args.num_gpus = available_gpus

    total_processes = args.num_gpus * args.processes_per_gpu

    print("Enhanced video quality evaluation started (using advanced evaluation metrics)")
    print(f"Using {args.num_gpus} GPUs, {args.processes_per_gpu} processes per GPU")
    print(f"Weights directory: {args.weights_dir or 'Using online weights'}")
    if args.save_visuals:
        print(f"Will save visualization results to: {args.vis_output_dir}")

    # Split the host's cores between the workers (spawned children inherit these
    # variables) to avoid oversubscribing OpenMP/MKL threads.
    threads_per_process = str(max(1, (os.cpu_count() or 96) // total_processes))
    os.environ['OMP_NUM_THREADS'] = threads_per_process
    os.environ['MKL_NUM_THREADS'] = threads_per_process
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

    # Load GT video list; blank lines would otherwise become tasks with an empty video name.
    with open(args.gt_list, 'r') as f:
        gt_videos = [line.strip() for line in f if line.strip()]

    print(f"Loaded {len(gt_videos)} GT videos")

    # Create output directories
    for output_path in (args.output, args.output_table, args.output_masked_table):
        os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    if args.save_visuals:
        os.makedirs(args.vis_output_dir, exist_ok=True)

    # Collect all evaluation tasks
    all_tasks = []
    for gt_path in gt_videos:
        # normpath drops a trailing separator, which would make basename() return ''.
        video_name = os.path.basename(os.path.normpath(gt_path))
        for baseline_root, baseline_name in zip(args.baseline_roots, args.baseline_names):
            pred_path = os.path.join(baseline_root, video_name)
            if os.path.exists(pred_path):
                all_tasks.append((gt_path, pred_path, baseline_name, video_name))
            else:
                print(f"Warning: Baseline video not found {pred_path}")

    process_tasks = _distribute_tasks(all_tasks, total_processes)
    for i, tasks in enumerate(process_tasks):
        print(f"Process {i} (GPU {i % args.num_gpus}): {len(tasks)} tasks")

    # Setup multiprocessing
    manager = mp.Manager()
    baseline_results_dict = manager.dict()

    processes = []
    for i, tasks in enumerate(process_tasks):
        if not tasks:
            continue  # nothing to do: avoid loading the evaluator models for no work
        gpu_id = i % args.num_gpus
        p = mp.Process(
            target=optimized_worker_process,
            args=(
                i, gpu_id, tasks, baseline_results_dict,
                args.weights_dir, args.verbose_paths, args.save_visuals,
                args.vis_output_dir, args.batch_size, args.prefetch_size,
                args.io_threads, args.use_mixed_precision, args.use_dali
            )
        )
        p.daemon = True
        p.start()
        processes.append((i, p))

    # Wait for all processes to complete and report the ones that crashed.
    for i, p in processes:
        p.join()
        if p.exitcode != 0:
            print(f"Warning: Worker process {i} exited with code {p.exitcode}; "
                  f"some of its tasks may have no results")

    print("\nAll evaluation tasks completed, consolidating results...")

    # Take a local copy of the shared dict, then release the manager process.
    collected = dict(baseline_results_dict.items())
    manager.shutdown()

    all_results = {
        'full_frame': {name: {} for name in args.baseline_names},
        'masked_region': {name: {} for name in args.baseline_names},
        'per_person_masked_region': {name: {} for name in args.baseline_names}
    }
    video_sets = set()
    timings_by_baseline = {name: [] for name in args.baseline_names}
    failed = []

    for (baseline_name, video_name), (results, timings) in collected.items():
        if timings:
            timings_by_baseline[baseline_name].append(timings)
        # Keep the video in the tables even when this baseline failed, so it shows as N/A.
        video_sets.add(video_name)

        # Failed tasks are stored by the worker as ({"error": message}, {}).
        if not isinstance(results, dict) or 'error' in results or results.get('full_frame') is None:
            if isinstance(results, dict):
                reason = results.get('error', 'no full-frame results returned')
            else:
                reason = 'no results returned'
            failed.append((baseline_name, video_name, reason))
            continue

        all_results['full_frame'][baseline_name][video_name] = results['full_frame']
        if results.get('masked_region') is not None:
            all_results['masked_region'][baseline_name][video_name] = results['masked_region']
        if results.get('per_person_masked_region'):
            all_results['per_person_masked_region'][baseline_name][video_name] = results['per_person_masked_region']

    if failed:
        print(f"\nWarning: {len(failed)} evaluation task(s) failed and are excluded from the results:")
        for baseline_name, video_name, reason in failed:
            print(f"  - {video_name} ({baseline_name}): {reason}")

    missing = [(b, v) for _, _, b, v in all_tasks if (b, v) not in collected]
    if missing:
        print(f"\nWarning: {len(missing)} task(s) produced no result (worker process likely crashed):")
        for baseline_name, video_name in missing:
            print(f"  - {video_name} ({baseline_name})")

    # Print performance statistics
    print("\nPerformance Statistics:")
    for baseline_name, timings_list in timings_by_baseline.items():
        if not timings_list:
            continue
        grouped = {}
        for timing_dict in timings_list:
            for key, value in timing_dict.items():
                grouped.setdefault(key, []).append(value)
        print(f"\nBaseline method {baseline_name} average timing:")
        for key, values in grouped.items():
            unit = " fps" if key == 'frames_per_second' else "s"
            print(f"  - {key}: {np.mean(values):.2f}{unit} "
                  f"(min: {np.min(values):.2f}{unit}, max: {np.max(values):.2f}{unit})")

    # Save complete results
    all_results = convert_numpy_types(all_results)
    with open(args.output, 'w') as f:
        json.dump(all_results, f, indent=2)

    print(f"\nComplete results saved to: {args.output}")

    # Metric columns for the CSV tables come from the full-frame results.
    all_metrics = set()
    for baseline_results in all_results['full_frame'].values():
        for video_results in baseline_results.values():
            all_metrics.update(video_results.keys())
    all_metrics = sorted(m for m in all_metrics if m not in ('num_frames', 'mask_type'))
    print(f"Metric order within each CSV cell: {' | '.join(all_metrics)}")

    columns = ['Video'] + args.baseline_names

    full_frame_rows = _build_metric_table_rows(
        all_results['full_frame'], video_sets, args.baseline_names, all_metrics)
    _write_metric_csv(args.output_table, columns, full_frame_rows)
    print(f"Full frame evaluation table saved to: {args.output_table}")

    has_masked = any(all_results['masked_region'].values())
    if has_masked:
        masked_rows = _build_metric_table_rows(
            all_results['masked_region'], video_sets, args.baseline_names, all_metrics)
        _write_metric_csv(args.output_masked_table, columns, masked_rows)
        print(f"Masked region evaluation table saved to: {args.output_masked_table}")
    else:
        print("Warning: No masked region evaluation data found, skipping masked region table generation")

    print("\nEvaluation completed!")

    # Print summary tables
    print_average_metrics_table(all_results, args.baseline_names, "full_frame")
    if has_masked:
        print_average_metrics_table(all_results, args.baseline_names, "masked_region")
    if any(all_results['per_person_masked_region'].values()):
        print_per_person_metrics_table(all_results, args.baseline_names)

if __name__ == "__main__":
    # Workers use CUDA, which cannot be used in forked children, so force the
    # "spawn" start method before main() creates any process. `mp` is the
    # torch.multiprocessing module imported at the top of this file.
    mp.set_start_method("spawn", force=True)
    main()