import asyncio
import argparse
import os
import json
from datetime import datetime
from transformers import AutoTokenizer

from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_execution_engine import AgentExecutionEngine
from rllm.trainer.env_agent_mappings import AGENT_CLASS_MAPPING, ENV_CLASS_MAPPING
from rllm.utils import compute_pass_at_k

def str2bool(v):
    """Interpret a command-line token as a boolean (argparse ``type=`` helper).

    Values that are already ``bool`` are returned unchanged. Strings are
    matched case-insensitively: 'yes', 'true', 't', 'y', '1' map to True and
    'no', 'false', 'f', 'n', '0' map to False.

    Raises:
        argparse.ArgumentTypeError: if the string is not one of the
            recognised spellings.
    """
    if isinstance(v, bool):
        return v

    truth_table = {
        'yes': True, 'true': True, 't': True, 'y': True, '1': True,
        'no': False, 'false': False, 'f': False, 'n': False, '0': False,
    }
    verdict = truth_table.get(v.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return verdict

def argparse_args():
    """Define the evaluation script's command-line options and parse them.

    Every option is optional and carries a default; the options are
    registered in the order listed in the table below.

    Returns:
        argparse.Namespace: the parsed options, one attribute per flag
        (without the leading dashes).
    """
    # Each row is (flag, type converter, default value).
    option_table = [
        # environment / agent selection
        ("--env_type", str, "reasoning_gym"),
        ("--agent_type", str, "reasoning_gym_agent"),
        # model
        ("--model_path", str, "Qwen/Qwen3-8B"),
        ("--base_url", str, "None"),
        ("--engine_name", str, "openai"),
        # agent args
        ("--use_accumulate_thinking", str2bool, True),
        ("--additional_info_path", str, None),
        # sampling args
        ("--temperature", float, 0.6),
        ("--top_p", float, 0.95),
        ("--max_prompt_length", int, 8192),
        ("--max_response_length", int, 32000),
        # n_parallel_agents
        ("--n_parallel_agents", int, 16),
        # exp_args
        ("--dataset", str, "leg_counting"),
        ("--num_examples", int, None),
        ("--start_idx", int, 0),
        ("--save_folder", str, None),
        ("--save_name", str, None),
    ]

    cli = argparse.ArgumentParser()
    for flag, converter, fallback in option_table:
        cli.add_argument(flag, type=converter, default=fallback)
    return cli.parse_args()


def map_env_and_agent(env_name, agent_name):
    """Look up the environment and agent classes registered under the given names.

    Args:
        env_name: key into ``ENV_CLASS_MAPPING``.
        agent_name: key into ``AGENT_CLASS_MAPPING``.

    Returns:
        A ``(env_class, agent_class)`` tuple.

    Raises:
        KeyError: if either name is not present in its mapping (the
            environment name is looked up first).
    """
    return ENV_CLASS_MAPPING[env_name], AGENT_CLASS_MAPPING[agent_name]

def save_results(results, save_folder, save_name):
    """Write one JSON trajectory log per result into a fresh timestamped directory.

    The output directory is ``<save_folder>/<save_name>_<YYYYmmdd_HHMMSS>``
    and is created if needed. Result ``idx`` is written to
    ``traj_log_<idx>.json`` inside it.

    Args:
        results: iterable of result objects exposing a ``traj_log_dict``
            attribute; ``None`` entries are written as an empty JSON object
            so file indices stay aligned with the input order.
        save_folder: parent directory for the run directory. ``None`` falls
            back to the current working directory.
        save_name: prefix of the run directory name. ``None`` falls back to
            ``"results"``.
    """
    # The CLI defaults for both values are None; fall back to usable values
    # instead of raising a TypeError after the rollouts have already finished.
    if save_folder is None:
        save_folder = "."
    if save_name is None:
        save_name = "results"

    save_name_dt = save_name + '_' + datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = os.path.join(save_folder, save_name_dt)
    os.makedirs(out_dir, exist_ok=True)

    for idx, result in enumerate(results):
        traj_log_dict = {} if result is None else result.traj_log_dict
        out_path = os.path.join(out_dir, f"traj_log_{idx}.json")
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be opened as UTF-8 explicitly rather than with the locale default.
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(traj_log_dict, f, indent=4, ensure_ascii=False)


def main():
    """Run agent rollouts over a registered dataset and save the trajectory logs.

    Steps: parse the CLI options, resolve the environment/agent classes,
    load a tokenizer, build an ``AgentExecutionEngine``, load the "test"
    data of the requested dataset, optionally restrict it to a window of
    examples, execute all tasks, and write the results with ``save_results``.
    """
    args = argparse_args()
    # When the variable is unset the literal string "None" is passed on,
    # not the None object.
    api_key = os.getenv("OPENAI_API_KEY", "None")

    env_class, agent_class = map_env_and_agent(args.env_type, args.agent_type)

    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    except Exception as exc:
        # Catch Exception rather than everything so Ctrl-C / SystemExit still
        # propagate, and report the substitution because the fallback
        # tokenizer is a different one from the requested model's.
        print(
            f"Could not load tokenizer for {args.model_path!r} ({exc!r}); "
            "falling back to 'Qwen/Qwen3-4B'."
        )
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

    sampling_params = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "model": args.model_path,
    }
    agent_args = {
        "use_accumulate_thinking": args.use_accumulate_thinking,
    }
    env_args = {}
    rollout_engine_args = {
        "base_url": args.base_url,
        "api_key": api_key,
    }
    # chat_mode is always enabled ("do not use parser"); a previous heuristic
    # derived it from whether the model path contained a "/".
    chat_mode = True
    engine = AgentExecutionEngine(
        agent_class=agent_class,
        env_class=env_class,
        agent_args=agent_args,
        env_args=env_args,
        engine_name=args.engine_name,
        tokenizer=tokenizer,
        sampling_params=sampling_params,
        rollout_engine_args=rollout_engine_args,
        max_response_length=args.max_response_length,
        max_prompt_length=args.max_prompt_length,
        n_parallel_agents=args.n_parallel_agents,
        enforce_max_prompt_length=True,
        chat_mode=chat_mode,
    )

    tasks = DatasetRegistry.load_dataset(args.dataset, "test")

    # --num_examples defaults to None, which cannot be compared with 0; treat
    # None (and non-positive values) as "use every loaded task".
    if args.num_examples is not None and args.num_examples > 0:
        window_end = args.start_idx + args.num_examples
        tasks = tasks[args.start_idx:window_end]

    results = asyncio.run(engine.execute_tasks(tasks))

    save_results(results, args.save_folder, args.save_name)

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()