#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from typing import Any, Callable, Dict, List, Optional

from habitat.config import Config
from habitat.core.registry import registry
from habitat.datasets.pointnav.pointnav_dataset import (
    CONTENT_SCENES_PATH_FIELD,
    DEFAULT_SCENE_PATH_PREFIX,
    PointNavDatasetV1,
)
from habitat.tasks.nav.nav import (
    NavigationEpisode,
    NavigationGoal,
    ShortestPathPoint,
)

@registry.register_dataset(name="VisTargetNav-v1")
class VisTargetNavDatasetV1(PointNavDatasetV1):
    """Navigation dataset registered under the name ``VisTargetNav-v1``.

    Extends ``PointNavDatasetV1`` with two things visible in this class:

    * an optional ``filter_fn`` predicate that can drop episodes while they
      are being deserialized, and
    * goal handling that keeps only the ``position`` and ``radius`` entries
      of each goal before wrapping it in a ``NavigationGoal``.
    """

    def __init__(
        self,
        config: Optional[Config] = None,
        filter_fn: Optional[Callable[[NavigationEpisode], bool]] = None,
    ) -> None:
        """Store the episode filter and delegate construction to the parent.

        Args:
            config: Dataset configuration forwarded unchanged to
                ``PointNavDatasetV1``.
            filter_fn: Optional predicate called with each freshly built
                ``NavigationEpisode``; episodes for which it returns a falsy
                value are skipped. ``None`` keeps every episode.
        """
        # Assigned before the parent constructor runs so the filter is
        # already available if the parent loads episodes (and therefore
        # calls ``from_json``) during construction.
        self.filter_fn = filter_fn
        super().__init__(config)

    def from_json(
        self, json_str: str, scenes_dir: Optional[str] = None
    ) -> None:
        """Deserialize episodes from ``json_str`` and append them to ``self.episodes``.

        Args:
            json_str: JSON document with an ``"episodes"`` list and,
                optionally, a content-scenes path entry.
            scenes_dir: When given, each episode's ``scene_id`` has the
                default scene path prefix stripped (if present) and is then
                joined onto this directory.

        Raises:
            KeyError: If ``"episodes"`` is missing, or a goal has no
                ``"position"`` entry.
        """
        deserialized = json.loads(json_str)
        if CONTENT_SCENES_PATH_FIELD in deserialized:
            self.content_scenes_path = deserialized[CONTENT_SCENES_PATH_FIELD]

        for raw_episode in deserialized["episodes"]:
            episode = NavigationEpisode(**raw_episode)

            # The filter sees the episode as deserialized: scene_id has not
            # been rewritten yet and goals are still plain dicts.
            if self.filter_fn is not None and not self.filter_fn(episode):
                continue

            if scenes_dir is not None:
                if episode.scene_id.startswith(DEFAULT_SCENE_PATH_PREFIX):
                    episode.scene_id = episode.scene_id[
                        len(DEFAULT_SCENE_PATH_PREFIX) :
                    ]

                episode.scene_id = os.path.join(scenes_dir, episode.scene_id)

            for g_index, goal in enumerate(episode.goals):
                # NOTE: modification -- only position and radius are passed
                # on; any other keys stored with the goal are dropped.
                # A goal without a "radius" entry gets radius=None instead
                # of raising KeyError.
                episode.goals[g_index] = NavigationGoal(
                    position=goal["position"],
                    radius=goal.get("radius"),
                )

            # NOTE: conversion of shortest-path points to ShortestPathPoint
            # is disabled here; episode.shortest_paths is kept exactly as it
            # was deserialized.
            self.episodes.append(episode)

