import argparse
import os
import dotenv
import json
import sys
from pathlib import Path
from collections import defaultdict

# TODO Install home_robot, home_robot_sim and remove this
# Make the in-repo packages importable without installing them. Paths are
# resolved relative to the directory three levels above this file. Each one is
# inserted at the front of sys.path, so the entry inserted last
# (src/home_robot_sim) ends up being searched first.
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
for _pkg_subdir in ("src/home_robot", "src/home_robot_sim"):
    sys.path.insert(0, str(_REPO_ROOT / _pkg_subdir))

from config_utils import get_config 
from habitat.core.env import Env

from home_robot.agent.objectnav_agent.objectnav_agent import ObjectNavAgent
from home_robot.agent.objectnav_agent.objectnav_agent_lang import ObjectNavAgent as ObjectNavAgentLang
from home_robot_sim.env.habitat_objectnav_env.habitat_objectnav_env import (
    HabitatObjectNavEnv,
)

def build_arg_parser():
    """Build the command-line parser for this ObjectNav evaluation script.

    Returns:
        argparse.ArgumentParser with scene filtering, config paths, OpenAI
        credentials and debugging switches. Argument order is significant only
        for the printed ``vars(args)`` dump and the trailing ``opts`` positional.
    """
    parser = argparse.ArgumentParser()
    # Add an additional argument scene_ids that sets what scenes for this process to evaluate
    parser.add_argument(
        "--scene_ids",
        type=str,
        default="",
        help="Comma separated list of scene ids to evaluate",
    )
    parser.add_argument(
        "--use_language",
        action="store_true",
        help="Use language",
    )
    parser.add_argument(
        "--experiment_name",
        type=str,
        default="debug",
        help="Experiment name",
    )
    parser.add_argument(
        "--habitat_config_path",
        type=str,
        default="objectnav/modular_objectnav_hm3d.yaml",
        help="Path to config yaml",
    )
    parser.add_argument(
        "--baseline_config_path",
        type=str,
        default="projects/habitat_objectnav/configs/agent/hm3d_eval.yaml",
        help="Path to config yaml",
    )
    parser.add_argument(
        "--max_episodes_per_scene",
        type=int,
        default=-1,
        help="Maximum number of episodes to evaluate per scene. Set to -1 for no limit.",
    )
    # REMAINDER positional: everything after the first non-flag token lands here.
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="Modify config options from command line",
    )
    parser.add_argument(
        "--openai_key",
        type=str,
        default="",
        help="openai api key",
    )
    parser.add_argument(
        "--openai_org",
        type=str,
        default="",
        help="openai api org",
    )
    parser.add_argument(
        "--print_images",
        action="store_true",
        help="Print images",
    )
    parser.add_argument(
        "--use_gt_semantics",
        action="store_true",
        help="Use ground truth semantics",
    )
    return parser


def parse_scene_ids(raw):
    """Split a comma-separated ``--scene_ids`` value into non-empty, stripped ids.

    Empty entries (from ``""``, whitespace, or stray commas such as ``"a,"`` /
    ``"a,,b"``) are dropped. Because scenes are matched by substring, an empty
    entry would otherwise match every scene and silently disable the filter.

    Args:
        raw: the raw ``--scene_ids`` string.

    Returns:
        list of scene-id fragments; an empty list means "evaluate every scene".
    """
    return [part.strip() for part in raw.split(",") if part.strip()]


def should_evaluate_episode(scene_id, selected_scene_ids, episode_counts, max_episodes_per_scene):
    """Decide whether the current episode of ``scene_id`` should be evaluated.

    Args:
        scene_id: scene id of the current episode. A selected id matches when
            it is a substring of this value.
        selected_scene_ids: output of ``parse_scene_ids``; empty selects all scenes.
        episode_counts: mapping scene_id -> episodes already evaluated for it.
        max_episodes_per_scene: per-scene cap; any negative value means no limit.

    Returns:
        True if the episode should be run, False if it should be skipped.
    """
    if selected_scene_ids and not any(sel in scene_id for sel in selected_scene_ids):
        return False
    if max_episodes_per_scene >= 0 and episode_counts.get(scene_id, 0) >= max_episodes_per_scene:
        return False
    return True


# Header of the per-episode metrics file. The ", " separator is kept for
# compatibility with existing consumers of this file.
METRICS_CSV_HEADER = (
    "id, episode_id, scene_id, goal_name, distance_to_goal, success, spl, soft_spl, distance_to_goal_reward\n"
)


def format_metrics_row(index, episode_id, scene_id, goal_name, metrics):
    """Format one line of the metrics file, matching ``METRICS_CSV_HEADER``.

    Args:
        index: running episode index within this evaluation run.
        episode_id: episode id of the evaluated episode.
        scene_id: scene id of the evaluated episode.
        goal_name: goal name taken from the agent's last observation.
        metrics: mapping with keys "distance_to_goal", "success", "spl",
            "soft_spl" and "distance_to_goal_reward".

    Returns:
        The newline-terminated row. Values are not quoted, so a field that
        itself contains ", " would make the row ambiguous to parse.
    """
    return (
        f"{index}, {episode_id}, {scene_id}, {goal_name}, "
        f"{metrics['distance_to_goal']}, {metrics['success']}, {metrics['spl']}, "
        f"{metrics['soft_spl']}, {metrics['distance_to_goal_reward']}\n"
    )


def build_config(args):
    """Load the habitat + baseline config and apply the command-line overrides.

    Args:
        args: parsed command-line namespace from ``build_arg_parser``.

    Returns:
        The config object returned by ``get_config``, mutated in place.
    """
    config = get_config(args.habitat_config_path, args.baseline_config_path)
    # NOTE(review): args.opts is parsed but never applied to the config here —
    # confirm whether get_config (or a merge step) should receive it.

    print(f"Using config: {args.baseline_config_path}")
    config.NUM_ENVIRONMENTS = 1
    print(f"Using {config.NUM_ENVIRONMENTS} environments")
    config.GROUND_TRUTH_SEMANTICS = args.use_gt_semantics
    print(f"Using ground truth semantics: {config.GROUND_TRUTH_SEMANTICS}")
    if not config.GROUND_TRUTH_SEMANTICS:
        if config.SEMANTIC_MODEL == "rednet":
            print("Using rednet")
        elif config.SEMANTIC_MODEL == "detic":
            print("Using detic")
    config.PRINT_IMAGES = args.print_images
    print(f"Printing images: {config.PRINT_IMAGES}")
    config.habitat.dataset.split = "val"
    config.EXP_NAME = args.experiment_name  # Set the experiment name from the argument
    print(f"Experiment name: {config.EXP_NAME}")
    return config


def build_agent(args, config):
    """Create the plain or language-conditioned ObjectNav agent.

    For the language agent the OpenAI credentials come from the command line
    when ``--openai_key`` is given, otherwise from ``.env`` / the process
    environment (``OPENAI_API_KEY``, ``OPENAI_ORG``). They are stored on
    ``config`` before the agent is constructed.
    """
    if not args.use_language:
        return ObjectNavAgent(config=config)

    if args.openai_key == "":
        print("Loading openai key from .env")
        dotenv.load_dotenv(".env")
        openai_key = os.getenv("OPENAI_API_KEY")
        config.OPENAI_KEY = openai_key
        config.OPENAI_ORG = os.getenv("OPENAI_ORG")
        if not openai_key:
            # Keep going as before, but make the likely cause of later API errors visible.
            print("Warning: OPENAI_API_KEY not found in .env or environment")
    else:
        print("Loading openai key from arguments")
        config.OPENAI_KEY = args.openai_key
        config.OPENAI_ORG = args.openai_org
    return ObjectNavAgentLang(config=config)


def evaluate_episodes(agent, env, config, args):
    """Run the agent on every episode of ``env`` and append per-episode metrics.

    Metrics are written to ``datadump/<EXP_NAME>/episode_metrics.csv``. The
    file is truncated and re-headed at the start, then appended to after each
    finished episode.

    Every episode is reset before the scene filter and per-scene cap are
    applied. Presumably this is what advances the env to the next episode —
    confirm. Filtered-out episodes are not counted as skipped.

    An exception anywhere in an episode is reported and the episode is skipped
    (best effort), so one failing episode does not abort the run.

    Returns:
        Number of episodes skipped because of an exception.
    """
    import traceback

    episode_metrics_filename = f"datadump/{config.EXP_NAME}/episode_metrics.csv"
    Path(episode_metrics_filename).parent.mkdir(parents=True, exist_ok=True)
    with open(episode_metrics_filename, "w") as f:
        f.write(METRICS_CSV_HEADER)

    selected_scene_ids = parse_scene_ids(args.scene_ids)
    episodes_skipped = 0
    episode_counts = defaultdict(int)
    for i in range(len(env.habitat_env.episodes)):
        try:
            agent.reset()
            env.reset()

            current = env.habitat_env.current_episode
            if not should_evaluate_episode(
                current.scene_id, selected_scene_ids, episode_counts, args.max_episodes_per_scene
            ):
                print(f"Skipping scene {current.scene_id} episode {current.episode_id}")
                continue

            # Reset per episode so a stale observation from the previous
            # episode can never leak its goal name into this episode's row.
            obs = None
            while not env.episode_over:
                obs = env.get_observation()
                action, info = agent.act(obs)
                # The agent's debug info is only forwarded when images are requested.
                env.apply_action(action, info=info if config.PRINT_IMAGES else None)

            episode = env.habitat_env.current_episode
            episode_counts[episode.scene_id] += 1

            # Keys are "distance_to_goal", "success", "spl", "soft_spl", "distance_to_goal_reward"
            metrics = env.get_episode_metrics()
            print(metrics)
            # No observation means the episode ended right after reset.
            goal_name = obs.task_observations["goal_name"] if obs is not None else ""
            with open(episode_metrics_filename, "a") as f:
                f.write(format_metrics_row(i, episode.episode_id, episode.scene_id, goal_name, metrics))
        except Exception as e:
            # Best effort: report the failure (with traceback) and move on.
            episodes_skipped += 1
            print(e)
            traceback.print_exc()
            print("Skipping episode")
    return episodes_skipped


if __name__ == "__main__":
    parser = build_arg_parser()
    print("Arguments:")
    args = parser.parse_args()
    print(json.dumps(vars(args), indent=4))
    print("-" * 100)

    config = build_config(args)
    agent = build_agent(args, config)
    env = HabitatObjectNavEnv(Env(config=config), config=config)

    episodes_skipped = evaluate_episodes(agent, env, config, args)
    print(f"Episodes skiped: {episodes_skipped}")