"""
original from https://github.com/vchoutas/smplx
modified by Vassilis and Yao
"""

import torch
import torch.nn as nn
import numpy as np
import pickle

from .lbs import (
    Struct,
    to_tensor,
    to_np,
    lbs,
    vertices2landmarks,
    JointsFromVerticesSelector,
    find_dynamic_lmk_idx_and_bcoords,
)

# SMPLX
# Names of the 14 joints produced by the J14 regressor loaded in SMPLX.__init__.
# The position of a name in this list is used as its target index into that
# regressor's output (presumably the regressor's row order — confirm against
# the regressor file).
J14_NAMES = [
    # right leg, ankle -> hip
    "right_ankle", "right_knee", "right_hip",
    # left leg, hip -> ankle
    "left_hip", "left_knee", "left_ankle",
    # right arm, wrist -> shoulder
    "right_wrist", "right_elbow", "right_shoulder",
    # left arm, shoulder -> wrist
    "left_shoulder", "left_elbow", "left_wrist",
    # head
    "neck", "head",
]
# Keypoint names, intended to line up with the joints returned by
# ``SMPLX.forward`` (145 in total):
#   0-54    skeleton joints (presumably the joints regressed by J_regressor)
#   55-122  the 68 face landmarks
#   123-144 the vertex-based extra joints listed in ``extra_names``
# NOTE: the extra joint names are appended exactly once, from ``extra_names``.
# Listing them inline as well would duplicate them and make the list longer
# than the joint tensor it describes.
SMPLX_names = [
    # body skeleton (0-21)
    "pelvis", "left_hip", "right_hip", "spine1",
    "left_knee", "right_knee", "spine2",
    "left_ankle", "right_ankle", "spine3",
    "left_foot", "right_foot", "neck",
    "left_collar", "right_collar", "head",
    "left_shoulder", "right_shoulder",
    "left_elbow", "right_elbow",
    "left_wrist", "right_wrist",
    # jaw and eyeball joints (22-24)
    "jaw", "left_eye_smplx", "right_eye_smplx",
    # left hand fingers (25-39)
    "left_index1", "left_index2", "left_index3",
    "left_middle1", "left_middle2", "left_middle3",
    "left_pinky1", "left_pinky2", "left_pinky3",
    "left_ring1", "left_ring2", "left_ring3",
    "left_thumb1", "left_thumb2", "left_thumb3",
    # right hand fingers (40-54)
    "right_index1", "right_index2", "right_index3",
    "right_middle1", "right_middle2", "right_middle3",
    "right_pinky1", "right_pinky2", "right_pinky3",
    "right_ring1", "right_ring2", "right_ring3",
    "right_thumb1", "right_thumb2", "right_thumb3",
    # face landmarks (55-122)
    "right_eye_brow1", "right_eye_brow2", "right_eye_brow3",
    "right_eye_brow4", "right_eye_brow5",
    "left_eye_brow5", "left_eye_brow4", "left_eye_brow3",
    "left_eye_brow2", "left_eye_brow1",
    "nose1", "nose2", "nose3", "nose4",
    "right_nose_2", "right_nose_1", "nose_middle", "left_nose_1", "left_nose_2",
    "right_eye1", "right_eye2", "right_eye3",
    "right_eye4", "right_eye5", "right_eye6",
    "left_eye4", "left_eye3", "left_eye2",
    "left_eye1", "left_eye6", "left_eye5",
    "right_mouth_1", "right_mouth_2", "right_mouth_3", "mouth_top",
    "left_mouth_3", "left_mouth_2", "left_mouth_1",
    "left_mouth_5", "left_mouth_4", "mouth_bottom",
    "right_mouth_4", "right_mouth_5",
    "right_lip_1", "right_lip_2", "lip_top", "left_lip_2", "left_lip_1",
    "left_lip_3", "lip_bottom", "right_lip_3",
    "right_contour_1", "right_contour_2", "right_contour_3", "right_contour_4",
    "right_contour_5", "right_contour_6", "right_contour_7", "right_contour_8",
    "contour_middle",
    "left_contour_8", "left_contour_7", "left_contour_6", "left_contour_5",
    "left_contour_4", "left_contour_3", "left_contour_2", "left_contour_1",
]
# Vertex-based extra joints (indices 123-144 of SMPLX_names once appended).
extra_names = [
    "head_top",
    "left_big_toe",
    "left_ear",
    "left_eye",
    "left_heel",
    "left_index",
    "left_middle",
    "left_pinky",
    "left_ring",
    "left_small_toe",
    "left_thumb",
    "nose",
    "right_big_toe",
    "right_ear",
    "right_eye",
    "right_heel",
    "right_index",
    "right_middle",
    "right_pinky",
    "right_ring",
    "right_small_toe",
    "right_thumb",
]
SMPLX_names += extra_names

# Keypoint indices (into SMPLX_names / the joints returned by SMPLX.forward)
# grouped by body part. Contiguous runs are spelled as ranges.
part_indices = {
    # skeleton joints 0-24 plus the extra joints head_top, left_big_toe,
    # left_ear, left_eye, left_heel, left_small_toe, nose, right_big_toe,
    # right_ear, right_eye, right_heel, right_small_toe
    "body": np.array(
        [*range(0, 25), 123, 124, 125, 126, 127, 132, 134, 135, 136, 137, 138, 143]
    ),
    # NOTE(review): includes face-landmark indices 55-59 and 76-122 as well as
    # the arm joints 16-19; kept as-is, confirm this is intended.
    "torso": np.array(
        [0, 1, 2, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24,
         *range(55, 60), *range(76, 145)]
    ),
    # neck, head, jaw, eyeballs, all face landmarks, head_top, ears, eyes, nose
    "head": np.array(
        [12, 15, 22, 23, 24, *range(55, 124), 125, 126, 134, 136, 137]
    ),
    # the 68 face landmarks
    "face": np.array([*range(55, 123)]),
    # neck, both collars and the face landmarks
    "upper": np.array([12, 13, 14, *range(55, 123)]),
    # both wrists, all finger joints and the fingertip/thumb extra joints
    "hand": np.array(
        [20, 21, *range(25, 55), 128, 129, 130, 131, 133, 139, 140, 141, 142, 144]
    ),
    "left_hand": np.array([20, *range(25, 40), 128, 129, 130, 131, 133]),
    "right_hand": np.array([21, *range(40, 55), 139, 140, 141, 142, 144]),
}
# Kinematic chain from the head joint up to the pelvis (joint indices listed in
# the table below). SMPLX uses it to recover the head orientation when picking
# the dynamic face-contour landmarks.
head_kin_chain = [
    15,  # head
    12,  # neck
    9,  # spine3
    6,  # spine2
    3,  # spine1
    0,  # pelvis
]

# --smplx joints
# (names below follow the SMPL naming; SMPLX_names calls 03/06/09
#  spine1/spine2/spine3 and 10/11 left_foot/right_foot)
# 00 - Global
# 01 - L_Thigh
# 02 - R_Thigh
# 03 - Spine
# 04 - L_Calf
# 05 - R_Calf
# 06 - Spine1
# 07 - L_Foot
# 08 - R_Foot
# 09 - Spine2
# 10 - L_Toes
# 11 - R_Toes
# 12 - Neck
# 13 - L_Shoulder
# 14 - R_Shoulder
# 15 - Head
# 16 - L_UpperArm
# 17 - R_UpperArm
# 18 - L_ForeArm
# 19 - R_ForeArm
# 20 - L_Hand
# 21 - R_Hand
# 22 - Jaw
# 23 - L_Eye
# 24 - R_Eye


class SMPLX(nn.Module):
    """
    Given smplx parameters, this class generates a differentiable SMPLX function
    which outputs a mesh, face landmarks and 3D joints.

    Poses are handled as rotation matrices throughout (``lbs`` is called with
    ``pose2rot=False``), so every pose tensor ends in ``[..., 3, 3]``.
    """

    # Kinematic chains used by ``pose_abs2rel`` / ``pose_rel2abs``. Each entry
    # indexes the full pose (global pose at index 0 followed by the 21 body
    # joints) and lists the requested joint first, then its ancestors up to the
    # pelvis.
    _KIN_CHAINS = {
        # Pelvis -> Spine 1, 2, 3 -> Neck -> Head
        "head": (15, 12, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> Neck
        "neck": (12, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> Right collar -> Right shoulder
        # -> Right elbow -> Right wrist
        "right_wrist": (21, 19, 17, 14, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> Left collar -> Left shoulder
        # -> Left elbow -> Left wrist
        "left_wrist": (20, 18, 16, 13, 9, 6, 3, 0),
    }

    def __init__(self, config):
        """Load the SMPL-X model and its auxiliary regressors.

        Args:
            config: object exposing ``smplx_model_path``, ``n_shape``, ``n_exp``,
                ``extra_joint_path`` (falsy to disable the extra joints) and
                ``j14_regressor_path``.
        """
        super(SMPLX, self).__init__()
        ss = np.load(config.smplx_model_path, allow_pickle=True)
        smplx_model = Struct(**ss)

        self.dtype = torch.float32
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long),
        )
        # The vertices of the template model
        self.register_buffer(
            "v_template", to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)
        )
        # Shape and expression blend shapes share one tensor: the first n_shape
        # shape components followed by n_exp expression components. The
        # expression space is the same as FLAME; it is read starting at
        # component 300 of the model's shapedirs (assumed layout — confirm
        # against the model file).
        shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
        shapedirs = torch.cat(
            [
                shapedirs[:, :, :config.n_shape],
                shapedirs[:, :, 300:300 + config.n_exp],
            ],
            2,
        )
        self.register_buffer("shapedirs", shapedirs)
        # The pose components, flattened to [num_pose_basis, -1]
        num_pose_basis = smplx_model.posedirs.shape[-1]
        posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
        self.register_buffer(
            "J_regressor", to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)
        )
        parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
        parents[0] = -1  # the root joint has no parent
        self.register_buffer("parents", parents)
        self.register_buffer("lbs_weights", to_tensor(to_np(smplx_model.weights), dtype=self.dtype))
        # for face keypoints
        self.register_buffer(
            "lmk_faces_idx", torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)
        )
        self.register_buffer(
            "lmk_bary_coords",
            torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
        )
        self.register_buffer(
            "dynamic_lmk_faces_idx",
            torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long),
        )
        self.register_buffer(
            "dynamic_lmk_bary_coords",
            torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype),
        )
        # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
        self.register_buffer("head_kin_chain", torch.tensor(head_kin_chain, dtype=torch.long))

        # -- default parameters used by forward() when an input is omitted
        # shape and expression: zeros
        self.register_buffer(
            "shape_params",
            nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False),
        )
        self.register_buffer(
            "expression_params",
            nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False),
        )
        # poses: identity rotation matrices, [number of joints, 3, 3].
        # Registration order is kept so state_dict key order is unchanged.
        for name, num_joints in (
            ("global_pose", 1),
            ("head_pose", 1),
            ("neck_pose", 1),
            ("jaw_pose", 1),
            ("eye_pose", 2),
            ("body_pose", 21),
            ("left_hand_pose", 15),
            ("right_hand_pose", 15),
        ):
            self.register_buffer(
                name,
                nn.Parameter(
                    torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(num_joints, 1, 1),
                    requires_grad=False,
                ),
            )

        if config.extra_joint_path:
            self.extra_joint_selector = JointsFromVerticesSelector(fname=config.extra_joint_path)
        self.use_joint_regressor = True
        self.keypoint_names = SMPLX_names
        self.part_indices = part_indices
        if self.use_joint_regressor:
            # NOTE: the regressor file is unpickled; only load trusted files.
            with open(config.j14_regressor_path, "rb") as f:
                j14_regressor = pickle.load(f, encoding="latin1")
            # Map keypoints that also exist in the J14 set:
            # source = index in keypoint_names, target = index in J14_NAMES.
            source = []
            target = []
            for idx, name in enumerate(self.keypoint_names):
                if name in J14_NAMES:
                    source.append(idx)
                    target.append(J14_NAMES.index(name))
            source = np.asarray(source)
            target = np.asarray(target)
            self.register_buffer("source_idxs", torch.from_numpy(source))
            self.register_buffer("target_idxs", torch.from_numpy(target))
            joint_regressor = torch.from_numpy(j14_regressor).to(dtype=torch.float32)
            self.register_buffer("extra_joint_regressor", joint_regressor)

    @staticmethod
    def _infer_batch_size(*tensors):
        """Return the leading dimension of the first non-None tensor, or 1 if all are None."""
        for tensor in tensors:
            if tensor is not None:
                return tensor.shape[0]
        return 1

    def forward(
        self,
        shape_params=None,
        expression_params=None,
        global_pose=None,
        body_pose=None,
        jaw_pose=None,
        eye_pose=None,
        left_hand_pose=None,
        right_hand_pose=None,
    ):
        """
        Args:
            shape_params: [N, number of shape parameters]
            expression_params: [N, number of expression parameters]
            global_pose: pelvis pose, [N, 1, 3, 3]
            body_pose: [N, 21, 3, 3]
            jaw_pose: [N, 1, 3, 3]
            eye_pose: [N, 2, 3, 3]
            left_hand_pose: [N, 15, 3, 3]
            right_hand_pose: [N, 15, 3, 3]
            Any argument left as None is replaced by the matching default
            buffer (zeros / identity rotations) expanded to the batch size.
            The batch size is taken from shape_params, else global_pose, else
            the first other provided input; with no inputs at all it is 1.
        Returns:
            vertices: [N, number of vertices, 3]
            landmarks: [N, number of landmarks (68 face keypoints), 3]
            joints: [N, number of smplx joints (145), 3] -- the lbs joints,
                then the landmarks, then the extra joints (only present when
                ``extra_joint_path`` was configured).
        """
        batch_size = self._infer_batch_size(
            shape_params,
            global_pose,
            expression_params,
            body_pose,
            jaw_pose,
            eye_pose,
            left_hand_pose,
            right_hand_pose,
        )
        if shape_params is None:
            shape_params = self.shape_params.expand(batch_size, -1)
        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)
        if global_pose is None:
            global_pose = self.global_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if body_pose is None:
            body_pose = self.body_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if jaw_pose is None:
            jaw_pose = self.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if eye_pose is None:
            eye_pose = self.eye_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if left_hand_pose is None:
            left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if right_hand_pose is None:
            right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)

        shape_components = torch.cat([shape_params, expression_params], dim=1)
        # Joint order of the full pose: global, body, jaw, eyes, left hand, right hand
        full_pose = torch.cat(
            [
                global_pose,
                body_pose,
                jaw_pose,
                eye_pose,
                left_hand_pose,
                right_hand_pose,
            ],
            dim=1,
        )
        template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        # smplx
        vertices, joints = lbs(
            shape_components,
            full_pose,
            template_vertices,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            dtype=self.dtype,
            pose2rot=False,
        )
        # face landmarks: static ones plus the pose-dependent (dynamic) contour
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.head_kin_chain,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        final_joint_set = [joints, landmarks]
        if hasattr(self, "extra_joint_selector"):
            # Add any extra joints that might be needed
            extra_joints = self.extra_joint_selector(vertices, self.faces_tensor)
            final_joint_set.append(extra_joints)
        # Create the final joint set
        joints = torch.cat(final_joint_set, dim=1)
        # NOTE: extra_joint_regressor / source_idxs / target_idxs are registered
        # in __init__ but currently not applied; the disabled code was:
        # if self.use_joint_regressor:
        #     reg_joints = torch.einsum("ji,bik->bjk",
        #                               self.extra_joint_regressor, vertices)
        #     joints[:, self.source_idxs] = (
        #         joints[:, self.source_idxs].detach() * 0.0 +
        #         reg_joints[:, self.target_idxs] * 1.0)
        return vertices, landmarks, joints

    @classmethod
    def _get_kin_chain(cls, abs_joint, method_name):
        """Return the kinematic chain for ``abs_joint``.

        Raises:
            NotImplementedError: if ``abs_joint`` has no chain; the message
                names ``method_name``.
        """
        try:
            return cls._KIN_CHAINS[abs_joint]
        except (KeyError, TypeError):
            raise NotImplementedError(f"{method_name} does not support: {abs_joint}") from None

    def pose_abs2rel(self, global_pose, body_pose, abs_joint="head"):
        """change absolute pose to relative pose
        Basic knowledge for SMPLX kinematic tree:
                absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)

        Args:
            global_pose: [N, 1, 3, 3] pelvis pose.
            body_pose: [N, 21, 3, 3] body joint poses; the entry of
                ``abs_joint`` is taken to be absolute.
            abs_joint: one of "head", "neck", "right_wrist", "left_wrist".
        Returns:
            ``body_pose``, modified IN PLACE so that the entry of ``abs_joint``
            is relative to its parent. Gradients do not flow into the parent
            chain (it is detached).
        Raises:
            NotImplementedError: for an unsupported ``abs_joint``.
        """
        kin_chain = self._get_kin_chain(abs_joint, "pose_abs2rel")

        batch_size = global_pose.shape[0]
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        rel_rot_mat = (
            torch.eye(3, device=global_pose.device, dtype=global_pose.dtype)
            .unsqueeze(0)
            .repeat(batch_size, 1, 1)
        )
        # Compose the ancestors (excluding the joint itself) from the root down.
        for idx in kin_chain[1:]:
            rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)

        # This contains the absolute pose of the parent
        abs_parent_pose = rel_rot_mat.detach()
        # body_pose has no global entry, hence the -1 offset from full-pose indices
        body_idx = kin_chain[0] - 1
        # Let's assume that in the input this specific joint is predicted as an absolute value
        abs_joint_pose = body_pose[:, body_idx]
        # abs_joint = abs_parent * rel_joint ==> rel_joint = abs_parent.T * abs_joint
        rel_joint_pose = torch.matmul(
            abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2),
            abs_joint_pose.reshape(-1, 3, 3),
        )
        # Replace the new relative pose
        body_pose[:, body_idx, :, :] = rel_joint_pose
        return body_pose

    def pose_rel2abs(self, global_pose, body_pose, abs_joint="head"):
        """change relative pose to absolute pose
        Basic knowledge for SMPLX kinematic tree:
                absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)

        Args:
            global_pose: [N, 1, 3, 3] pelvis pose.
            body_pose: [N, 21, 3, 3] relative body joint poses.
            abs_joint: one of "head", "neck", "right_wrist", "left_wrist".
        Returns:
            [N, 1, 3, 3] absolute rotation of ``abs_joint``: the product of the
            rotations from the pelvis down to and including the joint.
        Raises:
            NotImplementedError: for an unsupported ``abs_joint``.
        """
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        kin_chain = self._get_kin_chain(abs_joint, "pose_rel2abs")

        rel_rot_mat = torch.eye(3, device=full_pose.device, dtype=full_pose.dtype).unsqueeze(0)
        for idx in kin_chain:
            rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
        abs_pose = rel_rot_mat[:, None, :, :]
        return abs_pose
