import sys
import argparse
import json
import yaml
import os
import re
import time
import shutil
import pandas as pd
from tqdm import tqdm
import alfworld.agents.environment as environment
import alfworld.agents.modules.generic as generic

from decision_oaif.agents.openai_agent import OpenAIAgent
from decision_oaif.agents.hf_agent import HFAgent
from decision_oaif.agents.hf_spaces_agent import HFSpaceAgent
# from decision_oaif.agents.sglang_server_agent import SGLangServerAgent
from decision_oaif.utils.alfworld import parse_gamefile, extract_task_from_observation
from decision_oaif.utils.parser import (
    parse_reason_and_action_alfworld,
    substitute_placeholders,
)


def parse_and_load_config():
    """Parse the command line and return the run configuration.

    Exactly one of two modes must be selected:

    * ``--eval_config PATH``: the YAML file at PATH is returned as-is.
    * ``--training_config PATH --iter N``: the ``rollout_student_trajectory``
      section of the YAML file is returned after ``substitute_placeholders``
      has been applied to it with ``"{iter}"`` and ``str(N)``.

    ``--eval_config`` takes precedence when both modes are supplied.

    Side effect: ``sys.argv`` is overwritten with the program name followed by
    the ``--alfworld_config`` path, for compatibility with
    ``generic.load_config``.

    Returns:
        The loaded configuration object.

    Raises:
        SystemExit: via ``argparse`` when neither mode is fully specified.
    """
    cli = argparse.ArgumentParser(description="Evaluate agent on alfworld")
    cli.add_argument(
        "--alfworld_config",
        type=str,
        default="configs/env_config/alfworld_config.yaml",
        help="Path to the Alfred base config file",
    )
    # Evaluation mode.
    cli.add_argument(
        "--eval_config", type=str, help="Path to the evaluation config file"
    )
    # Rollout mode.
    cli.add_argument(
        "--training_config", type=str, help="Path to the training config file"
    )
    cli.add_argument("--iter", type=int, help="Iteration number")
    opts = cli.parse_args()

    # Leave only the ALFWorld config path on the command line for
    # generic.load_config.
    program_name = sys.argv[0]
    sys.argv = [program_name, opts.alfworld_config]

    if opts.eval_config:
        with open(opts.eval_config, "r") as handle:
            return yaml.safe_load(handle)

    if not opts.training_config or opts.iter is None:
        # ArgumentParser.error prints the usage message and exits.
        cli.error(
            "You must provide either --eval_config or both --training_config and --iter."
        )

    with open(opts.training_config, "r") as handle:
        rollout_section = yaml.safe_load(handle)["rollout_student_trajectory"]
    return substitute_placeholders(rollout_section, "{iter}", str(opts.iter))


def _agent_label(agent_config):
    """Return the name used for an agent's log directory and summary rows.

    ``model_id`` is preferred. ``space_id`` is the fallback because "hf_space"
    agents are constructed from ``space_id`` and may not define ``model_id``.

    Raises:
        KeyError: if the agent config has neither key.
    """
    for key in ("model_id", "space_id"):
        if key in agent_config:
            return agent_config[key]
    raise KeyError("agent config needs a 'model_id' or a 'space_id'")


def _build_agent(agent_config, config):
    """Instantiate the agent described by one entry of ``config["agents"]``.

    Args:
        agent_config: Mapping with a ``type`` key ("openai", "hf" or
            "hf_space"), a ``prompt_template_file`` and the identifier that
            type needs (``model_id`` or ``space_id``).
        config: Run configuration; ``verbose`` and ``debug`` are read from it.

    Returns:
        The constructed agent.

    Raises:
        ValueError: if ``agent_config["type"]`` is not a supported type.
    """
    agent_type = agent_config["type"]
    # NOTE: an "sglang_server" agent type is currently disabled (see the
    # commented-out SGLangServerAgent import at the top of the file).
    if agent_type not in ("openai", "hf", "hf_space"):
        raise ValueError(f"Unsupported agent type: {agent_config['type']}")

    shared_kwargs = {
        "prompt_template_file": agent_config["prompt_template_file"],
        "verbose": config["verbose"],
        "debug": config["debug"],
        "parse_reason_action_fn": parse_reason_and_action_alfworld,
    }
    if agent_type == "openai":
        return OpenAIAgent(model_id=agent_config["model_id"], **shared_kwargs)
    if agent_type == "hf":
        return HFAgent(
            model_id=agent_config["model_id"], max_length=6000, **shared_kwargs
        )
    return HFSpaceAgent(space_id=agent_config["space_id"], **shared_kwargs)


def _run_episode(env, agent, task, obs, info, max_actions, env_idx):
    """Roll the agent out on the current game for at most ``max_actions`` steps.

    Args:
        env: Batched environment (batch size 1) that has just been reset.
        agent: Agent exposing ``reset`` and ``predict_reason_action``.
        task: Task description passed to the agent at every step.
        obs: Initial observation returned by ``env.reset()``.
        info: Initial info dict; ``info["admissible_commands"][0]`` is offered
            to the agent as its candidate actions.
        max_actions: Upper bound on the number of steps.
        env_idx: Zero-based game index, used only for the progress-bar label.

    Returns:
        A ``(trajectory, score)`` tuple. ``trajectory`` holds one dict per
        step. ``score`` is the score reported by the last step, or 0 when no
        step was taken.
    """
    trajectory = []
    # Initialised here so a zero-step episode neither raises NameError nor
    # reports the previous game's score.
    score = 0
    agent.reset()
    for _ in tqdm(range(max_actions), desc=f"Actions for env_idx: {env_idx+1}"):
        candidate_actions = info["admissible_commands"][0]
        reason, action = agent.predict_reason_action(
            task=task,
            observation=obs,
            candidate_actions=candidate_actions,
        )

        # Record the observation and candidates the agent acted on, before
        # env.step replaces them.
        data = {
            "observation": obs,
            "candidate_actions": candidate_actions,
            "reason": reason,
            "action": action,
        }

        obss, scores, dones, info = env.step([action])
        obs, score, done = obss[0], scores[0], dones[0]
        data["score"] = score
        trajectory.append(data)

        if done:
            break

    return trajectory, score


def main():
    """Run every configured agent on the ALFWorld games and log the results.

    The destination directory is ``config["logdir"]`` when
    ``config["exact_path"]`` is truthy, otherwise a timestamped subdirectory
    of it. The following files are written:

    * ``config.yaml`` in the destination directory: a dump of the run config.
    * ``<env_idx>.json`` in the agent's log directory: gamefile, task and
      trajectory of one game.
    * ``summary.csv`` in the destination directory: one row per finished
      game, rewritten after each game.
    """
    config = parse_and_load_config()
    alfworld_config = generic.load_config()

    if config["exact_path"]:
        dstdir = config["logdir"]
    else:
        dstdir = f"{config['logdir']}/{time.strftime('%Y%m%d-%H%M%S')}"
    os.makedirs(dstdir, exist_ok=True)
    with open(os.path.join(dstdir, "config.yaml"), "w") as f:
        yaml.dump(config, f)

    summary_file_path = os.path.join(dstdir, "summary.csv")
    df_summary = pd.DataFrame()

    for agent_config in config["agents"]:
        agent = _build_agent(agent_config, config)

        # With exact_path every agent logs straight into config["logdir"];
        # otherwise each agent gets its own subdirectory.
        if config["exact_path"]:
            logdir = config["logdir"]
        else:
            logdir = f"{dstdir}/{_agent_label(agent_config)}"
        os.makedirs(logdir, exist_ok=True)
        print(f"Creating log directory {logdir}")

        # Build the environment class named in the ALFWorld configuration.
        env_type = alfworld_config["env"]["type"]
        env = getattr(environment, env_type)(
            alfworld_config, train_eval=config["eval_set"]
        )
        # A falsy max_env_idxs (None or 0) means "play every game".
        max_env_idxs = (
            min(env.num_games, config["max_env_idxs"])
            if config["max_env_idxs"]
            else env.num_games
        )
        env = env.init_env(batch_size=1)

        # Iterate over all games
        for env_idx in tqdm(range(max_env_idxs), desc="env_idx"):
            # reset() runs before the skip check, so skipped indices still
            # trigger a reset -- presumably each reset advances to the next
            # game, keeping env_idx aligned with the game order (confirm
            # against the environment implementation).
            obss, info = env.reset()
            if config["start_env_idx"] and (env_idx < config["start_env_idx"]):
                continue

            gamefile = parse_gamefile(env)
            max_actions = env.batch_env.envs[0].max_episode_steps
            obs = obss[0]
            task = extract_task_from_observation(obs)
            trajectory, score = _run_episode(
                env, agent, task, obs, info, max_actions, env_idx
            )

            log_file_path = os.path.join(logdir, f"{env_idx}.json")
            log = {"gamefile": gamefile, "trajectory": trajectory, "task": task}
            print(f"Saving log to {log_file_path}")
            with open(log_file_path, "w") as log_file:
                json.dump(log, log_file, indent=4)

            summary_data = {
                "env_idx": env_idx,
                "gamefile": gamefile,
                "model_id": _agent_label(agent_config),
                "num_actions": len(trajectory),
                "score": score,
            }
            # Re-read the CSV when it exists so rows already on disk are kept
            # (e.g. from an earlier run into the same exact_path directory).
            if os.path.exists(summary_file_path):
                df_summary = pd.read_csv(summary_file_path)
            df_summary = pd.concat(
                [df_summary, pd.DataFrame([summary_data])], ignore_index=True
            )
            df_summary.to_csv(summary_file_path, index=False)
            print(f"Current summary:\n {df_summary}")


# Run the evaluation/rollout only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
