{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "import dataclasses\n",
    "import gzip\n",
    "import json\n",
    "from dataclasses import dataclass, Field, MISSING\n",
    "from typing import Any, cast, Dict, IO, Optional, Tuple, Type, TypeVar, Union\n",
    "from typing import get_args, get_origin\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Type variable so that `load_dataclass(f, cls: Type[_X])` is typed as returning `_X`.\n",
    "_X = TypeVar(\"_X\")\n",
    "\n",
    "# Three floats; used for `ViewpointAnnotation.T` and for each of the three entries of `ViewpointAnnotation.R`.\n",
    "TF3 = Tuple[float, float, float]\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ImageAnnotation:\n",
    "    \"\"\"Location and size of the image of one frame.\"\"\"\n",
    "\n",
    "    # path to jpg file, relative w.r.t. dataset_root\n",
    "    path: str\n",
    "    # image size as (H, W): height first, unlike the (W, H) order used by cameras\n",
    "    size: Tuple[int, int]  # TODO: rename size_hw?\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DepthAnnotation:\n",
    "    \"\"\"Location of the depth map of one frame and how to decode its values.\"\"\"\n",
    "\n",
    "    # path to png file, relative w.r.t. dataset_root, storing `depth / scale_adjustment`\n",
    "    path: str\n",
    "    # a factor to convert png values to actual depth: `depth = png * scale_adjustment`\n",
    "    scale_adjustment: float\n",
    "    # path to png file, relative w.r.t. dataset_root, storing binary `depth` mask;\n",
    "    # may be None, but has no default value so it must always be passed\n",
    "    mask_path: Optional[str]\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class MaskAnnotation:\n",
    "    \"\"\"Location of the (soft) foreground mask of one frame.\"\"\"\n",
    "\n",
    "    # path to png file storing (Prob(fg | pixel) * 255)\n",
    "    path: str\n",
    "    # (soft) number of pixels in the mask; sum(Prob(fg | pixel)); defaults to None\n",
    "    mass: Optional[float] = None\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ViewpointAnnotation:\n",
    "    \"\"\"Camera pose and intrinsics of one frame.\"\"\"\n",
    "\n",
    "    # In right-multiply (PyTorch3D) format. X_cam = X_world @ R + T\n",
    "    R: Tuple[TF3, TF3, TF3]\n",
    "    T: TF3\n",
    "\n",
    "    # both expressed in the coordinate system named by `intrinsics_format`\n",
    "    focal_length: Tuple[float, float]\n",
    "    principal_point: Tuple[float, float]\n",
    "\n",
    "    # Defines the coordinate system where focal_length and principal_point live.\n",
    "    # Possible values: ndc_isotropic | ndc_norm_image_bounds (default)\n",
    "    # ndc_norm_image_bounds: legacy PyTorch3D NDC format, where image boundaries\n",
    "    #     correspond to [-1, 1] x [-1, 1], and the scale along x and y may differ\n",
    "    # ndc_isotropic: PyTorch3D 0.5+ NDC convention where the shorter side has\n",
    "    #     the range [-1, 1], and the longer one has the range [-s, s]; s >= 1,\n",
    "    #     where s is the aspect ratio. The scale is the same along x and y.\n",
    "    intrinsics_format: str = \"ndc_norm_image_bounds\"\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class FrameAnnotation:\n",
    "    \"\"\"A dataclass used to load annotations from json.\n",
    "\n",
    "    One instance per frame of a sequence; the optional annotations default to None.\n",
    "    \"\"\"\n",
    "\n",
    "    # can be used to join with `SequenceAnnotation`\n",
    "    sequence_name: str\n",
    "    # 0-based, continuous frame number within sequence\n",
    "    # NOTE(review): the teddybear sample in this notebook starts at frame_number=1;\n",
    "    # confirm whether frame 0 is just missing there or numbering is 1-based\n",
    "    frame_number: int\n",
    "    # timestamp in seconds from the video start\n",
    "    frame_timestamp: float\n",
    "\n",
    "    image: ImageAnnotation\n",
    "    depth: Optional[DepthAnnotation] = None\n",
    "    mask: Optional[MaskAnnotation] = None\n",
    "    # camera of this frame; get_pytorch3d_camera asserts that it is set\n",
    "    viewpoint: Optional[ViewpointAnnotation] = None\n",
    "    # free-form extras, e.g. {'frame_type': ..., 'frame_splits': [...]} in CO3D\n",
    "    meta: Optional[Dict[str, Any]] = None\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class PointCloudAnnotation:\n",
    "    \"\"\"Point cloud attached to a sequence (see `SequenceAnnotation.point_cloud`).\"\"\"\n",
    "\n",
    "    # path to ply file with points only, relative w.r.t. dataset_root\n",
    "    path: str\n",
    "    # the bigger the better\n",
    "    quality_score: float\n",
    "    # presumably the number of points in the ply file; may be None, but has no\n",
    "    # default value so it must always be passed\n",
    "    n_points: Optional[int]\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class VideoAnnotation:\n",
    "    \"\"\"The source video a sequence was extracted from.\"\"\"\n",
    "\n",
    "    # path to the original video file, relative w.r.t. dataset_root\n",
    "    path: str\n",
    "    # length of the video in seconds\n",
    "    length: float\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class SequenceAnnotation:\n",
    "    \"\"\"Per-sequence annotation; joined with frames via `sequence_name`.\"\"\"\n",
    "\n",
    "    sequence_name: str\n",
    "    category: str\n",
    "    video: Optional[VideoAnnotation] = None\n",
    "    point_cloud: Optional[PointCloudAnnotation] = None\n",
    "    # the bigger the better\n",
    "    viewpoint_quality_score: Optional[float] = None\n",
    "\n",
    "\n",
    "def dump_dataclass(obj: Any, f: IO, binary: bool = False) -> None:\n",
    "    \"\"\"\n",
    "    Serialises `obj` as JSON into the already-open file object `f`.\n",
    "\n",
    "    Args:\n",
    "        obj: A @dataclass or collection hierarchy including dataclasses.\n",
    "        f: A file object opened for writing (not a path).\n",
    "        binary: True if `f` expects bytes (the JSON text is utf8-encoded\n",
    "            first), False if `f` is a text-mode file.\n",
    "    \"\"\"\n",
    "    serialisable = _asdict_rec(obj)\n",
    "    if not binary:\n",
    "        json.dump(serialisable, f)\n",
    "        return\n",
    "    payload = json.dumps(serialisable)\n",
    "    f.write(payload.encode(\"utf8\"))\n",
    "\n",
    "\n",
    "def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:\n",
    "    \"\"\"\n",
    "    Loads to a @dataclass or collection hierarchy including dataclasses\n",
    "    from a json recursively.\n",
    "    Call it like load_dataclass(f, typing.List[FrameAnnotation]).\n",
    "\n",
    "    If the json top level is a list and `cls` is `List[T]`, the faster\n",
    "    vectorised `_dataclass_list_from_dict_list` is used (it ignores json keys\n",
    "    that do not map to dataclass fields). Otherwise `_dataclass_from_dict` is\n",
    "    used, which raises KeyError for such keys.\n",
    "\n",
    "    Args:\n",
    "        f: A file object opened for reading.\n",
    "        cls: The type (annotation) of the loaded object, e.g. List[FrameAnnotation].\n",
    "        binary: True if `f` yields bytes (decoded as utf8), False if it yields str.\n",
    "\n",
    "    Returns:\n",
    "        The loaded object, of type `cls`.\n",
    "    \"\"\"\n",
    "    if binary:\n",
    "        asdict = json.loads(f.read().decode(\"utf8\"))\n",
    "    else:\n",
    "        asdict = json.load(f)\n",
    "\n",
    "    # Only a List[T] annotation says how to convert every element of a list;\n",
    "    # anything else (bare list, Optional[List[T]], Tuple[...]) goes element-wise.\n",
    "    element_types = get_args(cls) if get_origin(cls) is list else ()\n",
    "    if isinstance(asdict, list) and len(element_types) == 1:\n",
    "        # in the list case, run a faster \"vectorized\" version\n",
    "        return list(_dataclass_list_from_dict_list(asdict, element_types[0]))\n",
    "    return _dataclass_from_dict(asdict, cls)\n",
    "\n",
    "\n",
    "def _dataclass_list_from_dict_list(dlist, typeannot):\n",
    "    \"\"\"\n",
    "    Vectorised version of `_dataclass_from_dict`.\n",
    "    The output should be equivalent to\n",
    "    `[_dataclass_from_dict(d, typeannot) for d in dlist]`.\n",
    "\n",
    "    Rather than converting the objects one by one, every level of the type\n",
    "    hierarchy is converted once for the whole list:\n",
    "      - None entries are set aside and put back after converting the rest;\n",
    "      - namedtuples and fixed-size tuples are transposed position-wise;\n",
    "      - homogeneous sequences (List[T], Tuple[T, ...], untyped list/tuple) and\n",
    "        dicts are flattened, converted in one call and split back per object,\n",
    "        so their lengths may differ between objects;\n",
    "      - dataclasses are converted field-wise (missing keys get the field default).\n",
    "\n",
    "    Args:\n",
    "        dlist: list of objects to convert.\n",
    "        typeannot: type of each of those objects.\n",
    "    Returns:\n",
    "        list of converted objects of the same length as `dlist` (or `dlist`\n",
    "        itself when nothing needs converting).\n",
    "\n",
    "    Note:\n",
    "        Fixed-size tuples/namedtuples are assumed to have the same length in all\n",
    "        objects. Unlike `_dataclass_from_dict`, dict keys that are not dataclass\n",
    "        fields are ignored instead of raising KeyError.\n",
    "    \"\"\"\n",
    "    if typeannot is Any:\n",
    "        return dlist\n",
    "    if all(obj is None for obj in dlist):  # 1st recursion base: all None nodes\n",
    "        return dlist\n",
    "    if any(obj is None for obj in dlist):\n",
    "        # filter out Nones and recurse on the resulting list\n",
    "        idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]\n",
    "        idx, notnone = zip(*idx_notnone)\n",
    "        converted = _dataclass_list_from_dict_list(notnone, typeannot)\n",
    "        res = [None] * len(dlist)\n",
    "        for i, obj in zip(idx, converted):\n",
    "            res[i] = obj\n",
    "        return res\n",
    "\n",
    "    is_optional, contained_type = _resolve_optional(typeannot)\n",
    "    if is_optional:\n",
    "        return _dataclass_list_from_dict_list(dlist, contained_type)\n",
    "\n",
    "    # otherwise, we dispatch by the type of the provided annotation to convert to\n",
    "    cls = get_origin(typeannot) or typeannot\n",
    "    if not isinstance(cls, type):\n",
    "        # e.g. a non-Optional Union, a Literal or a string forward reference:\n",
    "        # there is no class to dispatch on, so keep the values as they are\n",
    "        return dlist\n",
    "\n",
    "    if issubclass(cls, tuple) and hasattr(cls, \"_fields\"):  # namedtuple\n",
    "        # `_field_types` was removed in Python 3.9; read the annotations instead\n",
    "        annotations = getattr(cls, \"__annotations__\", {})\n",
    "        types = [annotations.get(name, Any) for name in cls._fields]\n",
    "        res_T = [\n",
    "            _dataclass_list_from_dict_list(key_list, tp)\n",
    "            for key_list, tp in zip(zip(*dlist), types)\n",
    "        ]\n",
    "        return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]\n",
    "\n",
    "    if issubclass(cls, (list, tuple)):\n",
    "        types = get_args(typeannot) or (Any,)  # untyped list/tuple: items as-is\n",
    "        if len(types) == 1 or (len(types) == 2 and types[1] is Ellipsis):\n",
    "            # homogeneous sequence whose length may vary between objects:\n",
    "            # convert all items at once, then split back per object\n",
    "            flat = _dataclass_list_from_dict_list(\n",
    "                [item for obj in dlist for item in obj], types[0]\n",
    "            )\n",
    "            return _regroup_flat_list(flat, dlist, cls)\n",
    "        # fixed-size tuple, e.g. Tuple[float, float, float]: convert each\n",
    "        # position across all objects in one call, then transpose back\n",
    "        res_T = [\n",
    "            _dataclass_list_from_dict_list(pos_list, tp)\n",
    "            for pos_list, tp in zip(zip(*dlist), types)\n",
    "        ]\n",
    "        return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]\n",
    "\n",
    "    if issubclass(cls, dict):\n",
    "        # convert all keys and all values at once, then split back per dict;\n",
    "        # plain slicing keeps the key types intact (numpy arrays would not)\n",
    "        key_val_t = get_args(typeannot)\n",
    "        key_t, val_t = key_val_t if len(key_val_t) == 2 else (Any, Any)\n",
    "        all_keys_res = _dataclass_list_from_dict_list(\n",
    "            [k for obj in dlist for k in obj.keys()], key_t\n",
    "        )\n",
    "        all_vals_res = _dataclass_list_from_dict_list(\n",
    "            [v for obj in dlist for v in obj.values()], val_t\n",
    "        )\n",
    "        return _regroup_flat_list(list(zip(all_keys_res, all_vals_res)), dlist, cls)\n",
    "\n",
    "    if not dataclasses.is_dataclass(cls):\n",
    "        return dlist\n",
    "\n",
    "    # dataclass node: 2nd recursion base; call the function recursively on the\n",
    "    # lists of the corresponding fields\n",
    "    field_names = []\n",
    "    field_values = []\n",
    "    for f in dataclasses.fields(cls):\n",
    "        if not f.init:\n",
    "            continue  # not an __init__ argument\n",
    "        # a missing key takes the field default; a default_factory is called per\n",
    "        # object so that instances never share one mutable default\n",
    "        values = [\n",
    "            obj[f.name] if f.name in obj else _get_dataclass_field_default(f)\n",
    "            for obj in dlist\n",
    "        ]\n",
    "        field_names.append(f.name)\n",
    "        field_values.append(\n",
    "            list(_dataclass_list_from_dict_list(values, _unwrap_type(f.type)))\n",
    "        )\n",
    "    # keyword construction also works for kw_only fields\n",
    "    return [\n",
    "        cls(**{name: vals[i] for name, vals in zip(field_names, field_values)})\n",
    "        for i in range(len(dlist))\n",
    "    ]\n",
    "\n",
    "\n",
    "def _regroup_flat_list(flat, dlist, cls):\n",
    "    \"\"\"Splits `flat` into consecutive chunks sized like the items of `dlist`, each wrapped in `cls`.\"\"\"\n",
    "    flat = list(flat)\n",
    "    res = []\n",
    "    start = 0\n",
    "    for obj in dlist:\n",
    "        end = start + len(obj)\n",
    "        res.append(cls(flat[start:end]))\n",
    "        start = end\n",
    "    assert start == len(flat), \"flattened conversion changed the number of items\"\n",
    "    return res\n",
    "\n",
    "\n",
    "def _dataclass_from_dict(d, typeannot):\n",
    "    \"\"\"\n",
    "    Recursively converts json-parsed data `d` into an object of type `typeannot`.\n",
    "\n",
    "    Handles Optional, namedtuples, lists and tuples (fixed-size, `Tuple[T, ...]`\n",
    "    or untyped), dicts and dataclasses; `None`, `Any` and any other type are\n",
    "    passed through unchanged.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a dict converted to a dataclass has a key that is not one\n",
    "            of the dataclass fields.\n",
    "    \"\"\"\n",
    "    if d is None or typeannot is Any:\n",
    "        return d\n",
    "    is_optional, contained_type = _resolve_optional(typeannot)\n",
    "    if is_optional:\n",
    "        # an Optional not set to None, just use the contents of the Optional.\n",
    "        return _dataclass_from_dict(d, contained_type)\n",
    "\n",
    "    cls = get_origin(typeannot) or typeannot\n",
    "    if not isinstance(cls, type):\n",
    "        # e.g. a non-Optional Union, a Literal or a string forward reference:\n",
    "        # there is no class to dispatch on, so keep the value as it is\n",
    "        return d\n",
    "    if issubclass(cls, tuple) and hasattr(cls, \"_fields\"):  # namedtuple\n",
    "        # `_field_types` was removed in Python 3.9; read the annotations instead\n",
    "        annotations = getattr(cls, \"__annotations__\", {})\n",
    "        types = [annotations.get(name, Any) for name in cls._fields]\n",
    "        return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])\n",
    "    if issubclass(cls, (list, tuple)):\n",
    "        types = get_args(typeannot) or (Any,)  # untyped list/tuple: items as-is\n",
    "        if len(types) == 1 or (len(types) == 2 and types[1] is Ellipsis):\n",
    "            # homogeneous: List[T] or Tuple[T, ...]; convert every item as T\n",
    "            return cls(_dataclass_from_dict(v, types[0]) for v in d)\n",
    "        return cls(_dataclass_from_dict(v, tp) for v, tp in zip(d, types))\n",
    "    if issubclass(cls, dict):\n",
    "        key_val_t = get_args(typeannot)\n",
    "        key_t, val_t = key_val_t if len(key_val_t) == 2 else (Any, Any)\n",
    "        return cls(\n",
    "            (_dataclass_from_dict(k, key_t), _dataclass_from_dict(v, val_t))\n",
    "            for k, v in d.items()\n",
    "        )\n",
    "    if not dataclasses.is_dataclass(cls):\n",
    "        return d\n",
    "\n",
    "    fieldtypes = {f.name: _unwrap_type(f.type) for f in dataclasses.fields(cls)}\n",
    "    return cls(**{k: _dataclass_from_dict(v, fieldtypes[k]) for k, v in d.items()})\n",
    "\n",
    "\n",
    "def _unwrap_type(tp):\n",
    "    \"\"\"Returns T for `Optional[T]` (NoneType in either Union position); otherwise `tp` unchanged.\"\"\"\n",
    "    if get_origin(tp) is not Union:\n",
    "        return tp\n",
    "    members = get_args(tp)\n",
    "    none_type = type(None)\n",
    "    if len(members) != 2 or not any(m is none_type for m in members):\n",
    "        return tp\n",
    "    # this is typing.Optional: keep whichever member is not NoneType\n",
    "    first, second = members\n",
    "    return first if second is none_type else second\n",
    "\n",
    "\n",
    "def _get_dataclass_field_default(field: Field) -> Any:\n",
    "    \"\"\"\n",
    "    Returns the value to use for `field` when it is absent from the input:\n",
    "    a fresh `default_factory()` result, else the plain `default`, else None.\n",
    "    \"\"\"\n",
    "    factory = field.default_factory\n",
    "    if factory is MISSING:\n",
    "        return None if field.default is MISSING else field.default\n",
    "    # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,\n",
    "    #  dataclasses._DefaultFactory[typing.Any]]` is not a function.\n",
    "    return factory()\n",
    "\n",
    "\n",
    "def _asdict_rec(obj):\n",
    "    \"\"\"\n",
    "    Recursively converts `obj` into plain Python containers for json.\n",
    "\n",
    "    Dataclass instances become dicts (via `dataclasses.asdict`); lists, tuples,\n",
    "    namedtuples and dicts are rebuilt with converted contents; anything else is\n",
    "    deep-copied. Unlike `dataclasses.asdict`, the top-level object does not have\n",
    "    to be a dataclass. Only public `dataclasses` API is used, since the private\n",
    "    `dataclasses._asdict_inner` may change between Python versions.\n",
    "    \"\"\"\n",
    "    import collections\n",
    "    import copy\n",
    "\n",
    "    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):\n",
    "        return dataclasses.asdict(obj)\n",
    "    if isinstance(obj, tuple) and hasattr(obj, \"_fields\"):  # namedtuple\n",
    "        return type(obj)(*[_asdict_rec(v) for v in obj])\n",
    "    if isinstance(obj, (list, tuple)):\n",
    "        return type(obj)(_asdict_rec(v) for v in obj)\n",
    "    if isinstance(obj, dict):\n",
    "        items = ((_asdict_rec(k), _asdict_rec(v)) for k, v in obj.items())\n",
    "        if isinstance(obj, collections.defaultdict):\n",
    "            # defaultdict's first constructor argument is the factory\n",
    "            return type(obj)(obj.default_factory, items)\n",
    "        return type(obj)(items)\n",
    "    return copy.deepcopy(obj)\n",
    "\n",
    "\n",
    "def dump_dataclass_jgzip(outfile: str, obj: Any) -> None:\n",
    "    \"\"\"\n",
    "    Writes `obj` as gzip-compressed JSON to `outfile`.\n",
    "\n",
    "    Args:\n",
    "        obj: A @dataclass or collection hierarchy including dataclasses.\n",
    "        outfile: The path to the output file.\n",
    "    \"\"\"\n",
    "    gz_handle = gzip.GzipFile(outfile, \"wb\")\n",
    "    with gz_handle:\n",
    "        dump_dataclass(obj, cast(IO, gz_handle), binary=True)\n",
    "\n",
    "\n",
    "def load_dataclass_jgzip(outfile, cls):\n",
    "    \"\"\"\n",
    "    Reads gzip-compressed JSON from `outfile` and converts it to `cls`.\n",
    "\n",
    "    Args:\n",
    "        outfile: The path to the file to read (an input, despite the name).\n",
    "        cls: The type annotation of the loaded object, e.g. List[FrameAnnotation].\n",
    "\n",
    "    Returns:\n",
    "        loaded_dataclass: The loaded dataclass (or collection of dataclasses).\n",
    "    \"\"\"\n",
    "    with gzip.GzipFile(outfile, \"rb\") as gz_in:\n",
    "        loaded = load_dataclass(cast(IO, gz_in), cls, binary=True)\n",
    "    return loaded\n",
    "\n",
    "\n",
    "def _resolve_optional(type_: Any) -> Tuple[bool, Any]:\n",
    "    \"\"\"\n",
    "    Check whether `type_` is equivalent to `typing.Optional[T]` for some T.\n",
    "\n",
    "    Returns:\n",
    "        (True, T) if `type_` is a Union containing NoneType in any position; T is\n",
    "        the Union of the remaining members when there is more than one.\n",
    "        (True, Any) for `Any`, which also admits None.\n",
    "        (False, type_) otherwise.\n",
    "    \"\"\"\n",
    "    if get_origin(type_) is Union:\n",
    "        args = get_args(type_)\n",
    "        non_none = tuple(a for a in args if a is not type(None))  # noqa: E721\n",
    "        if len(non_none) < len(args):\n",
    "            if len(non_none) == 1:\n",
    "                return True, non_none[0]\n",
    "            return True, Union[non_none]\n",
    "    if type_ is Any:\n",
    "        return True, Any\n",
    "\n",
    "    return False, type_\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pytorch3d.renderer.camera_utils import join_cameras_as_batch\n",
    "import torch\n",
    "from pytorch3d.renderer.cameras import PerspectiveCameras\n",
    "\n",
    "\n",
    "def get_pytorch3d_camera(\n",
    "    entry: FrameAnnotation,\n",
    ") -> PerspectiveCameras:\n",
    "    \"\"\"\n",
    "    Builds a single-camera PerspectiveCameras from a frame annotation.\n",
    "\n",
    "    Intrinsics in the legacy \"ndc_norm_image_bounds\" format are rescaled per axis\n",
    "    by (image size along the axis / shorter image side); \"ndc_isotropic\"\n",
    "    intrinsics are used as-is. Every tensor gets a leading batch dimension of 1.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for any other `intrinsics_format`.\n",
    "    \"\"\"\n",
    "    viewpoint = entry.viewpoint\n",
    "    assert viewpoint is not None\n",
    "    focal = torch.tensor(viewpoint.focal_length, dtype=torch.float)\n",
    "    principal = torch.tensor(viewpoint.principal_point, dtype=torch.float)\n",
    "\n",
    "    intrinsics_format = viewpoint.intrinsics_format\n",
    "    if intrinsics_format == \"ndc_norm_image_bounds\":\n",
    "        # legacy PyTorch3D NDC format: convert to pixels unequally and back\n",
    "        # to ndc equally; image.size is (H, W), so reverse it to get (W, H)\n",
    "        size_wh = torch.tensor(list(entry.image.size)[::-1], dtype=torch.float)\n",
    "        scale_wh = size_wh / size_wh.min()\n",
    "        focal = focal * scale_wh\n",
    "        principal = principal * scale_wh\n",
    "    elif intrinsics_format != \"ndc_isotropic\":\n",
    "        raise ValueError(f\"Unknown intrinsics format: {intrinsics_format}\")\n",
    "\n",
    "    rotation = torch.tensor(viewpoint.R, dtype=torch.float)\n",
    "    translation = torch.tensor(viewpoint.T, dtype=torch.float)\n",
    "    return PerspectiveCameras(\n",
    "        focal_length=focal.unsqueeze(0),\n",
    "        principal_point=principal.unsqueeze(0),\n",
    "        R=rotation.unsqueeze(0),\n",
    "        T=translation.unsqueeze(0),\n",
    "    )\n",
    "\n",
    "class AdditionCameras:\n",
    "    \"\"\"\n",
    "    Loads the frame and sequence annotations of one CO3D category and builds\n",
    "    batched PyTorch3D cameras for a few frames of a sequence.\n",
    "\n",
    "    Deterministic frame selections and the camera batches built from them are\n",
    "    cached per (sequence_name, num_frames); random selections are never cached.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, category_name=\"teddybear\", dataset_root=\"/data/datasets/co3d\"):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            category_name: category sub-directory of `dataset_root` to load.\n",
    "            dataset_root: directory holding `{category}/frame_annotations.jgz`\n",
    "                and `{category}/sequence_annotations.jgz`.\n",
    "        \"\"\"\n",
    "        self.category_frame_annotations = load_dataclass_jgzip(\n",
    "            f\"{dataset_root}/{category_name}/frame_annotations.jgz\",\n",
    "            List[FrameAnnotation],\n",
    "        )\n",
    "        self.category_sequence_annotations = load_dataclass_jgzip(\n",
    "            f\"{dataset_root}/{category_name}/sequence_annotations.jgz\",\n",
    "            List[SequenceAnnotation],\n",
    "        )\n",
    "        # group frames by sequence once instead of scanning every frame per query\n",
    "        self._frames_by_sequence = {}\n",
    "        for frame in self.category_frame_annotations:\n",
    "            self._frames_by_sequence.setdefault(frame.sequence_name, []).append(frame)\n",
    "        self.seq_cache = {}\n",
    "        self.camera_cache = {}\n",
    "\n",
    "    def get_sequence(self, sequence_name, num_frames=4, determinstic=True):\n",
    "        \"\"\"\n",
    "        Returns up to `num_frames` FrameAnnotations of `sequence_name`.\n",
    "\n",
    "        Args:\n",
    "            sequence_name: name of the sequence to pick frames from.\n",
    "            num_frames: number of frames to return.\n",
    "            determinstic: if True, take frames at equal intervals in\n",
    "                frame_number order and cache the result; otherwise draw a\n",
    "                fresh random subset on every call.\n",
    "\n",
    "        Returns:\n",
    "            list of FrameAnnotation; shorter than `num_frames` (after printing\n",
    "            a warning) when the sequence has fewer frames, empty if unknown.\n",
    "        \"\"\"\n",
    "        cache_key = (sequence_name, num_frames)\n",
    "        if determinstic and cache_key in self.seq_cache:\n",
    "            return self.seq_cache[cache_key]\n",
    "\n",
    "        sequence = list(self._frames_by_sequence.get(sequence_name, []))\n",
    "        if len(sequence) < num_frames:\n",
    "            print(\n",
    "                f\"Warning: sequence {sequence_name} has less than {num_frames} frames\"\n",
    "            )\n",
    "\n",
    "        if not determinstic:\n",
    "            # random sampling, capped at the number of available frames\n",
    "            return random.sample(sequence, min(num_frames, len(sequence)))\n",
    "\n",
    "        # equal intervals to choose the num_frames samples; the step is at least 1\n",
    "        # so short or empty sequences do not produce a zero slice step\n",
    "        sample_intervals = max(1, len(sequence) // max(1, num_frames))\n",
    "        sequence = sorted(sequence, key=lambda x: x.frame_number)\n",
    "        sequence = sequence[::sample_intervals][:num_frames]\n",
    "        self.seq_cache[cache_key] = sequence\n",
    "        return sequence\n",
    "\n",
    "    def get_camera(self, sequence_name, num_frames=4, determinstic=True, device=\"cuda\"):\n",
    "        \"\"\"\n",
    "        Returns a batched PerspectiveCameras on `device`, one camera per frame\n",
    "        picked by `get_sequence(sequence_name, num_frames, determinstic)`.\n",
    "\n",
    "        Deterministic batches are cached on CPU. PyTorch3D's `.to()` moves the\n",
    "        cameras in place and returns the same object, so only clones are cached\n",
    "        and handed out; callers may move or modify the result freely.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if no frames are found for `sequence_name`.\n",
    "        \"\"\"\n",
    "        cache_key = (sequence_name, num_frames)\n",
    "        if determinstic and cache_key in self.camera_cache:\n",
    "            return self.camera_cache[cache_key].clone().to(device)\n",
    "\n",
    "        sequence = self.get_sequence(sequence_name, num_frames, determinstic)\n",
    "        if not sequence:\n",
    "            raise ValueError(f\"No frames found for sequence {sequence_name}\")\n",
    "        camera_batch = join_cameras_as_batch(\n",
    "            [get_pytorch3d_camera(frame) for frame in sequence]\n",
    "        )\n",
    "        if determinstic:\n",
    "            self.camera_cache[cache_key] = camera_batch.clone().to(\"cpu\")\n",
    "        return camera_batch.to(device)\n",
    "# Explicit category (it is also the default) so it is clear which annotations are loaded.\n",
    "a = AdditionCameras(category_name=\"teddybear\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[FrameAnnotation(sequence_name='598_91820_182418', frame_number=1, frame_timestamp=0.0, image=ImageAnnotation(path='teddybear/598_91820_182418/images/frame000001.jpg', size=(1230, 691)), depth=DepthAnnotation(path='teddybear/598_91820_182418/depths/frame000001.jpg.geometric.png', scale_adjustment=1.0, mask_path='teddybear/598_91820_182418/depth_masks/frame000001.png'), mask=MaskAnnotation(path='teddybear/598_91820_182418/masks/frame000001.png', mass=133561.0), viewpoint=ViewpointAnnotation(R=((-0.9758423566818237, -0.20728996396064758, -0.06901119649410248), (0.19773271679878235, -0.9723125100135803, 0.12454000860452652), (-0.092916339635849, 0.10788564383983612, 0.9898117184638977)), T=(0.8250232338905334, 1.3136308193206787, 12.107494354248047), focal_length=(3.694533348083496, 3.694533348083496), principal_point=(-0.0, -0.0), intrinsics_format='ndc_isotropic'), meta={'frame_type': 'test_known', 'frame_splits': ['singlesequence_teddybear_test_0_known'], 'eval_batch_maps': []}),\n",
       " FrameAnnotation(sequence_name='598_91820_182418', frame_number=51, frame_timestamp=4.9950248756218905, image=ImageAnnotation(path='teddybear/598_91820_182418/images/frame000051.jpg', size=(1230, 691)), depth=DepthAnnotation(path='teddybear/598_91820_182418/depths/frame000051.jpg.geometric.png', scale_adjustment=1.0, mask_path='teddybear/598_91820_182418/depth_masks/frame000051.png'), mask=MaskAnnotation(path='teddybear/598_91820_182418/masks/frame000051.png', mass=85380.0), viewpoint=ViewpointAnnotation(R=((0.1951938271522522, 0.737608015537262, 0.6464006304740906), (-0.7145118117332458, -0.34453338384628296, 0.6089085340499878), (0.6718424558639526, -0.580716073513031, 0.4597788453102112)), T=(-0.6224270462989807, 0.7326756715774536, 12.550195693969727), focal_length=(3.6733319759368896, 3.6733319759368896), principal_point=(-0.0, -0.0), intrinsics_format='ndc_isotropic'), meta={'frame_type': 'test_known', 'frame_splits': ['singlesequence_teddybear_test_0_known'], 'eval_batch_maps': []}),\n",
       " FrameAnnotation(sequence_name='598_91820_182418', frame_number=101, frame_timestamp=9.990049751243781, image=ImageAnnotation(path='teddybear/598_91820_182418/images/frame000101.jpg', size=(1230, 691)), depth=DepthAnnotation(path='teddybear/598_91820_182418/depths/frame000101.jpg.geometric.png', scale_adjustment=1.0, mask_path='teddybear/598_91820_182418/depth_masks/frame000101.png'), mask=MaskAnnotation(path='teddybear/598_91820_182418/masks/frame000101.png', mass=109388.0), viewpoint=ViewpointAnnotation(R=((0.7249362468719482, -0.5041567087173462, -0.4693543612957001), (0.45325133204460144, -0.16395339369773865, 0.8761749267578125), (-0.5186817049980164, -0.8479064106941223, 0.10965393483638763)), T=(-0.19810901582241058, 1.5162384510040283, 12.337506294250488), focal_length=(3.753972291946411, 3.753972291946411), principal_point=(-0.0, -0.0014471779577434063), intrinsics_format='ndc_isotropic'), meta={'frame_type': 'test_known', 'frame_splits': ['singlesequence_teddybear_test_0_known'], 'eval_batch_maps': []}),\n",
       " FrameAnnotation(sequence_name='598_91820_182418', frame_number=151, frame_timestamp=14.985074626865671, image=ImageAnnotation(path='teddybear/598_91820_182418/images/frame000151.jpg', size=(1230, 691)), depth=DepthAnnotation(path='teddybear/598_91820_182418/depths/frame000151.jpg.geometric.png', scale_adjustment=1.0, mask_path='teddybear/598_91820_182418/depth_masks/frame000151.png'), mask=MaskAnnotation(path='teddybear/598_91820_182418/masks/frame000151.png', mass=156887.0), viewpoint=ViewpointAnnotation(R=((-0.8001500368118286, -0.42043229937553406, -0.42778101563453674), (0.4088907241821289, -0.9041472673416138, 0.12379880249500275), (-0.4388260543346405, -0.07585807144641876, 0.8953643441200256)), T=(0.21985338628292084, 1.0529967546463013, 11.508438110351562), focal_length=(3.7386839389801025, 3.7386839389801025), principal_point=(0.0014471779577434063, -0.0), intrinsics_format='ndc_isotropic'), meta={'frame_type': 'test_known', 'frame_splits': ['singlesequence_teddybear_test_0_known'], 'eval_batch_maps': []})]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Frames picked at equal intervals from one teddybear sequence.\n",
    "example_sequence_name = \"598_91820_182418\"\n",
    "a.get_sequence(example_sequence_name, num_frames=4, determinstic=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Camera batch for the same frames; the last expression shows its device.\n",
    "camera_batch = a.get_camera(\"598_91820_182418\", num_frames=4, determinstic=True)\n",
    "camera_batch.device"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "3dgen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
