{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6ae7b0e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from pytorch3d import transforms\n",
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e5d3fbb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code - only use it on trusted local files.\n",
    "def _load_pickle(path):\n",
    "    with open(path, 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "data_or = _load_pickle('/data/lmg/code/DART/data/original_1.pkl')\n",
    "data_stand = _load_pickle('/data/lmg/code/DART/data/stand_20fps.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d23986f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gender male\n",
      "betas (10,)\n",
      "transl (159, 3)\n",
      "global_orient (159, 3)\n",
      "body_pose (159, 63)\n",
      "pelvis_delta (3,)\n",
      "joints (159, 66)\n",
      "text {'proc_label': \"A person reciprocates the handshake by extending their right hand and grasping the other person's palm with similar firmness, while keeping their left hand holding their cards.\\n\", 'start_t': 0.0, 'end_t': 7.95}\n"
     ]
    }
   ],
   "source": [
    "# Summarize each entry: arrays/tensors by shape, anything else by its value.\n",
    "# hasattr only catches AttributeError, unlike the bare except it replaces.\n",
    "for key, value in data_or.items():\n",
    "    print(key, value.shape if hasattr(value, 'shape') else value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0ae2a878",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert torch tensors to NumPy arrays; detach and move to CPU first so that\n",
    "# GPU-resident or grad-tracking tensors convert as well.\n",
    "for key, value in data_or.items():\n",
    "    if isinstance(value, torch.Tensor):\n",
    "        data_or[key] = value.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "702f1702",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten per-frame pose/joint arrays to 2-D: (frames, features).\n",
    "for name in ('body_pose', 'joints'):\n",
    "    arr = data_or[name]\n",
    "    data_or[name] = arr.reshape(arr.shape[0], -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0f950c35",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gender male\n",
      "betas (10,)\n",
      "transl (21, 3)\n",
      "global_orient (21, 3)\n",
      "body_pose (21, 63)\n",
      "texts ['stand']\n"
     ]
    }
   ],
   "source": [
    "# Summarize each entry: arrays/tensors by shape, anything else by its value.\n",
    "# hasattr only catches AttributeError, unlike the bare except it replaces.\n",
    "for key, value in data_stand.items():\n",
    "    print(key, value.shape if hasattr(value, 'shape') else value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "e8c98a7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "transl\n",
      "global_orient\n",
      "body_pose\n"
     ]
    }
   ],
   "source": [
    "# Overwrite frame 1 of the standing clip with frame 1 of the original clip (in place).\n",
    "# NOTE(review): the result depends on data_or not yet having been modified by the\n",
    "# frame-copy cell further down - re-running out of order gives a different clip.\n",
    "for name in ('transl', 'global_orient', 'body_pose'):\n",
    "    data_stand[name][1] = data_or[name][1]\n",
    "    print(name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "5abbcddd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the mixed standing clip built above.\n",
    "with open('/data/lmg/code/DART/data/mixed.pkl', 'wb') as out_file:\n",
    "    pickle.dump(data_stand, out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8ed37e5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "transl\n",
      "global_orient\n",
      "body_pose\n"
     ]
    }
   ],
   "source": [
    "# Make frame 1 an exact copy of frame 0 for each pose field (mutates data_or in place).\n",
    "for name in ('transl', 'global_orient', 'body_pose'):\n",
    "    data_or[name][1] = data_or[name][0]\n",
    "    print(name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4f55a277",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the modified original clip; it is reloaded below as t_pose.\n",
    "with open('/data/lmg/code/DART/data/t_pose.pkl', 'wb') as out_file:\n",
    "    pickle.dump(data_or, out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "b74d42be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the clip written above (round-trip through disk).\n",
    "with open('/data/lmg/code/DART/data/t_pose.pkl', 'rb') as in_file:\n",
    "    t_pose = pickle.load(in_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "88df6cba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gender male\n",
      "betas torch.Size([1, 10])\n",
      "transl torch.Size([1, 159, 3])\n",
      "global_orient torch.Size([1, 159, 3, 3])\n",
      "body_pose torch.Size([1, 159, 21, 3, 3])\n",
      "pelvis_delta torch.Size([1, 3])\n",
      "joints torch.Size([1, 159, 66])\n",
      "text {'proc_label': \"A person reciprocates the handshake by extending their right hand and grasping the other person's palm with similar firmness, while keeping their left hand holding their cards.\\n\", 'start_t': 0.0, 'end_t': 7.95}\n"
     ]
    }
   ],
   "source": [
    "# Summarize each entry: arrays/tensors by shape, anything else by its value.\n",
    "# hasattr only catches AttributeError, unlike the bare except it replaces.\n",
    "for key, value in t_pose.items():\n",
    "    print(key, value.shape if hasattr(value, 'shape') else value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "da44c846",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: You are using a SMPL+H model, with only 10 shape coefficients.\n",
      "WARNING: You are using a SMPL+H model, with only 10 shape coefficients.\n"
     ]
    }
   ],
   "source": [
    "import smplx\n",
    "# NOTE(review): hard-coded GPU index; falls back to CPU when CUDA is unavailable.\n",
    "device = 'cuda:4' if torch.cuda.is_available() else 'cpu'\n",
    "body_model_dir = r'/data/lmg/code/DART/data/smplx_lockedhead_20230207/models_lockedhead'\n",
    "# One SMPL+H layer per gender, built with identical options.\n",
    "smplh_body_model_dict = {\n",
    "    gender: smplx.build_layer(body_model_dir, model_type='smplh',\n",
    "                              gender=gender, ext='pkl',\n",
    "                              num_pca_comps=12)\n",
    "    for gender in ('male', 'female')\n",
    "}\n",
    "# Move BOTH models to `device` and switch them to eval mode, so that either gender\n",
    "# can be evaluated on the device-resident inputs built later in the notebook.\n",
    "bm_male = smplh_body_model_dict['male'].to(device).eval()\n",
    "bm_female = smplh_body_model_dict['female'].to(device).eval()\n",
    "def get_new_coordinate_cal(jts: torch.Tensor):\n",
    "    \"\"\"Build an upright coordinate frame from the first three joints.\n",
    "\n",
    "    jts: [b, J, 3] joint positions (J >= 3); only joints 0, 1 and 2 are read.\n",
    "    Returns (new_rotmat [b, 3, 3], new_transl [b, 1, 3]): the frame axes stacked as\n",
    "    columns (x, y, z) and the frame origin, which is joint 0.\n",
    "    x points from joint 1 to joint 2 projected onto the xy-plane, z is world +z and\n",
    "    y = cross(z, x). If joints 1 and 2 coincide in xy the normalisation divides by\n",
    "    zero and the result contains NaN.\n",
    "    \"\"\"\n",
    "    x_axis = jts[:, 2, :] - jts[:, 1, :]  # [b,3]\n",
    "    x_axis[:, -1] = 0  # drop the vertical component so the frame stays upright\n",
    "    x_axis = x_axis / torch.norm(x_axis, dim=-1, keepdim=True)\n",
    "    # World up axis created with the joints' dtype and device, so non-float32 inputs\n",
    "    # do not hit a dtype mismatch in torch.cross (torch.FloatTensor forced float32).\n",
    "    z_axis = x_axis.new_tensor([0.0, 0.0, 1.0]).expand_as(x_axis)\n",
    "    y_axis = torch.cross(z_axis, x_axis, dim=-1)\n",
    "    y_axis = y_axis / torch.norm(y_axis, dim=-1, keepdim=True)\n",
    "    new_rotmat = torch.stack([x_axis, y_axis, z_axis], dim=-1)  # [b,3,3]\n",
    "    new_transl = jts[:, :1]  # [b,1,3]\n",
    "    return new_rotmat, new_transl\n",
    "\n",
    "def get_new_coordinate(body_param_dict, use_predicted_joints=False, pred_joints=None):\n",
    "    \"\"\"Return the canonical frame (rotmat [b,3,3], transl [b,1,3]) of a batch of bodies.\n",
    "\n",
    "    With use_predicted_joints=True the frame is computed from pred_joints ([b, J, 3]);\n",
    "    otherwise the gender-specific SMPL+H layer is run on body_param_dict and its\n",
    "    joints are used. Raises ValueError if predicted joints are requested but missing.\n",
    "    \"\"\"\n",
    "    if use_predicted_joints:\n",
    "        if pred_joints is None:\n",
    "            raise ValueError('use_predicted_joints=True requires pred_joints')\n",
    "        joints = pred_joints\n",
    "    else:\n",
    "        body_model = bm_male if body_param_dict['gender'] == 'male' else bm_female\n",
    "        joints = body_model(**body_param_dict).joints  # [b,J,3]\n",
    "    # The chosen joints must not be overwritten by pred_joints here, otherwise the\n",
    "    # use_predicted_joints flag has no effect and pred_joints=None crashes.\n",
    "    new_rotmat, new_transl = get_new_coordinate_cal(joints)  # transformation from new coord axis to old coord axis\n",
    "\n",
    "    return new_rotmat, new_transl\n",
    "\n",
    "def calc_calibrate_offset(body_param_dict):\n",
    "    \"\"\"Return joint 0 (pelvis) of the shape-only body, as a [b, 3] tensor.\n",
    "\n",
    "    Only body_param_dict['gender'] and body_param_dict['betas'] are read; no pose is\n",
    "    passed, so the body model uses its default pose.\n",
    "    \"\"\"\n",
    "    if body_param_dict['gender'] == 'male':\n",
    "        model = bm_male\n",
    "    else:\n",
    "        model = bm_female\n",
    "    shape_only_output = model(betas=body_param_dict['betas'])\n",
    "    return shape_only_output.joints[:, 0, :]\n",
    "\n",
    "def canonicalize(primitive_dict, use_predicted_joints=False):\n",
    "    \"\"\"Canonicalize a motion primitive in place.\n",
    "\n",
    "    The first frame's body defines a new coordinate frame (via get_new_coordinate);\n",
    "    transl, global_orient and, if present, joints of every frame are re-expressed in\n",
    "    that frame. primitive_dict is mutated and the same dict object is returned.\n",
    "\n",
    "    primitive_dict:{\n",
    "    'gender': 'male' or 'female',\n",
    "    'betas': [B, T, 10],\n",
    "    'transl', 'global_orient', 'body_pose': [B, T, 3], [B, T, 3, 3], [B, T, 21, 3, 3]\n",
    "    'joints': optional, [B, T, 22*3],\n",
    "    'pelvis_delta': optional, [B, 3]; computed from betas when absent,\n",
    "    }\n",
    "    Returns (transf_rotmat [B, 3, 3], transf_transl [B, 1, 3], primitive_dict).\n",
    "    \"\"\"\n",
    "    body_param_dict = {\n",
    "        'gender': primitive_dict['gender'],\n",
    "        'betas': primitive_dict['betas'][:, 0, :],\n",
    "        'transl': primitive_dict['transl'][:, 0, :],\n",
    "        'body_pose': primitive_dict['body_pose'][:, 0, :, :, :],\n",
    "        'global_orient': primitive_dict['global_orient'][:, 0, :, :],\n",
    "    }   # first frame bodies\n",
    "    # Pelvis offset [b,3]: taken from the data if stored, else computed from betas.\n",
    "    delta_T = primitive_dict['pelvis_delta'] if 'pelvis_delta' in primitive_dict else calc_calibrate_offset(body_param_dict)  # [b,3]\n",
    "    # Frame of the first frame's body; pred_joints is the first frame's stored joints, if any.\n",
    "    transf_rotmat, transf_transl = get_new_coordinate(body_param_dict,\n",
    "                                                            use_predicted_joints=use_predicted_joints,\n",
    "                                                            pred_joints=primitive_dict['joints'][:, 0, :].reshape(-1, 22, 3) if 'joints' in primitive_dict else None\n",
    "                                                            )  # [b,3,3], [b,1,3]\n",
    "\n",
    "    transl = primitive_dict['transl']  # [b, T, 3]\n",
    "    global_ori = primitive_dict['global_orient']  # [b, T, 3, 3]\n",
    "    # Left-multiply every frame's root orientation by R^T (R = transf_rotmat).\n",
    "    global_ori_new = torch.einsum('bij,btjk->btik', transf_rotmat.permute(0, 2, 1), global_ori)\n",
    "    # Transform the offset point (transl + delta_T) into the new frame (subtract the\n",
    "    # frame origin, rotate by R^T), then remove delta_T again to get the new transl.\n",
    "    transl = torch.einsum('bij,btj->bti', transf_rotmat.permute(0, 2, 1),\n",
    "                            transl + delta_T.unsqueeze(1) - transf_transl) - delta_T.unsqueeze(1)\n",
    "    primitive_dict['global_orient'] = global_ori_new\n",
    "    primitive_dict['transl'] = transl\n",
    "\n",
    "    if 'joints' in primitive_dict:\n",
    "        B, T, _ = primitive_dict['transl'].shape\n",
    "        joints = primitive_dict['joints'].reshape(B, T, 22, 3)  # [b, T, 22*3] -> [b, T, 22, 3]\n",
    "        # transf_transl.unsqueeze(1) is [b,1,1,3] and broadcasts over frames and joints.\n",
    "        joints = torch.einsum('bij,btkj->btki', transf_rotmat.permute(0, 2, 1), joints - transf_transl.unsqueeze(1))\n",
    "        primitive_dict['joints'] = joints.reshape(B, T, 22 * 3)  # [b, T, 22*3]\n",
    "\n",
    "    return transf_rotmat, transf_transl, primitive_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "9168cf63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the reloaded NumPy arrays to float32 tensors. Values are read from t_pose\n",
    "# itself so the result reflects the file just loaded, not other in-memory dicts.\n",
    "for key, value in t_pose.items():\n",
    "    if isinstance(value, np.ndarray):\n",
    "        t_pose[key] = torch.Tensor(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "72d650b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Axis-angle -> rotation matrices: global_orient [T, 3] -> [T, 3, 3],\n",
    "# body_pose [T, 63] -> [T, 21, 3, 3]. Not idempotent: re-running this cell on the\n",
    "# converted matrices does not reproduce the same result.\n",
    "aa_to_mat = transforms.axis_angle_to_matrix\n",
    "t_pose['body_pose'] = aa_to_mat(t_pose['body_pose'].reshape(-1, 21, 3))\n",
    "t_pose['global_orient'] = aa_to_mat(t_pose['global_orient'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "68e6286a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add a leading batch dimension and move every tensor to `device`. Non-tensor\n",
    "# entries (gender string, text dict) are skipped explicitly instead of via a bare\n",
    "# except, which would also hide real errors such as an invalid device.\n",
    "for key, value in t_pose.items():\n",
    "    if isinstance(value, torch.Tensor):\n",
    "        t_pose[key] = value.unsqueeze(0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "9b46c4a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# canonicalize() indexes betas as [:, 0, :], so add a second axis: [1, 10] -> [1, 1, 10].\n",
    "t_pose['betas'] = t_pose['betas'][None].to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "6cf83201",
   "metadata": {},
   "outputs": [],
   "source": [
    "# canonicalize mutates t_pose in place and returns that same dict, so afterwards\n",
    "# canonicalized_t_pose and t_pose are the same object.\n",
    "transf_rotmat_t, transf_transl_t, canonicalized_t_pose = canonicalize(t_pose, use_predicted_joints=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "f548a34e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 159, 21, 3, 3])"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: body_pose is still batched rotation matrices, [1, T, 21, 3, 3].\n",
    "canonicalized_t_pose['body_pose'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "f1df3ed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the batch dimension added before canonicalization. betas also carries the\n",
    "# extra time axis, so it is squeezed a second time ([1, 1, 10] -> [10]).\n",
    "for key, value in canonicalized_t_pose.items():\n",
    "    if isinstance(value, torch.Tensor):\n",
    "        canonicalized_t_pose[key] = value.squeeze(0).to(device)\n",
    "canonicalized_t_pose['betas'] = canonicalized_t_pose['betas'].squeeze(0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b53546b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rotation matrices back to axis-angle; body_pose is flattened to [T, 63] as in the source data.\n",
    "mat_to_aa = transforms.matrix_to_axis_angle\n",
    "canonicalized_t_pose['global_orient'] = mat_to_aa(canonicalized_t_pose['global_orient'])\n",
    "canonicalized_t_pose['body_pose'] = mat_to_aa(canonicalized_t_pose['body_pose']).reshape(-1, 21 * 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "58b52564",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert every tensor to a detached CPU NumPy array, then save the canonicalized clip.\n",
    "for key, value in canonicalized_t_pose.items():\n",
    "    if isinstance(value, torch.Tensor):\n",
    "        canonicalized_t_pose[key] = value.detach().cpu().numpy()\n",
    "with open('/data/lmg/code/DART/data/t_pose_canonicalized.pkl', 'wb') as f:\n",
    "    pickle.dump(canonicalized_t_pose, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "182e0bcf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.9994,  0.0033,  0.0334],\n",
       "        [ 0.0335,  0.0172,  0.9993],\n",
       "        [ 0.0027,  0.9998, -0.0173]])"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the first frame's canonical root orientation as a rotation matrix.\n",
    "first_frame_orient = torch.tensor(canonicalized_t_pose['global_orient'][0])\n",
    "transforms.axis_angle_to_matrix(first_frame_orient)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "416f7e12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.        ,  0.        ,  0.        ],\n",
       "       [-0.05940776, -0.0171077 , -0.08180368],\n",
       "       [ 0.05952707, -0.0171077 , -0.0904277 ],\n",
       "       [-0.0053116 , -0.03606908,  0.12506074],\n",
       "       [-0.09140807, -0.00710184, -0.4693435 ],\n",
       "       [ 0.06510136, -0.0187335 , -0.47653306],\n",
       "       [-0.00844998, -0.00674394,  0.26254386],\n",
       "       [-0.06905508, -0.05025088, -0.8953259 ],\n",
       "       [ 0.01947643, -0.0591465 , -0.89399004],\n",
       "       [-0.00590783, -0.00300258,  0.31851214],\n",
       "       [-0.09368798,  0.0745691 , -0.9586828 ],\n",
       "       [ 0.07697382,  0.06185122, -0.9582008 ],\n",
       "       [ 0.00717213, -0.02964383,  0.5311343 ],\n",
       "       [-0.07776197, -0.01558292,  0.433286  ],\n",
       "       [ 0.07664081, -0.02561218,  0.4315073 ],\n",
       "       [ 0.00510446,  0.0510582 ,  0.59467906],\n",
       "       [-0.20107542, -0.02894975,  0.4794492 ],\n",
       "       [ 0.18970288, -0.03626017,  0.47831792],\n",
       "       [-0.45670983, -0.01846424,  0.4569344 ],\n",
       "       [ 0.45124424, -0.05514   ,  0.46881753],\n",
       "       [-0.7222548 , -0.03377725,  0.44870418],\n",
       "       [ 0.7136745 , -0.11371554,  0.45480543]], dtype=float32)"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First frame's joints in the canonical frame, one row per joint.\n",
    "canonicalized_t_pose['joints'][0].reshape(-1, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "15333cad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([159, 3, 3])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): this cell's execution count (5) is out of order with the cells above,\n",
    "# so the saved output likely comes from a different kernel session. If the cells above\n",
    "# ran in order, t_pose['global_orient'] is by now a NumPy axis-angle array, not a\n",
    "# [T, 3, 3] tensor - re-run to refresh the output.\n",
    "t_pose['global_orient'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1bf0abd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset every frame's root orientation to identity, i.e. axis-angle zeros (float32,\n",
    "# one row per frame). This replaces a round-trip through rotation matrices that\n",
    "# produced the same zeros.\n",
    "# NOTE(review): transl and the stored joints are left unchanged, so the joints no\n",
    "# longer correspond to the new global_orient - confirm this is intended.\n",
    "with open('/data/lmg/code/DART/data/t_pose_canonicalized.pkl', 'rb') as f:\n",
    "    t_pose = pickle.load(f)\n",
    "num_frames = len(t_pose['global_orient'])\n",
    "t_pose['global_orient'] = np.zeros((num_frames, 3), dtype=np.float32)\n",
    "with open('/data/lmg/code/DART/data/t_pose_canonicalized_v1.pkl', 'wb') as f:\n",
    "    pickle.dump(t_pose, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "intergen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
