#!/usr/bin/env python3
"""
Script to create PosePenetrationDataset from MDM motion files

Converts MDM-generated motion files (*.pkl) to PosePenetrationDataset format.
Uses existing collision detection functionality from motion.py.
"""

import os
import pickle
import numpy as np
import torch
from pathlib import Path
import argparse
from tqdm import tqdm
import random
from sklearn.model_selection import train_test_split
from motion import Motion

def load_mdm_motion(pkl_path):
    """
    Read one pickled MDM motion file and wrap its contents in a Motion

    The unpickled payload is probed for the keys 'joints', 'poses' and
    'motion', in that order; the first key present is converted to a
    float32 tensor and handed to Motion as its ``joints`` argument.

    Args:
        pkl_path: Path to MDM motion file

    Returns:
        motion: Motion object built from the first matching key
        motion_name: File name of pkl_path without its extension

    Raises:
        ValueError: If none of the recognised keys is present in the file
    """
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only point this at motion files from a trusted source.
    with open(pkl_path, 'rb') as handle:
        payload = pickle.load(handle)

    # Priority order matters when a file carries more than one of these keys.
    # NOTE(review): 'poses' and 'motion' payloads are also passed as `joints`;
    # confirm that Motion accepts them in that role.
    for key in ('joints', 'poses', 'motion'):
        if key in payload:
            tensor = torch.tensor(payload[key], dtype=torch.float32)
            return Motion(joints=tensor), Path(pkl_path).stem

    raise ValueError(f"Cannot extract motion data from MDM file: {pkl_path}")

def extract_pose_vectors(motion, target_dim=263):
    """
    Build fixed-width per-frame pose vectors from a motion's SMPL parameters

    The 'poses', 'trans' and 'betas' entries of ``motion.smpl_params`` are
    concatenated along dim 1 and the result is zero-padded or truncated to
    ``target_dim`` columns. With the commonly used SMPL sizes (72 + 3 + 10)
    only the first 85 columns carry data; the remaining columns are zero
    padding, not additional motion features.

    Args:
        motion: Motion object exposing ``smpl_params`` with 'poses', 'trans'
            and 'betas' tensors (assumed to be 2-D with one row per frame --
            TODO confirm against Motion)
        target_dim: Number of columns in the returned tensor (default 263)

    Returns:
        poses: [L, target_dim] pose vectors

    Raises:
        ValueError: If target_dim is smaller than 1
    """
    if target_dim < 1:
        raise ValueError(f"target_dim must be a positive integer, got {target_dim}")

    smpl_params = motion.smpl_params

    # Pose (global orientation + body pose), translation, then shape.
    pose_vectors = torch.cat(
        [smpl_params['poses'], smpl_params['trans'], smpl_params['betas']],
        dim=1,
    )

    missing = target_dim - pose_vectors.shape[1]
    if missing > 0:
        # new_zeros inherits dtype and device from pose_vectors, so the
        # concatenation also works for tensors that do not live on the CPU.
        padding = pose_vectors.new_zeros(pose_vectors.shape[0], missing)
        pose_vectors = torch.cat([pose_vectors, padding], dim=1)
    elif missing < 0:
        # More features than requested: keep the leading columns only.
        pose_vectors = pose_vectors[:, :target_dim]

    return pose_vectors

def detect_collisions_per_joint(motion, threshold=0.01, num_joints=22):
    """
    Produce a per-frame, per-joint penetration score matrix

    Collision detection is delegated to ``motion.check_penetration``, which
    only yields one penetration loss per frame. That frame-level value is
    copied to every joint column, so all joints of a frame share the same
    score; true per-joint detection would have to be provided by Motion.

    Args:
        motion: Motion object providing ``check_penetration(threshold=...)``
        threshold: Collision threshold forwarded to check_penetration
        num_joints: Number of joint columns in the result (default 22)

    Returns:
        joint_scores: [L, num_joints] float64 array of collision scores
    """
    penetration_stats = motion.check_penetration(threshold=threshold)

    # One loss value per frame.
    frame_losses = penetration_stats['all_penetration_losses']

    # Entries are normally 0-d tensors (or numpy scalars) exposing .item();
    # plain Python numbers are accepted as well.
    frame_scores = np.asarray(
        [loss.item() if hasattr(loss, 'item') else loss for loss in frame_losses],
        dtype=np.float64,
    )

    # Broadcast each frame score across all joint columns.
    return np.repeat(frame_scores[:, np.newaxis], num_joints, axis=1)

def _split_motion_names(motion_names, seed=42):
    """
    Split motion names into 60%/20%/20% train/val/test subsets

    Args:
        motion_names: Sequence of motion names to split
        seed: Random state used for both train_test_split calls

    Returns:
        (train_names, val_names, test_names) lists
    """
    names = list(motion_names)
    if len(names) < 3:
        # train_test_split raises ValueError when a subset would end up
        # empty, which this two-stage split hits for fewer than three
        # samples. Fill train first, then val, and leave the rest empty.
        return names[:1], names[1:2], []

    train_names, held_out = train_test_split(names, test_size=0.4, random_state=seed)
    val_names, test_names = train_test_split(held_out, test_size=0.5, random_state=seed)
    return train_names, val_names, test_names


def _write_name_list(path, names):
    """Write one motion name per line to the given split file."""
    with open(path, 'w') as f:
        f.writelines(f"{name}\n" for name in names)


def process_mdm_motions(input_dir, output_dir, penetration_threshold=0.01, seed=42):
    """
    Process MDM-generated motion files

    Every ``*.pkl`` file in input_dir is converted to a pose-vector array
    (``ori_vecs/<name>.npy``) and a collision-score array
    (``penetration_score/<name>_penetration_score.npy``). The names of the
    successfully converted motions are then split 60%/20%/20% and written to
    ``train.txt``, ``val.txt`` and ``test.txt`` in output_dir. Files that fail
    to convert are reported and skipped.

    Args:
        input_dir: Directory containing MDM motion files
        output_dir: Output dataset directory
        penetration_threshold: Collision threshold
        seed: Random state for the train/val/test split (default 42)
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    # Create output directory structure
    vec_dir = output_path / "ori_vecs"
    score_dir = output_path / "penetration_score"
    vec_dir.mkdir(parents=True, exist_ok=True)
    score_dir.mkdir(parents=True, exist_ok=True)

    # Sort so that the split depends only on the seed, not on the order in
    # which the filesystem happens to list the directory.
    pkl_files = sorted(input_path.glob("*.pkl"))
    print(f"Found {len(pkl_files)} MDM motion files")

    motion_names = []
    failed_count = 0

    for pkl_file in tqdm(pkl_files, desc="Processing motion files"):
        try:
            motion, motion_name = load_mdm_motion(pkl_file)
            poses = extract_pose_vectors(motion)
            joint_scores = detect_collisions_per_joint(motion, penetration_threshold)

            # Both arrays are computed before anything is written, so a
            # failed conversion does not leave a pose file without scores.
            # detach() keeps the numpy conversion valid for autograd tensors.
            np.save(vec_dir / f"{motion_name}.npy", poses.detach().cpu().numpy())
            np.save(score_dir / f"{motion_name}_penetration_score.npy", joint_scores)
        except Exception as e:
            # Deliberately best-effort: one broken file must not abort the run.
            print(f"Error processing file {pkl_file}: {e}")
            failed_count += 1
            continue

        motion_names.append(motion_name)

    # Split dataset into 60%/20%/20% train/val/test
    train_names, val_names, test_names = _split_motion_names(motion_names, seed)

    # Save split files
    _write_name_list(output_path / "train.txt", train_names)
    _write_name_list(output_path / "val.txt", val_names)
    _write_name_list(output_path / "test.txt", test_names)

    print("Dataset creation completed:")
    print(f"  Training set: {len(train_names)} motions")
    print(f"  Validation set: {len(val_names)} motions")
    print(f"  Test set: {len(test_names)} motions")
    if failed_count:
        print(f"  Failed files: {failed_count}")
    print(f"  Output directory: {output_path}")

def main():
    """Parse command-line arguments, seed the RNGs and build the dataset."""
    parser = argparse.ArgumentParser(description="Create PosePenetrationDataset from MDM motions")
    parser.add_argument('--input_dir', type=str, required=True, 
                       help='Directory containing MDM motion files')
    parser.add_argument('--output_dir', type=str, required=True,
                       help='Output dataset directory')
    parser.add_argument('--penetration_threshold', type=float, default=0.01,
                       help='Collision threshold')
    parser.add_argument('--seed', type=int, default=42,
                       help='Random seed')
    
    args = parser.parse_args()
    
    # Set random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    
    # Process motion files; the seed is forwarded so that --seed also
    # controls the train/val/test split.
    process_mdm_motions(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        penetration_threshold=args.penetration_threshold,
        seed=args.seed,
    )

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()