import sys
sys.path.append(sys.path[0]+"/../")
sys.path.append(sys.path[0]+"/../../")
import torch
import numpy as np
import argparse
import pickle
import smplx
import os

from visualization.utils import bvh, quat
from HHInter.global_path import *


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=get_SMPL_SMPLH_SMPLX_body_model_path())
    parser.add_argument("--model_type", type=str, default="smpl", choices=["smpl", "smplx"])
    parser.add_argument("--gender", type=str, default="MALE", choices=["MALE", "FEMALE", "NEUTRAL"])
    parser.add_argument("--num_betas", type=int, default=10, choices=[10, 300])
    parser.add_argument("--poses", type=str, default=os.path.join(get_program_root_path(), "Sitcom-Crafter/HSInter/results_Story_HIM/smplh-bvh"))
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--output", type=str, default=os.path.join(get_program_root_path(), "Sitcom-Crafter/HSInter/results_Story_HIM/smplh-bvh"))
    parser.add_argument("--mirror", action="store_true")
    return parser.parse_args()

def mirror_rot_trans(lrot, trans, names, parents):
    joints_mirror = np.array([(
        names.index("Left"+n[5:]) if n.startswith("Right") else (
        names.index("Right"+n[4:]) if n.startswith("Left") else 
        names.index(n))) for n in names])

    mirror_pos = np.array([-1, 1, 1])
    mirror_rot = np.array([1, 1, -1, -1])
    grot = quat.fk_rot(lrot, parents)
    trans_mirror = mirror_pos * trans
    grot_mirror = mirror_rot * grot[:,joints_mirror]
    
    return quat.ik_rot(grot_mirror, parents), trans_mirror

def smpl2bvh(model_path:str, poses:str, output:str, mirror:bool,
             model_type="smpl", gender="MALE",
             num_betas=10, fps=60) -> None:
    """Save bvh file created by smpl parameters.

    Args:
        model_path (str): Path to smpl models.
        poses (str): Path to npz or pkl file.
        output (str): Where to save bvh.
        mirror (bool): Whether save mirror motion or not.
        model_type (str, optional): I prepared "smpl" only. Defaults to "smpl".
        gender (str, optional): Gender Information. Defaults to "MALE".
        num_betas (int, optional): How many pca parameters to use in SMPL. Defaults to 10.
        fps (int, optional): Frame per second. Defaults to 30.
    """
    
    # names = [
    #     "Pelvis",
    #     "Left_hip",
    #     "Right_hip",
    #     "Spine1",
    #     "Left_knee",
    #     "Right_knee",
    #     "Spine2",
    #     "Left_ankle",
    #     "Right_ankle",
    #     "Spine3",
    #     "Left_foot",
    #     "Right_foot",
    #     "Neck",
    #     "Left_collar",
    #     "Right_collar",
    #     "Head",
    #     "Left_shoulder",
    #     "Right_shoulder",
    #     "Left_elbow",
    #     "Right_elbow",
    #     "Left_wrist",
    #     "Right_wrist",
    #     "Left_palm",
    #     "Right_palm",
    # ]

    names = [
        "Hips",
        "LeftUpLeg",
        "RightUpLeg",
        "Spine",
        "LeftLeg",
        "RightLeg",
        "Spine1",
        "LeftFoot",
        "RightFoot",
        "Spine2",
        "LeftToe",
        "RightToe",
        "Neck",
        "LeftShoulder",
        "RightShoulder",
        "Head",
        "LeftArm",
        "RightArm",
        "LeftForeArm",
        "RightForeArm",
        "LeftHand",
        "RightHand",
        "LeftThumb",
        "RightThumb",
    ]

    # I prepared smpl models only, 
    # but I will release for smplx models recently.
    model = smplx.create(model_path=model_path, 
                        model_type=model_type,
                        gender=gender, 
                        batch_size=1)
    
    parents = model.parents.detach().cpu().numpy()

    scaling = None

    # Pose setting.
    if poses.endswith(".npz"):
        poses = np.load(poses)
        rots = np.squeeze(poses["poses"], axis=0)  # (N, 24, 3)
        trans = np.squeeze(poses["trans"], axis=0)  # (N, 3)

    elif poses.endswith(".pkl"):
        with open(poses, "rb") as f:
            poses = pickle.load(f)

            rots = poses[:, 3:75].reshape(-1, 24, 3)  # (N, 72)
            trans = poses[:, :3]  # (N, 3)
            betas = torch.from_numpy(poses[:, -10:]).type(torch.float32)  # (N, 10)

    else:
        raise Exception("This file type is not supported!")

    # You can define betas like this.(default betas are 0 at all.)
    rest = model(
        betas = betas[[0]]
    )
    rest_pose = rest.joints.detach().cpu().numpy().squeeze()[:24,:]
    
    root_offset = rest_pose[0]
    offsets = rest_pose - rest_pose[parents]
    offsets[0] = root_offset
    offsets *= 1
    
    if scaling is not None:
        trans /= scaling
    
    # to quaternion
    rots = quat.from_axis_angle(rots)
    
    order = "zyx"
    pos = offsets[None].repeat(len(rots), axis=0)
    positions = pos.copy()
    # positions[:,0] += trans * 10
    positions[:, 0] += trans
    rotations = np.degrees(quat.to_euler(rots, order=order))
    
    bvh_data ={
        "rotations": rotations[:, :22],
        "positions": positions[:, :22],
        "offsets": offsets[:22],
        "parents": parents[:22],
        "names": names[:22],
        "order": order,
        "frametime": 1 / fps,
    }
    
    if not output.endswith(".bvh"):
        output = output + ".bvh"
    
    bvh.save(output, bvh_data)
    
    if mirror:
        rots_mirror, trans_mirror = mirror_rot_trans(
                rots, trans, names, parents)
        positions_mirror = pos.copy()
        positions_mirror[:,0] += trans_mirror
        rotations_mirror = np.degrees(
            quat.to_euler(rots_mirror, order=order))
        
        bvh_data ={
            "rotations": rotations_mirror,
            "positions": positions_mirror,
            "offsets": offsets,
            "parents": parents,
            "names": names,
            "order": order,
            "frametime": 1 / fps,
        }
        
        output_mirror = output.split(".")[0] + "_mirror.bvh"
        bvh.save(output_mirror, bvh_data)


if __name__ == "__main__":
    args = parse_args()

    for i in [1, 2]:
        poses = f"{args.poses}/person{i}.pkl"
        output = f"{args.output}/person{i}.bvh"
        smpl2bvh(model_path=args.model_path, model_type=args.model_type,
                 mirror = args.mirror, gender=args.gender,
                 poses=poses, num_betas=args.num_betas,
                 fps=args.fps, output=output)
    
    print("finished!")