import os
import uuid
from pathlib import Path
from typing import Optional

import ipdb
import yaml
import numpy as np
import torch
import typer
from scipy.spatial.transform import Rotation as sRot
import pickle
from smpl_sim.smpllib.smpl_joint_names import (
    SMPL_BONE_ORDER_NAMES,
    SMPL_MUJOCO_NAMES,
    SMPLH_BONE_ORDER_NAMES,
    SMPLH_MUJOCO_NAMES,
)
from smpl_sim.smpllib.smpl_local_robot import SMPL_Robot
from tqdm import tqdm

from poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState, SkeletonTree
import time
from datetime import timedelta

# NOTE(review): TMP_SMPL_DIR is not referenced anywhere else in this script;
# confirm whether it is still needed.
TMP_SMPL_DIR = "/tmp/smpl"


def _is_source_folder(folder_name: str) -> bool:
    """Return True if *folder_name* should be treated as a raw input subset.

    Folders whose name contains "retarget", "smpl" or "h1" are skipped. This
    covers the ``<subset>-<robot>_retargeted_npy`` and ``retargeted_motion_pkl``
    folders that ``main`` writes under the same root when retargeting.
    """
    return not any(tag in folder_name for tag in ("retarget", "smpl", "h1"))


def _list_motion_files(data_dir: Path) -> list:
    """Return candidate motion files found recursively under *data_dir*.

    The glob matches ``.npz`` and ``.pkl`` (and any other suffix built from the
    same character classes). Files named ``shape.npz`` or containing
    ``stagei.npz`` are excluded.
    """
    return [
        f
        for f in Path(data_dir).glob("**/*.[np][pk][lz]")
        if f.name != "shape.npz" and "stagei.npz" not in f.name
    ]


def _output_name(name: str) -> str:
    """Return the output file name for the input file name *name*.

    ``.npz``/``.pkl`` become ``.npy``; ``-``, spaces and parentheses become ``_``.
    """
    for old, new in (
        (".npz", ".npy"),
        (".pkl", ".npy"),
        ("-", "_"),
        (" ", "_"),
        ("(", "_"),
        (")", "_"),
    ):
        name = name.replace(old, new)
    return name


def _output_path(output_dir: Path, data_dir: Path, filename: Path) -> Path:
    """Return the output path for *filename*, mirroring its sub-directory under *output_dir*."""
    return output_dir / filename.relative_to(data_dir).parent / _output_name(filename.name)


def _load_fps_dict() -> dict:
    """Read the per-motion FPS table used for SMPL-X input."""
    with open(Path("data/yaml_files/motion_fps_amassx.yaml"), "r") as f:
        return yaml.safe_load(f)


def _load_smpl_shape_fit() -> tuple:
    """Build the neutral SMPL parser and load the pre-fitted (shape, scale) pair."""
    import joblib
    from smpl_sim.smpllib.smpl_parser import SMPL_Parser

    smpl_parser_n = SMPL_Parser(model_path="data/smpl", gender="neutral")
    print("smpl_parser_n: ", smpl_parser_n)
    shape_new, scale = joblib.load("./data/scripts/shape_optimized_neutral.pkl")
    print("shape_new: ", shape_new)
    print("scale: ", scale)
    return smpl_parser_n, shape_new, scale


def _load_motion(filename: Path, yaml_key: str, humanoid_type: str, get_fps_dict):
    """Load one motion file.

    Args:
        filename: input file.
        yaml_key: key looked up in the FPS table (SMPL-X input only).
        humanoid_type: "smpl", "smplx" or "smplh".
        get_fps_dict: zero-argument callable returning the FPS table.

    Returns:
        ``(betas, gender, poses, trans, mocap_fr)``, or ``None`` when the file
        is neither an ``.npz`` without "samp" in its path nor a ``.pkl`` with
        "samp" in its path.

    Raises:
        Exception: SMPL-X input whose FPS is neither in the table nor in the file.
    """
    if filename.suffix == ".npz" and "samp" not in str(filename):
        # The context manager closes the NpzFile handle; the arrays read below
        # are materialized, so they remain valid after closing.
        with np.load(filename) as data:
            betas = data["betas"]
            gender = data["gender"]
            poses = data["poses"]
            trans = data["trans"]
            if humanoid_type == "smplx":
                fps_dict = get_fps_dict()
                if yaml_key in fps_dict:
                    mocap_fr = fps_dict[yaml_key]
                elif "mocap_framerate" in data:
                    mocap_fr = data["mocap_framerate"]
                elif "mocap_frame_rate" in data:
                    mocap_fr = data["mocap_frame_rate"]
                else:
                    raise Exception(f"FPS not found for {yaml_key}")
                print(f"FPS: {mocap_fr}")
            elif "mocap_framerate" in data:
                mocap_fr = data["mocap_framerate"]
            else:
                mocap_fr = data["mocap_frame_rate"]
        return betas, gender, poses, trans, mocap_fr

    if filename.suffix == ".pkl" and "samp" in str(filename):
        with open(filename, "rb") as f:
            # NOTE: pickle.load can execute arbitrary code; only use trusted datasets.
            data = pickle.load(f, encoding="latin1")
        return (
            data["shape_est_betas"][:10],
            "neutral",
            data["pose_est_fullposes"],
            data["pose_est_trans"],
            data["mocap_framerate"],
        )

    return None


def _smpl_global_motion(pose_aa, trans, smpl_ctx, skeleton_tree, smpl_2_mujoco, upright_start):
    """Compute scaled joint positions and global joint rotations for a pose sequence.

    Args:
        pose_aa: per-frame axis-angle array; only the first 66 values of each
            frame are kept and 6 zeros are appended, giving 24 joints x 3.
        trans: per-frame root translation array.
        smpl_ctx: ``(parser, shape, scale)`` from ``_load_smpl_shape_fit``.
        skeleton_tree: poselib skeleton the local rotations are applied to.
        smpl_2_mujoco: SMPL joint indices in MuJoCo joint order.
        upright_start: if True, right-multiply every global rotation by the
            inverse of quaternion (0.5, 0.5, 0.5, 0.5).

    Returns:
        ``(global_trans, pose_quat_global)``: joint positions in MuJoCo order
        (torch tensor) and global joint quaternions (numpy array).

    Raises:
        ValueError: if *skeleton_tree* is None (no ``humanoid_mjcf_path``).
    """
    if skeleton_tree is None:
        raise ValueError("humanoid_mjcf_path is required to build the skeleton tree")
    smpl_parser_n, shape_new, scale = smpl_ctx
    batch_size = pose_aa.shape[0]

    # TODO: need to extract correct hand rotations instead of zero
    pose_aa = np.concatenate([pose_aa[:, :66], np.zeros((batch_size, 6))], axis=1)

    with torch.no_grad():
        verts, joints = smpl_parser_n.get_joints_verts(
            torch.from_numpy(pose_aa).float(), shape_new, torch.from_numpy(trans)
        )
        root_pos = joints[:, 0:1]
        # Scale joint offsets about the root by the pre-fitted scale.
        joints = (joints - root_pos) * scale.detach() + root_pos
    # Shift all joints down by the first frame's minimum vertex coordinate on
    # axis 2 (presumably the vertical axis — confirm).
    joints[..., 2] -= verts[0, :, 2].min().item()
    root_pos = joints[:, 0]
    global_trans = joints[:, smpl_2_mujoco]

    local_quat = (
        sRot.from_rotvec(pose_aa.reshape(batch_size, 24, 3)[:, smpl_2_mujoco].reshape(-1, 3))
        .as_quat()
        .reshape(batch_size, -1, 4)
    )

    # Use the parent relationship to get global rotations.
    sk_state = SkeletonState.from_rotation_and_root_translation(
        skeleton_tree,
        torch.from_numpy(local_quat),
        root_pos,
        is_local=True,
    )

    if upright_start:
        pose_quat_global = (
            (
                sRot.from_quat(sk_state.global_rotation.reshape(-1, 4).numpy())
                * sRot.from_quat([0.5, 0.5, 0.5, 0.5]).inv()
            )
            .as_quat()
            .reshape(batch_size, -1, 4)
        )
    else:
        pose_quat_global = sk_state.global_rotation.numpy()
    return global_trans, pose_quat_global


def _downsample_params(mocap_fr, target_fps: int = 30) -> tuple:
    """Return ``(frame stride, resulting fps)`` for downsampling toward *target_fps*.

    The stride is at least 1, so sources slower than *target_fps* are kept at
    full rate instead of producing a zero slice step. The returned fps is the
    true rate after striding (an int when it is a whole number).
    """
    src_fps = float(mocap_fr)
    skip = max(1, int(src_fps // target_fps))
    fps = src_fps / skip
    if fps.is_integer():
        fps = int(fps)
    return skip, fps


def main(
    amass_root_dir: Path,
    robot_type: str = 'g1',
    humanoid_type: str = "smpl",
    force_remake: bool = False,
    force_neutral_body: bool = True,
    upright_start: bool = True,  # By default, let's start upright (for consistency across all models).
    humanoid_mjcf_path: Optional[str] = "data/assets/mjcf/smpl_humanoid.xml",
    force_retarget: bool = True,
):
    """Retarget every motion subset under AMASS_ROOT_DIR to ROBOT_TYPE.

    Each sub-folder of AMASS_ROOT_DIR is a subset, except generated output
    folders. Each motion is read, converted to SMPL global joint positions and
    rotations, and retargeted with mink. Results are saved under
    ``<subset>-<robot>_retargeted_npy`` and ``retargeted_motion_pkl``.
    """
    from functools import lru_cache

    if robot_type is None:
        robot_type = humanoid_type
    elif robot_type in ["h1", "g1"]:
        assert (
            force_retarget
        ), f"Data is either SMPL or SMPL-X. The {robot_type} robot must use the retargeting pipeline."

    assert humanoid_type in [
        "smpl",
        "smplx",
        "smplh",
    ], "Humanoid type must be one of smpl, smplx, smplh"

    # The body pipeline always builds a 24-joint SMPL pose (66 body values + 6
    # zeros), so SMPL joint ordering is used for every input type. The SMPL-H
    # name lists also contain hand joints, which do not exist in that 24-joint
    # array and caused an IndexError for smplx/smplh input.
    smpl_2_mujoco = [
        SMPL_BONE_ORDER_NAMES.index(q) for q in SMPL_MUJOCO_NAMES if q in SMPL_BONE_ORDER_NAMES
    ]

    # Construct the SMPL skeleton tree.
    if humanoid_mjcf_path is not None:
        skeleton_tree = SkeletonTree.from_mjcf(humanoid_mjcf_path)
        print("skeleton_tree_parents: ", skeleton_tree.parent_indices)
    else:
        skeleton_tree = None

    append_name = robot_type
    if force_retarget:
        append_name += "_retargeted_npy"
    source_folders = [
        entry.name
        for entry in os.scandir(amass_root_dir)
        if entry.is_dir() and _is_source_folder(entry.name)
    ]

    # Count how many files need processing (for the progress / ETA report).
    start_time = time.time()
    total_files = 0
    total_files_to_process = 0
    processed_files = 0
    for folder_name in source_folders:
        data_dir = amass_root_dir / folder_name
        output_dir = amass_root_dir / f"{folder_name}-{append_name}"
        all_files_in_folder = _list_motion_files(data_dir)
        if force_remake:
            files_to_process = all_files_in_folder
        else:
            # Only count files that don't already have outputs.
            files_to_process = [
                f for f in all_files_in_folder
                if not _output_path(output_dir, data_dir, f).exists()
            ]
        print(
            f"Processing {len(files_to_process)}/{len(all_files_in_folder)} files in {folder_name}"
        )
        total_files_to_process += len(files_to_process)
        total_files += len(all_files_in_folder)

    print(f"Total files to process: {total_files_to_process}/{total_files}")

    # Expensive, file-independent resources are loaded once, on first use.
    get_fps_dict = lru_cache(maxsize=1)(_load_fps_dict)
    smpl_ctx = None

    for folder_name in source_folders:
        data_dir = amass_root_dir / folder_name
        output_dir = amass_root_dir / f"{folder_name}-{append_name}"

        print(f"Processing subset {folder_name}")
        os.makedirs(output_dir, exist_ok=True)

        files = sorted(_list_motion_files(data_dir))
        print(f"Processing {len(files)} files")

        # read data --> mink_retarget --> save data
        for filename in tqdm(files):
            relative_path_dir = filename.relative_to(data_dir).parent
            out_name = _output_name(filename.name)
            outpath = output_dir / relative_path_dir / out_name

            if not force_remake and outpath.exists():
                continue

            os.makedirs(output_dir / relative_path_dir, exist_ok=True)

            print(f"Processing {filename}")
            yaml_key = folder_name + "/" + str(relative_path_dir / out_name)
            loaded = _load_motion(filename, yaml_key, humanoid_type, get_fps_dict)
            if loaded is None:
                print(f"Skipping {filename} as it is not a valid file")
                continue
            # NOTE(review): betas/gender are not used below; the body shape
            # always comes from shape_optimized_neutral.pkl, so
            # force_neutral_body currently has no effect on the output.
            _betas, _gender, amass_pose, amass_trans, mocap_fr = loaded

            pose_aa = torch.tensor(amass_pose).numpy()
            trans = torch.tensor(amass_trans).numpy()

            if smpl_ctx is None:
                smpl_ctx = _load_smpl_shape_fit()
            global_trans, pose_quat_global = _smpl_global_motion(
                pose_aa, trans, smpl_ctx, skeleton_tree, smpl_2_mujoco, upright_start
            )

            if not force_retarget:
                # NOTE(review): nothing is saved for this file without force_retarget.
                continue

            from retargeting.mink_retarget import retarget_fit_motion

            print("Force retargeting motion using mink retargeter...")
            # Downsample to ~30 fps to speed up Mink retargeting; fps is the
            # actual rate after striding, not a fixed 30.
            skip, fps = _downsample_params(mocap_fr, target_fps=30)
            new_sk_motion = retarget_fit_motion(
                global_trans[::skip], pose_quat_global[::skip], fps, robot_type=robot_type, render=False
            )

            print(f"Saving to {outpath}")

            # Keys read from new_sk_motion: 'global_translation',
            # 'global_rotation', 'dof_pos', 'fps'.
            motion_data = {
                'root_trans_offset': new_sk_motion['global_translation'][:, 0, :],
                'root_rot': new_sk_motion['global_rotation'][:, 0, :],
                'dof': new_sk_motion['dof_pos'],
                'fps': new_sk_motion['fps'],
            }
            motion_data = {k: np.array(v) for k, v in motion_data.items()}

            output_folder_path = amass_root_dir / "retargeted_motion_pkl"
            os.makedirs(output_folder_path, exist_ok=True)
            # NOTE(review): files are named by stem only, so same-named motions
            # from different sub-folders overwrite each other here.
            path = output_folder_path / f"{filename.stem}.pkl"
            print(path)

            with open(path, 'wb') as f:
                pickle.dump({filename: motion_data}, f, protocol=pickle.HIGHEST_PROTOCOL)

            if robot_type in ["h1", "g1"]:
                torch.save(new_sk_motion, str(outpath))
            else:
                new_sk_motion.to_file(str(outpath))

            processed_files += 1
            elapsed_time = time.time() - start_time
            avg_time_per_file = elapsed_time / processed_files
            remaining_files = total_files_to_process - processed_files
            estimated_time_remaining = avg_time_per_file * remaining_files

            print(
                f"\nProgress: {processed_files}/{total_files_to_process} files"
            )
            print(
                f"Average time per file: {timedelta(seconds=int(avg_time_per_file))}"
            )
            print(
                f"Estimated time remaining: {timedelta(seconds=int(estimated_time_remaining))}"
            )
            print(
                f"Estimated completion time: {time.strftime('%H:%M:%S', time.localtime(time.time() + estimated_time_remaining))}\n"
            )


if __name__ == "__main__":
    # Run the whole CLI invocation with torch autograd switched off
    # (set_grad_enabled(False) is the same guard as torch.no_grad()).
    with torch.set_grad_enabled(False):
        typer.run(main)
