# Extract the InterHuman two-person interaction dataset using the SMPL-H body model.
import sys
sys.path.insert(0, sys.path[0]+r"/../")

import pdb

import numpy
from pathlib import Path
import pickle
import os
import numpy as np
import json
from os.path import join as ospj
from config_files.data_paths import *
from utils.misc_util import have_overlap
from utils.smpl_utils import *
from tqdm import tqdm
import time
import smplx
import torch
import pickle
import trimesh
import pyrender
from pytorch3d import transforms
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp
import pandas as pd

def calc_pelvis_delta(motion_data):
    """Return the per-subject offset computed by ``primitive_utility.calc_calibrate_offset``.

    ``motion_data`` must provide ``'betas'`` (10 shape coefficients) and ``'gender'``.
    The computation runs on the module-level ``device``. The result is presumably the
    rest-pose pelvis position for these betas (it is added to ``trans`` elsewhere before
    rotating the root) — TODO confirm against PrimitiveUtility.

    Returns:
        numpy array squeezed from a [1, 3] tensor, i.e. shape [3].
    """
    shape_params = torch.tensor(motion_data['betas'], device=device).reshape(1, 10)
    offset = primitive_utility.calc_calibrate_offset({
        'betas': shape_params,
        'gender': motion_data['gender'],
    })  # [1, 3]
    return offset.detach().cpu().numpy().squeeze()  # [3]

def downsample(fps, target_fps, seq_data):
    """Resample a pose sequence from ``fps`` to ``target_fps``.

    Translation is linearly interpolated; the root orientation and the 21 body joint
    rotations (axis-angle) are interpolated with spherical linear interpolation.

    Args:
        fps: frame rate of ``seq_data``.
        target_fps: desired output frame rate.
        seq_data: dict with ``'trans'`` [T, 3], ``'root_orient'`` [T, 3] and
            ``'pose_body'`` [T, >=63]; only the first 63 pose values (21 joints) are used.

    Returns:
        ``(trans, root_orient, pose_body)`` with shapes [T', 3], [T', 3], [T', 63], or
        ``(None, None, None)`` when fewer than two output frames would result.
    """
    old_trans = seq_data['trans']
    old_orient = seq_data['root_orient']
    old_pose_body = seq_data['pose_body'][:, :63].reshape((-1, 21, 3))
    old_poses = np.concatenate([old_orient[:, np.newaxis], old_pose_body], axis=1)  # [T, 22, 3]
    old_num_frames = len(old_trans)
    # int() truncates so the last output timestamp never exceeds the last input one.
    new_num_frames = int((old_num_frames - 1) / fps * target_fps) + 1
    if new_num_frames < 2:
        # Same arity as the success path so callers can always unpack three values.
        return None, None, None

    old_time = np.arange(old_num_frames) / fps
    # Clamp against float round-off: Slerp raises on times outside [old_time[0], old_time[-1]].
    new_time = np.minimum(np.arange(new_num_frames) / target_fps, old_time[-1])

    trans = np.zeros((new_num_frames, 3))
    for axis in range(3):  # linear interpolation for translation
        trans[:, axis] = np.interp(x=new_time, xp=old_time, fp=old_trans[:, axis])

    poses = np.zeros((new_num_frames, 22, 3))
    for joint_idx in range(22):
        slerp = Slerp(times=old_time, rotations=R.from_rotvec(old_poses[:, joint_idx, :]))
        poses[:, joint_idx, :] = slerp(new_time).as_rotvec()
    return trans, poses[:, 0].reshape((-1, 3)), poses[:, 1:].reshape((-1, 63))

def swap_left_right_position(data):
    """Mirror joint positions across the x = 0 plane and exchange left/right joints.

    Args:
        data: array of shape [T, J, 3]. The body chains are always swapped; when
            J > 24 the hand joints (indices 22-61) are swapped as well.

    Returns:
        A new mirrored array; ``data`` itself is left untouched.
    """
    assert len(data.shape) == 3 and data.shape[-1] == 3
    out = data.copy()
    out[..., 0] *= -1  # reflect x

    # Each entry: (right-side joint indices, matching left-side joint indices).
    joint_pairs = [
        ([2, 5, 8, 11, 14, 17, 19, 21], [1, 4, 7, 10, 13, 16, 18, 20]),
    ]
    if out.shape[1] > 24:
        joint_pairs.append((
            [43, 44, 45, 46, 47, 48, 40, 41, 42, 37, 38, 39, 49, 50, 51, 57, 58, 59, 60, 61],
            [22, 23, 24, 34, 35, 36, 25, 26, 27, 31, 32, 33, 28, 29, 30, 52, 53, 54, 55, 56],
        ))
    for right, left in joint_pairs:
        # Fancy indexing yields copies on the right-hand side, so this is a true swap.
        out[:, right], out[:, left] = out[:, left], out[:, right]
    return out

def swap_left_right_rot(data):
    """Mirror 6D joint rotations left/right and exchange left/right joint slots.

    Slots here are the position-joint indices shifted down by one, so ``data`` is
    presumably the per-joint rotations without the root (e.g. 21 body joints) —
    TODO confirm against the producer of the features.

    Args:
        data: array of shape [T, J, 6]. Hand slots are swapped too when J > 24.

    Returns:
        A new mirrored array; ``data`` itself is left untouched.
    """
    assert len(data.shape) == 3 and data.shape[-1] == 6
    out = data.copy()

    # Negate 6D components 1, 2 and 4. NOTE(review): this equals conjugation by
    # diag(-1, 1, 1) if the 6D vector is rotmat[:, :2] flattened row-major — confirm.
    out[..., [1, 2, 4]] *= -1

    body_right = [j - 1 for j in (2, 5, 8, 11, 14, 17, 19, 21)]
    body_left = [j - 1 for j in (1, 4, 7, 10, 13, 16, 18, 20)]
    hand_left = [j - 1 for j in (22, 23, 24, 34, 35, 36, 25, 26, 27, 31, 32, 33, 28, 29, 30)]
    hand_right = [j - 1 for j in (43, 44, 45, 46, 47, 48, 40, 41, 42, 37, 38, 39, 49, 50, 51)]

    chains = [(body_right, body_left)]
    if out.shape[1] > 24:
        chains.append((hand_right, hand_left))
    for right, left in chains:
        out[:, right], out[:, left] = out[:, left], out[:, right]
    return out


def swap_left_right(data, n_joints):
    """Left/right mirror a flattened ``[positions | 6D rotations]`` feature sequence.

    Args:
        data: [T, D] array; the first ``3 * n_joints`` columns are joint positions and
            the remaining columns are 6D rotations.
        n_joints: number of position joints.

    Returns:
        New [T, D] array: mirrored positions followed by mirrored rotations. ``data`` is
        not modified (both helpers work on copies).
    """
    num_frames = data.shape[0]
    boundary = 3 * n_joints
    mirrored_pos = swap_left_right_position(data[..., :boundary].reshape(num_frames, n_joints, 3))
    mirrored_rot = swap_left_right_rot(data[..., boundary:].reshape(num_frames, -1, 6))
    return np.concatenate(
        [mirrored_pos.reshape(num_frames, -1), mirrored_rot.reshape(num_frames, -1)],
        axis=-1,
    )

def mirror_global_orient(global_orient):
    """Mirror root orientations across the yz-plane (x -> -x).

    For the reflection ``M = diag(-1, 1, 1)`` the mirrored rotation is ``M @ R @ M``.
    Rotation axes are pseudo-vectors, so the axis (ax, ay, az) becomes (ax, -ay, -az);
    in scalar-last quaternion form (x, y, z, w) that negates the y and z components.
    This matches the x-negation applied to joint positions in this script.

    Args:
        global_orient: [B, 3] axis-angle rotations.

    Returns:
        [B, 3] float32 axis-angle rotations.
    """
    quat = R.from_rotvec(global_orient).as_quat()  # [B, 4], (x, y, z, w)
    # Negating z alone is only correct for pure rotations about z; y must flip as well.
    quat[:, 1:3] *= -1
    mirrored_orient = R.from_quat(quat).as_rotvec().astype(np.float32)
    return mirrored_orient  # [B, 3]

def mirror_sequence(trans, poses):
    """Left-right mirror (x -> -x) of a translation sequence and 22-joint axis-angle poses.

    Args:
        trans: [T, 3] root translations.
        poses: axis-angle rotations reshapeable to [T, 22, 3] (root followed by 21 body
            joints), e.g. [T, 66].

    Returns:
        ``(trans_mirror, poses_mirror)``: translations with x negated (same type as
        ``trans``) and float32 poses of shape [T, 66]. The inputs are not modified.

    Note: the translation is reflected about x = 0 directly; no pelvis-offset correction
    is applied.
    """
    import copy

    trans_mirror = copy.deepcopy(trans)
    trans_mirror[:, 0] = -trans_mirror[:, 0]

    right_chain = [2, 5, 8, 11, 14, 17, 19, 21]
    left_chain = [1, 4, 7, 10, 13, 16, 18, 20]

    # Conjugation by diag(-1, 1, 1): quaternion (x, y, z, w) -> (x, -y, -z, w).
    quat = R.from_rotvec(np.asarray(poses).reshape((-1, 3))).as_quat()
    quat[:, 1:3] *= -1
    poses_mirror = R.from_quat(quat).as_rotvec().astype(np.float32).reshape((-1, 22, 3))
    # Swap along the joint axis (axis 1), not the frame axis.
    poses_mirror[:, right_chain], poses_mirror[:, left_chain] = (
        poses_mirror[:, left_chain], poses_mirror[:, right_chain])
    return trans_mirror, poses_mirror.reshape((-1, 66))


# Body-model settings. model_path, gender and enforce_zero_male are not referenced by
# the code below.
model_path = body_model_dir
gender = "male"
enforce_zero_male = False
# NOTE(review): the CUDA device index is hard-coded; adjust it for the target machine.
device = 'cuda:3'
# Shared helper used by calc_pelvis_delta to compute per-subject offsets.
primitive_utility = PrimitiveUtility(device=device, body_type='smplh')

# InterHuman train/val/test split files: one sequence id per line (whitespace stripped).
splits = {}
for split_name in ('train', 'val', 'test'):
    with open(f'./data/InterHuman/split/{split_name}.txt', 'r') as f:
        splits[split_name] = [line.strip() for line in f.readlines()]

# Locate the InterHuman data: processed per-person motions and raw SMPL-H pickles.
dataset_path = interhuman_dir / 'motions_processed/'
raw_dataset_path = interhuman_dir / 'motions/'
# Per-person file lists sorted by their integer file stem. Index i is treated as the same
# sequence for both persons (assumes both directories contain the same ids).
data_list = {
    person: sorted(os.listdir(dataset_path / person),
                   key=lambda fname: int(os.path.splitext(fname)[0]))
    for person in ['person1', 'person2']
}
total_amount = len(data_list['person1'])

# Text annotations: text1/text2 feed the person1/person2 labels, annots the interaction label.
text1_dir = interhuman_dir / 'separate_annots/text1/'
text2_dir = interhuman_dir / 'separate_annots/text2/'
text_dir = interhuman_dir / 'annots/'

target_fps = 30
mirror_data = True   # also store left/right mirrored copies of training sequences
exchange_yz = True   # rotate (x, y, z) -> (x, z, -y) so the old z axis becomes y

# Output directory name encodes the processing options.
output_path = f'./data/InterHuman/seq_data_single_interaction_d262_fps{target_fps}'
if mirror_data:
    output_path += '_mirror'
if exchange_yz:
    output_path += '_exchangeyz'
Path(output_path).mkdir(exist_ok=True, parents=True)

dataset = {}
fps = 60.0  # assumed frame rate of the raw SMPL-H parameters
# Keep every downsample_rate-th raw frame; int() truncates if target_fps does not divide fps.
downsample_rate = int(fps / target_fps)

def downsample_with_last_frame(data, rate):
    """Take every ``rate``-th element of ``data`` along axis 0, always keeping the last one.

    Args:
        data: numpy array (or anything supporting list-of-int indexing).
        rate: positive integer step.

    Returns:
        ``data[idxs]`` with ``idxs = 0, rate, 2*rate, ...`` plus ``len(data) - 1`` when it
        is not already included. Empty input yields an empty selection.

    Raises:
        ValueError: if ``rate`` is smaller than 1.
    """
    if rate < 1:
        raise ValueError(f"rate must be a positive integer, got {rate}")
    idxs = list(range(0, len(data), rate))
    # The `idxs and` guard avoids IndexError on empty input.
    if idxs and idxs[-1] != len(data) - 1:
        idxs.append(len(data) - 1)
    return data[idxs]

def _mirror_text(text):
    """Return ``text`` with left/right and clockwise/counterclockwise swapped.

    Matching is case-insensitive and anchored at the start of a word, so "leftward"
    becomes "rightward" while "upright" or "bright" are untouched. All-caps and
    capitalised words keep their casing. Every word is replaced in a single pass, so
    "counterclockwise" becomes "clockwise" (not "countercounterclockwise").
    """
    import re

    def _swap(match):
        word = match.group(0)
        key = word.lower()
        if key == 'left':
            new = 'right'
        elif key == 'right':
            new = 'left'
        elif key == 'clockwise':
            new = 'counterclockwise'
        else:  # counterclockwise / counter-clockwise / anticlockwise / anti-clockwise
            new = 'clockwise'
        if word.isupper():
            return new.upper()
        if word[0].isupper():
            return new.capitalize()
        return new

    return re.sub(r"\b(?:counter-?clockwise|anti-?clockwise|clockwise|left|right)",
                  _swap, text, flags=re.IGNORECASE)


def _read_frame_labels(text_path, mirror, end_t):
    """Read one annotation file and wrap every line as a label spanning [0, end_t].

    Lines are kept exactly as returned by ``readlines`` (trailing newline included);
    with ``mirror`` set, directional words are swapped via ``_mirror_text``.
    """
    with open(text_path, 'r') as f:
        texts = f.readlines()
    if mirror:
        texts = [_mirror_text(text) for text in texts]
    return [{'proc_label': text, 'start_t': 0.0, 'end_t': end_t} for text in texts]


def load_dual_human_sequence(i, mirror=False):
    """Load the i-th two-person InterHuman sequence together with its text labels.

    For each person, joint positions and 6D body rotations come from the processed
    ``.npy`` in ``dataset_path``. Shape, root translation and root orientation come
    from the raw ``.pkl`` in ``raw_dataset_path``. The raw parameters are downsampled
    by ``downsample_rate``. Everything is shifted so the lowest joint has z = 0, is
    optionally mirrored left/right (x -> -x), and is optionally rotated so the z axis
    becomes the y axis (``exchange_yz``).

    Args:
        i: index into ``data_list['person1']``; person2 is paired by the same file name.
        mirror: if True, build the mirrored copy. Its name is suffixed with ``_mirror``,
            and left/right and clockwise/counterclockwise are swapped in its texts.

    Returns:
        ``(seq_data_dict, seq_name)``.

    Raises:
        FileNotFoundError: if a processed motion, raw pickle or annotation file is missing.
    """
    source_name = data_list['person1'][i]
    idx = source_name.split('.')[0]
    new_name = f"{idx}_mirror" if mirror else idx

    # The raw pickle holds SMPL-H parameters for both persons; load it once.
    raw_path = os.path.join(raw_dataset_path, source_name.replace('.npy', '.pkl'))
    with open(raw_path, 'rb') as f:
        raw_data = pickle.load(f)

    motion_data = {}
    for person in ['person1', 'person2']:
        seq_path = os.path.join(dataset_path, person, source_name)
        if not os.path.exists(seq_path):
            # The pair is unusable without both persons; fail with the real cause.
            raise FileNotFoundError(f"seq_path not found: {seq_path}")
        seq_data = np.load(seq_path).astype(np.float32)

        motion_data[person] = {
            'gender': raw_data[person]['gender'],
            'betas': raw_data[person]['betas'][:10].astype(np.float32),
        }
        pelvis_delta = calc_pelvis_delta(motion_data[person])

        joints = seq_data[:, :22 * 3]                      # [num_frames, 22*3]
        pose_body = seq_data[:, 62 * 3:62 * 3 + 21 * 6]    # [num_frames, 21*6]
        trans = downsample_with_last_frame(raw_data[person]['trans'], downsample_rate).astype(np.float32)
        global_orient = downsample_with_last_frame(raw_data[person]['root_orient'], downsample_rate).astype(np.float32)

        if mirror:
            mirrored = swap_left_right(np.concatenate([joints, pose_body], axis=-1), 22)
            joints = mirrored[:, :22 * 3]     # [num_frames, 22*3]
            pose_body = mirrored[:, 22 * 3:]  # [num_frames, 21*6]
            global_orient = mirror_global_orient(global_orient)
            # Reflect the root translation like the joints (x -> -x). The reflection acts
            # on the pelvis position trans + pelvis_delta, as in the exchange_yz step below.
            mirror_x = np.array([-1.0, 1.0, 1.0], dtype=np.float32)
            trans = ((trans + pelvis_delta) * mirror_x - pelvis_delta).astype(np.float32)

        # Put on the floor: shift z so the lowest joint over the whole clip sits at 0.
        joints = joints.reshape(-1, 22, 3)
        floor_height = joints[..., 2].min()
        joints[..., 2] -= floor_height
        trans[..., 2] -= floor_height

        if exchange_yz:
            # Proper rotation (x, y, z) -> (x, z, -y): the old z axis becomes the new y axis.
            trans_matrix = np.array([
                [1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0],
                [0.0, -1.0, 0.0]
            ], dtype=joints.dtype)
            joints = np.einsum('ij,tbj->tbi', trans_matrix, joints)
            # The root rotation pivots about the pelvis at trans + pelvis_delta (pelvis_delta
            # presumably being the rest-pose pelvis offset), so rotate that point instead.
            trans = np.einsum('ij,tj->ti', trans_matrix, trans + pelvis_delta) - pelvis_delta
            # A change of world frame composes on the left: R_new = M @ R_old.
            r_old = R.from_rotvec(global_orient).as_matrix()
            r_new = np.einsum('ij,tjk->tik', trans_matrix, r_old)
            global_orient = R.from_matrix(r_new).as_rotvec().astype(np.float32)

        joints = joints.reshape(-1, 22 * 3)

        # Trim the processed streams to the number of downsampled raw frames.
        # NOTE(review): if the processed clip is shorter, the lengths still differ.
        cut_dim = trans.shape[0]

        motion_data[person].update({
            'pose_body': pose_body[:cut_dim],
            'trans': trans,
            'global_orient': global_orient,
            'joints': joints[:cut_dim],
            'pelvis_delta': pelvis_delta,
        })

    seq_data_dict = {
        'motion_p1': motion_data['person1'],
        'motion_p2': motion_data['person2'],
        'data_source': 'interhuman',
        'seq_name': new_name
    }

    end_t_p1 = motion_data['person1']['trans'].shape[0] / target_fps
    end_t_p2 = motion_data['person2']['trans'].shape[0] / target_fps
    seq_data_dict['frame_labels_person1'] = _read_frame_labels(
        Path(text1_dir) / f'{idx}.txt', mirror, end_t_p1)
    seq_data_dict['frame_labels_person2'] = _read_frame_labels(
        Path(text2_dir) / f'{idx}.txt', mirror, end_t_p2)
    # The interaction label spans person1's duration.
    seq_data_dict['frame_labels_interaction'] = _read_frame_labels(
        Path(text_dir) / f'{idx}.txt', mirror, end_t_p1)

    return seq_data_dict, new_name

# Build the dataset: every sequence once, plus a mirrored copy of each training sequence.
train_ids = set(splits['train'])  # set: O(1) membership per sequence instead of a list scan
for i in tqdm(range(total_amount)):
    base_name = os.path.splitext(data_list['person1'][i])[0]
    seq_data_dict, new_name = load_dual_human_sequence(i, mirror=False)
    dataset[new_name] = seq_data_dict
    if mirror_data and base_name in train_ids:
        seq_data_dict_mirror, new_name_mirror = load_dual_human_sequence(i, mirror=True)
        dataset[new_name_mirror] = seq_data_dict_mirror

# All sequences (original and mirrored), keyed by sequence name.
with open(ospj(output_path, 'all.pkl'), 'wb') as f:
    pickle.dump(dataset, f)

# One list per split in split-file order; each training sequence is followed by its mirror.
for split, seq_names in splits.items():
    split_data = []
    for seq_name in seq_names:
        if seq_name not in dataset:
            continue
        split_data.append(dataset[seq_name])
        if mirror_data and split == 'train':
            seq_name_mirror = f"{seq_name}_mirror"
            if seq_name_mirror in dataset:
                split_data.append(dataset[seq_name_mirror])
    with open(ospj(output_path, f'{split}.pkl'), 'wb') as f:
        pickle.dump(split_data, f)

