{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "def collect_dataset_info(root_dir):\n",
    "    \"\"\"Index every complete sample folder under the train/val/test splits.\n",
    "\n",
    "    The layout walked is <root_dir>/<mode>/<video_id>/, where a folder counts\n",
    "    as complete when it holds a .mp4, a .wav and a .pkl file. If several\n",
    "    files share an extension, the alphabetically last one is recorded.\n",
    "\n",
    "    Args:\n",
    "        root_dir: Path of the dataset root directory.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts with the keys \"video_id\", \"video_path\",\n",
    "        \"audio_path\", \"motion_path\" and \"mode\". Splits appear in the order\n",
    "        train, val, test; folders within a split are sorted by name.\n",
    "    \"\"\"\n",
    "    required_extensions = ('.mp4', '.wav', '.pkl')\n",
    "    dataset_info = []\n",
    "\n",
    "    for mode in ['train', 'val', 'test']:\n",
    "        mode_dir = os.path.join(root_dir, mode)\n",
    "        # isdir rather than exists: a regular file named like a split would\n",
    "        # pass an existence check and then make os.listdir raise.\n",
    "        if not os.path.isdir(mode_dir):\n",
    "            continue\n",
    "\n",
    "        # os.listdir gives no ordering guarantee; sorting keeps the generated\n",
    "        # index identical from run to run. The folder name is the video_id.\n",
    "        for video_id in sorted(os.listdir(mode_dir)):\n",
    "            video_folder_path = os.path.join(mode_dir, video_id)\n",
    "            if not os.path.isdir(video_folder_path):\n",
    "                continue\n",
    "\n",
    "            # extension -> path of the file found for it\n",
    "            found = {}\n",
    "            for file_name in sorted(os.listdir(video_folder_path)):\n",
    "                for extension in required_extensions:\n",
    "                    if file_name.endswith(extension):\n",
    "                        found[extension] = os.path.join(video_folder_path, file_name)\n",
    "                        break\n",
    "\n",
    "            # Create an entry only if all the necessary files are present\n",
    "            if len(found) == len(required_extensions):\n",
    "                dataset_info.append({\n",
    "                    \"video_id\": video_id,\n",
    "                    \"video_path\": found['.mp4'],\n",
    "                    \"audio_path\": found['.wav'],\n",
    "                    \"motion_path\": found['.pkl'],\n",
    "                    \"mode\": mode\n",
    "                })\n",
    "\n",
    "    return dataset_info\n",
    "\n",
    "\n",
    "# Index the dataset found under root_dir and store the result as JSON.\n",
    "root_dir = '/path/to/ExpressiveWholeBodyDatasetReleaseV1.0'  # edit to match your local copy\n",
    "output_file = 'dataset_info.json'\n",
    "dataset_info = collect_dataset_info(root_dir)\n",
    "with open(output_file, 'w') as json_file:\n",
    "    json_file.write(json.dumps(dataset_info, indent=4))\n",
    "print(f\"Dataset information saved to {output_file}\")\n",
    "\n",
    "\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "import wave\n",
    "\n",
    "class _CpuFallbackUnpickler(pickle.Unpickler):\n",
    "    \"\"\"Unpickler that can restore CUDA-saved torch storages without a GPU.\"\"\"\n",
    "\n",
    "    def find_class(self, module, name):\n",
    "        # Tensors written with plain pickle are rebuilt through\n",
    "        # torch.storage._load_from_bytes, which calls torch.load without a\n",
    "        # map_location and so raises RuntimeError when CUDA is unavailable.\n",
    "        if module == 'torch.storage' and name == '_load_from_bytes':\n",
    "            # Imported lazily: torch is only needed when the pickle refers to it.\n",
    "            import io\n",
    "            import torch\n",
    "\n",
    "            if not torch.cuda.is_available():\n",
    "                return lambda b: torch.load(io.BytesIO(b), map_location='cpu')\n",
    "        return super().find_class(module, name)\n",
    "\n",
    "\n",
    "def load_pkl(pkl_path):\n",
    "    \"\"\"Unpickle the file at pkl_path, returning None if it cannot be loaded.\n",
    "\n",
    "    Torch storages saved from a CUDA device are mapped to the CPU when no\n",
    "    CUDA device is available; every other object is unpickled as usual.\n",
    "\n",
    "    Args:\n",
    "        pkl_path: Path of the pickle file.\n",
    "\n",
    "    Returns:\n",
    "        The unpickled object, or None after printing the error if opening or\n",
    "        unpickling raised an exception.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        with open(pkl_path, 'rb') as f:\n",
    "            # SECURITY: unpickling can execute arbitrary code embedded in the\n",
    "            # file; only call this on dataset files you trust.\n",
    "            data = _CpuFallbackUnpickler(f).load()\n",
    "        return data\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading {pkl_path}: {e}\")\n",
    "        return None\n",
    "\n",
    "def load_wav(wav_path):\n",
    "    \"\"\"Return the number of audio frames in a WAV file, or None on failure.\n",
    "\n",
    "    Any exception raised while opening or reading the file is printed and\n",
    "    swallowed, so callers only have to check the result for None.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        with wave.open(wav_path, 'rb') as wav_reader:\n",
    "            return wav_reader.getnframes()\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading {wav_path}: {e}\")\n",
    "    return None\n",
    "\n",
    "def generate_clips(data, stride, window_length):\n",
    "    \"\"\"Slide a fixed-length window over each entry's motion data and list the clips.\n",
    "\n",
    "    Entries whose motion pickle or audio file cannot be loaded are skipped.\n",
    "\n",
    "    Args:\n",
    "        data: Iterable of dicts with the keys \"video_id\", \"video_path\",\n",
    "            \"audio_path\", \"motion_path\" and \"mode\".\n",
    "        stride: Step between consecutive window starts; must be positive.\n",
    "        window_length: Number of frames covered by one clip; must be positive.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts, one per window, repeating the entry's fields and\n",
    "        adding \"start_idx\" and \"end_idx\" (= start_idx + window_length).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If stride or window_length is not positive.\n",
    "    \"\"\"\n",
    "    # Reject degenerate settings up front: a non-positive window would emit\n",
    "    # clips that are empty or run outside the sequence, and a non-positive\n",
    "    # stride either breaks range() or yields nothing.\n",
    "    if stride <= 0:\n",
    "        raise ValueError(f\"stride must be a positive integer, got {stride!r}\")\n",
    "    if window_length <= 0:\n",
    "        raise ValueError(f\"window_length must be a positive integer, got {window_length!r}\")\n",
    "\n",
    "    clips = []\n",
    "    for entry in data:\n",
    "        pkl_data = load_pkl(entry['motion_path'])\n",
    "        wav_frames = load_wav(entry['audio_path'])\n",
    "\n",
    "        # Only continue if both the pkl and wav files are successfully loaded\n",
    "        if pkl_data is None or wav_frames is None:\n",
    "            continue\n",
    "\n",
    "        # NOTE(review): this assumes the pickle is a sequence with one item per\n",
    "        # motion frame. If it is a dict, len() counts keys instead - confirm\n",
    "        # against the dataset format.\n",
    "        total_frames = len(pkl_data)\n",
    "\n",
    "        # Windows that would run past the last frame are not emitted.\n",
    "        for start_idx in range(0, total_frames - window_length + 1, stride):\n",
    "            clips.append({\n",
    "                \"video_id\": entry[\"video_id\"],\n",
    "                \"video_path\": entry[\"video_path\"],\n",
    "                \"audio_path\": entry[\"audio_path\"],\n",
    "                \"motion_path\": entry[\"motion_path\"],\n",
    "                \"mode\": entry[\"mode\"],\n",
    "                \"start_idx\": start_idx,\n",
    "                \"end_idx\": start_idx + window_length\n",
    "            })\n",
    "\n",
    "    return clips\n",
    "\n",
    "\n",
    "# Sliding-window settings, counted in items of each motion pickle\n",
    "stride = 5  # step between consecutive clip starts; adjust as needed\n",
    "window_length = 10  # items per clip; adjust as needed\n",
    "\n",
    "# Read back the dataset index from its JSON file\n",
    "input_json = 'dataset_info.json'\n",
    "with open(input_json, 'r') as f:\n",
    "    dataset_info = json.loads(f.read())\n",
    "\n",
    "# Expand every indexed entry into its clips\n",
    "clips_data = generate_clips(dataset_info, stride, window_length)\n",
    "\n",
    "# Write the clip list to a new JSON file\n",
    "output_json = 'filtered_clips_data.json'\n",
    "with open(output_json, 'w') as f:\n",
    "    f.write(json.dumps(clips_data, indent=4))\n",
    "\n",
    "print(f\"Filtered clips data saved to {output_json}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "A module that was compiled using NumPy 1.x cannot be run in\n",
      "NumPy 2.1.1 as it may crash. To support both 1.x and 2.x\n",
      "versions of NumPy, modules must be compiled with NumPy 2.0.\n",
      "Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.\n",
      "\n",
      "If you are a user of the module, the easiest solution will be to\n",
      "downgrade to 'numpy<2' or try to upgrade the affected module.\n",
      "We expect that some modules will need time to support NumPy 2.\n",
      "\n",
      "Traceback (most recent call last):  File \"<frozen runpy>\", line 198, in _run_module_as_main\n",
      "  File \"<frozen runpy>\", line 88, in _run_code\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/ipykernel_launcher.py\", line 18, in <module>\n",
      "    app.launch_new_instance()\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n",
      "    app.start()\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/ipykernel/kernelapp.py\", line 739, in start\n",
      "    self.io_loop.start()\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/tornado/platform/asyncio.py\", line 205, in start\n",
      "    self.asyncio_loop.run_forever()\n",
      "  File \"/usr/local/Cellar/python@3.11/3.11.9/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py\", line 608, in run_forever\n",
      "    self._run_once()\n",
      "  File \"/usr/local/Cellar/python@3.11/3.11.9/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/base_events.py\", line 1936, in _run_once\n",
      "    handle._run()\n",
      "  File \"/usr/local/Cellar/python@3.11/3.11.9/Frameworks/Python.framework/Versions/3.11/lib/python3.11/asyncio/events.py\", line 84, in _run\n",
      "    self._context.run(self._callback, *self._args)\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/ipykernel/kernelbase.py\", line 545, in dispatch_queue\n",
      "    await self.process_one()\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/ipykernel/kernelbase.py\", line 534, in process_one\n",
      "    await dispatch(*args)\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/ipykernel/kernelbase.py\", line 437, in dispatch_shell\n",
      "    await result\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/ipykernel/ipkernel.py\", line 362, in execute_request\n",
      "    await super().execute_request(stream, ident, parent)\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/ipykernel/kernelbase.py\", line 778, in execute_request\n",
      "    reply_content = await reply_content\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/ipykernel/ipkernel.py\", line 449, in do_execute\n",
      "    res = shell.run_cell(\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n",
      "    return super().run_cell(*args, **kwargs)\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/IPython/core/interactiveshell.py\", line 3075, in run_cell\n",
      "    result = self._run_cell(\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/IPython/core/interactiveshell.py\", line 3130, in _run_cell\n",
      "    result = runner(coro)\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/IPython/core/async_helpers.py\", line 128, in _pseudo_sync_runner\n",
      "    coro.send(None)\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/IPython/core/interactiveshell.py\", line 3334, in run_cell_async\n",
      "    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/IPython/core/interactiveshell.py\", line 3517, in run_ast_nodes\n",
      "    if await self.run_code(code, result, async_=asy):\n",
      "  File \"/Users/haiyang/Library/Python/3.11/lib/python/site-packages/IPython/core/interactiveshell.py\", line 3577, in run_code\n",
      "    exec(code_obj, self.user_global_ns, self.user_ns)\n",
      "  File \"/var/folders/fs/6rm22wts5gd2ll3xz0snkhl00000gn/T/ipykernel_89872/2585702779.py\", line 1, in <module>\n",
      "    import torch\n",
      "  File \"/usr/local/lib/python3.11/site-packages/torch/__init__.py\", line 1477, in <module>\n",
      "    from .functional import *  # noqa: F403\n",
      "  File \"/usr/local/lib/python3.11/site-packages/torch/functional.py\", line 9, in <module>\n",
      "    import torch.nn.functional as F\n",
      "  File \"/usr/local/lib/python3.11/site-packages/torch/nn/__init__.py\", line 1, in <module>\n",
      "    from .modules import *  # noqa: F403\n",
      "  File \"/usr/local/lib/python3.11/site-packages/torch/nn/modules/__init__.py\", line 35, in <module>\n",
      "    from .transformer import TransformerEncoder, TransformerDecoder, \\\n",
      "  File \"/usr/local/lib/python3.11/site-packages/torch/nn/modules/transformer.py\", line 20, in <module>\n",
      "    device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),\n",
      "/usr/local/lib/python3.11/site-packages/torch/nn/modules/transformer.py:20: UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)\n",
      "  device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 6\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpickle\u001b[39;00m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/Users/haiyang/download_backup/oliver/Abortion_Laws_-_Last_Week_Tonight_with_John_Oliver_HBO-DRauXXz6t0Y.webm/test/214438-00_07_16-00_07_26/214438-00_07_16-00_07_26.pkl\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrb\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m      5\u001b[0m     \u001b[38;5;66;03m# Load the file by mapping any GPU tensors to CPU\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m     pkl_example \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcpu\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;66;03m# Now check the type of the object\u001b[39;00m\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mtype\u001b[39m(pkl_example))\n",
      "File \u001b[0;32m/usr/local/lib/python3.11/site-packages/torch/serialization.py:1040\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[1;32m   1038\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m   1039\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m pickle\u001b[38;5;241m.\u001b[39mUnpicklingError(UNSAFE_MESSAGE \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(e)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1040\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_legacy_load\u001b[49m\u001b[43m(\u001b[49m\u001b[43mopened_file\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_module\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpickle_load_args\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.11/site-packages/torch/serialization.py:1258\u001b[0m, in \u001b[0;36m_legacy_load\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m   1252\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(f, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mreadinto\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m8\u001b[39m, \u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39mversion_info \u001b[38;5;241m<\u001b[39m (\u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m8\u001b[39m, \u001b[38;5;241m2\u001b[39m):\n\u001b[1;32m   1253\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m   1254\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1255\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mReceived object of type \u001b[39m\u001b[38;5;130;01m\\\"\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(f)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\\"\u001b[39;00m\u001b[38;5;124m. 
Please update to Python 3.8.2 or newer to restore this \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1256\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfunctionality.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 1258\u001b[0m magic_number \u001b[38;5;241m=\u001b[39m \u001b[43mpickle_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpickle_load_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1259\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m magic_number \u001b[38;5;241m!=\u001b[39m MAGIC_NUMBER:\n\u001b[1;32m   1260\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid magic number; corrupt file?\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m/usr/local/lib/python3.11/site-packages/torch/storage.py:371\u001b[0m, in \u001b[0;36m_load_from_bytes\u001b[0;34m(b)\u001b[0m\n\u001b[1;32m    370\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_load_from_bytes\u001b[39m(b):\n\u001b[0;32m--> 371\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mio\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mBytesIO\u001b[49m\u001b[43m(\u001b[49m\u001b[43mb\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.11/site-packages/torch/serialization.py:1040\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[1;32m   1038\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m   1039\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m pickle\u001b[38;5;241m.\u001b[39mUnpicklingError(UNSAFE_MESSAGE \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(e)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1040\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_legacy_load\u001b[49m\u001b[43m(\u001b[49m\u001b[43mopened_file\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_module\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpickle_load_args\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.11/site-packages/torch/serialization.py:1268\u001b[0m, in \u001b[0;36m_legacy_load\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m   1266\u001b[0m unpickler \u001b[38;5;241m=\u001b[39m UnpicklerWrapper(f, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_load_args)\n\u001b[1;32m   1267\u001b[0m unpickler\u001b[38;5;241m.\u001b[39mpersistent_load \u001b[38;5;241m=\u001b[39m persistent_load\n\u001b[0;32m-> 1268\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43munpickler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1270\u001b[0m deserialized_storage_keys \u001b[38;5;241m=\u001b[39m pickle_module\u001b[38;5;241m.\u001b[39mload(f, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_load_args)\n\u001b[1;32m   1272\u001b[0m offset \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mtell() \u001b[38;5;28;01mif\u001b[39;00m f_should_read_directly \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.11/site-packages/torch/serialization.py:1205\u001b[0m, in \u001b[0;36m_legacy_load.<locals>.persistent_load\u001b[0;34m(saved_id)\u001b[0m\n\u001b[1;32m   1201\u001b[0m     obj\u001b[38;5;241m.\u001b[39m_torch_load_uninitialized \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m   1202\u001b[0m     \u001b[38;5;66;03m# TODO: Once we decide to break serialization FC, we can\u001b[39;00m\n\u001b[1;32m   1203\u001b[0m     \u001b[38;5;66;03m# stop wrapping with TypedStorage\u001b[39;00m\n\u001b[1;32m   1204\u001b[0m     typed_storage \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstorage\u001b[38;5;241m.\u001b[39mTypedStorage(\n\u001b[0;32m-> 1205\u001b[0m         wrap_storage\u001b[38;5;241m=\u001b[39m\u001b[43mrestore_location\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlocation\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m   1206\u001b[0m         dtype\u001b[38;5;241m=\u001b[39mdtype,\n\u001b[1;32m   1207\u001b[0m         _internal\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m   1208\u001b[0m     deserialized_objects[root_key] \u001b[38;5;241m=\u001b[39m typed_storage\n\u001b[1;32m   1209\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m/usr/local/lib/python3.11/site-packages/torch/serialization.py:391\u001b[0m, in \u001b[0;36mdefault_restore_location\u001b[0;34m(storage, location)\u001b[0m\n\u001b[1;32m    389\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdefault_restore_location\u001b[39m(storage, location):\n\u001b[1;32m    390\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m _, _, fn \u001b[38;5;129;01min\u001b[39;00m _package_registry:\n\u001b[0;32m--> 391\u001b[0m         result \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstorage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlocation\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    392\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    393\u001b[0m             \u001b[38;5;28;01mreturn\u001b[39;00m result\n",
      "File \u001b[0;32m/usr/local/lib/python3.11/site-packages/torch/serialization.py:266\u001b[0m, in \u001b[0;36m_cuda_deserialize\u001b[0;34m(obj, location)\u001b[0m\n\u001b[1;32m    264\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_cuda_deserialize\u001b[39m(obj, location):\n\u001b[1;32m    265\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m location\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[0;32m--> 266\u001b[0m         device \u001b[38;5;241m=\u001b[39m \u001b[43mvalidate_cuda_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlocation\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    267\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(obj, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_torch_load_uninitialized\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m    268\u001b[0m             \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mdevice(device):\n",
      "File \u001b[0;32m/usr/local/lib/python3.11/site-packages/torch/serialization.py:250\u001b[0m, in \u001b[0;36mvalidate_cuda_device\u001b[0;34m(location)\u001b[0m\n\u001b[1;32m    247\u001b[0m device \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39m_utils\u001b[38;5;241m.\u001b[39m_get_device_index(location, \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m    249\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available():\n\u001b[0;32m--> 250\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAttempting to deserialize object on a CUDA \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m    251\u001b[0m                        \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdevice but torch.cuda.is_available() is False. \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m    252\u001b[0m                        \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mIf you are running on a CPU-only machine, \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m    253\u001b[0m                        \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mplease use torch.load with map_location=torch.device(\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m) \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m    254\u001b[0m                        \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mto map your storages to the CPU.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    255\u001b[0m device_count \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mdevice_count()\n\u001b[1;32m    256\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m device \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m device_count:\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU."
     ]
    }
   ],
   "source": [
    "import io\n",
    "import pickle\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "class CPUUnpickler(pickle.Unpickler):\n",
    "    \"\"\"Unpickler that maps torch storages saved on a GPU onto the CPU.\n",
    "\n",
    "    torch.load(f, map_location='cpu') is not enough for this file: the tensors\n",
    "    inside it are rebuilt through torch.storage._load_from_bytes, which calls\n",
    "    torch.load again without a map_location and fails on a CPU-only machine.\n",
    "    \"\"\"\n",
    "\n",
    "    def find_class(self, module, name):\n",
    "        if module == 'torch.storage' and name == '_load_from_bytes':\n",
    "            # Swap in a loader that forwards the CPU map_location.\n",
    "            return lambda b: torch.load(io.BytesIO(b), map_location=torch.device('cpu'))\n",
    "        return super().find_class(module, name)\n",
    "\n",
    "\n",
    "# NOTE: machine-specific absolute path; point it at a .pkl from your own copy of the dataset.\n",
    "pkl_example_path = \"/Users/haiyang/download_backup/oliver/Abortion_Laws_-_Last_Week_Tonight_with_John_Oliver_HBO-DRauXXz6t0Y.webm/test/214438-00_07_16-00_07_26/214438-00_07_16-00_07_26.pkl\"\n",
    "\n",
    "with open(pkl_example_path, \"rb\") as f:\n",
    "    # Load the file by mapping any GPU tensors to CPU.\n",
    "    # SECURITY: unpickling can execute code from the file; only load trusted data.\n",
    "    pkl_example = CPUUnpickler(f).load()\n",
    "\n",
    "# Now check the type of the object\n",
    "print(type(pkl_example))\n",
    "\n",
    "# If it's a dictionary, print its keys\n",
    "if isinstance(pkl_example, dict):\n",
    "    print(pkl_example.keys())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
