#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import argparse, warnings, math, traceback, time
from pathlib import Path
from typing import Dict, Tuple, List, Optional, Union
import concurrent.futures
import multiprocessing
import os
import cProfile
import pstats
import io
import gc

import cv2
try:
    from skimage.metrics import structural_similarity as ssim_sk
except ImportError:
    ssim_sk = None
import numpy as np
import pandas as pd
from tqdm import tqdm


# Optional GPU acceleration through CuPy. Any failure here falls back to NumPy
# on the CPU instead of crashing at import time. That covers CuPy not being
# installed (ImportError). It also covers CuPy installed without a usable CUDA
# driver or device, where getDeviceCount() raises a CUDA runtime error rather
# than ImportError.
try:
    import cupy as cp
    import cupy.cuda.runtime as runtime
    num_gpus = runtime.getDeviceCount()
    HAS_CUPY = True
except Exception:
    HAS_CUPY = False
    cp = np
    num_gpus = 0


DEBUG = False    # debug_print() only prints when this is True
PROFILE = False  # profile_function() only profiles when this is True

def debug_print(*args, **kwargs):
    """Print ``args`` prefixed with ``[DEBUG]`` when the module-level DEBUG flag is set.

    Keyword arguments are forwarded to ``print`` unchanged; nothing happens otherwise.
    """
    if not DEBUG:
        return
    print("[DEBUG]", *args, **kwargs)


# One multiprocessing lock per detected GPU, keyed by device index. A single
# lock (key 0) is created when no GPU is visible. A comprehension is used so
# no loop variable leaks into the module namespace.
# NOTE(review): the locks are created at import time. Under the "spawn" start
# method each child re-imports the module and gets its own fresh locks, so they
# only coordinate processes that inherit them (fork) or receive them explicitly.
_GPU_LOCKS = {gpu_index: multiprocessing.Lock() for gpu_index in range(max(1, num_gpus))}

def get_gpu_id(worker_id, total_gpus):
    """Map a worker index onto a GPU index, round-robin.

    Returns ``None`` when ``total_gpus`` is 0 (CPU only); otherwise
    ``worker_id % total_gpus``.
    """
    return None if total_gpus == 0 else worker_id % total_gpus


def profile_function(func):
    """Decorator that runs ``func`` under cProfile when the module-level PROFILE flag is set.

    When profiling, the 20 entries with the highest cumulative time are emitted
    through ``debug_print``, so they only appear if DEBUG is enabled too. With
    PROFILE off, the call is passed straight through.

    The wrapper keeps ``func``'s name and docstring (``functools.wraps``). The
    profiler is always disabled, even if ``func`` raises, so it is not left
    running.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not PROFILE:
            return func(*args, **kwargs)

        profiler = cProfile.Profile()
        profiler.enable()
        try:
            return func(*args, **kwargs)
        finally:
            profiler.disable()
            report = io.StringIO()
            pstats.Stats(profiler, stream=report).sort_stats('cumulative').print_stats(20)
            debug_print(f"Profile for {func.__name__}:")
            debug_print(report.getvalue())
    return wrapper


# Per-keypoint OKS sigmas for the 17 COCO keypoints (presumably in COCO
# keypoint order), written in thousandths and scaled to fractions (26 -> 0.026).
# NOTE(review): keypoint arrays built elsewhere in this file have 18 joints per
# person, and oks_frame repeats this array to cover more joints. Confirm that the
# joint layout actually matches the COCO ordering.
COCO_SIGMAS = np.array(
    [ 26, 25, 25, 35, 35,
      79, 79, 72, 72, 62,
      62,107,107, 87, 87,
      89, 89], dtype=np.float32) / 1000.

def _align_sim3(src: np.ndarray, tgt: np.ndarray) -> Tuple[np.ndarray,float,np.ndarray]:
    try:

        n = min(len(src), len(tgt))
        src = src[:n]
        tgt = tgt[:n]


        valid = (~np.isnan(tgt[:, 0])) & (~np.isnan(src[:, 0]))
        if valid.sum() < 3:
            return src, 1., np.zeros(2)


        valid_src = src[valid]
        valid_tgt = tgt[valid]
        

        A = np.hstack([valid_src, np.ones((valid.sum(), 1))])
        

        X, *_ = np.linalg.lstsq(A, valid_tgt, rcond=None)
        
        s = X[:2, :].mean()
        t = X[2, :]
        

        src_hat = src * s + t
        return src_hat, s, t
    except Exception as e:
        debug_print(f"Error in _align_sim3: {e}")
        debug_print(f"src shape: {src.shape}, tgt shape: {tgt.shape}")
        return src, 1., np.zeros(2)

def mpjpe_2d(pred: np.ndarray, gt: np.ndarray) -> float:
    """Mean per-joint position error of ``pred`` vs ``gt`` after ``_align_sim3`` alignment.

    Joints are scored where ``gt`` has a non-NaN x coordinate; the result is NaN
    when there are none, when a scored prediction is NaN, or if anything raises.
    """
    try:
        gt_ok = ~np.isnan(gt[:, 0])
        if not gt_ok.any():
            return np.nan
        aligned = _align_sim3(pred, gt)[0]
        per_joint = np.linalg.norm(aligned[gt_ok] - gt[gt_ok], axis=1)
        return float(per_joint.mean())
    except Exception as exc:
        debug_print(f"Error in mpjpe_2d: {exc}")
        return np.nan

def temporal_deriv(arr: np.ndarray, fps:int=30, order:int=1) -> np.ndarray:
    """Finite-difference time derivative of ``arr`` along axis 0, in per-second units.

    The ``order``-th difference is multiplied by ``fps ** order``. The first
    ``order`` rows, which have no complete stencil, are NaN, so the output always
    has the same shape as ``arr``. If ``arr`` has ``order`` or fewer rows, the
    result is entirely NaN.

    ``order == 0`` returns ``arr`` itself. Non-float input is promoted to float64
    so the NaN padding is representable. On any other error (e.g. negative
    ``order``) an all-NaN float array shaped like ``arr`` is returned.
    """
    try:
        if order == 0:
            return arr
        data = np.asarray(arr)
        if not np.issubdtype(data.dtype, np.floating):
            data = data.astype(np.float64)
        # Never pad more rows than the input has; otherwise the output would grow.
        n_pad = min(order, data.shape[0])
        diff = np.diff(data, n=order, axis=0) * (fps ** order)
        pad = np.full((n_pad,) + data.shape[1:], np.nan, data.dtype)
        return np.concatenate([pad, diff], axis=0)
    except Exception as e:
        debug_print(f"Error in temporal_deriv: {e}")
        return np.full(np.shape(arr), np.nan)

def rms(arr: np.ndarray, axis=None):
    """Root-mean-square of ``arr`` along ``axis`` (every element when None), skipping NaNs.

    Returns NaN when every element is NaN, or if the computation raises.
    """
    try:
        all_missing = np.isnan(arr).all()
        if all_missing:
            return np.nan
        mean_square = np.nanmean(arr ** 2, axis=axis)
        return np.sqrt(mean_square)
    except Exception as err:
        debug_print(f"Error in rms: {err}")
        return np.nan

def oks_frame(pred: np.ndarray, gt: np.ndarray, area: float,
              sigmas: Optional[np.ndarray] = None) -> float:
    """
    COCO-style Object Keypoint Similarity for one flattened (P*J, 2) frame.

    For every joint i whose GT x coordinate is not NaN:

        OKS_i = exp(-d_i^2 / (2 * area * k_i^2)),  k_i = 2 * sigma_i

    This matches pycocotools' COCOeval. The frame score is the mean over those
    joints. It is NaN when no GT joint is valid, when a prediction is NaN at a
    valid GT joint, or if anything raises.

    Sigmas are assigned by joint index: ``sigmas`` (default ``COCO_SIGMAS``) is
    repeated so joint i uses ``sigmas[i % len(sigmas)]``. This handles several
    persons stacked in one frame. Previously sigmas were assigned by position
    among the *valid* joints, so any missing joint shifted every later sigma.
    """
    try:
        valid = ~np.isnan(gt[:, 0])
        if not valid.any():
            return np.nan

        base = np.asarray(COCO_SIGMAS if sigmas is None else sigmas, dtype=np.float64)
        # Expand over ALL joints first, then select, so each joint keeps its own sigma.
        sig = np.resize(base, len(gt))[valid]

        d2 = ((pred[valid] - gt[valid]) ** 2).sum(axis=1)
        kappa_sq = (2.0 * sig) ** 2
        oks = np.exp(-d2 / (2.0 * kappa_sq * (area + 1e-6)))
        return float(oks.mean())
    except Exception as e:
        debug_print(f"Error in oks_frame: {e}")
        return np.nan

def gaussian_heatmap_cpu(kps: np.ndarray, H:int, W:int, sigma:int=4) -> np.ndarray:
    """Render an (H, W) float32 keypoint heatmap with OpenCV.

    Each keypoint with a non-NaN x becomes a filled disk of radius ``sigma`` and
    value 1 (overlaps are not summed). The canvas is then Gaussian-blurred with
    standard deviation ``sigma``. Any error yields an all-zero (H, W) canvas.
    """
    try:
        heat = np.zeros((H, W), np.float32)
        for px, py in kps:
            if not np.isnan(px):
                cv2.circle(heat, (int(px), int(py)), sigma, 1, -1)
        return cv2.GaussianBlur(heat, (0, 0), sigma)
    except Exception as err:
        debug_print(f"Error in gaussian_heatmap_cpu: {err}")
        return np.zeros((H, W), np.float32)

def gaussian_heatmap_gpu(kps: np.ndarray, H:int, W:int, sigma:int=4, gpu_id: Optional[int]=None) -> np.ndarray:
    """GPU-accelerated implementation of gaussian heatmap.

    Renders an (H, W) float32 heatmap with CuPy. An unnormalised Gaussian kernel
    (peak 1, std ``sigma``, ``6 * sigma + 1`` pixels square) is added at every
    keypoint whose x is not NaN and whose integer position lies inside the
    canvas. Overlapping kernels are summed.

    Args:
        kps: one (x, y) row per joint; presumably pixel coordinates (TODO confirm).
        H, W: canvas height and width in pixels.
        sigma: Gaussian standard deviation in pixels.
        gpu_id: CUDA device made current before allocating; None keeps the current device.

    Returns:
        The heatmap copied back to a NumPy array. Falls back to
        ``gaussian_heatmap_cpu`` when CuPy is unavailable or anything raises.

    NOTE(review): the CPU path draws a filled disk of radius ``sigma`` and then
    blurs it, so the two paths do not produce the same heatmap. Summed overlaps
    here can also exceed 1.0. Metrics computed from these heatmaps may therefore
    depend on whether a GPU is used.
    """
    try:
        if not HAS_CUPY:
            return gaussian_heatmap_cpu(kps, H, W, sigma)

        # Switch the current CuPy device so the allocations below land on gpu_id.
        if gpu_id is not None:
            cp.cuda.Device(gpu_id).use()

        canvas = cp.zeros((H,W), cp.float32)

        # Only the x coordinate is checked for NaN.
        valid_kps = kps[~np.isnan(kps[:,0])]

        if len(valid_kps) > 0:

            # Truncate to integer pixel positions. NOTE(review): a NaN y slips through the
            # filter above; its int cast is platform-defined and presumably rejected by the
            # bounds check below — confirm.
            points = valid_kps.astype(int)

            # Unnormalised 2-D Gaussian on a (6*sigma+1)^2 integer grid centred at 0 (peak 1).
            kernel_size = sigma * 6 + 1
            kernel_x = cp.arange(kernel_size) - kernel_size // 2
            kernel_y = cp.arange(kernel_size) - kernel_size // 2
            kernel_xx, kernel_yy = cp.meshgrid(kernel_x, kernel_y)
            kernel = cp.exp(-(kernel_xx**2 + kernel_yy**2) / (2 * sigma**2))

            # One small device-side slice update per keypoint centred inside the canvas.
            for x, y in points:
                if 0 <= x < W and 0 <= y < H:

                    # Canvas window covered by the kernel, clipped to the image bounds.
                    x_min = max(0, x - kernel_size // 2)
                    x_max = min(W, x + kernel_size // 2 + 1)
                    y_min = max(0, y - kernel_size // 2)
                    y_max = min(H, y + kernel_size // 2 + 1)

                    # Matching kernel sub-window, trimmed by the amount clipped on each side.
                    k_x_min = max(0, kernel_size // 2 - x)
                    k_x_max = kernel_size - max(0, x + kernel_size // 2 + 1 - W)
                    k_y_min = max(0, kernel_size // 2 - y)
                    k_y_max = kernel_size - max(0, y + kernel_size // 2 + 1 - H)

                    # Additive: overlapping keypoints can push values above 1.0.
                    canvas[y_min:y_max, x_min:x_max] += kernel[k_y_min:k_y_max, k_x_min:k_x_max]

        # Copy the result from device to host memory.
        return cp.asnumpy(canvas)
    except Exception as e:
        debug_print(f"Error in gaussian_heatmap_gpu: {e}")
        return gaussian_heatmap_cpu(kps, H, W, sigma)

def gaussian_heatmap(kps: np.ndarray, H:int, W:int, sigma:int=4, gpu_id: Optional[int]=None) -> np.ndarray:
    """Render a keypoint heatmap, choosing the backend.

    The GPU renderer is used only when CuPy is available AND a ``gpu_id`` is
    given; every other case goes to the OpenCV CPU renderer.
    """
    use_gpu = HAS_CUPY and gpu_id is not None
    if use_gpu:
        return gaussian_heatmap_gpu(kps, H, W, sigma, gpu_id)
    return gaussian_heatmap_cpu(kps, H, W, sigma)

def heat_ssim(pred: np.ndarray, gt: np.ndarray, H: int, W: int, gpu_id: Optional[int]=None) -> float:
    """Structural similarity between heatmaps rendered from ``pred`` and ``gt`` keypoints.

    Both keypoint sets are rendered with ``gaussian_heatmap`` (sigma=4) on an
    H x W canvas, clipped to [0, 1] and quantised to uint8. The score comes from
    the first available of:

    * OpenCV's ``cv2.quality.QualitySSIM_compute`` (first channel);
    * scikit-image's ``structural_similarity``;
    * the Pearson correlation of the pixels (0.0 if either heatmap is constant).

    Returns 0.0 on any error.
    """
    try:
        hp = gaussian_heatmap(pred, H, W, sigma=4, gpu_id=gpu_id)
        hg = gaussian_heatmap(gt, H, W, sigma=4, gpu_id=gpu_id)

        # The GPU renderer sums overlapping Gaussians, so values can exceed 1.0.
        # Clip before quantising, or the uint8 cast overflows instead of saturating.
        hp = (np.clip(hp, 0.0, 1.0) * 255).astype(np.uint8)
        hg = (np.clip(hg, 0.0, 1.0) * 255).astype(np.uint8)

        if hasattr(cv2, "quality") and hasattr(cv2.quality, "QualitySSIM_compute"):
            (ssim_val, *_), *_ = cv2.quality.QualitySSIM_compute(hp, hg)
            return float(ssim_val)

        if ssim_sk is not None:
            return float(ssim_sk(hp, hg, data_range=255))

        if hp.std() == 0 or hg.std() == 0:
            return 0.0
        return float(np.corrcoef(hp.flatten(), hg.flatten())[0, 1])
    except Exception as e:
        debug_print(f"Error in heat_ssim: {e}")
        return 0.0

def frechet_distance(mu1: np.ndarray, c1: np.ndarray,
                    mu2: np.ndarray, c2: np.ndarray, eps: float = 1e-6) -> float:
    """Fréchet distance between the Gaussians N(mu1, c1) and N(mu2, c2).

        d = ||mu1 - mu2||^2 + Tr(c1 + c2 - 2 (c1 c2)^{1/2})

    ``eps * I`` is added to both covariances for numerical stability, and the
    result is clamped at 0.

    Returns NaN when any input contains NaN, or when the shapes are inconsistent
    (means must be 1-D and equal; covariances 2-D, equal and matching the mean
    length). If an eigen-decomposition fails, the upper bound without the
    cross term is returned.
    """
    try:
        if np.isnan(mu1).any() or np.isnan(mu2).any() or np.isnan(c1).any() or np.isnan(c2).any():
            return np.nan

        if mu1.ndim != 1 or mu2.ndim != 1 or mu1.shape != mu2.shape:
            return np.nan

        if c1.ndim != 2 or c2.ndim != 2 or c1.shape != c2.shape or c1.shape[0] != mu1.shape[0]:
            return np.nan

        diff = mu1 - mu2
        term1 = np.dot(diff, diff)

        eye = np.eye(c1.shape[0])
        c1_reg = c1 + eps * eye
        c2_reg = c2 + eps * eye

        try:
            # Tr((c1 c2)^{1/2}) = sum_i sqrt(lambda_i), with lambda_i the eigenvalues of
            # c1^{1/2} c2 c1^{1/2}. That matrix is similar to c1 c2 but symmetric PSD,
            # so eigvalsh is valid on it. It is not valid on c1 @ c2 itself, which is
            # generally non-symmetric.
            w, v = np.linalg.eigh(0.5 * (c1_reg + c1_reg.T))
            sqrt_c1 = (v * np.sqrt(np.maximum(w, 0.0))) @ v.T
            middle = sqrt_c1 @ c2_reg @ sqrt_c1
            eigvals = np.linalg.eigvalsh(0.5 * (middle + middle.T))

            tr_covmean = np.sum(np.sqrt(np.maximum(eigvals, 0.0)))
            term2 = np.trace(c1_reg) + np.trace(c2_reg) - 2.0 * tr_covmean

            return float(max(0.0, term1 + term2))
        except np.linalg.LinAlgError:
            return float(term1 + np.trace(c1_reg) + np.trace(c2_reg))
    except Exception as e:
        debug_print(f"Error in frechet_distance: {e}")
        return np.nan


@profile_function
def _frames_to_kps(frames: List[dict], max_joints=None):
    """
    Convert a DWPose ``frames`` list (exported by visualize_dwpose_with_ids_color_dataset_face.py)
    to an ndarray of shape (T, P, J, 2) with NaN for missing joints.

    Each frame is expected to be a dict with:

    * ``bodies['candidate']``: rows whose first two values are x and y. They are
      multiplied by the frame's W and H, taken from the first two entries of
      ``frame_shape`` (default 1 x 1).
    * ``bodies['subset']``: one row of candidate indices per person (-1 = missing).

    Only the first ``max_joints`` joints are kept. When ``max_joints`` is None it
    is inferred as the longest subset row. P is the largest person count of any
    frame.

    Returns an empty (0, 0, J, 2) array for empty or invalid input, or on error,
    and a (T, 0, J, 2) array when no frame contains a person.
    """
    n_joints = 0 if max_joints is None else max_joints
    try:
        if max_joints is None:
            for f in frames:
                if 'bodies' in f and 'subset' in f['bodies']:
                    for sub in f['bodies']['subset']:
                        n_joints = max(n_joints, len(sub))

        T = len(frames)
        if T == 0:
            debug_print("Warning: Empty frames list")
            return np.zeros((0, 0, n_joints, 2), dtype=np.float32)

        if not isinstance(frames[0], dict) or 'bodies' not in frames[0]:
            debug_print(f"Warning: Invalid frame structure: {type(frames[0])}")
            return np.zeros((0, 0, n_joints, 2), dtype=np.float32)

        maxP = max((len(f['bodies']['subset']) for f in frames
                    if 'bodies' in f and 'subset' in f['bodies']), default=0)
        if maxP == 0:
            debug_print("Warning: No people found in frames")
            return np.zeros((T, 0, n_joints, 2), dtype=np.float32)

        kps = np.full((T, maxP, n_joints, 2), np.nan, dtype=np.float32)

        for t, fr in enumerate(frames):
            if 'bodies' not in fr or 'candidate' not in fr['bodies'] or 'subset' not in fr['bodies']:
                continue

            # frame_shape may carry a channel count (H, W, C); only H and W are needed.
            H, W = tuple(fr.get("frame_shape", (1, 1)))[:2]
            cand = np.asarray(fr["bodies"]["candidate"])

            # pid < maxP always holds: maxP is the largest subset length over all frames.
            for pid, sub in enumerate(fr["bodies"]["subset"]):
                pairs = [(j, int(idx)) for j, idx in enumerate(sub[:n_joints])
                         if idx != -1 and 0 <= int(idx) < len(cand)]
                if not pairs:
                    continue
                joints, rows = zip(*pairs)
                kps[t, pid, list(joints)] = cand[list(rows), :2] * [W, H]

        return kps
    except Exception as e:
        debug_print(f"Error in _frames_to_kps: {e}")
        debug_print(traceback.format_exc())
        return np.zeros((0, 0, n_joints, 2), dtype=np.float32)


def load_npz_minimal(file_path):
    """Load an .npz (or .npy) file and return its contents fully in memory.

    For .npz archives a plain ``{name: array}`` dict is returned; for .npy files,
    the array itself. Returns None if loading fails.

    NOTE(security): ``allow_pickle=True`` unpickles object arrays, which can run
    arbitrary code. Only use this on trusted files.
    """
    try:
        loaded = np.load(file_path, allow_pickle=True)
        names = getattr(loaded, "files", None)
        if names is None:
            return loaded
        # Read every member while the archive is open. The lazy NpzFile returned
        # before was tied to a file handle that had already been closed, so every
        # array access on it failed.
        with loaded:
            return {name: loaded[name] for name in names}
    except Exception as e:
        debug_print(f"Error in minimal NPZ loading: {e}")
        return None


def _is_empty(value) -> bool:
    """True for None or a zero-length container. Unlike ``not value`` it is safe on NumPy arrays."""
    if value is None:
        return True
    try:
        return len(value) == 0
    except TypeError:
        return not value


def _frame_hw(frame) -> Tuple[float, float]:
    """(H, W) from the first two entries of ``frame['frame_shape']``, or (1, 1) if absent/malformed."""
    shape = frame.get('frame_shape')
    if isinstance(shape, (list, tuple, np.ndarray)) and len(shape) >= 2:
        return shape[0], shape[1]
    return 1, 1


def _candidate_xy(candidate, idx, W, H) -> Optional[Tuple[float, float]]:
    """Scaled (x, y) of ``candidate[idx]``, or None for -1, invalid, or out-of-range indices."""
    try:
        if idx == -1:
            return None
        i = int(idx)
        # Reject negatives too: Python would otherwise silently index from the end.
        if not 0 <= i < len(candidate):
            return None
        coords = candidate[i]
        if len(coords) < 2:
            return None
        return coords[0] * W, coords[1] * H
    except (TypeError, ValueError, IndexError):
        return None


def _kps_from_person_layout(pose_by_person, global_person_ids,
                            max_frames: int, max_persons: int, num_joints: int) -> np.ndarray:
    """Keypoints from ``pose_by_person[pid] -> [frame dict, ...]`` for the first ``max_persons`` ids.

    Only frames flagged ``visible`` are used, and only the first subset row of
    each frame. The time axis is always ``max_frames`` long, NaN-padded.
    Exceptions stop processing and return what was filled so far.
    """
    result = np.full((max_frames, max_persons, num_joints, 2), np.nan, dtype=np.float32)
    try:
        if hasattr(pose_by_person, 'item'):
            # Presumably stored as a 0-d object array wrapping the dict.
            pose_by_person = pose_by_person.item()
        if _is_empty(pose_by_person) or _is_empty(global_person_ids):
            return result

        P = min(len(global_person_ids), max_persons)
        result = np.full((max_frames, P, num_joints, 2), np.nan, dtype=np.float32)

        for p_idx, p_id in enumerate(global_person_ids[:P]):
            if p_id not in pose_by_person:
                continue
            person_frames = pose_by_person[p_id]
            if _is_empty(person_frames):
                continue

            for t in range(min(max_frames, len(person_frames))):
                frame = person_frames[t]
                if _is_empty(frame) or not frame.get('visible', False):
                    continue
                bodies = frame.get('bodies', {})
                if _is_empty(bodies):
                    continue
                subset = bodies.get('subset', [])
                candidate = bodies.get('candidate', [])
                if _is_empty(subset) or _is_empty(candidate):
                    continue
                sub = subset[0]
                if _is_empty(sub):
                    continue

                H, W = _frame_hw(frame)
                for j in range(min(len(sub), num_joints)):
                    xy = _candidate_xy(candidate, sub[j], W, H)
                    if xy is not None:
                        result[t, p_idx, j] = xy
    except Exception as e:
        debug_print(f"Error in minimal person-organized processing: {e}")
    return result


def _kps_from_frame_layout(frames, max_frames: int, max_persons: int, num_joints: int) -> np.ndarray:
    """Keypoints from a per-frame list of DWPose dicts (``bodies.candidate`` / ``bodies.subset``).

    The time axis is ``min(len(frames), max_frames)``. The person count is
    sampled from the first 10 frames only and capped at ``max_persons``.
    Exceptions stop processing and return what was filled so far.
    """
    result = np.full((max_frames, max_persons, num_joints, 2), np.nan, dtype=np.float32)
    try:
        if _is_empty(frames):
            return result

        T = min(len(frames), max_frames)
        maxP = max((len(f['bodies']['subset']) for f in frames[:10]
                    if 'bodies' in f and 'subset' in f['bodies']), default=0)
        P = min(maxP, max_persons)
        if P == 0:
            return result

        result = np.full((T, P, num_joints, 2), np.nan, dtype=np.float32)
        for t, fr in enumerate(frames[:T]):
            if 'bodies' not in fr or 'candidate' not in fr['bodies'] or 'subset' not in fr['bodies']:
                continue
            H, W = _frame_hw(fr)
            cand = fr['bodies']['candidate']
            for pid, sub in enumerate(fr['bodies']['subset'][:P]):
                for j, idx in enumerate(sub[:num_joints]):
                    xy = _candidate_xy(cand, idx, W, H)
                    if xy is not None:
                        result[t, pid, j] = xy
    except Exception as e:
        debug_print(f"Error in minimal frames processing: {e}")
    return result


def extract_keypoints_fast(npz_path, *, max_frames: int = 660, max_persons: int = 2,
                           num_joints: int = 18) -> np.ndarray:
    """Extract body keypoints from a pose .npz file as a (T, P, J, 2) float32 array.

    Layouts are tried in this order:

    1. a stored ``kps`` array, returned as-is;
    2. ``pose_by_person`` + ``global_person_ids`` (per-person frame lists);
    3. ``frames`` (per-frame DWPose dicts).

    Candidate coordinates are multiplied by the frame's (W, H), taken from the
    first two entries of ``frame_shape`` (default 1 x 1). Missing joints are
    NaN. If the file cannot be read or has no known layout, returns an all-NaN
    (max_frames, max_persons, num_joints, 2) array.

    NOTE(review): layout 2 is always padded to ``max_frames`` rows.
    evaluate_sequence scores frames missing in both pred and GT as perfect
    agreement, so this padding may inflate scores. Confirm this is intended.
    NOTE(security): uses ``allow_pickle=True``; only load trusted files.
    """
    fallback = np.full((max_frames, max_persons, num_joints, 2), np.nan, dtype=np.float32)
    try:
        # One open, and every needed array is read before the archive is closed.
        with np.load(npz_path, allow_pickle=True) as data:
            if 'kps' in data:
                return data['kps']
            if 'pose_by_person' in data and 'global_person_ids' in data:
                return _kps_from_person_layout(data['pose_by_person'], data['global_person_ids'],
                                               max_frames, max_persons, num_joints)
            if 'frames' in data:
                return _kps_from_frame_layout(data['frames'], max_frames, max_persons, num_joints)
    except Exception as e:
        debug_print(f"Error in fast keypoint extraction: {e}")
    return fallback


def estimate_area(pred_frame, gt_frame, default_area: float = 10000):
    """Object scale (bounding-box area) used to normalise OKS for one frame.

    The area is computed from GT joints only, as COCO normalises OKS by the
    ground-truth object area. Including prediction points would let a wild
    prediction enlarge the box and make its own OKS more lenient.

    Area = (x_range + 1) * (y_range + 1) over GT joints with a non-NaN x.
    Returns ``default_area`` when 2 or fewer GT joints are valid.
    ``pred_frame`` is accepted for signature compatibility and is not used.
    """
    valid = ~np.isnan(gt_frame[:, 0])
    if valid.sum() <= 2:
        return default_area
    pts = gt_frame[valid]
    xs, ys = pts[:, 0], pts[:, 1]
    return float((xs.max() - xs.min() + 1) * (ys.max() - ys.min() + 1))


_METRIC_NAMES = ('MPJPE_2D', 'OKS', 'PoseSSIM', 'SmoothRMS', 'TimeDyn_RMSE', 'FVMD')


def _nan_metrics() -> Dict[str, float]:
    """Fresh result dict with every metric NaN, used whenever a sequence cannot be scored."""
    return {name: np.nan for name in _METRIC_NAMES}


def _consecutive_runs(indices: np.ndarray) -> List[np.ndarray]:
    """Split a sorted 1-D integer array into runs of consecutive values ([] for empty input)."""
    if len(indices) == 0:
        return []
    return np.split(indices, np.flatnonzero(np.diff(indices) != 1) + 1)


def _segment_motion_metrics(seg_pred: np.ndarray, seg_gt: np.ndarray, fps: int,
                            person: int) -> Tuple[Optional[float], Optional[float], Optional[float]]:
    """Motion metrics for one contiguous run of frames where pred and GT are both valid.

    Returns ``(jerk_rms, accel_rms, fvmd)``; each entry is None when it cannot
    be computed. Jerk and acceleration RMS come from the prediction alone. FVMD
    is the Fréchet distance between the pooled per-joint velocity distributions
    of pred and GT, computed only with more than 10 velocity samples.
    """
    jerk = temporal_deriv(seg_pred, fps, 3)
    accel = temporal_deriv(seg_pred, fps, 2)
    jerk_rms = None if np.isnan(jerk).all() else rms(jerk)
    accel_rms = None if np.isnan(accel).all() else rms(accel)
    debug_print(f"Jerk RMS: {jerk_rms}, Accel RMS: {accel_rms}")

    fvmd = None
    try:
        vel_pred = temporal_deriv(seg_pred, fps, 1)[1:].reshape(-1, 2)
        vel_gt = temporal_deriv(seg_gt, fps, 1)[1:].reshape(-1, 2)
        mask = ~np.isnan(vel_pred).any(axis=1) & ~np.isnan(vel_gt).any(axis=1)
        debug_print(f"Velocity: {np.sum(mask)}/{len(mask)} valid points")

        if np.sum(mask) > 10:
            vp, vg = vel_pred[mask], vel_gt[mask]
            reg = 1e-6 * np.eye(2)
            value = frechet_distance(vp.mean(axis=0), np.cov(vp.T) + reg,
                                     vg.mean(axis=0), np.cov(vg.T) + reg)
            debug_print(f"FVMD value: {value}")
            if np.isfinite(value):
                fvmd = value
        else:
            debug_print(f"Not enough valid velocity points ({np.sum(mask)}) for FVMD calculation")
    except Exception as e:
        debug_print(f"Error computing FVMD for person {person}: {e}")
        debug_print(traceback.format_exc())
    return jerk_rms, accel_rms, fvmd


@profile_function
def evaluate_sequence(pred_npz:Path, gt_npz:Path, fps:int, imgH:int, imgW:int,
                     gpu_id:Optional[int]=None) -> Dict[str,float]:
    """Compare predicted and ground-truth keypoints for one sequence.

    Both files are read with ``extract_keypoints_fast`` into (T, P, J, 2)
    arrays, then truncated to their common frame count and person count. For
    each person and frame:

    * pred and GT both have joints: MPJPE (after scale + translation alignment),
      OKS, and heatmap SSIM. SSIM is the expensive one, so it is computed on
      every 3rd frame only.
    * both missing: perfect agreement (MPJPE 0, OKS 1, SSIM 1).
    * exactly one missing: complete disagreement (MPJPE inf, OKS 0, SSIM 0).
      ``handle_inf_values`` later replaces the inf with a finite penalty.

    Motion metrics are computed on contiguous runs of at least 5 frames where
    both are valid, for persons with more than 5 such frames:

    * SmoothRMS: jerk RMS of the prediction;
    * TimeDyn_RMSE: acceleration RMS of the prediction;
    * FVMD: velocity Fréchet distance between pred and GT.

    Returns a dict keyed MPJPE_2D, OKS, PoseSSIM, SmoothRMS, TimeDyn_RMSE, FVMD.
    All values are NaN when the inputs are unusable or an unexpected error occurs.
    """
    debug_print(f"Evaluating {pred_npz} against {gt_npz} (GPU ID: {gpu_id})")
    start_time = time.time()

    try:
        pred = extract_keypoints_fast(pred_npz)
        gt = extract_keypoints_fast(gt_npz)
        debug_print(f"Pred shape: {pred.shape}, GT shape: {gt.shape}")

        # Validate first: everything below assumes a 4-D (T, P, J, 2) layout.
        for data, name in ((pred, 'pred'), (gt, 'gt')):
            if data.size == 0 or data.ndim != 4:
                debug_print(f"WARNING: Invalid {name} data shape: {data.shape}")
                return _nan_metrics()

        if DEBUG:
            # Diagnostics only. allclose is skipped when shapes differ; calling it
            # would raise and previously discarded the whole sequence.
            identical = pred.shape == gt.shape and np.allclose(pred, gt, equal_nan=True)
            debug_print(f"Pred and GT are identical: {identical}")
            for data, label in ((pred, 'pred'), (gt, 'GT')):
                n_valid = int(np.sum(~np.all(np.isnan(data[..., 0]), axis=(1, 2))))
                debug_print(f"Valid frames in {label}: {n_valid}/{data.shape[0]}")

        T = min(len(pred), len(gt))
        debug_print(f"Using {T} frames")
        if T == 0:
            debug_print("Warning: No frames to evaluate")
            return _nan_metrics()

        pred = pred[:T]
        gt = gt[:T]
        num_persons = min(pred.shape[1], gt.shape[1])
        debug_print(f"Processing {num_persons} persons")

        mpjpes, okss, hssims = [], [], []
        smoothness_values, timedyn_values, fvmd_values = [], [], []

        for p in range(num_persons):
            try:
                person_pred = pred[:, p]
                person_gt = gt[:, p]
                pred_valid = ~np.all(np.isnan(person_pred[:, :, 0]), axis=1)
                gt_valid = ~np.all(np.isnan(person_gt[:, :, 0]), axis=1)
                both_valid = pred_valid & gt_valid
                n_both = int(both_valid.sum())
                runs = _consecutive_runs(np.flatnonzero(both_valid))
                debug_print(f"Person {p}: {int(pred_valid.sum())} valid pred frames, "
                            f"{int(gt_valid.sum())} valid GT frames, {n_both} both valid, "
                            f"longest run {max((len(r) for r in runs), default=0)}")

                person_mpjpe, person_oks, person_ssim = [], [], []
                for t in range(T):
                    if both_valid[t]:
                        p_frame, g_frame = person_pred[t], person_gt[t]
                        try:
                            mpjpe_val = mpjpe_2d(p_frame, g_frame)
                            oks_val = oks_frame(p_frame, g_frame, estimate_area(p_frame, g_frame))
                            ssim_val = heat_ssim(p_frame, g_frame, imgH, imgW, gpu_id) if t % 3 == 0 else np.nan
                        except Exception as e:
                            debug_print(f"Error computing metrics for frame {t}: {e}")
                            mpjpe_val = oks_val = ssim_val = np.nan
                    elif not pred_valid[t] and not gt_valid[t]:
                        # NOTE(review): frames missing in both, including NaN padding from
                        # extract_keypoints_fast, count as perfect agreement.
                        mpjpe_val, oks_val, ssim_val = 0.0, 1.0, 1.0
                    else:
                        mpjpe_val, oks_val, ssim_val = float('inf'), 0.0, 0.0
                    person_mpjpe.append(mpjpe_val)
                    person_oks.append(oks_val)
                    person_ssim.append(ssim_val)

                mpjpes.extend(person_mpjpe)
                okss.extend(person_oks)
                hssims.extend(v for v in person_ssim if not np.isnan(v))

                if n_both > 5:
                    try:
                        debug_print(f"Found {len(runs)} contiguous segments")
                        for segment in runs:
                            if len(segment) < 5:
                                debug_print(f"Skipping segment with only {len(segment)} frames (need >= 5)")
                                continue
                            jerk_rms, accel_rms, fvmd = _segment_motion_metrics(
                                person_pred[segment], person_gt[segment], fps, p)
                            if jerk_rms is not None:
                                smoothness_values.append(jerk_rms)
                            if accel_rms is not None:
                                timedyn_values.append(accel_rms)
                            if fvmd is not None:
                                fvmd_values.append(fvmd)
                    except Exception as e:
                        debug_print(f"Error computing motion metrics for person {p}: {e}")
                        debug_print(traceback.format_exc())
                else:
                    debug_print(f"Not enough valid frames for motion metrics: {n_both} (need >5)")
            except Exception as e:
                debug_print(f"Error processing person {p}: {e}")
                debug_print(traceback.format_exc())

        if not (mpjpes or okss or hssims):
            debug_print("No valid metrics collected")
            results = _nan_metrics()
        else:
            if DEBUG:
                # Guarded: the f-string statistics would otherwise be computed even with DEBUG off.
                for label, vals in (("MPJPE", mpjpes), ("OKS", okss), ("SSIM", hssims),
                                    ("Smoothness", smoothness_values), ("TimeDyn", timedyn_values),
                                    ("FVMD", fvmd_values)):
                    if vals:
                        debug_print(f"{label} stats: min={np.nanmin(vals):.4f}, max={np.nanmax(vals):.4f}, "
                                    f"mean={np.nanmean(vals):.4f}, count={len(vals)}")

            results = {
                'MPJPE_2D': handle_inf_values(mpjpes),
                'OKS': float(np.nanmean(okss)),
                'PoseSSIM': float(np.nanmean(hssims)) if hssims else np.nan,
                'SmoothRMS': float(np.nanmean(smoothness_values)) if smoothness_values else np.nan,
                'TimeDyn_RMSE': float(np.nanmean(timedyn_values)) if timedyn_values else np.nan,
                'FVMD': float(np.nanmean(fvmd_values)) if fvmd_values else np.nan
            }

        debug_print(f"Sequence evaluation completed in {time.time() - start_time:.2f}s")
        debug_print(f"Results: {results}")

        # One collection per sequence. The old per-chunk gc.collect() calls were pure overhead.
        gc.collect()
        return results
    except Exception as e:
        debug_print(f"Error in evaluate_sequence: {e}")
        debug_print(traceback.format_exc())
        return _nan_metrics()


def handle_inf_values(values):
    """Return the NaN-ignoring mean of ``values`` after replacing infinities.

    Every +/-inf entry is overwritten with ten times the largest finite entry,
    or with 1000 when there is no finite entry at all, so a diverged sample
    contributes a large but finite value instead of making the mean infinite.
    NaN entries are ignored by ``np.nanmean``.

    Args:
        values: Sequence (or array) of numbers; it is copied, not modified.

    Returns:
        The mean as a Python ``float`` (NaN when every entry is NaN or the
        input is empty).
    """
    arr = np.array(values)
    inf_mask = np.isinf(arr)
    if inf_mask.any():
        finite = arr[np.isfinite(arr)]
        arr[inf_mask] = np.max(finite) * 10 if finite.size else 1000
    return float(np.nanmean(arr))


def estimate_sequence_complexity(seq_path):
    """Estimate the relative cost of evaluating one sequence directory.

    The estimate is the byte size of ``<seq_path>/poses.npz`` plus 100 per
    frame. The frame count comes from the length of the ``frames`` array if
    present, otherwise from the first dimension of the ``kps`` array.

    Args:
        seq_path: Sequence directory as a ``pathlib.Path`` or ``str``.

    Returns:
        The estimate as an int. Returns 0 when the file is missing or the path
        cannot be inspected. If the archive contents cannot be read, the
        size-only estimate is returned (best effort, never raises for that).
    """
    try:
        poses_npz = Path(seq_path) / 'poses.npz'
        if not poses_npz.exists():
            return 0
        complexity = poses_npz.stat().st_size
    except (OSError, TypeError, ValueError):
        # Unusable path (wrong type, permission problem, invalid characters).
        return 0

    try:
        # NOTE(review): allow_pickle=True executes pickled objects embedded in
        # the archive; only safe if poses.npz files come from a trusted source.
        # mmap_mode has no effect on .npz archives and is kept for parity.
        with np.load(poses_npz, allow_pickle=True, mmap_mode='r') as data:
            if 'frames' in data:
                complexity += len(data['frames']) * 100
            elif 'kps' in data:
                complexity += data['kps'].shape[0] * 100
    except Exception as exc:
        # Best effort: a corrupt or unexpected archive still gets the
        # file-size estimate instead of aborting the caller.
        debug_msg = f"Could not inspect {poses_npz}: {exc}"
        del debug_msg
    return complexity


def sequence_worker_fn(queue, npz_pred, npz_gt, fps, imgH, imgW, gpu_id, seq_name):
    """Child-process entry point: evaluate one sequence and report through ``queue``.

    Puts the metrics dict returned by ``evaluate_sequence`` (with ``'seq'`` set
    to ``seq_name``) when at least one of its int/float values is not NaN.
    Otherwise puts ``None``. ``None`` is also put if evaluating or reporting
    raises.
    """
    try:
        metrics = evaluate_sequence(npz_pred, npz_gt, fps, imgH, imgW, gpu_id)
        metrics['seq'] = seq_name

        scalars = [val for val in metrics.values() if isinstance(val, (int, float))]
        # Only report the dict when it carries at least one real number.
        informative = any(not np.isnan(val) for val in scalars)
        queue.put(metrics if informative else None)
    except Exception as exc:
        debug_print(f"Worker error processing {seq_name}: {exc}")
        debug_print(traceback.format_exc())
        queue.put(None)


def process_sequence_with_gpu(args):
    """Evaluate one sequence in a freshly spawned child process.

    Args:
        args: Tuple ``(seq_name, pred_root, gt_root, fps, imgH, imgW, gpu_id,
            gpu_lock)``. ``gpu_lock`` may be ``None`` or any object with
            ``acquire()``/``release()``. When it and ``gpu_id`` are both given,
            the lock is held for the whole evaluation, so sequences sharing the
            lock run one at a time.

    Returns:
        The item ``sequence_worker_fn`` put on the result queue (a result dict
        or ``None``). Returns ``None`` when either ``poses.npz`` is missing, the
        child is still running after 300 s, the child sends nothing, or an
        exception is raised here.
    """
    import queue

    seq_name, pred_root, gt_root, fps, imgH, imgW, gpu_id, gpu_lock = args
    max_time = 300
    lock_held = False

    try:
        npz_pred = Path(pred_root) / seq_name / 'poses.npz'
        npz_gt = Path(gt_root) / seq_name / 'poses.npz'

        # Check inputs before taking the GPU lock, so a sequence that will be
        # skipped never waits behind other sequences assigned to the same GPU.
        if not npz_pred.exists() or not npz_gt.exists():
            debug_print(f"Missing NPZ file for {seq_name}")
            return None

        if gpu_id is not None and gpu_lock is not None:
            try:
                gpu_lock.acquire()
                lock_held = True
                if HAS_CUPY:
                    cp.cuda.Device(gpu_id).use()
                    debug_print(f"Using GPU {gpu_id} for sequence {seq_name}")
            except Exception as e:
                # Best effort: evaluation still proceeds without the lock/device.
                debug_print(f"Error setting GPU {gpu_id}: {e}")

        # A separate spawned process lets a hung evaluation be terminated
        # without taking this worker down with it.
        ctx = multiprocessing.get_context('spawn')
        result_queue = ctx.Queue()
        worker = ctx.Process(
            target=sequence_worker_fn,
            args=(result_queue, npz_pred, npz_gt, fps, imgH, imgW, gpu_id, seq_name),
        )
        worker.start()
        worker.join(timeout=max_time)

        if worker.is_alive():
            debug_print(f"Evaluation of {seq_name} timed out after {max_time}s, terminating process")
            worker.terminate()
            worker.join(1)
            if worker.is_alive():
                worker.kill()
            return None

        # The child has exited, and a process flushes its queued items before
        # terminating. A short blocking get() therefore sees any result, and it
        # avoids the unreliable empty()-then-get() pattern.
        try:
            return result_queue.get(timeout=1.0)
        except queue.Empty:
            debug_print(f"No result returned for {seq_name}")
            debug_print(f"Worker for {seq_name} exited with code {worker.exitcode}")
            return None

    except Exception as e:
        debug_print(f"Error processing sequence {seq_name}: {e}")
        debug_print(traceback.format_exc())
        return None
    finally:
        # Release only a lock this call actually acquired.
        if lock_held:
            try:
                gpu_lock.release()
                debug_print(f"Released GPU {gpu_id} for sequence {seq_name}")
            except Exception as e:
                debug_print(f"Error releasing GPU {gpu_id}: {e}")

        gc.collect()


def process_batch(batch_sequences, pred_root, gt_root, fps, imgH, imgW, gpu_ids=None):
    """Evaluate a batch of sequences in parallel and return the successful results.

    One pool worker is started per sequence. When ``gpu_ids`` is non-empty and
    CuPy is available, sequences are assigned to GPUs round-robin. Each GPU id
    gets one lock, so sequences sharing a GPU are evaluated one at a time.

    Args:
        batch_sequences: Sequence directory names (relative to both roots).
        pred_root: Root directory holding predicted ``<name>/poses.npz``.
        gt_root: Root directory holding ground-truth ``<name>/poses.npz``.
        fps, imgH, imgW: Forwarded unchanged to ``process_sequence_with_gpu``.
        gpu_ids: Optional list of CUDA device ids.

    Returns:
        List of non-``None`` results, in completion order (not input order).
    """
    if not batch_sequences:
        # ProcessPoolExecutor(max_workers=0) raises ValueError.
        return []

    use_gpu = bool(gpu_ids) and HAS_CUPY
    # A plain multiprocessing.Lock cannot be pickled into executor task
    # arguments ("Lock objects should only be shared between processes through
    # inheritance"), which made every GPU task fail. Manager lock proxies are
    # picklable. Building locks from gpu_ids also avoids KeyError for ids
    # missing from the module-level lock table.
    manager = multiprocessing.Manager() if use_gpu else None
    results = []
    try:
        gpu_locks = {gid: manager.Lock() for gid in set(gpu_ids)} if use_gpu else {}

        args_list = []
        for i, seq_name in enumerate(batch_sequences):
            gpu_id = gpu_ids[i % len(gpu_ids)] if use_gpu else None
            args_list.append((seq_name, pred_root, gt_root, fps, imgH, imgW,
                              gpu_id, gpu_locks.get(gpu_id)))

        with concurrent.futures.ProcessPoolExecutor(max_workers=len(args_list)) as executor:
            futures = [executor.submit(process_sequence_with_gpu, task) for task in args_list]

            for future in concurrent.futures.as_completed(futures):
                try:
                    result = future.result()
                except Exception as e:
                    debug_print(f"Error in future execution: {e}")
                    continue
                if result is not None:
                    results.append(result)
    finally:
        if manager is not None:
            manager.shutdown()

    return results


def _default_num_workers():
    """Return the default ``--num_workers``: ``min(32, cpu_count - 4)``, at least 1.

    On hosts with four or fewer CPUs the bare formula gives 0 or less, which
    ``ProcessPoolExecutor`` rejects. ``os.cpu_count()`` may also return ``None``.
    """
    return max(1, min(32, (os.cpu_count() or 1) - 4))


def _parse_gpu_ids(spec, available):
    """Parse a comma-separated GPU id string, dropping ids outside ``range(available)``.

    Args:
        spec: GPU list such as ``"0,1"``. Empty tokens are ignored.
        available: Number of visible CUDA devices.

    Returns:
        The valid ids in the order given (duplicates kept).

    Raises:
        ValueError: If a non-empty token is not an integer.
    """
    gpu_ids = []
    for token in spec.split(','):
        if not token.strip():
            continue
        gid = int(token)
        if 0 <= gid < available:
            gpu_ids.append(gid)
        else:
            print(f"Warning: ignoring GPU id {gid}; only {available} device(s) visible")
    return gpu_ids


def main():
    """Evaluate every sequence present under both ``--gt_root`` and ``--pred_root``.

    A sequence is a sub-directory of ``--gt_root`` whose ``poses.npz`` exists
    and has a matching ``<pred_root>/<name>/poses.npz``. Sequences are processed
    in batches on a process pool via ``process_sequence_with_gpu``.

    Outputs:
        - After each batch, the accumulated rows go to ``--out_csv`` with its
          suffix replaced by ``.partial.csv``.
        - At the end, all rows plus a ``MEAN`` row go to ``--out_csv``.
        - Failed sequence names go to the ``.failed.txt`` sibling.

    Exits with status 1 on Ctrl-C (after saving partial results) or when the
    optional watchdog fires.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--pred_root", required=True, help="")
    ap.add_argument("--gt_root",   required=True, help="")
    ap.add_argument("--out_csv",   required=True)
    ap.add_argument("--fps", type=int, default=30)
    ap.add_argument("--imgH", type=int, default=512)
    ap.add_argument("--imgW", type=int, default=512)
    ap.add_argument("--debug", action="store_true", help="Enable debug mode")
    ap.add_argument("--profile", action="store_true", help="Enable profiling")
    ap.add_argument("--num_workers", type=int, default=_default_num_workers(),
                    help="Number of worker processes (default: min(32, cpu_count-4))")
    ap.add_argument("--gpus", type=str, default="",
                    help="Comma-separated list of GPU IDs to use")
    ap.add_argument("--batch_size", type=int, default=10,
                    help="Number of sequences to process in each batch")
    ap.add_argument("--max_time", type=int, default=3600,
                    help="Maximum time in seconds to run evaluation (default: 1 hour)")
    ap.add_argument("--skip_heat_ssim", action="store_true",
                    help="Skip heat map SSIM calculation to save time")
    ap.add_argument("--watchdog", action="store_true",
                    help="Enable watchdog to kill stuck processes")
    args = ap.parse_args()

    global DEBUG, PROFILE
    DEBUG = args.debug
    PROFILE = args.profile
    start_time = time.time()

    gpu_ids = []
    if args.gpus and HAS_CUPY:
        gpu_ids = _parse_gpu_ids(args.gpus, num_gpus)
        debug_print(f"Using GPUs: {gpu_ids}")
    else:
        debug_print("No GPUs specified or cupy not available, using CPU only")

    gt_seqs = [d for d in Path(args.gt_root).glob('*') if d.is_dir()]
    gt_seq_names = [d.name for d in gt_seqs]

    print(f"Found {len(gt_seq_names)} GT sequence directories")
    print(f"Using {args.num_workers} worker processes")

    valid_seqs = [
        seq_name for seq_name in gt_seq_names
        if (Path(args.gt_root) / seq_name / 'poses.npz').exists()
        and (Path(args.pred_root) / seq_name / 'poses.npz').exists()
    ]

    print(f"Found {len(valid_seqs)} valid sequences for evaluation")

    columns = ['seq', 'MPJPE_2D', 'OKS', 'PoseSSIM', 'SmoothRMS', 'TimeDyn_RMSE', 'FVMD']
    if not valid_seqs:
        print("No valid sequences found. Exiting.")
        pd.DataFrame(columns=columns).to_csv(args.out_csv, index=False)
        return

    if DEBUG and len(valid_seqs) > 5:
        import random
        valid_seqs = random.sample(valid_seqs, 5)
        print(f"DEBUG MODE: Using 5 random sequences: {valid_seqs}")

    all_results = []
    # Clamp to >= 1: a non-positive --batch_size would make range() raise.
    batch_size = max(1, min(args.batch_size, len(valid_seqs)))
    batches = [valid_seqs[i:i + batch_size] for i in range(0, len(valid_seqs), batch_size)]
    num_workers = max(1, args.num_workers)

    total_processed = 0
    total_batches = len(batches)
    failed_sequences = []
    partial_csv = Path(args.out_csv).with_suffix('.partial.csv')

    # A plain multiprocessing.Lock cannot be pickled into ProcessPoolExecutor
    # task arguments, so every GPU task used to fail. Manager lock proxies can
    # be pickled. One lock per GPU id serializes sequences sharing that GPU.
    manager = multiprocessing.Manager() if gpu_ids else None
    gpu_locks = {gid: manager.Lock() for gid in set(gpu_ids)} if manager is not None else {}

    if args.watchdog:
        import threading
        watchdog_event = threading.Event()

        def watchdog_timer():
            time.sleep(args.max_time)
            if not watchdog_event.is_set():
                print(f"\nWatchdog triggered after {args.max_time}s. Terminating evaluation.")
                # os._exit skips finalizers, so stop the lock manager explicitly.
                if manager is not None:
                    try:
                        manager.shutdown()
                    except Exception:
                        pass
                os._exit(1)

        watchdog_thread = threading.Thread(target=watchdog_timer, daemon=True)
        watchdog_thread.start()

    try:
        for batch_idx, batch_seqs in enumerate(batches):
            print(f"\nProcessing batch {batch_idx+1}/{total_batches} with {len(batch_seqs)} sequences")

            task_args = []
            for i, seq_name in enumerate(batch_seqs):
                # gpu_ids is only non-empty when CuPy is available.
                gpu_id = gpu_ids[i % len(gpu_ids)] if gpu_ids else None
                task_args.append((seq_name, args.pred_root, args.gt_root, args.fps,
                                  args.imgH, args.imgW, gpu_id, gpu_locks.get(gpu_id)))

            with concurrent.futures.ProcessPoolExecutor(max_workers=min(num_workers, len(task_args))) as executor:
                future_to_seq = {executor.submit(process_sequence_with_gpu, task): task[0]
                                 for task in task_args}

                for future in tqdm(concurrent.futures.as_completed(future_to_seq),
                                   total=len(future_to_seq),
                                   desc=f"Batch {batch_idx+1}/{total_batches}"):
                    seq_name = future_to_seq[future]
                    try:
                        result = future.result()
                        if result is not None:
                            all_results.append(result)
                            total_processed += 1
                        else:
                            failed_sequences.append(seq_name)
                            print(f"Warning: Failed to process sequence {seq_name}")
                    except Exception as e:
                        failed_sequences.append(seq_name)
                        print(f"Exception processing {seq_name}: {e}")

            # Checkpoint after every batch so a crash loses at most one batch.
            if all_results:
                pd.DataFrame(all_results).to_csv(partial_csv, index=False)
                print(f"Saved {len(all_results)} results to {partial_csv}")

            elapsed = time.time() - start_time
            if elapsed > args.max_time:
                print(f"\nReached time limit of {args.max_time}s. Processed {total_processed}/{len(valid_seqs)} sequences.")
                break

        if args.watchdog:
            watchdog_event.set()

        if not all_results:
            print("No valid sequences were evaluated. Check if poses.npz files exist and are valid.")
            df = pd.DataFrame(columns=columns)
        else:
            df = pd.DataFrame(all_results)

            try:
                mean_row = df.mean(numeric_only=True)
                mean_dict = {'seq': 'MEAN'}
                mean_dict.update(mean_row.items())
                df = pd.concat([df, pd.DataFrame([mean_dict])], ignore_index=True)
            except Exception as e:
                print(f"Error calculating mean: {e}")

        df.to_csv(args.out_csv, index=False)

        success_rate = (total_processed / len(valid_seqs)) * 100
        total_time = time.time() - start_time

        print("\nEvaluation complete!")
        print(f"Total time: {total_time:.2f}s")
        print(f"Sequences processed: {total_processed}/{len(valid_seqs)} ({success_rate:.1f}%)")
        print(f"Failed sequences: {len(failed_sequences)}")

        if failed_sequences:
            failed_txt = Path(args.out_csv).with_suffix('.failed.txt')
            with open(failed_txt, 'w', encoding='utf-8') as f:
                f.write('\n'.join(failed_sequences))
            print(f"Failed sequence list saved to {failed_txt}")

        print(df.tail(3))
        print(f"Final results saved to {args.out_csv}")

    except KeyboardInterrupt:
        print("\nEvaluation interrupted by user. Saving partial results...")
        if all_results:
            pd.DataFrame(all_results).to_csv(partial_csv, index=False)
            print(f"Saved partial results to {partial_csv}")

        if args.watchdog:
            watchdog_event.set()

        # `sys` is not imported in this module, so the former sys.exit(1)
        # raised NameError; SystemExit is equivalent and needs no import.
        raise SystemExit(1)
    finally:
        if manager is not None:
            manager.shutdown()

# Script entry point. The guard also stops spawn-started child processes,
# which re-import this module, from re-running main().
if __name__ == "__main__":
    main()