# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license 
from glob import glob
import json
import os
import time
from collections import defaultdict
from enum import Enum
import traceback
from typing import TYPE_CHECKING, Any, Dict, Optional

import numpy as np
import pandas as pd
from habitat_baselines.rl.ppo.ppo_trainer import PPOTrainer
from habitat_baselines.utils.info_dict import extract_scalars_from_info
from omegaconf import DictConfig
from tqdm import tqdm
from natsort import natsorted

from helios.utils.env_utils import create_priv_env_fn
from helios.utils.metrics_utils import get_stats_from_episode_metrics
from helios.evaluator.ovmm_evaluator import OVMMEvaluator

if TYPE_CHECKING:
    from habitat.core.dataset import BaseEpisode
    from habitat.core.vector_env import VectorEnv

    from home_robot.agent.ovmm_agent.ovmm_agent import OpenVocabManipAgent
    from home_robot.core.abstract_agent import Agent

from home_robot.agent.ovmm_agent.ovmm_agent import Skill

class EvaluationType(Enum):
    """Selects which evaluation routine ``HELIOSEvaluator.evaluate`` dispatches to.

    Each member's value is the string accepted by ``evaluate``'s
    ``evaluation_type`` argument.
    """

    # Local evaluation through ``local_evaluate``.
    LOCAL = "local"
    # Local evaluation through ``local_evaluate_vectorized``.
    LOCAL_VECTORIZED = "local_vectorized"
    # Remote evaluation through ``remote_evaluate`` (no local env is created).
    REMOTE = "remote"


class HELIOSEvaluator(OVMMEvaluator):
    """Class for creating vectorized environments, evaluating OpenVocabManipAgent on an episode dataset and returning metrics.

    On top of ``OVMMEvaluator`` this evaluator:

    * skips the first ``start_episode`` episodes before evaluating;
    * optionally catches exceptions raised while stepping an episode and
      records their tracebacks in ``exceptions.json``;
    * adds a one-hot outcome summary (``_OUTCOME_NAMES``) to every episode;
    * stores per-episode videos in one sub-directory per outcome;
    * recomputes metrics for different caps on the number of pick attempts
      (see ``write_metrics_n_picks``).
    """

    # Pick caps evaluated by ``write_metrics_n_picks`` are 1 .. _MAX_PICKS - 1
    # (exclusive upper bound, as in ``range``).
    _MAX_PICKS = 50

    # One-hot outcome columns added to every recorded episode. The names must
    # match the strings returned by ``_classify_episode_outcome`` exactly,
    # otherwise the corresponding column can never be set.
    _OUTCOME_NAMES = (
        "0_failed_find_object",
        "1_failed_pick_object",
        "2_failed_find_recep",
        "3_failed_place_goal",
        "4_failed_place_rest",
        "5_failed_place_stable",
        "6_collision",
        "7_success",
        "8_undefined",
    )

    def __init__(
        self,
        eval_config: DictConfig,
        visualize: bool,
        start_episode: int,
        catch_exceptions: bool,
    ):
        """
        :param eval_config: evaluation config forwarded to ``OVMMEvaluator``.
        :param visualize: if True, agents exposing ``generate_video`` are asked
            to write a video for every recorded episode.
        :param start_episode: number of episodes to skip (by resetting the
            environment) before evaluation starts.
        :param catch_exceptions: if True, an exception raised while stepping an
            episode is recorded and ends that episode instead of aborting the run.
        """
        super().__init__(eval_config)
        self.visualize = visualize
        self.start_episode = start_episode
        self.catch_exceptions = catch_exceptions

    def local_evaluate(
        self, agent: "Agent", num_episodes: Optional[int] = None
    ) -> Dict[str, float]:
        """
        Evaluates the agent in the local environment.

        The first ``self.start_episode`` episodes are skipped by resetting the
        environment. Results are saved to ``self.results_dir`` every
        ``self.metrics_save_freq`` recorded episodes and once more at the end.

        :param agent: agent to be evaluated in environment.
        :param num_episodes: count of number of episodes for which the evaluation
            should be run; defaults to every episode remaining after ``start_episode``.
        :return: dict containing metrics tracked by environment.
        """
        env_num_episodes = self._env.number_of_episodes
        # Episodes consumed by the start_episode skip cannot be evaluated.
        remaining_episodes = env_num_episodes - self.start_episode
        if num_episodes is None:
            num_episodes = remaining_episodes
        else:
            assert num_episodes <= remaining_episodes, (
                "num_episodes({}) is larger than number of episodes "
                "in environment ({}) after skipping start_episode={}".format(
                    num_episodes, remaining_episodes, self.start_episode
                )
            )

        assert num_episodes > 0, "num_episodes should be greater than 0"

        max_picks = self._MAX_PICKS
        episode_metrics: Dict[str, Dict[str, Any]] = {}
        episode_metrics_per_pick: Dict[int, Dict] = {
            n_pick_k: {} for n_pick_k in range(1, max_picks)
        }
        # Episode key -> formatted traceback of the exception that ended it.
        exceptions: Dict[str, str] = {}

        if hasattr(agent, "set_envs"):
            agent.set_envs(self._env)
        for _ in range(self.start_episode):
            self._env.reset()

        with tqdm(total=num_episodes, desc="Episodes Evaluated") as pbar:
            for i_episode in range(num_episodes):
                episode_metrics_per_pick = self._evaluate_episode(
                    agent,
                    i_episode,
                    episode_metrics,
                    episode_metrics_per_pick,
                    exceptions,
                    max_picks,
                )
                pbar.update(1)

        agent.reset()
        if hasattr(agent, "end_eval"):
            agent.end_eval()
        self._env.close()

        average_metrics = self._save_eval_outputs(episode_metrics, exceptions)
        print(exceptions)
        return average_metrics

    def _evaluate_episode(
        self,
        agent: "Agent",
        i_episode: int,
        episode_metrics: Dict[str, Dict[str, Any]],
        episode_metrics_per_pick: Dict[int, Dict],
        exceptions: Dict[str, str],
        max_picks: int,
    ) -> Dict[int, Dict]:
        """Run the next episode and record its results.

        ``episode_metrics`` and ``exceptions`` are updated in place. The
        episode is only recorded if at least one environment step succeeded.

        :param agent: agent being evaluated.
        :param i_episode: index of this episode among the evaluated ones
            (excluding the skipped ``start_episode`` prefix).
        :param episode_metrics: episode key -> metrics of every recorded episode.
        :param episode_metrics_per_pick: pick cap -> {episode key -> metrics}.
        :param exceptions: episode key -> traceback of a caught exception.
        :param max_picks: exclusive upper bound of the pick caps.
        :return: the per-pick metrics mapping as returned by ``write_metrics_n_picks``.
        """
        observations = self._env.reset()
        current_episode = self._env.get_current_episode()
        agent.reset()
        self._check_set_planner_vis_dir(agent, current_episode)

        current_episode_key = self._episode_key(current_episode)
        if hasattr(agent, "set_episode_key"):
            agent.set_episode_key(episode_key=current_episode_key)

        current_episode_metrics: Dict[str, Any] = {}
        obs_data = [observations]
        info: Dict = {}
        hab_info: Dict = {}
        in_pick_last_step = False
        n_picks = 0
        # Pick index (1-based) -> agent timestep at which that pick started.
        steps_per_pick: Dict[int, Any] = {}

        max_episode_steps = self.config.habitat.environment.max_episode_steps
        for _step in tqdm(range(max_episode_steps), desc="Steps"):
            try:
                action, info, _ = agent.act(observations)
                observations, done, hab_info = self._env.apply_action(action, info)
            except Exception:
                if not self.catch_exceptions:
                    raise
                # Record the failure and end this episode; metrics from the last
                # successful step (if any) are still reported below.
                error = traceback.format_exc()
                print(f"Exception caught: {error}")
                exceptions[current_episode_key] = error
                break

            # A new pick attempt starts whenever the agent enters the PICK skill.
            is_picking = agent.states[0] == Skill.PICK
            if is_picking and not in_pick_last_step:
                n_picks += 1
                steps_per_pick[n_picks] = agent.timesteps[0]
            in_pick_last_step = is_picking

            if self.data_dir:
                obs_data.append(observations)
            if "skill_done" in info and info["skill_done"] != "":
                metrics = extract_scalars_from_info(hab_info)
                metrics_at_skill_end = {
                    f"{info['skill_done']}." + k: v for k, v in metrics.items()
                }
                # Existing keys take precedence, so the metrics captured the
                # first time a given skill finished are kept.
                current_episode_metrics = {
                    **metrics_at_skill_end,
                    **current_episode_metrics,
                }
                if "goal_name" in info:
                    current_episode_metrics["goal_name"] = info["goal_name"]
            if done:
                break

        if not hab_info:
            # No step completed (e.g. an exception on the first step).
            return episode_metrics_per_pick

        metrics_at_episode_end = {
            "END." + k: v for k, v in extract_scalars_from_info(hab_info).items()
        }
        summary = self._classify_episode_outcome(metrics_at_episode_end)
        current_episode_metrics = {
            **{
                name: 1.0 if name == summary else 0.0
                for name in self._OUTCOME_NAMES
            },
            **metrics_at_episode_end,
            **current_episode_metrics,
        }

        if self.visualize and hasattr(agent, "generate_video"):
            # Videos are grouped into one sub-directory per outcome.
            agent.generate_video(
                dir=os.path.join(self.config.habitat_baselines.video_dir, summary),
                i_episode=self.start_episode + i_episode,
                current_episode_key=current_episode_key,
            )

        if self.data_dir:
            self._save_obs_data(current_episode_key, obs_data)

        if "goal_name" in info:
            current_episode_metrics["goal_name"] = info["goal_name"]
        current_episode_metrics["n_picks"] = float(n_picks)

        # Key -1 holds the full episode length for write_metrics_n_picks.
        steps_per_pick[-1] = current_episode_metrics["END.num_steps"]

        episode_metrics[current_episode_key] = current_episode_metrics
        episode_metrics_per_pick = self.write_metrics_n_picks(
            n_picks,
            metrics_at_episode_end,
            episode_metrics_per_pick,
            max_picks,
            current_episode_key,
            steps_per_pick,
        )
        if len(episode_metrics) % self.metrics_save_freq == 0:
            self._save_eval_outputs(episode_metrics, exceptions)
        return episode_metrics_per_pick

    @staticmethod
    def _classify_episode_outcome(metrics_at_episode_end: Dict[str, Any]) -> str:
        """Map ``END.``-prefixed episode metrics to one of ``_OUTCOME_NAMES``.

        Phases are checked in order (find object, pick object, find receptacle,
        place on goal, object at rest, placement stability) and the first
        failing one is returned. A successful placement is then split by whether
        ``END.robot_collisions.robot_scene_colls`` is truthy.
        """
        m = metrics_at_episode_end
        if not m["END.ovmm_find_object_phase_success"]:
            return "0_failed_find_object"
        if not m["END.ovmm_pick_object_phase_success"]:
            return "1_failed_pick_object"
        if not m["END.ovmm_find_recep_phase_success"]:
            return "2_failed_find_recep"
        # A missing goal-placement metric is treated as a failure as well.
        if not m.get("END.obj_anywhere_on_goal.0"):
            return "3_failed_place_goal"
        if not m["END.object_at_rest"]:
            return "4_failed_place_rest"
        if not m["END.ovmm_placement_stability"]:
            return "5_failed_place_stable"
        if m["END.ovmm_place_success"]:
            if m["END.robot_collisions.robot_scene_colls"]:
                return "6_collision"
            return "7_success"
        return "8_undefined"

    def _save_obs_data(self, episode_key: str, obs_data: list) -> None:
        """Pickle ``obs_data`` to ``<self.data_dir>/<episode_key>/obs_data.pkl``."""
        import pickle

        data_episode_path = os.path.join(self.data_dir, episode_key)
        os.makedirs(data_episode_path, exist_ok=True)
        with open(os.path.join(data_episode_path, "obs_data.pkl"), "wb") as f:
            pickle.dump(obs_data, f)

    def _save_eval_outputs(
        self, episode_metrics: Dict[str, Dict[str, Any]], exceptions: Dict[str, str]
    ) -> Dict:
        """Persist the current results.

        Calls ``_write_results`` with the per-episode and aggregated metrics,
        prints the averaged metrics, and writes ``average_metrics.json`` and
        ``exceptions.json`` into ``self.results_dir``.

        :return: the averaged metrics from ``_summarize_metrics``.
        """
        aggregated_metrics = self._aggregate_metrics(episode_metrics)
        self._write_results(episode_metrics, aggregated_metrics)

        average_metrics = self._summarize_metrics(episode_metrics)
        self._print_summary(average_metrics)
        with open(f"{self.results_dir}/average_metrics.json", "w") as f:
            json.dump(average_metrics, f, indent=4)
        with open(f"{self.results_dir}/exceptions.json", "w") as f:
            json.dump(exceptions, f, indent=4)
        return average_metrics

    def write_metrics_n_picks(
        self,
        n_picks,
        metrics_at_episode_end,
        episode_metrics_per_pick,
        max_picks,
        current_episode_key,
        steps_per_pick,
    ):
        """Recompute metrics as if each episode had been capped at a number of picks.

        For every cap ``n_pick_k`` in ``1 .. max_picks - 1``, the current episode:

        * keeps its end-of-episode metrics if it used at most ``n_pick_k`` picks;
        * otherwise has every metric zeroed (it counts as a failure).

        In both cases ``END.num_steps`` becomes the timestep at which pick
        ``n_pick_k + 1`` started, or the full episode length if that pick
        never happened.

        Writes ``average_metrics_per_pick.json`` (one indented JSON object per
        cap, separated by newlines -- not a single JSON document) and
        ``episode_results_per_pick.json`` into ``self.results_dir``. Only the
        summary for cap 1 is printed.

        :param n_picks: number of times the agent entered the PICK skill in the
            current episode.
        :param metrics_at_episode_end: ``END.``-prefixed metrics of the episode.
        :param episode_metrics_per_pick: cap -> {episode key -> metrics};
            updated in place (missing caps are created).
        :param max_picks: exclusive upper bound of the caps.
        :param current_episode_key: key of the current episode.
        :param steps_per_pick: pick index -> timestep at which that pick
            started, plus key ``-1`` holding the episode length.
        :return: ``episode_metrics_per_pick``.
        """
        with open(f"{self.results_dir}/average_metrics_per_pick.json", "w") as f:
            for n_pick_k in range(1, max_picks):
                # Step count had the episode been cut when pick n_pick_k + 1 began.
                if n_pick_k + 1 in steps_per_pick:
                    steps = steps_per_pick[n_pick_k + 1]
                else:
                    steps = steps_per_pick[-1]

                if n_picks <= n_pick_k:
                    episode_result = dict(metrics_at_episode_end)
                else:
                    # Exceeding the pick cap counts as failing the episode.
                    episode_result = {
                        k: 0.0 for k in metrics_at_episode_end if k != "END.num_steps"
                    }
                episode_result["END.num_steps"] = steps

                per_episode = episode_metrics_per_pick.setdefault(n_pick_k, {})
                per_episode[current_episode_key] = episode_result

                average_metrics = self._summarize_metrics(per_episode)
                average_metrics["max_picks"] = n_pick_k

                if n_pick_k == 1:
                    self._print_summary(average_metrics)
                else:
                    f.write("\n")
                json.dump(average_metrics, f, indent=4)

        with open(f"{self.results_dir}/episode_results_per_pick.json", "w") as f:
            json.dump(episode_metrics_per_pick, f, indent=4)

        return episode_metrics_per_pick

    def evaluate(
        self,
        agent: "Agent",
        num_episodes: Optional[int] = None,
        evaluation_type: str = "local",
    ) -> Dict[str, float]:
        """Set up the environment for ``evaluation_type`` and evaluate ``agent``.

        :param agent: agent to be evaluated in environment.
        :param num_episodes: count of number of episodes for which the
            evaluation should be run.
        :param evaluation_type: one of the ``EvaluationType`` values:
            ``"local"``, ``"local_vectorized"`` or ``"remote"``.
        :return: dict containing metrics tracked by environment.
        :raises ValueError: if ``evaluation_type`` is not one of the above.
        """
        if evaluation_type == EvaluationType.LOCAL.value:
            self._env = create_priv_env_fn(self.config)
            return self.local_evaluate(agent, num_episodes)
        if evaluation_type == EvaluationType.LOCAL_VECTORIZED.value:
            self._env = create_priv_env_fn(self.config)
            return self.local_evaluate_vectorized(agent, num_episodes)
        if evaluation_type == EvaluationType.REMOTE.value:
            # Remote evaluation does not use a local environment.
            self._env = None
            return self.remote_evaluate(agent, num_episodes)
        raise ValueError(
            "Invalid evaluation type. Please choose from 'local', 'local_vectorized', 'remote'"
        )

    @staticmethod
    def _scene_name(episode: "BaseEpisode") -> str:
        """Return the basename of ``episode.scene_id`` up to its first '.'."""
        return episode.scene_id.split("/")[-1].split(".")[0]

    @classmethod
    def _episode_key(cls, episode: "BaseEpisode") -> str:
        """Return the ``<scene name>_<episode id>`` key used for result files."""
        return f"{cls._scene_name(episode)}_{episode.episode_id}"

    def _check_set_planner_vis_dir(
        self, agent: "Agent", current_episode: "BaseEpisode"
    ):
        """
        Sets vis_dir for storing planner's debug visualisations if the agent has a planner.

        ``agent.planner`` is configured whenever it exists; ``agent.our_planner``
        and ``agent.nav`` only when they expose ``set_vis_dir``. Each receives
        the scene name and the episode id.
        """
        scene_name = self._scene_name(current_episode)
        episode_id = current_episode.episode_id

        if hasattr(agent, "planner"):
            agent.planner.set_vis_dir(scene_name, episode_id)
        for attr in ("our_planner", "nav"):
            component = getattr(agent, attr, None)
            if hasattr(component, "set_vis_dir"):
                component.set_vis_dir(scene_name, episode_id)