# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import glob
from omegaconf import DictConfig
from typing import Optional

import torch

from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.io import IO
from pytorch3d.renderer import (
    NDCMultinomialRaysampler,
    ray_bundle_to_ray_points,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds


HOLDOUT_CATEGORIES = set([
    'apple',
    'baseballglove',
    'cup',
    'ball',
    'toyplane',
    'handbag',
    'book',
    'carrot',
    'suitcase',
    'bowl',
])

def get_dataset_map(
    dataset_root: str,
    category: str,
    subset_name: str,
) -> DatasetMap:
    """
    Obtain the dataset map that contains the train/val/test dataset objects.
    """
    expand_args_fields(JsonIndexDatasetMapProviderV2)
    dataset_map_provider = JsonIndexDatasetMapProviderV2(
        category=category,
        subset_name=subset_name,
        dataset_root=dataset_root,
        test_on_train=False,
        only_test_set=False,
        load_eval_batches=True,
        dataset_JsonIndexDataset_args=DictConfig({"remove_empty_masks": False, "load_point_clouds": False}),
    )
    return dataset_map_provider.get_dataset_map()


def _load_pointcloud(pcl_path, max_points):
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)

    return pcl


def get_all_dataset_maps(co3d_path, holdout_categories, one_class = ''):
    all_categories = [c.split('/')[-1] for c in list(glob.glob(co3d_path + '/*')) if not c.endswith('.json')]
    all_categories = sorted(all_categories, key=lambda x: hash(x))

    # Obtain the CO3Dv2 dataset map
    train_dataset_maps = {}
    val_dataset_maps = {}
    if one_class != '':
        all_categories = [one_class]
    for category in all_categories:

        print(f'Loading dataset map ({category})')
        dataset_map = {
            'train': torch.load(f'dataset_cache/{category}_train.pt'),
            'val': torch.load(f'dataset_cache/{category}_val.pt')
        }
        if not holdout_categories or category not in HOLDOUT_CATEGORIES:
            train_dataset_maps[category] = dataset_map['train']
        if not holdout_categories or category in HOLDOUT_CATEGORIES:
            val_dataset_maps[category] = dataset_map['val']

    print('Loaded', len(train_dataset_maps), 'categores for train')
    print('Loaded', len(val_dataset_maps), 'categores for val')
    return train_dataset_maps, val_dataset_maps


def get_rgbd_points(
    imh, imw,
    camera: CamerasBase,
    depth_map: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    mask_thr: float = 0.5,
) -> Pointclouds:
    """
    Given a batch of images, depths, masks and cameras, generate a colored
    point cloud by unprojecting depth maps to the  and coloring with the source
    pixel colors.
    """
    depth_map = torch.nn.functional.interpolate(
        depth_map,
        size=[imh, imw],
        mode="bilinear",
        align_corners=False,
    )
    # convert the depth maps to point clouds using the grid ray sampler
    pts_3d = ray_bundle_to_ray_points(
        NDCMultinomialRaysampler(
            image_width=imw,
            image_height=imh,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )(camera)._replace(lengths=depth_map[:, 0, ..., None])
    ).squeeze(3)[None]

    pts_mask = depth_map > 0.0
    if mask is not None:
        mask = torch.nn.functional.interpolate(
            mask,
            size=[imh, imw],
            mode="bilinear",
            align_corners=False,
        )
        pts_mask *= mask > mask_thr
    pts_3d[~pts_mask] = float('inf')
    return pts_3d.squeeze(0).squeeze(0)

