#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import time
import habitat
import hydra
import wandb
import imageio.v2 as imageio
from omegaconf import DictConfig, OmegaConf
from habitat.config import Config
from habitat.config.default import Config as CN
from habitat_vc.config import get_config
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Queue, Process, Lock

from collections import defaultdict, deque
from typing import Any, DefaultDict, Dict, List, Optional, Type, Union, Tuple

import numpy as np
import torch
import tqdm

from numpy import ndarray
from torch import Tensor

from habitat import logger, Env, RLEnv, VectorEnv, make_dataset
from habitat.core.env import Env, RLEnv
from habitat.core.vector_env import VectorEnv
from habitat.utils import profiling_wrapper
from habitat.utils.visualizations.utils import observations_to_image

from habitat_baselines.common.base_trainer import BaseRLTrainer
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.environments import get_env_class
from habitat_baselines.common.obs_transformers import (
    apply_obs_transforms_batch,
    apply_obs_transforms_obs_space,
    get_active_obs_transforms,
)
from habitat_baselines.common.tensorboard_utils import TensorboardWriter
from habitat_baselines.utils.common import (
    batch_obs,
    generate_video,
    linear_decay,
    get_checkpoint_id,
)
from habitat_baselines.utils.env_utils import construct_envs

from habitat_vc.il.objectnav.algos.agent import ILAgent
from habitat_vc.il.objectnav.rollout_storage import RolloutStorage
from habitat_vc.il.objectnav.custom_baseline_registry import custom_baseline_registry
from habitat_baselines.rl.ddppo.ddp_utils import rank0_only

import habitat_vc.utils as utils


def make_env_fn(
    config: Config, env_class: Union[Type[Env], Type[RLEnv]]
) -> Union[Env, RLEnv]:
    r"""Build a single, seeded environment of type ``env_class``.

    Intended to be handed to VectorEnv as its environment factory.

    Args:
        config: root experiment config; ``config.TASK_CONFIG.DATASET`` is
            used to build the dataset and ``config.TASK_CONFIG.SEED`` to
            seed the environment.
        env_class: class of the environment to instantiate.

    Returns:
        the newly created environment, already seeded.
    """
    task_cfg = config.TASK_CONFIG
    dataset_cfg = task_cfg.DATASET
    created_env = env_class(
        config=config,
        dataset=make_dataset(dataset_cfg.TYPE, config=dataset_cfg),
    )
    created_env.seed(task_cfg.SEED)
    return created_env


def construct_envs_args(
    config: Config,
    env_class: Union[Type[Env], Type[RLEnv]],
    workers_ignore_signals: bool = False,
    gpu_ids: Optional[List[int]] = None,
) -> List[Dict[str, Any]]:
    r"""Build the keyword arguments for one single-env VectorEnv per GPU id.

    No VectorEnv is created here. Each returned dict can be unpacked into the
    :ref:`habitat.VectorEnv` constructor by the process that will own the
    environment. Scenes are distributed round-robin, so every environment
    gets its own disjoint subset of scenes.

    :param config: experiment config; ``TASK_CONFIG`` (dataset, seed,
        simulator), ``SENSORS`` and ``SIMULATOR_GPU_ID`` are read from it.
    :param env_class: class type of the envs to be created.
    :param workers_ignore_signals: Passed to :ref:`habitat.VectorEnv`'s constructor
    :param gpu_ids: one entry per environment to create; entry ``i`` is the
        simulator GPU for environment ``i``. Defaults to
        ``[config.SIMULATOR_GPU_ID]``.

    :return: one kwargs dict per environment with the keys ``make_env_fn``,
        ``env_fn_args`` and ``workers_ignore_signals``.
    :raises RuntimeError: if more than one environment is requested and
        there are fewer scenes than environments.
    """

    if gpu_ids is None:
        gpu_ids = [config.SIMULATOR_GPU_ID]
    gpu_ids = [int(gpu_id) for gpu_id in gpu_ids]
    num_environments = len(gpu_ids)

    dataset = make_dataset(config.TASK_CONFIG.DATASET.TYPE)
    scenes = config.TASK_CONFIG.DATASET.CONTENT_SCENES
    if "*" in config.TASK_CONFIG.DATASET.CONTENT_SCENES:
        scenes = dataset.get_scenes_to_load(config.TASK_CONFIG.DATASET)
    # Work on a copy: shuffling below must not reorder the list that is
    # still referenced by the caller's config.
    scenes = list(scenes)

    if num_environments > 1:
        if len(scenes) == 0:
            raise RuntimeError(
                "No scenes to load, multiple process logic relies on being able to split scenes uniquely between processes"
            )

        if len(scenes) < num_environments:
            raise RuntimeError(
                "reduce the number of environments as there "
                "aren't enough number of scenes.\n"
                "num_environments: {}\tnum_scenes: {}".format(
                    num_environments, len(scenes)
                )
            )

        random.shuffle(scenes)

    # Round-robin assignment: environment i gets scenes i, i + n, i + 2n, ...
    scene_splits: List[List[str]] = [
        scenes[i::num_environments] for i in range(num_environments)
    ]
    assert sum(map(len, scene_splits)) == len(scenes)

    configs = []
    for i, gpu_id in enumerate(gpu_ids):
        proc_config = config.clone()
        proc_config.defrost()

        task_config = proc_config.TASK_CONFIG
        # Distinct seed per environment.
        task_config.SEED = task_config.SEED + i
        if len(scenes) > 0:
            task_config.DATASET.CONTENT_SCENES = scene_splits[i]

        task_config.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = gpu_id
        task_config.SIMULATOR.AGENT_0.SENSORS = config.SENSORS

        proc_config.freeze()
        configs.append(proc_config)

    build_args = [
        {
            "make_env_fn": make_env_fn,
            "env_fn_args": [(proc_config, env_class)],
            "workers_ignore_signals": workers_ignore_signals,
        }
        for proc_config in configs
    ]
    return build_args
    
    # envs = habitat.VectorEnv(
    #     make_env_fn=make_env_fn,
    #     env_fn_args=tuple(zip(configs, env_classes)),
    #     workers_ignore_signals=workers_ignore_signals,
    #     auto_reset_done=False,
    # )


class Collector:
    """Replays demonstration episodes in one environment and saves them to disk.

    An instance is meant to be called in its own process. Calling it builds a
    ``habitat.VectorEnv`` from ``env_args``, steps each episode with the action
    found under the ``"demonstration"`` observation key, and writes one
    ``.mp4`` and one ``.npz`` file per episode. Progress is reported through
    ``queue``: first the number of episodes, then one episode id per saved
    episode.
    """

    def __init__(self, env_args, queue, npz_dir, video_dir, start_lock) -> None:
        """Store the collaborators; no environment is created until the call.

        Args:
            env_args: keyword arguments for ``habitat.VectorEnv``.
            queue: queue used to report the episode count and saved episode ids.
            npz_dir: directory receiving the ``<episode_id>.npz`` files.
            video_dir: directory receiving the ``<episode_id>.mp4`` files.
            start_lock: lock the collector waits on before collecting.
        """
        self.env_args = env_args
        self.queue = queue
        self.npz_dir = npz_dir
        self.video_dir = video_dir
        self.start_lock = start_lock

    def save_agent_stats(self, episode_id, agent_stats):
        """Write one episode's video and observations, then report its id.

        Args:
            episode_id: id used as the base name of both output files.
            agent_stats: maps each observation key to the list of per-step
                values; must contain ``"rgb"``. The lists are replaced in
                place by stacked arrays.
        """
        # save the video
        video_path = os.path.join(self.video_dir, episode_id + ".mp4")
        writer = imageio.get_writer(video_path,
                                    fps=5,
                                    codec='libx264',
                                    format='mp4',
                                    quality=8,
                                    ffmpeg_log_level='error')
        try:
            for frame in agent_stats["rgb"]:
                writer.append_data(frame)
        finally:
            # Always release the writer, even if a frame fails to encode.
            writer.close()
        # save the npz file
        # Convert each list of per-step values to one ndarray. This is done in
        # place so each list can be freed as soon as it has been converted.
        for sensor, value in agent_stats.items():
            agent_stats[sensor] = np.array(value)
        npz_path = os.path.join(self.npz_dir, episode_id + ".npz")
        np.savez_compressed(npz_path, **agent_stats)
        # send the episode id to the main process
        self.queue.put(episode_id)
        return

    @staticmethod
    def _raise_failed(jobs):
        """Re-raise the error of any finished save job; return the unfinished ones."""
        pending = []
        for job in jobs:
            if job.done():
                # result() re-raises the exception raised in the save thread.
                job.result()
            else:
                pending.append(job)
        return pending

    def _collect_episodes(self, thread_pool):
        """Step through episodes until one repeats, saving each in the pool."""
        # NOTE(review): episode ids serve both as the "already seen" key and as
        # output file names; this assumes ids are unique across all scenes
        # handled by all collectors -- confirm for the dataset in use.
        collected_episodes = set()
        save_jobs = []

        done = True
        while True:
            if done:
                observations = self.envs.reset()
                done = False

                # An episode seen before means the iterator has wrapped around.
                current_episodes = self.envs.current_episodes()
                if current_episodes[0].episode_id in collected_episodes:
                    break

                obs = observations[0]
                agent_stats = {sensor: [] for sensor in obs.keys()}
            else:
                # The action to replay comes with the previous observation.
                outputs = self.envs.step([obs["demonstration"]])
                observations, _, dones, _ = zip(*outputs)
                done = dones[0]
                obs = observations[0]

            for sensor, value in obs.items():
                agent_stats[sensor].append(value)

            if done:
                episode_id = self.envs.current_episodes()[0].episode_id
                collected_episodes.add(episode_id)

                # Surface earlier save failures now rather than dropping them.
                save_jobs = self._raise_failed(save_jobs)
                # save the agent stats with a thread
                save_jobs.append(
                    thread_pool.submit(self.save_agent_stats, episode_id, agent_stats)
                )
        return save_jobs

    def __call__(self) -> None:
        """Run the collection; raises if any episode could not be saved."""
        self.envs = habitat.VectorEnv(auto_reset_done=False, **self.env_args)
        try:
            # Leaving the ``with`` block waits for all pending saves to finish.
            with ThreadPoolExecutor(max_workers=10) as thread_pool:
                number_of_eval_episodes = sum(self.envs.number_of_episodes)
                self.queue.put(number_of_eval_episodes)
                # Barrier: block until whoever holds the lock releases it.
                with self.start_lock:
                    pass
                save_jobs = self._collect_episodes(thread_pool)
        finally:
            self.envs.close()
        # Every job is finished here; re-raise the first failure, if any.
        self._raise_failed(save_jobs)
        return


def _next_worker_message(queue, processes, fail_on_any_exit=False, poll_seconds=5.0):
    """Return the next queue item, raising instead of waiting forever.

    Args:
        queue: queue the collector processes report to.
        processes: the started collector processes.
        fail_on_any_exit: if True, a single exited process is already treated
            as a failure; otherwise only "every process has exited" is.
        poll_seconds: how long to wait on the queue between liveness checks.

    Raises:
        RuntimeError: if the queue stays empty and the processes have exited.
    """
    # Local import: the module name is shadowed by the ``queue`` argument.
    from queue import Empty

    while True:
        try:
            return queue.get(timeout=poll_seconds)
        except Empty:
            exited = [p for p in processes if p.exitcode is not None]
            if (fail_on_any_exit and exited) or len(exited) == len(processes):
                codes = {p.pid: p.exitcode for p in exited}
                raise RuntimeError(
                    "collector process(es) exited before delivering all "
                    f"expected results (pid -> exit code): {codes}"
                )


@hydra.main(config_path="configs", config_name="config_imagenav")
def main(cfg: DictConfig) -> None:
    r"""Collect demonstration episodes with one collector process per GPU id.

    Args:
        cfg: DictConfig object containing the configs for the experiment.

    Raises:
        RuntimeError: if the collector processes exit before reporting all
            expected results.
    """

    cfg = OmegaConf.to_container(cfg, resolve=True)
    cfg = CN(cfg)

    config = get_config()
    config.merge_from_other_cfg(cfg)

    # Dataset paths must be made absolute via hydra's helper.
    config.defrost()
    config.TASK_CONFIG.DATASET.SCENES_DIR = hydra.utils.to_absolute_path(
        config.TASK_CONFIG.DATASET.SCENES_DIR
    )
    config.TASK_CONFIG.DATASET.DATA_PATH = hydra.utils.to_absolute_path(
        config.TASK_CONFIG.DATASET.DATA_PATH
    )
    config.freeze()

    os.makedirs(config.VIDEO_DIR, exist_ok=True)
    os.makedirs(config.NPZ_DIR, exist_ok=True)

    # Set up the environment

    envs_args = construct_envs_args(
        config=config,
        env_class=get_env_class(config.ENV_NAME),
        workers_ignore_signals=False,
        gpu_ids=config.GPU_IDS,
    )

    queue = Queue()

    # Held until every collector has reported its episode count, so that the
    # first messages on the queue are guaranteed to be counts, not episode ids.
    start_lock = Lock()
    start_lock.acquire()
    processes = []
    try:
        for env_args in envs_args:
            collector = Collector(
                env_args=env_args,
                queue=queue,
                npz_dir=config.NPZ_DIR,
                video_dir=config.VIDEO_DIR,
                start_lock=start_lock,
            )
            p = Process(target=collector)
            p.start()
            processes.append(p)

        # No collector can finish before the lock is released, so any exit
        # during this phase is a failure.
        num_episodes = sum(
            _next_worker_message(queue, processes, fail_on_any_exit=True)
            for _ in processes
        )
        print(f"Number of episodes: {num_episodes}")
        print(f"Number of processes: {len(processes)}")
        with tqdm.tqdm(total=num_episodes) as pbar:
            start_lock.release()

            while num_episodes > 0:
                episode_id = _next_worker_message(queue, processes)
                if episode_id is None:
                    continue
                pbar.update(1)
                num_episodes -= 1
    except BaseException:
        # Do not leave collectors running (or blocked on the start lock) when
        # the parent fails or is interrupted.
        for p in processes:
            if p.is_alive():
                p.terminate()
        raise
    finally:
        for p in processes:
            p.join()
    
    # debug with single env
    # env_args = envs_args[0]
    # collector = Collector(
    #     env_args=env_args,
    #     queue=queue,
    #     npz_dir=config.NPZ_DIR,
    #     video_dir=config.VIDEO_DIR,
    # )
    # collector()


if __name__ == "__main__":
    # Put the NVIDIA OpenGL directory first on the library search path.
    # LD_LIBRARY_PATH may be unset, so read it with a default instead of
    # indexing; no trailing ":" is added, which would create an empty entry.
    # NOTE(review): changing LD_LIBRARY_PATH after interpreter start-up
    # presumably only affects newly exec'ed child programs -- confirm this is
    # sufficient for the intended libraries.
    nvidia_opengl_dir = "/usr/lib/x86_64-linux-gnu/nvidia-opengl"
    existing_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    if existing_library_path:
        os.environ["LD_LIBRARY_PATH"] = nvidia_opengl_dir + ":" + existing_library_path
    else:
        os.environ["LD_LIBRARY_PATH"] = nvidia_opengl_dir
    # Set the GLOG and Magnum logging levels through the environment.
    os.environ["GLOG_minloglevel"] = "3"
    os.environ["MAGNUM_LOG"] = "quiet"

    main()