# Extract the Inter-X two-person interaction dataset (SMPL-X body model) into pickled per-sequence data
import sys
sys.path.insert(0, sys.path[0]+r"/../")

import pdb

import numpy
from pathlib import Path
import pickle
import os
import numpy as np
import json
from os.path import join as ospj
from config_files.data_paths import *
from utils.misc_util import have_overlap
from utils.smpl_utils import *
from tqdm import tqdm
import time
import smplx
import torch
import pickle
import trimesh
import pyrender
from pytorch3d import transforms
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp
import pandas as pd

def calc_joints_pelvis_delta(motion_data):
    """Run the body model over one person's motion; return joints and pelvis offset.

    Args:
        motion_data: dict with 'gender', 'betas' (reshaped to (1, 10) below),
            'poses' (per-frame axis-angle vector, sliced below as
            root[0:3] | body[3:66] | jaw[66:69] | left eye[69:72] |
            right eye[72:75] | left hand[75:120] | right hand[120:]) and
            'trans' (per-frame translation, reshaped to (num_frames, 3)).

    Returns:
        (joints, pelvis_delta) as numpy arrays. ``joints`` comes from
        ``primitive_utility.smpl_dict_inference`` (annotated upstream as
        [num_frames, 52, 3]). ``pelvis_delta`` is the squeezed result of
        ``primitive_utility.calc_calibrate_offset`` (annotated as [3]).

    Relies on the module-level ``primitive_utility`` and ``device``.
    """
    # Shape-dependent offset; its exact meaning is defined by
    # PrimitiveUtility.calc_calibrate_offset (not visible here).
    offset = primitive_utility.calc_calibrate_offset({
        'betas': torch.tensor(motion_data['betas'], device=device).reshape(1, 10),
        'gender': motion_data['gender'],
    })
    pelvis_delta = offset.detach().cpu().numpy().squeeze()

    frame_count = len(motion_data['trans'])
    axis_angles = torch.tensor(motion_data['poses'], device=device)

    def rotmats(lo, hi, parts):
        # Axis-angle columns [lo:hi) -> rotation matrices [frame_count, parts, 3, 3].
        return transforms.axis_angle_to_matrix(
            axis_angles[:, lo:hi].reshape(frame_count, parts, 3)
        )

    smpl_params = {
        'gender': motion_data['gender'],
        'betas': torch.tensor(motion_data['betas'], device=device).reshape(1, 10).repeat(frame_count, 1),
        'transl': torch.tensor(motion_data['trans'], device=device).reshape(frame_count, 3),
        # Root orientation has no per-part axis: [frame_count, 3, 3].
        'global_orient': transforms.axis_angle_to_matrix(axis_angles[:, :3]),
        'body_pose': rotmats(3, 66, 21),
        'jaw_pose': rotmats(66, 69, 1),
        'left_eye_pose': rotmats(69, 72, 1),
        'right_eye_pose': rotmats(72, 75, 1),
        'left_hand_pose': rotmats(75, 120, 15),
        'right_hand_pose': rotmats(120, None, 15),
    }
    joints = primitive_utility.smpl_dict_inference(smpl_params, return_vertices=False)
    return joints.detach().cpu().numpy(), pelvis_delta

def downsample(fps, target_fps, seq_data):
    """Resample one person's Inter-X motion from ``fps`` to ``target_fps``.

    Translation is linearly interpolated per axis. Each joint rotation (root,
    21 body, 15 left-hand, 15 right-hand) is interpolated with SLERP on its
    axis-angle representation.

    Args:
        fps: source frame rate of ``seq_data``.
        target_fps: desired output frame rate.
        seq_data: mapping (e.g. a loaded ``.npz``) with ``trans``,
            ``root_orient``, ``pose_body``, ``pose_lhand`` and ``pose_rhand``.

    Returns:
        ``(trans, root_orient, pose_body, pose_lhand, pose_rhand)`` with
        shapes ``(T, 3)``, ``(T, 3)``, ``(T, 63)``, ``(T, 45)``, ``(T, 45)``,
        or five ``None`` values when fewer than two output frames would result.
        Callers unpack five values, so the "too short" result must have the
        same arity.
    """
    num_old = len(seq_data['trans'])
    num_new = int((num_old - 1) / fps * target_fps) + 1
    if num_new < 2:
        return None, None, None, None, None

    old_trans = seq_data['trans']
    old_poses = np.concatenate(
        [
            seq_data['root_orient'][:, np.newaxis],
            seq_data['pose_body'].reshape((-1, 21, 3)),
            seq_data['pose_lhand'].reshape((-1, 15, 3)),
            seq_data['pose_rhand'].reshape((-1, 15, 3)),
        ],
        axis=1,
    )  # (num_old, 1 + 21 + 15 + 15, 3)
    num_rot = old_poses.shape[1]

    old_time = np.arange(num_old) / fps
    # Slerp raises on query times outside [old_time[0], old_time[-1]]; clamp so
    # floating-point rounding can never push the last sample past the end.
    new_time = np.minimum(np.arange(num_new) / target_fps, old_time[-1])

    # Linear interpolation for translation, one axis at a time.
    trans = np.stack(
        [np.interp(x=new_time, xp=old_time, fp=old_trans[:, axis]) for axis in range(3)],
        axis=1,
    )

    poses = np.zeros((num_new, num_rot, 3))
    for joint_idx in range(num_rot):
        slerp = Slerp(times=old_time, rotations=R.from_rotvec(old_poses[:, joint_idx, :]))
        poses[:, joint_idx, :] = slerp(new_time).as_rotvec()

    return (
        trans,
        poses[:, 0].reshape((-1, 3)),
        poses[:, 1:22].reshape((-1, 63)),
        poses[:, 22:37].reshape((-1, 45)),
        poses[:, 37:].reshape((-1, 45)),
    )


# ---- Configuration ---------------------------------------------------------
# NOTE(review): body_model_dir / interx_dir / PrimitiveUtility presumably come
# from the star imports of config_files.data_paths and utils.smpl_utils.
model_path = body_model_dir
gender = "neutral"
body_type = 'smplx'
n_joints = 52  # root + 21 body + 15 left-hand + 15 right-hand rotations
device = 'cuda:3'  # NOTE(review): hard-coded GPU index; adjust per machine
primitive_utility = PrimitiveUtility(device=device, body_type=body_type)
enforce_zero_male=False

# Split files list one sequence name per line; the raw lines (with newlines)
# are stored here and stripped when the per-split pickles are written.
splits = {}
for split_name in ('train', 'val', 'test'):
    with open(f'./data/Inter-X/splits/{split_name}.txt', 'r') as f:
        splits[split_name] = f.readlines()

# ---- Inter-X input locations -----------------------------------------------
raw_dataset_path = interx_dir / 'motions/'
# os.listdir order is filesystem-dependent; sort so the entry order of the
# output pickles (and the indices in progress messages) is reproducible.
data_list = sorted(os.listdir(raw_dataset_path))
total_amount = len(data_list)
text_dir = interx_dir / 'texts/'

# ---- Output location -------------------------------------------------------
output_path = f'./data/Inter-X/seq_data_single_interaction'
if enforce_zero_male:
    output_path = f'{output_path}_zero_male'

process_transition=False
target_fps = 30
output_path = f'{output_path}_fps{target_fps}'
Path(output_path).mkdir(exist_ok=True, parents=True)

def _stride_subsample(seq_data, rate):
    """Keep every ``rate``-th frame of one person's Inter-X record.

    Returns ``(trans, root_orient, pose_body, pose_lhand, pose_rhand)`` as
    float32 arrays, with the body pose flattened to 63 columns and each hand
    pose to 45 columns.
    """
    return (
        seq_data['trans'][::rate].astype(np.float32),
        seq_data['root_orient'][::rate].astype(np.float32),
        seq_data['pose_body'][::rate].reshape(-1, 63).astype(np.float32),
        seq_data['pose_lhand'][::rate].reshape(-1, 45).astype(np.float32),
        seq_data['pose_rhand'][::rate].reshape(-1, 45).astype(np.float32),
    )


def _assemble_smplx_poses(num_frames, root_orient, pose_body, pose_lhand, pose_rhand):
    """Concatenate per-frame axis-angle parts in the layout calc_joints_pelvis_delta slices.

    Order: root | body | jaw | left eye | right eye | left hand | right hand.
    Inter-X provides no face rotations, so jaw and eyes are zero-filled.
    """
    face = np.zeros((num_frames, 3), dtype=np.float32)
    return np.concatenate(
        [root_orient, pose_body, face, face, face, pose_lhand, pose_rhand], axis=-1
    )


dataset = {}

for i in tqdm(range(total_amount)):
    source_path = data_list[i]
    new_name = source_path.split('.')[0]
    p1_path = os.path.join(raw_dataset_path, source_path, 'P1.npz')
    p2_path = os.path.join(raw_dataset_path, source_path, 'P2.npz')
    if not os.path.exists(p1_path) or not os.path.exists(p2_path):
        print(f"p1 or p2 path not found: {p1_path} or {p2_path}")
        continue
    # NOTE(review): allow_pickle=True lets the npz execute pickled objects;
    # only run this on trusted dataset files.
    seq_data_p1 = np.load(p1_path, allow_pickle=True)
    seq_data_p2 = np.load(p2_path, allow_pickle=True)

    fps = 120.0  # source frame rate assumed for every Inter-X motion
    downsample_rate = int(fps / target_fps)

    betas_p1 = seq_data_p1['betas'][0, :10].astype(np.float32)
    betas_p2 = seq_data_p2['betas'][0, :10].astype(np.float32)
    gender_p1 = seq_data_p1['gender'].item()
    gender_p2 = seq_data_p2['gender'].item()
    if enforce_zero_male:
        betas_p1 = np.zeros_like(betas_p1)
        gender_p1 = 'male'
        betas_p2 = np.zeros_like(betas_p2)
        gender_p2 = 'male'

    if downsample_rate * target_fps != fps:
        # Non-integer frame-rate ratio: resample by interpolation. Keep each
        # result whole until both are known to be valid, so a sequence that is
        # too short for EITHER person is skipped rather than crashing on
        # unpacking or on ``None.astype``.
        resampled_p1 = downsample(fps, target_fps, seq_data_p1)
        resampled_p2 = downsample(fps, target_fps, seq_data_p2)
        if resampled_p1[0] is None or resampled_p2[0] is None:
            print(f'sequence too short: {i}')
            continue
        trans_p1, root_orient_p1, pose_body_p1, pose_lhand_p1, pose_rhand_p1 = (
            a.astype(np.float32) for a in resampled_p1
        )
        trans_p2, root_orient_p2, pose_body_p2, pose_lhand_p2, pose_rhand_p2 = (
            a.astype(np.float32) for a in resampled_p2
        )
    else:
        # Integer ratio: keep every downsample_rate-th frame.
        trans_p1, root_orient_p1, pose_body_p1, pose_lhand_p1, pose_rhand_p1 = _stride_subsample(
            seq_data_p1, downsample_rate
        )
        trans_p2, root_orient_p2, pose_body_p2, pose_lhand_p2, pose_rhand_p2 = _stride_subsample(
            seq_data_p2, downsample_rate
        )
        if len(trans_p1) == 0 or len(trans_p2) == 0:
            print(f'sequence too short: {i}')
            continue

    pose_p1 = _assemble_smplx_poses(trans_p1.shape[0], root_orient_p1, pose_body_p1, pose_lhand_p1, pose_rhand_p1)
    pose_p2 = _assemble_smplx_poses(trans_p2.shape[0], root_orient_p2, pose_body_p2, pose_lhand_p2, pose_rhand_p2)

    motion_data_p1 = {'gender': gender_p1, 'betas': betas_p1, 'poses': pose_p1, 'trans': trans_p1}
    motion_data_p2 = {'gender': gender_p2, 'betas': betas_p2, 'poses': pose_p2, 'trans': trans_p2}

    # Normalize translation: express both people relative to P1's first-frame
    # position. Copy the origin so it is not a view into the array that is
    # modified in place (otherwise the result would depend on update order).
    base = motion_data_p1['trans'][0].copy()
    motion_data_p2['trans'] -= base
    motion_data_p1['trans'] -= base

    # Calculate joints and pelvis delta.
    joints_p1, pelvis_delta_p1 = calc_joints_pelvis_delta(motion_data_p1)
    joints_p2, pelvis_delta_p2 = calc_joints_pelvis_delta(motion_data_p2)

    motion_data_p1['joints'] = joints_p1
    motion_data_p1['pelvis_delta'] = pelvis_delta_p1
    motion_data_p2['joints'] = joints_p2
    motion_data_p2['pelvis_delta'] = pelvis_delta_p2

    seq_data_dict = {'motion_p1': motion_data_p1, 'motion_p2': motion_data_p2, 'data_source': 'interx', 'seq_name': new_name}

    # Every text annotation covers the whole sequence. Labels keep the
    # trailing newline produced by readlines().
    text_path = text_dir / f'{new_name}.txt'
    with open(text_path, 'r') as f:
        texts = f.readlines()
    duration = motion_data_p1['trans'].shape[0] / target_fps
    seq_data_dict['frame_labels_interaction'] = [
        {'proc_label': text, 'start_t': 0.0, 'end_t': duration} for text in texts
    ]
    dataset[new_name] = seq_data_dict


# Persist every processed sequence, keyed by sequence name, in a single pickle.
with open(ospj(output_path, 'all.pkl'), 'wb') as f:
    pickle.dump(dataset, f)

# One pickle per split: a list of sequence dicts in split-file order. Names in
# a split file that are absent from ``dataset`` (missing or skipped above) are
# left out without notice.
for split, lines in splits.items():
    wanted = (line.strip() for line in lines)
    split_data = [dataset[name] for name in wanted if name in dataset]
    with open(ospj(output_path, f'{split}.pkl'), 'wb') as f:
        pickle.dump(split_data, f)

