#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Union

import numpy as np

from yacs.config import CfgNode as CN

# Presumably the default directory for config files (not referenced by the
# code visible in this file).
DEFAULT_CONFIG_DIR: str = "configs/"
# Separator accepted by get_task_config()/get_config() when several config
# paths are passed as one string, e.g. "a.yaml,b.yaml".
CONFIG_FILE_SEPARATOR: str = ","


def _get_default_task_config() -> CN:
    r"""Build a fresh, unfrozen CfgNode holding every default task setting.

    A new tree is built on each call, so callers may merge into it freely.
    """
    # Local name mirrors the module-level experiment `_C` convention; it is
    # unrelated to (and shadows) that module-level node inside this function.
    _C = CN()
    _C.SEED = 100
    # -----------------------------------------------------------------------------
    # ENVIRONMENT
    # -----------------------------------------------------------------------------
    _C.ENVIRONMENT = CN()
    _C.ENVIRONMENT.MAX_EPISODE_STEPS = 1000
    _C.ENVIRONMENT.MAX_EPISODE_SECONDS = 10000000
    _C.ENVIRONMENT.ITERATOR_OPTIONS = CN()
    _C.ENVIRONMENT.ITERATOR_OPTIONS.CYCLE = True
    _C.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
    _C.ENVIRONMENT.ITERATOR_OPTIONS.GROUP_BY_SCENE = True
    _C.ENVIRONMENT.ITERATOR_OPTIONS.NUM_EPISODE_SAMPLE = -1
    _C.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT = -1
    # -----------------------------------------------------------------------------
    # TASK
    # -----------------------------------------------------------------------------
    _C.TASK = CN()
    _C.TASK.TYPE = "Nav-v0"
    _C.TASK.SUCCESS_DISTANCE = 0.2
    _C.TASK.SENSORS = []
    _C.TASK.MEASUREMENTS = []
    _C.TASK.GOAL_SENSOR_UUID = "pointgoal"
    # -----------------------------------------------------------------------------
    # # POINTGOAL SENSOR
    # -----------------------------------------------------------------------------
    _C.TASK.POINTGOAL_SENSOR = CN()
    _C.TASK.POINTGOAL_SENSOR.TYPE = "PointGoalSensor"
    _C.TASK.POINTGOAL_SENSOR.GOAL_FORMAT = "POLAR"
    _C.TASK.POINTGOAL_SENSOR.DIMENSIONALITY = 2
    # -----------------------------------------------------------------------------
    # # POINTGOAL WITH GPS+COMPASS SENSOR
    # -----------------------------------------------------------------------------
    # Starts as a copy of POINTGOAL_SENSOR (same GOAL_FORMAT/DIMENSIONALITY);
    # only TYPE differs.
    _C.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR = _C.TASK.POINTGOAL_SENSOR.clone()
    _C.TASK.POINTGOAL_WITH_GPS_COMPASS_SENSOR.TYPE = (
        "PointGoalWithGPSCompassSensor"
    )
    # -----------------------------------------------------------------------------
    # # HEADING SENSOR
    # -----------------------------------------------------------------------------
    _C.TASK.HEADING_SENSOR = CN()
    _C.TASK.HEADING_SENSOR.TYPE = "HeadingSensor"
    # -----------------------------------------------------------------------------
    # # COMPASS SENSOR
    # -----------------------------------------------------------------------------
    _C.TASK.COMPASS_SENSOR = CN()
    _C.TASK.COMPASS_SENSOR.TYPE = "CompassSensor"
    # -----------------------------------------------------------------------------
    # # GPS SENSOR
    # -----------------------------------------------------------------------------
    _C.TASK.GPS_SENSOR = CN()
    _C.TASK.GPS_SENSOR.TYPE = "GPSSensor"
    _C.TASK.GPS_SENSOR.DIMENSIONALITY = 2
    # -----------------------------------------------------------------------------
    # # PROXIMITY SENSOR
    # -----------------------------------------------------------------------------
    _C.TASK.PROXIMITY_SENSOR = CN()
    _C.TASK.PROXIMITY_SENSOR.TYPE = "ProximitySensor"
    _C.TASK.PROXIMITY_SENSOR.MAX_DETECTION_RADIUS = 2.0
    # -----------------------------------------------------------------------------
    # # SPL MEASUREMENT
    # -----------------------------------------------------------------------------
    _C.TASK.SPL = CN()
    _C.TASK.SPL.TYPE = "SPL"
    _C.TASK.SPL.SUCCESS_DISTANCE = 0.2
    # -----------------------------------------------------------------------------
    # # TopDownMap MEASUREMENT
    # -----------------------------------------------------------------------------
    _C.TASK.TOP_DOWN_MAP = CN()
    _C.TASK.TOP_DOWN_MAP.TYPE = "TopDownMap"
    # The value is copied here, at build time. A later override of
    # ENVIRONMENT.MAX_EPISODE_STEPS (file or opts) does not propagate here.
    _C.TASK.TOP_DOWN_MAP.MAX_EPISODE_STEPS = _C.ENVIRONMENT.MAX_EPISODE_STEPS
    _C.TASK.TOP_DOWN_MAP.MAP_PADDING = 3
    _C.TASK.TOP_DOWN_MAP.NUM_TOPDOWN_MAP_SAMPLE_POINTS = 20000
    _C.TASK.TOP_DOWN_MAP.MAP_RESOLUTION = 1250
    _C.TASK.TOP_DOWN_MAP.DRAW_SOURCE_AND_TARGET = True
    _C.TASK.TOP_DOWN_MAP.DRAW_BORDER = True
    _C.TASK.TOP_DOWN_MAP.DRAW_SHORTEST_PATH = True
    _C.TASK.TOP_DOWN_MAP.FOG_OF_WAR = CN()
    _C.TASK.TOP_DOWN_MAP.FOG_OF_WAR.DRAW = True
    _C.TASK.TOP_DOWN_MAP.FOG_OF_WAR.VISIBILITY_DIST = 5.0
    _C.TASK.TOP_DOWN_MAP.FOG_OF_WAR.FOV = 90
    # -----------------------------------------------------------------------------
    # # COLLISIONS MEASUREMENT
    # -----------------------------------------------------------------------------
    _C.TASK.COLLISIONS = CN()
    _C.TASK.COLLISIONS.TYPE = "Collisions"
    # -----------------------------------------------------------------------------
    # SIMULATOR
    # -----------------------------------------------------------------------------
    _C.SIMULATOR = CN()
    _C.SIMULATOR.TYPE = "Sim-v0"
    _C.SIMULATOR.ACTION_SPACE_CONFIG = "v0"
    _C.SIMULATOR.FORWARD_STEP_SIZE = 0.25  # in metres
    _C.SIMULATOR.SCENE = (
        "data/scene_datasets/habitat-test-scenes/" "van-gogh-room.glb"
    )
    # Copied at build time; overriding SEED later does not update this value.
    _C.SIMULATOR.SEED = _C.SEED
    _C.SIMULATOR.TURN_ANGLE = 10  # angle to rotate left or right in degrees
    _C.SIMULATOR.TILT_ANGLE = 15  # angle to tilt the camera up or down in degrees
    _C.SIMULATOR.DEFAULT_AGENT_ID = 0
    # -----------------------------------------------------------------------------
    # # SENSORS
    # -----------------------------------------------------------------------------
    # Shared template; each concrete sensor below gets its own clone.
    SENSOR = CN()
    SENSOR.HEIGHT = 480
    SENSOR.WIDTH = 640
    SENSOR.HFOV = 90  # horizontal field of view in degrees
    SENSOR.POSITION = [0, 1.25, 0]
    # -----------------------------------------------------------------------------
    # # RGB SENSOR
    # -----------------------------------------------------------------------------
    _C.SIMULATOR.RGB_SENSOR = SENSOR.clone()
    _C.SIMULATOR.RGB_SENSOR.TYPE = "HabitatSimRGBSensor"
    # -----------------------------------------------------------------------------
    # DEPTH SENSOR
    # -----------------------------------------------------------------------------
    _C.SIMULATOR.DEPTH_SENSOR = SENSOR.clone()
    _C.SIMULATOR.DEPTH_SENSOR.TYPE = "HabitatSimDepthSensor"
    _C.SIMULATOR.DEPTH_SENSOR.MIN_DEPTH = 0
    _C.SIMULATOR.DEPTH_SENSOR.MAX_DEPTH = 10
    _C.SIMULATOR.DEPTH_SENSOR.NORMALIZE_DEPTH = True
    # -----------------------------------------------------------------------------
    # SEMANTIC SENSOR
    # -----------------------------------------------------------------------------
    _C.SIMULATOR.SEMANTIC_SENSOR = SENSOR.clone()
    _C.SIMULATOR.SEMANTIC_SENSOR.TYPE = "HabitatSimSemanticSensor"
    # -----------------------------------------------------------------------------
    # AGENT
    # -----------------------------------------------------------------------------
    _C.SIMULATOR.AGENT_0 = CN()
    _C.SIMULATOR.AGENT_0.HEIGHT = 1.5
    _C.SIMULATOR.AGENT_0.RADIUS = 0.1
    _C.SIMULATOR.AGENT_0.MASS = 32.0
    _C.SIMULATOR.AGENT_0.LINEAR_ACCELERATION = 20.0
    # NOTE(review): 3.14 looks like an approximation of pi; kept as-is so the
    # default value does not change.
    _C.SIMULATOR.AGENT_0.ANGULAR_ACCELERATION = 4 * 3.14
    _C.SIMULATOR.AGENT_0.LINEAR_FRICTION = 0.5
    _C.SIMULATOR.AGENT_0.ANGULAR_FRICTION = 1.0
    _C.SIMULATOR.AGENT_0.COEFFICIENT_OF_RESTITUTION = 0.0
    _C.SIMULATOR.AGENT_0.SENSORS = ["RGB_SENSOR"]
    _C.SIMULATOR.AGENT_0.IS_SET_START_STATE = False
    _C.SIMULATOR.AGENT_0.START_POSITION = [0, 0, 0]
    _C.SIMULATOR.AGENT_0.START_ROTATION = [0, 0, 0, 1]
    _C.SIMULATOR.AGENTS = ["AGENT_0"]
    # -----------------------------------------------------------------------------
    # SIMULATOR HABITAT_SIM_V0
    # -----------------------------------------------------------------------------
    _C.SIMULATOR.HABITAT_SIM_V0 = CN()
    _C.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = 0
    # -----------------------------------------------------------------------------
    # DATASET
    # -----------------------------------------------------------------------------
    _C.DATASET = CN()
    _C.DATASET.TYPE = "PointNav-v1"
    _C.DATASET.SPLIT = "train"
    _C.DATASET.SCENES_DIR = "data/scene_datasets"
    _C.DATASET.CONTENT_SCENES = ["*"]
    # `{split}` is a placeholder left unformatted here.
    _C.DATASET.DATA_PATH = (
        "data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
    )
    return _C


def get_task_config(
    config_paths: Optional[Union[List[str], str]] = None,
    opts: Optional[list] = None,
) -> CN:
    r"""Create a unified task config.

    Default values are overwritten by values from `config_paths`, and those
    are in turn overwritten by options from `opts`.

    Args:
        config_paths: List of config paths, or a string containing a
            comma-separated list of config paths. Whitespace around each path
            in the string form is ignored, as are empty entries (e.g. from a
            trailing comma). Files are merged in order, so later files win.
        opts: Config options (keys, values) in a flat list (e.g., passed from
            the command line), for example `opts = ['FOO.BAR', 0.5]`. Useful
            for parameter sweeping or quick tests.

    Returns:
        A frozen ``CN`` holding the merged task config.
    """
    config = _get_default_task_config()

    if config_paths:
        if isinstance(config_paths, str):
            # str.split already yields [config_paths] when no separator is
            # present. Strip each piece so "a.yaml, b.yaml" works, and drop
            # empty pieces instead of trying to open "".
            config_paths = [
                path.strip()
                for path in config_paths.split(CONFIG_FILE_SEPARATOR)
                if path.strip()
            ]

        for config_path in config_paths:
            config.merge_from_file(config_path)

    if opts:
        config.merge_from_list(opts)

    config.freeze()
    return config

# -----------------------------------------------------------------------------
# EXPERIMENT CONFIG
# -----------------------------------------------------------------------------
_C = CN()
_C.BASE_TASK_CONFIG_PATH = "./env/habitat/habitat_api/configs/tasks/pointnav.yaml"
_C.TASK_CONFIG = CN()  # task_config will be stored as a config node
_C.CMD_TRAILING_OPTS = []  # store command line options as list of strings
_C.TRAINER_NAME = "ppo"
_C.ENV_NAME = "NavRLEnv"
_C.SIMULATOR_GPU_ID = 0
_C.TORCH_GPU_ID = 0
_C.VIDEO_OPTION = ["disk", "tensorboard"]
_C.TENSORBOARD_DIR = "tb"
_C.VIDEO_DIR = "video_dir"
_C.TEST_EPISODE_COUNT = 2
_C.EVAL_CKPT_PATH_DIR = "data/checkpoints"  # path to ckpt or path to ckpts dir
_C.NUM_PROCESSES = 16
_C.SENSORS = ["RGB_SENSOR", "DEPTH_SENSOR"]
_C.CHECKPOINT_FOLDER = "data/checkpoints"
_C.NUM_UPDATES = 10000
_C.LOG_INTERVAL = 10
_C.LOG_FILE = "train.log"
_C.CHECKPOINT_INTERVAL = 50
# -----------------------------------------------------------------------------
# REINFORCEMENT LEARNING (RL) ENVIRONMENT CONFIG
# -----------------------------------------------------------------------------
_C.RL = CN()
_C.RL.SUCCESS_REWARD = 10.0
_C.RL.SLACK_REWARD = -0.01
# -----------------------------------------------------------------------------
# PROXIMAL POLICY OPTIMIZATION (PPO)
# -----------------------------------------------------------------------------
_C.RL.PPO = CN()
_C.RL.PPO.clip_param = 0.2
_C.RL.PPO.ppo_epoch = 4
_C.RL.PPO.num_mini_batch = 16
_C.RL.PPO.value_loss_coef = 0.5
_C.RL.PPO.entropy_coef = 0.01
_C.RL.PPO.lr = 7e-4
_C.RL.PPO.eps = 1e-5
_C.RL.PPO.max_grad_norm = 0.5
_C.RL.PPO.num_steps = 5
_C.RL.PPO.hidden_size = 512
_C.RL.PPO.use_gae = True
_C.RL.PPO.use_linear_lr_decay = False
_C.RL.PPO.use_linear_clip_decay = False
_C.RL.PPO.gamma = 0.99
_C.RL.PPO.tau = 0.95
_C.RL.PPO.reward_window_size = 50
# -----------------------------------------------------------------------------
# ORBSLAM2 BASELINE
# -----------------------------------------------------------------------------
# Build the default task config once. Every get_task_config() call rebuilds,
# clones and freezes the whole default tree, and several ORBSLAM2 defaults
# below are derived from its depth-sensor settings. These values come from
# the built-in defaults only, not from any task YAML file.
_task_defaults = get_task_config()

_C.ORBSLAM2 = CN()
_C.ORBSLAM2.SLAM_VOCAB_PATH = "habitat_baselines/slambased/data/ORBvoc.txt"
_C.ORBSLAM2.SLAM_SETTINGS_PATH = (
    "habitat_baselines/slambased/data/mp3d3_small1k.yaml"
)
_C.ORBSLAM2.MAP_CELL_SIZE = 0.1
_C.ORBSLAM2.MAP_SIZE = 40
# Second component of the default depth-sensor POSITION.
_C.ORBSLAM2.CAMERA_HEIGHT = _task_defaults.SIMULATOR.DEPTH_SENSOR.POSITION[1]
_C.ORBSLAM2.BETA = 100
_C.ORBSLAM2.H_OBSTACLE_MIN = 0.3 * _C.ORBSLAM2.CAMERA_HEIGHT
_C.ORBSLAM2.H_OBSTACLE_MAX = 1.0 * _C.ORBSLAM2.CAMERA_HEIGHT
_C.ORBSLAM2.D_OBSTACLE_MIN = 0.1
_C.ORBSLAM2.D_OBSTACLE_MAX = 4.0
_C.ORBSLAM2.PREPROCESS_MAP = True
_C.ORBSLAM2.MIN_PTS_IN_OBSTACLE = (
    _task_defaults.SIMULATOR.DEPTH_SENSOR.WIDTH / 2.0
)
_C.ORBSLAM2.ANGLE_TH = float(np.deg2rad(15))
_C.ORBSLAM2.DIST_REACHED_TH = 0.15
_C.ORBSLAM2.NEXT_WAYPOINT_TH = 0.5
_C.ORBSLAM2.NUM_ACTIONS = 3
_C.ORBSLAM2.DIST_TO_STOP = 0.05
_C.ORBSLAM2.PLANNER_MAX_STEPS = 500
_C.ORBSLAM2.DEPTH_DENORM = _task_defaults.SIMULATOR.DEPTH_SENSOR.MAX_DEPTH

# Only needed while building _C; keep it out of the module namespace.
del _task_defaults


def get_config(
    config_paths: Optional[Union[List[str], str]] = None,
    opts: Optional[list] = None,
) -> CN:
    r"""Create the experiment config.

    Module defaults are overwritten by values from `config_paths`, then by
    options from `opts`.

    ``TASK_CONFIG`` is loaded with ``get_task_config`` from
    ``BASE_TASK_CONFIG_PATH`` after the experiment files are merged, so a file
    may redirect it. ``opts`` are applied afterwards and may therefore target
    nested task keys (``TASK_CONFIG.<...>``). If ``opts`` change
    ``BASE_TASK_CONFIG_PATH``, the task config is reloaded from the new path
    and ``opts`` are re-applied, so the stored path and ``TASK_CONFIG`` agree.

    Args:
        config_paths: List of config paths, or a string containing a
            comma-separated list of config paths. Whitespace around each path
            in the string form is ignored, as are empty entries. Files are
            merged in order, so later files win.
        opts: Config options (keys, values) in a flat list (e.g., passed from
            the command line), for example `opts = ['FOO.BAR', 0.5]`. Useful
            for parameter sweeping or quick tests. A copy is kept in
            ``CMD_TRAILING_OPTS``.

    Returns:
        A frozen ``CN`` holding the merged experiment config.
    """
    config = _C.clone()
    if config_paths:
        if isinstance(config_paths, str):
            # Strip each piece so "a.yaml, b.yaml" works, and drop empty
            # pieces (e.g. a trailing comma) instead of trying to open "".
            config_paths = [
                path.strip()
                for path in config_paths.split(CONFIG_FILE_SEPARATOR)
                if path.strip()
            ]

        for config_path in config_paths:
            config.merge_from_file(config_path)

    task_config_path = config.BASE_TASK_CONFIG_PATH
    config.TASK_CONFIG = get_task_config(task_config_path)
    if opts:
        # Store a copy so later mutation of the caller's list cannot alter
        # the recorded options.
        config.CMD_TRAILING_OPTS = list(opts)
        config.merge_from_list(opts)
        if config.BASE_TASK_CONFIG_PATH != task_config_path:
            # opts redirected the task config. The TASK_CONFIG loaded above
            # came from the old path, so reload it and re-apply opts so any
            # TASK_CONFIG.* overrides land on the new tree.
            config.TASK_CONFIG = get_task_config(config.BASE_TASK_CONFIG_PATH)
            config.merge_from_list(opts)

    config.freeze()
    return config
