import os
import pickle
import torch
import smplx
import numpy as np
import h5py
from mld.transforms.joints2rots import config
from mld.transforms.joints2rots.smplify import SMPLify3D
from coap import attach_coap
import argparse
from pathlib import Path
from tqdm import tqdm
from motion import Motion


# Select the EGL platform for PyOpenGL. This assignment deliberately comes
# before the `pyrender` import below so the variable is already set when
# pyrender is loaded (presumably to allow headless, off-screen rendering —
# confirm against the deployment environment).
os.environ['PYOPENGL_PLATFORM'] = 'egl'
import pyrender
import trimesh

from human_body_prior.models.ik_engine import IK_Engine
import os
from drag_dev.fit.ik_engine_utils import SourceKeyPoints
from drag_dev.shape_optimization.coap_selfpene_loss import COAPSelfPenetrationLoss

def render_frame(vertices, faces, image_size=512, color=(0.3, 0.5, 0.8)):
    """Render one mesh off-screen and return the resulting colour image.

    Args:
        vertices: Mesh vertex positions, forwarded to ``trimesh.Trimesh``.
        faces: Mesh face indices, forwarded to ``trimesh.Trimesh``.
        image_size: Width and height (in pixels) of the square viewport.
        color: RGB base colour of the mesh material; an alpha of 1.0 is
            appended. Any 3-element sequence is accepted.

    Returns:
        The colour buffer produced by ``pyrender.OffscreenRenderer.render``
        (the depth buffer is discarded).
    """
    scene = pyrender.Scene(ambient_light=[0.4] * 3, bg_color=[1.0] * 3)

    material = pyrender.MetallicRoughnessMaterial(
        baseColorFactor=[*color, 1.0],
    )
    mesh = pyrender.Mesh.from_trimesh(
        trimesh.Trimesh(vertices=vertices, faces=faces),
        material=material,
    )
    scene.add(mesh)

    # Camera and light share one pose: translated +1 along Y and +5 along Z
    # from the origin, with no rotation.
    cam_pose = np.eye(4)
    cam_pose[1, 3] = 1.0
    cam_pose[2, 3] = 5.0
    scene.add(pyrender.PerspectiveCamera(yfov=np.pi / 3.0), pose=cam_pose)
    scene.add(
        pyrender.DirectionalLight(color=[1.0] * 3, intensity=2.0),
        pose=cam_pose,
    )

    renderer = pyrender.OffscreenRenderer(image_size, image_size)
    try:
        image, _ = renderer.render(scene)
    finally:
        # Always release the off-screen renderer, even if rendering fails.
        renderer.delete()

    return image

def process_single_file(file_path, device):
    """Evaluate self-penetration for one motion pickle file.

    The motion is loaded through ``Motion.from_pkl``. If a target-shape file
    named after the motion exists under ``datasets/humanml3d_drag/tgt/smpl``,
    its first row of ``smpl_betas`` replaces the motion's betas and the joints
    are recomputed. Failing to load the target shape is reported and the
    evaluation continues with the motion as loaded.

    Args:
        file_path: ``pathlib.Path`` of the ``.pkl`` file (``.name`` is read).
        device: Torch device passed to ``Motion.from_pkl`` and used for the
            betas tensor.

    Returns:
        A dict with ``file_name``, ``motion_name``, ``text`` and the
        penetration statistics reported by ``motion.check_penetration``.
    """
    print(f'\nProcessing file: {file_path}')

    motion = Motion.from_pkl(str(file_path), device=device, save_back=True)

    # Bind `data` up front: if reading the pickle fails inside the try block
    # the error is swallowed there, and the code below still needs a dict.
    data = {}
    try:
        tgt_smpl_path = 'datasets/humanml3d_drag/tgt/smpl'
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only run this on trusted result files.
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        smpl_file = os.path.join(tgt_smpl_path, f'{data["name"]}.npy')
        if os.path.exists(smpl_file):
            smpl_data = np.load(smpl_file, allow_pickle=True).item()
            tgt_shape = smpl_data['smpl_betas']
            print("Using target shape parameters")

            # Only the first row of betas is used.
            motion.smpl_params['betas'] = torch.tensor(tgt_shape[0:1]).to(device)
            motion.joints = motion.smpl2joints()
        else:
            print("Target shape parameter file not found")
    except Exception as e:
        # Deliberate best-effort: the target shape is optional, so report
        # the problem and evaluate the motion without it.
        print(f"Error loading target shape parameters: {e}")

    print(f"\nCalculating self-penetration loss for {len(motion.joints)} frames...")
    stats = motion.check_penetration(threshold=0.01, batch_size=4)

    print(f'\nOverall Statistics:')
    print(f'Average penetration: {stats["avg_penetration"]:.4f}')
    print(f'Maximum penetration: {stats["max_penetration"]:.4f}')
    print(f'Penetration rate: {stats["penetration_rate"]:.2%} ({stats["penetration_frames"]}/{stats["total_frames"]})')

    # Shift the mesh so its lowest vertex along axis 1 sits at zero.
    vertices = motion.get_vertices().detach().cpu().numpy()
    floor_height = vertices[..., 1].min()
    vertices[..., 1] -= floor_height
    # NOTE(review): `data` is never written back or returned from this
    # function, so this assignment currently has no visible effect.
    data['vertices'] = vertices

    return {
        'file_name': file_path.name,
        'motion_name': data.get('name', 'unknown'),
        'text': data.get('text', 'unknown'),
        'avg_penetration': stats["avg_penetration"],
        'max_penetration': stats["max_penetration"],
        'penetration_rate': stats["penetration_rate"],
        'penetration_frames': stats["penetration_frames"],
        'total_frames': stats["total_frames"]
    }

def _json_default(value):
    """Convert a scalar-like object exposing ``.item()`` to a Python value."""
    item = getattr(value, 'item', None)
    if callable(item):
        return item()
    raise TypeError(
        f'Object of type {type(value).__name__} is not JSON serializable'
    )


def main():
    """Command-line entry point for the self-penetration evaluation.

    Accepts a single ``.pkl`` file or a directory. A directory is searched
    recursively for ``.pkl`` files, skipping names containing ``_ref.pkl`` or
    ``_first.pkl``. Each file is evaluated with ``process_single_file``; the
    per-file and aggregate statistics are printed and, when at least one file
    was processed, written to ``penetration_results.json`` next to the input
    file or inside the input directory.
    """
    import json

    parser = argparse.ArgumentParser(description='Evaluate self-penetration in motion generation results')
    parser.add_argument('input_path', type=str, help='Input pkl file or folder path')
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    input_path = Path(args.input_path)
    results = []

    if input_path.is_file():
        # NOTE(review): unlike the directory branch, an explicitly given
        # '*_first.pkl' file is NOT excluded here — confirm this is intended.
        if input_path.suffix == '.pkl' and '_ref.pkl' not in input_path.name:
            results.append(process_single_file(input_path, device))
        else:
            print(f'Skipping {input_path}: not an evaluable .pkl file')
    elif input_path.is_dir():
        # Sort so the processing and report order is deterministic.
        pkl_files = sorted(
            f for f in input_path.glob('**/*.pkl')
            if '_ref.pkl' not in f.name and '_first.pkl' not in f.name
        )
        for pkl_file in tqdm(pkl_files, desc="Processing files"):
            results.append(process_single_file(pkl_file, device))
    else:
        print(f'Input path does not exist: {input_path}')

    print('\n=============== Overall Statistics ===============')
    print(f'Total files processed: {len(results)}')
    if not results:
        return

    avg_penetrations = [r['avg_penetration'] for r in results]
    max_penetrations = [r['max_penetration'] for r in results]
    total_penetration_frames = sum(r['penetration_frames'] for r in results)
    total_frames = sum(r['total_frames'] for r in results)
    overall_penetration_rate = total_penetration_frames / total_frames if total_frames > 0 else 0

    print(f'Average penetration across all files: {np.mean(avg_penetrations):.4f}')
    print(f'Maximum penetration across all files: {np.max(max_penetrations):.4f}')
    print(f'Overall penetration rate across all files: {overall_penetration_rate:.2%} ({total_penetration_frames}/{total_frames})')

    print('\nDetailed results for each file:')
    for r in results:
        print(f"File: {r['file_name']}")
        print(f"  Motion: {r['motion_name']}")
        print(f"  Average penetration: {r['avg_penetration']:.4f}")
        print(f"  Maximum penetration: {r['max_penetration']:.4f}")
        print(f"  Penetration rate: {r['penetration_rate']:.2%} ({r['penetration_frames']}/{r['total_frames']})")

    output = {
        'summary': {
            'total_files': len(results),
            'overall_avg_penetration': float(np.mean(avg_penetrations)),
            'overall_max_penetration': float(np.max(max_penetrations)),
            'overall_penetration_rate': float(overall_penetration_rate)
        },
        'details': results
    }

    if input_path.is_file():
        save_path = input_path.parent / 'penetration_results.json'
    else:
        save_path = input_path / 'penetration_results.json'

    with open(save_path, 'w', encoding='utf-8') as f:
        # The per-file details hold the raw statistics values; the default
        # hook keeps the save from failing if any of them is a NumPy/torch
        # scalar rather than a plain Python number.
        json.dump(output, f, indent=4, ensure_ascii=False, default=_json_default)
    print(f'\nResults saved to: {save_path}')

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main() 