#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the CC-BY-NC license found in the
# LICENSE file in the root directory of this source tree.

#!/usr/bin/env python3
import os
import pickle
import random
from datetime import datetime
import time

import hydra
import numpy as np
from omegaconf import DictConfig, OmegaConf
import torch
from habitat import logger
from habitat.config import Config
from habitat.config.default import Config as CN
from habitat_vc.config import get_config
from habitat_vc.il.objectnav.custom_baseline_registry import custom_baseline_registry
from habitat_vc.il.objectnav.algos.agent import ILAgent


@hydra.main(config_path="configs", config_name="config_imagenav")
def main(cfg: DictConfig) -> None:
    r"""Hydra entry point that hands the composed config to ``run_benchmark``.

    Args:
        cfg: experiment configuration composed by Hydra from the
            ``configs/config_imagenav`` config plus any command-line overrides.
    """
    return run_benchmark(cfg)


def run_benchmark(cfg: DictConfig, seed: int = 0) -> None:
    r"""Build the habitat config from ``cfg``, seed all RNGs and run the benchmark.

    The Hydra config is resolved, wrapped as a habitat ``Config`` and merged
    over the defaults returned by ``get_config()``. The task settings
    ``ANGLE_SUCCESS.USE_TRAIN_SUCCESS`` and
    ``IMAGEGOAL_ROTATION_SENSOR.SAMPLE_ANGLE`` are forced to ``False``.
    ``random``, NumPy and PyTorch are then seeded with a fixed value so runs
    are reproducible.

    Args:
        cfg: DictConfig object containing the configs for the experiment.
        seed: fixed seed stored in ``TASK_CONFIG.SEED`` and applied to
            ``random``, NumPy and PyTorch. Defaults to 0, the value that was
            previously hard-coded.

    Returns:
        None.
    """
    cfg = CN(OmegaConf.to_container(cfg, resolve=True))

    config = get_config()
    config.merge_from_other_cfg(cfg)
    # The seed is a fixed value, not a generated one; log it accurately.
    print("Using a fixed random seed {}".format(seed))
    config.defrost()
    config.TASK_CONFIG.TASK.ANGLE_SUCCESS.USE_TRAIN_SUCCESS = False
    config.TASK_CONFIG.TASK.IMAGEGOAL_ROTATION_SENSOR.SAMPLE_ANGLE = False
    config.TASK_CONFIG.SEED = seed
    config.freeze()

    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    torch.manual_seed(config.TASK_CONFIG.SEED)
    if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():
        torch.set_num_threads(1)

    setup_experiment(config)
    logger.add_filehandler(config.LOG_FILE)
    benckmark(config)


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file.

    NOTE(review): unpickling can execute arbitrary code; only point
    OBSERVATION_SPACE / ACTION_SPACE at trusted files.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_policy(config: Config, device: torch.device):
    """Build the IL policy, load ``config.CHECKPOINT`` into it, and return it in eval mode on ``device``."""
    il_cfg = config.IL.BehaviorCloning
    observation_space = _load_pickle(config.OBSERVATION_SPACE)
    action_spaces = _load_pickle(config.ACTION_SPACE)

    policy_cls = custom_baseline_registry.get_policy(config.IL.POLICY.name)
    policy = policy_cls.from_config(config, observation_space, action_spaces)

    agent = ILAgent(
        model=policy,
        num_envs=100,
        num_mini_batch=il_cfg.num_mini_batch,
        lr=il_cfg.lr,
        encoder_lr=il_cfg.encoder_lr,
        eps=il_cfg.eps,
        wd=il_cfg.wd,
        max_grad_norm=il_cfg.max_grad_norm,
    )

    ckpt_dict = torch.load(config.CHECKPOINT, map_location="cpu")
    state_dict = ckpt_dict["state_dict"]
    # Remove the running-mean-and-var entries before the strict load.
    # Keys are deleted in place (rather than building a new dict) so the
    # state_dict object, including any `_metadata` attribute, is preserved.
    for key in [k for k in state_dict if "running_mean_and_var" in k]:
        del state_dict[key]
    agent.load_state_dict(state_dict, strict=True)

    policy = agent.model
    policy.eval()
    policy.to(device)
    return policy


def _iter_samples(data_dir, device: torch.device):
    """Yield ``(batch, hidden_states, prev_actions, not_done_masks)`` per recorded step.

    Every ``*.npz`` file in ``data_dir`` is read in sorted filename order. Each
    file must contain the arrays 'test_recurrent_hidden_states',
    'prev_actions' and 'not_done_masks'. Every other array becomes an entry of
    the ``batch`` dict. All arrays are indexed along their first axis, which is
    assumed to be the per-step index (TODO confirm against the recording code).
    """
    state_keys = ("test_recurrent_hidden_states", "prev_actions", "not_done_masks")
    fns = sorted(fn for fn in os.listdir(data_dir) if fn.endswith(".npz"))
    for fn in fns:
        # NpzFile holds an open zip handle; close it once the arrays are loaded.
        with np.load(os.path.join(data_dir, fn)) as data:
            hidden = torch.from_numpy(data["test_recurrent_hidden_states"]).to(device)
            prev_actions = torch.from_numpy(data["prev_actions"]).to(device)
            masks = torch.from_numpy(data["not_done_masks"]).to(device)
            observations = {
                k: torch.from_numpy(data[k]).to(device)
                for k in data.keys()
                if k not in state_keys
            }
        for i in range(len(hidden)):
            batch = {k: v[i] for k, v in observations.items()}
            yield batch, hidden[i], prev_actions[i], masks[i]


def _sync_device(device: torch.device) -> None:
    """Block until queued CUDA work on ``device`` finishes (no-op on other devices)."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def benckmark(config: Config, warmup: int = 5, log_every: int = 10) -> None:
    """Measure the per-sample forward-pass latency of the IL policy.

    The name is misspelled but kept for compatibility with existing callers.

    The policy is built from ``config`` and ``config.CHECKPOINT``. It is then
    run once per recorded step from the ``.npz`` files in ``config.DATA_DIR``.
    Only the policy call is timed, under ``torch.no_grad()``. CUDA is
    synchronized around it, so the measurement covers kernel execution
    rather than just asynchronous launch.

    Args:
        config: experiment config. It is cloned, so the caller's copy is not
            modified.
        warmup: number of initial samples run but excluded from timing.
        log_every: log running totals every this many timed samples.

    Raises:
        ValueError: if ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    config = config.clone()
    device = torch.device(config.DEVICE)

    config.defrost()
    config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
    config.TASK_CONFIG.ENVIRONMENT.MAX_EPISODE_STEPS = 500
    config.freeze()

    policy = _load_policy(config, device)

    total_time = 0.0
    total_count = 0  # number of timed samples (after warm-up)
    seen = 0  # number of samples run, including warm-up
    # Inference only: disabling autograd keeps graph construction out of the timing.
    with torch.no_grad():
        for batch, hidden, prev_actions, masks in _iter_samples(config.DATA_DIR, device):
            _sync_device(device)
            start_time = time.perf_counter()
            policy(batch, hidden, prev_actions, masks)
            _sync_device(device)
            elapsed = time.perf_counter() - start_time

            seen += 1
            if seen <= warmup:
                continue
            total_time += elapsed
            total_count += 1
            if total_count % log_every == 0:
                logger.info(
                    f"Processed {total_count} samples, time: {total_time:.6f} seconds, average time: {total_time / total_count:.6f} seconds"
                )

    if total_count == 0:
        # Avoid dividing by zero when every sample was consumed by warm-up.
        logger.warning(
            f"No samples were timed: {seen} sample(s) found, warm-up is {warmup}."
        )
        return
    logger.info(
        f"Processed {total_count} samples, time: {total_time:.6f} seconds, average time: {total_time / total_count:.6f} seconds"
    )



def setup_experiment(config: Config) -> None:
    """Prepare output directories, dataset paths and logging env vars.

    Creates CHECKPOINT_FOLDER, VIDEO_DIR and LOG_DIR if they are missing.
    Rewrites TASK_CONFIG.DATASET.SCENES_DIR and DATA_PATH as absolute paths via
    ``hydra.utils.to_absolute_path``, mutating ``config`` in place and
    refreezing it. Finally sets the GLOG_minloglevel and MAGNUM_LOG
    environment variables to reduce native logging output.
    """
    for directory in (config.CHECKPOINT_FOLDER, config.VIDEO_DIR, config.LOG_DIR):
        os.makedirs(directory, exist_ok=True)

    config.defrost()
    dataset = config.TASK_CONFIG.DATASET
    for field in ("SCENES_DIR", "DATA_PATH"):
        setattr(dataset, field, hydra.utils.to_absolute_path(getattr(dataset, field)))
    config.freeze()

    os.environ.update({"GLOG_minloglevel": "3", "MAGNUM_LOG": "quiet"})


if __name__ == "__main__":
    # hydra.main composes `cfg` from the `configs/config_imagenav` config and
    # the command line, so main() is called without arguments.
    main()
