import asyncio
import argparse
import os
import json
from transformers import AutoTokenizer
import time

from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_execution_engine import AgentExecutionEngine
from rllm.trainer.env_agent_mappings import AGENT_CLASS_MAPPING, ENV_CLASS_MAPPING
from rllm.utils import compute_pass_at_k

def str2bool(v):
    """Convert a command-line flag value to a ``bool``.

    A real ``bool`` is passed through untouched. Strings are matched
    case-insensitively: yes/true/t/y/1 give True, no/false/f/n/0 give False.

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    lookup = {
        **dict.fromkeys(('yes', 'true', 't', 'y', '1'), True),
        **dict.fromkeys(('no', 'false', 'f', 'n', '0'), False),
    }
    verdict = lookup.get(v.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return verdict

def argparse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_type", type=str, default="sudoku")
    parser.add_argument("--agent_type", type=str, default="game_agent")
    # model
    parser.add_argument("--model_path", type=str, default="Qwen/Qwen3-8B")
    # NOTE(review): the default is the literal string "None", not Python None;
    # presumably the rollout engine handles that value — confirm before changing.
    parser.add_argument("--base_url", type=str, default="None")
    parser.add_argument("--engine_name", type=str, default="openai")
    # agent args
    parser.add_argument("--max_steps", type=int, default=30)
    parser.add_argument("--use_accumulate_thinking", type=str2bool, default=True)
    parser.add_argument("--history_window", type=int, default=None)
    parser.add_argument("--additional_info_path", type=str, default=None)
    parser.add_argument("--use_multi_turn_format", type=str2bool, default=True)
    # env args
    parser.add_argument("--max_turns", type=int, default=30)
    parser.add_argument("--progress_reward_type", type=str, default="")
    # sampling args
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--max_prompt_length", type=int, default=8192)
    parser.add_argument("--max_response_length", type=int, default=32000)
    # str2bool, not bool: bool("False") is True, so type=bool would turn any
    # non-empty value into True.
    parser.add_argument("--disable_thinking", type=str2bool, default=False)
    parser.add_argument(
        "--reasoning_effort",
        type=str,
        default=None,
        choices=["low", "medium", "high"],
    )
    # n_parallel_agents
    parser.add_argument("--n_parallel_agents", type=int, default=16)
    # exp_args
    parser.add_argument("--dataset", type=str, default="sudoku_bench")
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--num_examples", type=int, default=-1)
    parser.add_argument("--start_idx", type=int, default=0)
    # Required: the output path is always built from this folder, so leaving
    # it unset could never produce a usable run.
    parser.add_argument("--save_folder", type=str, required=True)
    parser.add_argument("--save_name", type=str, default=None)
    parser.add_argument("--check_is_exists", action="store_true")
    parser.add_argument("--save_obs_only", action="store_true")
    return parser.parse_args(argv)


def map_env_and_agent(env_name, agent_name):
    """Resolve registry names to their environment and agent classes.

    Returns:
        A ``(env_class, agent_class)`` tuple looked up in
        ``ENV_CLASS_MAPPING`` and ``AGENT_CLASS_MAPPING``. The environment
        name is looked up first. An unknown name presumably raises KeyError,
        assuming the mappings are plain dicts.
    """
    return ENV_CLASS_MAPPING[env_name], AGENT_CLASS_MAPPING[agent_name]

def save_results(results, save_path, save_obs_only):
    """Serialize rollout results to a JSON file.

    Args:
        results: iterable of dicts. Each has a ``"trajectory"`` entry and may
            have a ``"task"`` entry, which takes precedence over
            ``trajectory.task``.
        save_path: destination file path. An existing file is overwritten.
        save_obs_only: if True, store each step's
            ``observation["observation"]`` and ``model_response``. Otherwise
            store its ``chat_completions``.

    Each output record has the keys ``task``, ``steps``,
    ``last_observation`` and ``rewards``. If the task contains an
    ``"extra_info"`` key, only that payload is stored. If a trajectory has no
    ``steps`` attribute, a single placeholder step with empty
    ``chat_completions`` is written and ``rewards`` is empty.
    """
    final_results = []
    for result in results:
        trajectory = result["trajectory"]
        steps = []
        rewards = []
        if hasattr(trajectory, "steps"):
            for step in trajectory.steps:
                if save_obs_only:
                    steps.append({
                        "observation": step.observation["observation"],
                        "model_response": step.model_response,
                    })
                else:
                    steps.append({"chat_completions": step.chat_completions})
                rewards.append(step.reward)
        else:
            # No step data is available: write one empty placeholder step.
            steps.append({"chat_completions": []})

        if "task" in result:
            task = result["task"]
        else:
            try:
                task = trajectory.task
            except Exception:
                # Best effort: fall back to an empty task if the trajectory
                # doesn't expose one.
                task = {}

        try:
            last_observation = trajectory.last_observation
        except Exception:
            last_observation = {}

        # The membership test is guarded because task may be None or some
        # other non-container value.
        try:
            has_extra_info = "extra_info" in task
        except TypeError:
            has_extra_info = False

        final_results.append({
            "task": task["extra_info"] if has_extra_info else task,
            "steps": steps,
            "last_observation": last_observation,
            "rewards": rewards,
        })

    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(final_results, f, indent=4)


def main():
    """Run agent rollouts over a dataset split and save the trajectories.

    Parses CLI arguments via ``argparse_args`` and builds an
    ``AgentExecutionEngine`` for the selected env/agent classes. It then
    executes the selected slice of the dataset, prints the elapsed time in
    minutes, and writes ``<save_folder>/<save_name>.json`` via
    ``save_results``. If ``--check_is_exists`` is set and that file already
    exists, it returns before loading the tokenizer or running anything.

    The selected slice starts at ``--start_idx``. It contains at most
    ``--num_examples`` tasks when that value is positive, and otherwise
    runs to the end of the dataset.

    Raises:
        SystemExit: if ``--save_folder`` is not given.
        ValueError: if ``--reasoning_effort`` is not low/medium/high.
    """
    args = argparse_args()

    # The same key is read regardless of engine/model. If it is unset, the
    # literal string "None" is used.
    api_key = os.getenv("OPENAI_API_KEY", "None")

    env_class, agent_class = map_env_and_agent(args.env_type, args.agent_type)

    # Fail fast. Without this check, os.path.join(None, ...) would raise an
    # opaque TypeError, and only after the tokenizer had been loaded.
    if args.save_folder is None:
        raise SystemExit("--save_folder is required")
    save_path = os.path.join(args.save_folder, f"{args.save_name}.json")
    if args.check_is_exists and os.path.exists(save_path):
        print(f"Results already exist at {save_path}")
        return
    os.makedirs(args.save_folder, exist_ok=True)

    sampling_params = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "model": args.model_path,
    }
    if args.reasoning_effort is not None:
        # Use an explicit check rather than assert so validation survives `python -O`.
        if args.reasoning_effort not in ("low", "medium", "high"):
            raise ValueError("Reasoning effort must be one of low, medium, high")
        sampling_params["reasoning_effort"] = args.reasoning_effort

    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    except Exception as err:
        # Model names that are not loadable Hugging Face repos (e.g. API-only
        # models) fall back to a Qwen3 tokenizer.
        print(f"Could not load tokenizer for {args.model_path!r} ({err}); "
              "falling back to Qwen/Qwen3-4B")
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

    # NOTE(review): --disable_thinking is parsed but not forwarded anywhere
    # here. Confirm whether the engine or agent should receive it.
    agent_args = {
        "max_steps": args.max_steps,
        "use_accumulate_thinking": args.use_accumulate_thinking,
        "history_window": args.history_window,
        "additional_info_path": args.additional_info_path,
        "use_multi_turn_format": args.use_multi_turn_format,
    }
    env_args = {
        "max_turns": args.max_turns,
        "progress_reward_type": args.progress_reward_type,
    }
    rollout_engine_args = {
        "base_url": args.base_url,
        "api_key": api_key,
    }

    print("Running inference...")
    chat_mode = True  # always chat mode; do not use the parser
    engine = AgentExecutionEngine(
        agent_class=agent_class,
        env_class=env_class,
        agent_args=agent_args,
        env_args=env_args,
        engine_name=args.engine_name,
        tokenizer=tokenizer,
        sampling_params=sampling_params,
        rollout_engine_args=rollout_engine_args,
        max_response_length=args.max_response_length,
        max_prompt_length=args.max_prompt_length,
        n_parallel_agents=args.n_parallel_agents,
        max_steps=args.max_steps,
        enforce_max_prompt_length=True,
        chat_mode=chat_mode,
    )

    tasks = DatasetRegistry.load_dataset(args.dataset, args.split)

    # Honor --start_idx even when --num_examples is not set. Previously it
    # was silently ignored in that case.
    if args.num_examples > 0 or args.start_idx > 0:
        end_idx = len(tasks)
        if args.num_examples > 0:
            end_idx = min(args.start_idx + args.num_examples, end_idx)
        tasks = tasks[args.start_idx:end_idx]

    # perf_counter is monotonic, so the elapsed time is unaffected by clock changes.
    start_time = time.perf_counter()
    results = asyncio.run(engine.execute_tasks(tasks))
    time_taken = (time.perf_counter() - start_time) / 60  # minutes
    print(f"Time taken: {time_taken} minutes")
    save_results(results, save_path, args.save_obs_only)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()