import sys
sys.path.insert(0, sys.path[0]+r"/../")

import pickle
import numpy as np
import torch
import copy
from pytorch3d import transforms

# Input: pickle holding a single person's data; output: the two-person
# version written at the end of this script.
# Alternative pair used for other runs:
#   './data/stand_20fps.pkl' -> './data/stand_20fps_interaction_similar.pkl'
single_stand_path = './data/t_pose.pkl'
out_stand_path = './data/t_pose_interaction_v3.pkl'

# NOTE: pickle.load executes arbitrary code from the file; only point this
# at data files you trust.
with open(single_stand_path, 'rb') as src:
    single_stand = pickle.load(src)

# Build the second person ("partner") from the loaded single-person data.
# Translation: an independent copy, offset by +1 along the first axis.
# (Disabled alternative in earlier revisions: offset the second axis by -3.)
partner_transl = copy.deepcopy(single_stand['transl'])
partner_transl[:, 0] += 1

# Orientation: convert the copied axis-angle vectors to rotation matrices so
# a fixed rotation can be composed on the left.
partner_axis_angle = torch.tensor(copy.deepcopy(single_stand['global_orient']))
partner_rotmat = transforms.axis_angle_to_matrix(partner_axis_angle)

# diag(-1, -1, 1) is a 180-degree rotation about the z axis. The dtype is
# taken from the rotation matrices so the matmul does not mix precisions.
# (Disabled alternatives in earlier revisions: the identity matrix, and
# negating individual matrix columns after the product.)
half_turn_z = torch.tensor(
    [
        [-1, 0, 0],
        [0, -1, 0],
        [0, 0, 1],
    ],
    dtype=partner_rotmat.dtype,
)
partner_rotmat = torch.matmul(half_turn_z, partner_rotmat)

# Back to axis-angle as a NumPy array, so it can be concatenated with the
# original arrays later.
partner_orient = transforms.matrix_to_axis_angle(partner_rotmat)
partner_orient = partner_orient.detach().cpu().numpy()

partner = {'transl': partner_transl, 'global_orient': partner_orient}

# Merge the two people into `single_stand`, entry by entry:
#   - 'text'                     : left untouched
#   - 'gender'                   : becomes a two-element list [value, value]
#   - 'transl' / 'global_orient' : original joined with the partner's version
#   - anything else              : original joined with itself
# Joins are along the last axis. Only existing keys are reassigned, so the
# dict is not resized while it is being iterated.
for field in single_stand:
    if field == 'text':
        continue

    own_value = single_stand[field]
    if field == 'gender':
        single_stand[field] = [own_value, own_value]
        continue

    if field in ('transl', 'global_orient'):
        second_value = partner[field]
    else:
        second_value = own_value
    single_stand[field] = np.concatenate((own_value, second_value), axis=-1)

# Persist the merged two-person dict and report where it went.
with open(out_stand_path, 'wb') as sink:
    pickle.dump(single_stand, sink)

print(f"Saved to {out_stand_path}")