import sys
import argparse
import json
import yaml
import os
import re
import time
import shutil
import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool, Lock

import alfworld.agents.environment as environment
import alfworld.agents.modules.generic as generic

from decision_oaif.agents.openai_agent import OpenAIAgent
from decision_oaif.agents.hf_agent import HFAgent
from decision_oaif.agents.hf_spaces_agent import HFSpaceAgent
from decision_oaif.agents.sglang_server_agent import SGLangServerAgent
from decision_oaif.utils.alfworld import parse_gamefile, extract_task_from_observation
from decision_oaif.utils.parser import (
    parse_reason_and_action_alfworld,
    substitute_placeholders,
)

# Lock used to serialize updates to summary.csv across worker processes.
# It stays None at import time; main() assigns a multiprocessing Lock to it
# before the pool is created, and process_env_idx() acquires it while
# reading and rewriting the summary file.
lock = None


def parse_and_load_config():
    """Parse the command line and return the run configuration loaded from YAML.

    Two modes are supported:

    * ``--eval_config PATH``: the YAML document at PATH is returned unchanged.
    * ``--training_config PATH --iter N``: the ``rollout_student_trajectory``
      section of PATH is returned after ``substitute_placeholders`` has been
      applied to it with the placeholder ``"{iter}"`` and the value ``str(N)``.

    ``--eval_config`` wins when both modes are supplied. When neither mode is
    fully specified, ``ArgumentParser.error`` is called, which prints a usage
    message and exits the process.

    Side effect: ``sys.argv`` is replaced by ``[sys.argv[0], <alfworld_config>]``
    before the mode is examined.

    Returns:
        The configuration object produced by the selected mode.
    """
    cli = argparse.ArgumentParser(description="Evaluate agent on alfworld")
    cli.add_argument(
        "--alfworld_config",
        type=str,
        default="configs/env_config/alfworld_config.yaml",
        help="Path to the Alfred base config file",
    )
    # Evaluation mode
    cli.add_argument(
        "--eval_config", type=str, help="Path to the evaluation config file"
    )
    # Rollout mode
    cli.add_argument(
        "--training_config", type=str, help="Path to the training config file"
    )
    cli.add_argument("--iter", type=int, help="Iteration number")
    options = cli.parse_args()

    # Update sys.argv for compatibility with generic.load_config
    sys.argv = [sys.argv[0], options.alfworld_config]

    # Evaluation mode: the file content is the configuration.
    if options.eval_config:
        with open(options.eval_config, "r") as stream:
            return yaml.safe_load(stream)

    # Rollout mode needs both the training config and an iteration number.
    if not options.training_config or options.iter is None:
        cli.error(
            "You must provide either --eval_config or both --training_config and --iter."
        )

    with open(options.training_config, "r") as stream:
        training_document = yaml.safe_load(stream)
    rollout_section = training_document["rollout_student_trajectory"]
    return substitute_placeholders(rollout_section, "{iter}", str(options.iter))


def create_env_and_agent(config, agent_config, alfworld_config):
    """Build a fresh (environment, agent) pair for one worker.

    Args:
        config: Run configuration; ``eval_set`` is always read, ``verbose`` and
            ``debug`` are read only for the ``hf_space`` and ``sglang_server``
            agent types.
        agent_config: Agent description. ``type`` selects the agent class
            (``openai``, ``hf``, ``hf_space`` or ``sglang_server``); the other
            keys read depend on that type.
        alfworld_config: ALFWorld configuration; ``["env"]["type"]`` names the
            environment class looked up on the ``environment`` module.

    Returns:
        Tuple ``(env, agent)`` where ``env`` is the result of
        ``init_env(batch_size=1)``.

    Raises:
        ValueError: If ``agent_config["type"]`` is not one of the supported types.
    """
    requested_type = agent_config["type"]

    # Each factory is lazy so that only the keys of the selected type are read.
    def _make_openai():
        return OpenAIAgent(
            model_id=agent_config["model_id"],
            prompt_template_file=agent_config["prompt_template_file"],
            parse_reason_action_fn=parse_reason_and_action_alfworld,
        )

    def _make_hf():
        return HFAgent(
            model_id=agent_config["model_id"],
            prompt_template_file=agent_config["prompt_template_file"],
            parse_reason_action_fn=parse_reason_and_action_alfworld,
            max_length=6000,
        )

    def _make_hf_space():
        return HFSpaceAgent(
            space_id=agent_config["space_id"],
            prompt_template_file=agent_config["prompt_template_file"],
            verbose=config["verbose"],
            debug=config["debug"],
            parse_reason_action_fn=parse_reason_and_action_alfworld,
        )

    def _make_sglang_server():
        return SGLangServerAgent(
            server_url=agent_config["server_url"],
            prompt_template_file=agent_config["prompt_template_file"],
            verbose=config["verbose"],
            debug=config["debug"],
            parse_reason_action_fn=parse_reason_and_action_alfworld,
        )

    factories = (
        ("openai", _make_openai),
        ("hf", _make_hf),
        ("hf_space", _make_hf_space),
        ("sglang_server", _make_sglang_server),
    )

    # Pick the first factory whose name matches the requested agent type.
    for type_name, factory in factories:
        if requested_type == type_name:
            agent = factory()
            break
    else:
        raise ValueError(f"Unsupported agent type: {agent_config['type']}")

    # Instantiate the configured environment class with a batch of one game.
    env_class = getattr(environment, alfworld_config["env"]["type"])
    env = env_class(alfworld_config, train_eval=config["eval_set"])
    env = env.init_env(batch_size=1)

    return env, agent


def _append_summary_row(summary_file_path, summary_data):
    """Append one row to the summary CSV while holding the module-level ``lock``."""
    new_row = pd.DataFrame([summary_data])

    # The file is read, extended and rewritten as a whole, so the entire
    # read-modify-write sequence has to happen under the lock.
    with lock:
        if os.path.exists(summary_file_path):
            df_summary = pd.concat(
                [pd.read_csv(summary_file_path), new_row], ignore_index=True
            )
        else:
            df_summary = new_row
        df_summary.to_csv(summary_file_path, index=False)
        print(f"Current summary:\n {df_summary}")


def process_env_idx(args):
    """Roll out one agent on one environment index and record the results.

    Args:
        args: Tuple ``(env_idx, config, agent_config, alfworld_config, dstdir)``
            packed into a single argument so the function can be mapped over a
            process pool.

    Returns:
        None. On a full run the trajectory is written to
        ``<logdir>/<env_idx>.json`` and one row is appended to
        ``<dstdir>/summary.csv``. When ``config["start_env_idx"]`` is set and
        ``env_idx`` is below it, the function returns right after creating the
        log directory, without building an environment or agent.

    The module-level ``lock`` must hold a lock object in the process that runs
    this function, because it guards the summary file update.
    """
    env_idx, config, agent_config, alfworld_config, dstdir = args
    logdir = (
        f"{dstdir}/{agent_config['model_id']}"
        if not config["exact_path"]
        else config["logdir"]
    )
    os.makedirs(logdir, exist_ok=True)

    # Decide whether to skip before paying for environment and agent creation.
    start_env_idx = config["start_env_idx"]
    if start_env_idx and env_idx < start_env_idx:
        return

    # Create environment and agent for this worker
    env, agent = create_env_and_agent(config, agent_config, alfworld_config)

    # NOTE(review): the environment is created fresh and reset exactly once, and
    # env_idx is not passed to it, so which game gets loaded is decided by the
    # environment alone -- confirm that this yields the game intended for env_idx.
    obss, info = env.reset()

    gamefile = parse_gamefile(env)
    max_actions = env.batch_env.envs[0].max_episode_steps
    obs = obss[0]
    task = extract_task_from_observation(obs)
    trajectory = []
    agent.reset()

    # Defined up front so the summary can be written even if no step is taken.
    score = 0

    # Let the agent act until the episode ends or the step budget is exhausted.
    for _ in tqdm(range(max_actions), desc=f"Actions for env_idx: {env_idx+1}"):
        candidate_actions = info["admissible_commands"][0]
        reason, action = agent.predict_reason_action(
            task=task, observation=obs, candidate_actions=candidate_actions
        )

        # Record the state the agent saw together with its decision; the score
        # returned by the step is added once the action has been executed.
        data = {
            "observation": obs,
            "candidate_actions": candidate_actions,
            "reason": reason,
            "action": action,
        }
        obss, scores, dones, info = env.step([action])
        obs, score, done = obss[0], scores[0], dones[0]
        data["score"] = score
        trajectory.append(data)

        if done:
            break

    # Save the trajectory log to file
    log_file_path = os.path.join(logdir, f"{env_idx}.json")
    log = {"gamefile": gamefile, "trajectory": trajectory, "task": task}
    with open(log_file_path, "w") as log_file:
        json.dump(log, log_file, indent=4)

    # Append to the summary file shared by all workers
    summary_data = {
        "env_idx": env_idx,
        "gamefile": gamefile,
        "model_id": agent_config["model_id"],
        "num_actions": len(trajectory),
        "score": score,
    }
    _append_summary_row(os.path.join(dstdir, "summary.csv"), summary_data)


def _init_worker_lock(shared_lock):
    """Pool initializer: install the shared lock as this process's module-level ``lock``."""
    global lock
    lock = shared_lock


def main():
    """Run every configured agent on the selected games using a process pool.

    Loads the run configuration and the ALFWorld configuration, creates the
    destination directory (timestamped below ``config["logdir"]`` unless
    ``config["exact_path"]`` is set), stores a copy of the run configuration
    there as ``config.yaml``, and then maps ``process_env_idx`` over every
    (agent, env_idx) combination with ``config["num_processes"]`` workers.
    """
    global lock
    config = parse_and_load_config()
    alfworld_config = generic.load_config()

    # Create destination directory for logs
    dstdir = (
        f"{config['logdir']}/{time.strftime('%Y%m%d-%H%M%S')}"
        if not config["exact_path"]
        else config["logdir"]
    )
    os.makedirs(dstdir, exist_ok=True)
    with open(os.path.join(dstdir, "config.yaml"), "w") as f:
        yaml.dump(config, f)

    # The number of games depends only on the ALFWorld config and the eval set,
    # not on the agent, so the environment is built once to read it.
    env_type = alfworld_config["env"]["type"]
    env = getattr(environment, env_type)(
        alfworld_config, train_eval=config["eval_set"]
    )
    max_env_idxs = (
        min(env.num_games, config["max_env_idxs"])
        if config["max_env_idxs"]
        else env.num_games
    )

    # Prepare one task per (agent, environment index) pair
    args_list = []
    for agent_config in config["agents"]:
        for env_idx in range(max_env_idxs):
            args_list.append((env_idx, config, agent_config, alfworld_config, dstdir))

    # Keep the lock in this process's global as well, and hand it to every
    # worker through the pool initializer: a module global assigned here is only
    # visible in workers when they are forked, whereas the initializer also
    # covers start methods that re-import the module (spawn, forkserver).
    lock = Lock()

    # Parallelize the tasks with imap; the outer tqdm bar advances as each task
    # completes. Results are None, so they are drained rather than collected.
    with Pool(
        processes=config["num_processes"],
        initializer=_init_worker_lock,
        initargs=(lock,),
    ) as pool:
        for _ in tqdm(pool.imap(process_env_idx, args_list), total=len(args_list)):
            pass


# Run the evaluation only when this file is executed as a script, not when it
# is imported (for example by multiprocessing workers that re-import the module).
if __name__ == "__main__":
    main()
