# Extract the InterHuman two-person motion dataset using the SMPL-H body model.
import sys
sys.path.insert(0, sys.path[0]+r"/../")

import pdb

import numpy
from pathlib import Path
import pickle
import os
import numpy as np
import json
from os.path import join as ospj
from config_files.data_paths import *
from utils.misc_util import have_overlap
from utils.smpl_utils import *
from tqdm import tqdm
import time
import smplx
import torch
import pickle
import trimesh
import pyrender
from pytorch3d import transforms
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp
import pandas as pd

def calc_joints_pelvis_delta(motion_data):
    """Run SMPL-H forward kinematics for one person and fetch its calibration offset.

    Relies on the module-level ``primitive_utility`` and ``device``.

    Args:
        motion_data: dict with 'gender', 'betas' (10 shape values), 'poses'
            (per-frame axis-angle; columns 0:3 are the global orientation and
            3:66 the 21 body joints) and 'trans' (one 3-vector per frame).

    Returns:
        (joints, pelvis_delta) as numpy arrays: joints is [num_frames, 22, 3] and
        pelvis_delta is the squeezed [3] output of ``calc_calibrate_offset``
        (presumably the pelvis offset for these betas -- confirm in PrimitiveUtility).
    """
    gender = motion_data['gender']
    num_frames = len(motion_data['trans'])

    def make_betas():
        # Fresh tensor per call so the two consumers never share storage.
        return torch.tensor(motion_data['betas'], device=device).reshape(1, 10)

    offset = primitive_utility.calc_calibrate_offset({
        'betas': make_betas(),
        'gender': gender,
    })  # [1, 3]
    pelvis_delta = offset.detach().cpu().numpy().squeeze()  # [3]

    pose_aa = torch.tensor(motion_data['poses'], device=device)
    smpl_params = {
        'gender': gender,
        'betas': make_betas().repeat(num_frames, 1),
        'transl': torch.tensor(motion_data['trans'], device=device).reshape(num_frames, 3),
        # Axis-angle -> rotation matrices: [num_frames, 3, 3] and [num_frames, 21, 3, 3].
        'global_orient': transforms.axis_angle_to_matrix(pose_aa[:, :3]),
        'body_pose': transforms.axis_angle_to_matrix(pose_aa[:, 3:66].reshape(num_frames, 21, 3)),
    }
    joints = primitive_utility.smpl_dict_inference(smpl_params, return_vertices=False)  # [num_frames, 22, 3]
    return joints.detach().cpu().numpy(), pelvis_delta

def downsample(fps, target_fps, seq_data):
    """Resample one person's motion from ``fps`` to ``target_fps`` by interpolation.

    Translation is linearly interpolated per axis. The root orientation and the
    21 body-joint rotations are interpolated with spherical linear interpolation.

    Args:
        fps: source frame rate.
        target_fps: desired output frame rate.
        seq_data: dict with 'trans' [T, 3], 'root_orient' [T, 3] and 'pose_body'
            [T, >=63] axis-angle arrays; only the first 63 'pose_body' columns
            (21 joints x 3) are used.

    Returns:
        (trans [T', 3], root_orient [T', 3], pose_body [T', 63]) as float64 arrays,
        or (None, None, None) when fewer than two output frames would result.
    """
    old_trans = seq_data['trans']
    old_num_frames = len(old_trans)
    # Largest frame count whose timestamps stay inside the source time span.
    new_num_frames = int((old_num_frames - 1) / fps * target_fps) + 1
    if new_num_frames < 2:
        # Callers unpack three values, so the "too short" sentinel must be a 3-tuple.
        return None, None, None

    # Stack root + 21 body joints into [T, 22, 3] so all rotations share one path.
    old_orient = seq_data['root_orient']
    old_pose_body = seq_data['pose_body'][:, :63].reshape((-1, 21, 3))
    old_poses = np.concatenate([old_orient[:, np.newaxis], old_pose_body], axis=1)

    old_time = np.arange(old_num_frames) / fps
    new_time = np.arange(new_num_frames) / target_fps

    trans = np.stack(
        [np.interp(new_time, old_time, old_trans[:, axis]) for axis in range(3)],
        axis=-1,
    )

    poses = np.zeros((new_num_frames, 22, 3))
    for joint_idx in range(22):
        slerp = Slerp(old_time, R.from_rotvec(old_poses[:, joint_idx, :]))
        poses[:, joint_idx, :] = slerp(new_time).as_rotvec()
    return trans, poses[:, 0].reshape((-1, 3)), poses[:, 1:].reshape((-1, 63))

def mirror_sequence(trans, poses):
    """Mirror a motion sequence left-right through the x = 0 plane.

    Reflecting a rotation R through that plane gives M R M with M = diag(-1, 1, 1).
    In axis-angle form this maps (x, y, z) -> (x, -y, -z). After the reflection,
    the rotations of paired left/right joints are swapped.

    Args:
        trans: [T, 3] root translations; the input is not modified.
        poses: [T, 66] axis-angle rotations (global orientation followed by 21
            body joints); the input is not modified.

    Returns:
        (trans_mirror [T, 3], poses_mirror [T, 66] float32).

    NOTE(review): with SMPL the root joint is not at the model origin, so negating
    ``trans`` alone may not mirror the body exactly -- possibly why the caller's
    mirror augmentation was marked "fails for smpl". Confirm before enabling.
    """
    trans_mirror = np.array(trans, copy=True)
    trans_mirror[:, 0] = -trans_mirror[:, 0]
    # Joint indices of the 22-joint SMPL skeleton (0 = pelvis/root).
    right_chain = [2, 5, 8, 11, 14, 17, 19, 21]
    left_chain = [1, 4, 7, 10, 13, 16, 18, 20]

    poses_mirror = np.array(poses, dtype=np.float32).reshape((-1, 22, 3))
    # Reflect every rotation (root included): negate the y and z axis-angle components.
    poses_mirror[..., 1:] *= -1
    # Swap left/right joints along the joint axis (axis 1), not the frame axis.
    # Fancy indexing on the right-hand side copies, so the swap is safe in place.
    poses_mirror[:, right_chain + left_chain] = poses_mirror[:, left_chain + right_chain]
    return trans_mirror, poses_mirror.reshape((-1, 66))


model_path = body_model_dir
gender = "male"
device = 'cuda'
# Shared SMPL-H helper; calc_joints_pelvis_delta reads it as a module global.
primitive_utility = PrimitiveUtility(device=device, body_type='smplh')
# When True, every subject gets zero betas and the male body model.
# enforce_zero_male=True
enforce_zero_male = False

# Official split lists; lines keep their trailing newline here.
splits = {}
for split_name in ('train', 'val', 'test'):
    with open(f'./data/InterHuman/split/{split_name}.txt', 'r') as split_file:
        splits[split_name] = split_file.readlines()


# load interhuman data: one motion file per sequence plus matching text annotations
raw_dataset_path = interhuman_dir / 'motions/'
data_list = os.listdir(raw_dataset_path)
total_amount = len(data_list)
text_dir = interhuman_dir / 'annots/'

output_path = './data/InterHuman/seq_data' + ('_zero_male' if enforce_zero_male else '')
# NOTE(review): this un-suffixed base directory is created, but in this file only the
# fps-suffixed directory below is written to.
Path(output_path).mkdir(exist_ok=True, parents=True)

process_transition = False
# target_fps = 20
target_fps = 30
# output_path = f'{output_path}_fps{target_fps}'
output_path = f'{output_path}_fps{target_fps}_test'
os.makedirs(output_path, exist_ok=True)
dataset = {}  # sequence name -> processed record

def _person_motion(person_data, fps, target_fps, downsample_rate, zero_male):
    """Build one person's motion dict at ``target_fps``; return None if the clip is too short.

    The dict holds 'gender', 'betas' (first 10, float32), 'poses' (root orientation
    followed by the first 63 body-pose values, float32) and 'trans' (float32).
    """
    betas = person_data['betas'][:10].astype(np.float32)
    gender = person_data['gender']
    if zero_male:
        # Canonical shape: zero betas with the male body model.
        betas = np.zeros_like(betas)
        gender = 'male'

    if downsample_rate * target_fps != fps:
        # Non-integer frame-rate ratio: resample by interpolation.
        resampled = downsample(fps, target_fps, person_data)
        # Check the first element rather than unpacking, so a short-clip
        # sentinel of any length is handled.
        if resampled[0] is None:
            return None
        trans, root_orient, pose_body = resampled
    else:
        # Integer ratio: plain frame striding.
        trans = person_data['trans'][::downsample_rate]
        root_orient = person_data['root_orient'][::downsample_rate]
        pose_body = person_data['pose_body'][::downsample_rate, :63]
        if len(trans) == 0:
            return None

    poses = np.concatenate(
        [root_orient.astype(np.float32), pose_body.astype(np.float32)], axis=-1
    )
    return {'gender': gender, 'betas': betas, 'poses': poses, 'trans': trans.astype(np.float32)}


# Source frame rate of the motion files (hard-coded) and the striding factor;
# both are the same for every sequence, so compute them once.
fps = 60.0
downsample_rate = int(fps / target_fps)

for i, source_path in enumerate(tqdm(data_list)):
    new_name = source_path.split('.')[0]
    seq_path = os.path.join(raw_dataset_path, source_path)
    if not os.path.exists(seq_path):
        print(f"seq_path not found: {seq_path}")
        continue
    text_path = text_dir / f'{new_name}.txt'
    if not os.path.exists(text_path):
        # Same skip-and-report policy as a missing motion file, checked before any heavy work.
        print(f"text_path not found: {text_path}")
        continue
    # NOTE: pickle can execute code on load; only run this on trusted dataset files.
    with open(seq_path, 'rb') as f:
        seq_data = pickle.load(f)

    motion_data_p1 = _person_motion(seq_data['person1'], fps, target_fps, downsample_rate, enforce_zero_male)
    motion_data_p2 = _person_motion(seq_data['person2'], fps, target_fps, downsample_rate, enforce_zero_male)
    # Both people must be usable, not just person1.
    if motion_data_p1 is None or motion_data_p2 is None:
        print(f'sequence too short: {i}')
        continue

    for motion_data in (motion_data_p1, motion_data_p2):
        joints, pelvis_delta = calc_joints_pelvis_delta(motion_data)
        motion_data['joints'] = joints
        motion_data['pelvis_delta'] = pelvis_delta

    # Every annotation line labels the whole clip, so all labels span [0, duration].
    end_t = motion_data_p1['trans'].shape[0] / target_fps
    with open(text_path, 'r', encoding='utf-8') as f:
        texts = [line.strip() for line in f]
    # Strip trailing newlines and drop blank lines so labels are clean sentences.
    frame_labels = [
        {'proc_label': text, 'start_t': 0.0, 'end_t': end_t}
        for text in texts
        if text
    ]
    dataset[new_name] = {
        'motion_p1': motion_data_p1,
        'motion_p2': motion_data_p2,
        'data_source': 'interhuman',
        'seq_name': new_name,
        'frame_labels': frame_labels,
    }

    # Mirror augmentation is disabled here: it was noted to fail for SMPL and was
    # written for single-person data (see mirror_sequence).


# Full dictionary of processed sequences, keyed by sequence name.
with open(ospj(output_path, 'all.pkl'), 'wb') as f:
    pickle.dump(dataset, f)

# To rebuild only the split files, load 'all.pkl' into `dataset` here instead of reprocessing.

# One list per split, in split-file order; names absent from `dataset` are skipped.
for split, names in splits.items():
    split_data = [dataset[name] for name in map(str.strip, names) if name in dataset]
    with open(ospj(output_path, f'{split}.pkl'), 'wb') as f:
        pickle.dump(split_data, f)

