import argparse
import asyncio
import os
import random
from transformers import AutoTokenizer

import pandas as pd

from rllm.agents.agent import Trajectory
from rllm.data.dataset import DatasetRegistry
from rllm.pipeline.replay import ReplayEngine
from rllm.trainer.env_agent_mappings import ENV_CLASS_MAPPING, AGENT_CLASS_MAPPING

def str2bool(v):
    """Convert a command-line flag value into a ``bool``.

    A value that is already a ``bool`` is returned unchanged. Otherwise the
    value is lower-cased and compared against common yes/no spellings.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized spelling.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in {'yes', 'true', 't', 'y', '1'}:
        return True
    if normalized in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def load_dataset(dataset_name, split_name):
    """Look up the ``split_name`` split of ``dataset_name`` in ``DatasetRegistry``."""
    loaded = DatasetRegistry.load_dataset(dataset_name, split_name)
    return loaded

def map_env_and_agent(env_type, agent_type):
    """Resolve environment and agent type names to their registered classes.

    Returns:
        A ``(env_class, agent_class)`` tuple taken from ``ENV_CLASS_MAPPING``
        and ``AGENT_CLASS_MAPPING``.

    Raises:
        KeyError: if either name is missing from its mapping.
    """
    # The tuple is evaluated left to right, so the env lookup happens first.
    return ENV_CLASS_MAPPING[env_type], AGENT_CLASS_MAPPING[agent_type]


async def replay_trajectories(tasks, args):
    """Run every task through a ``ReplayEngine`` and return its results.

    Args:
        tasks: Task collection passed straight to ``ReplayEngine.execute_tasks``.
        args: Parsed CLI namespace. Supplies the env/agent type names, the
            tokenizer path, and the agent, env and engine settings.

    Returns:
        Whatever ``execute_tasks`` produces. ``main`` consumes it as a list of
        ``Trajectory`` objects.
    """
    env_class, agent_class = map_env_and_agent(args.env_type, args.agent_type)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    agent_settings = dict(
        max_steps=args.max_steps,
        use_accumulate_thinking=args.use_accumulate_thinking,
        history_window=args.history_window,
        use_multi_turn_format=args.use_multi_turn_format,
    )
    env_settings = dict(
        max_turns=args.max_turns,
        progress_reward_type=args.progress_reward_type,
    )

    engine = ReplayEngine(
        agent_class=agent_class,
        env_class=env_class,
        agent_args=agent_settings,
        env_args=env_settings,
        tokenizer=tokenizer,
        max_response_length=args.max_response_length,
        max_prompt_length=args.max_prompt_length,
        n_parallel_agents=args.n_parallel_agents,
        mode=args.mode,
    )
    trajectories = await engine.execute_tasks(tasks)
    return trajectories

def process_trajectories(results: "list[Trajectory]", reward_threshold):
    """Convert replayed trajectories into chat-format SFT examples.

    Trajectories whose ``reward`` is below ``reward_threshold`` are skipped.
    Each step of a kept trajectory becomes one example built from its
    ``chat_completions``. Only dict messages with a non-empty ``role`` and
    non-blank ``content`` are kept, with the content stripped of whitespace.
    A step yields an example only if at least two messages survive.

    Args:
        results: Trajectories, each exposing ``reward`` and ``steps``. Each
            step exposes ``chat_completions``.
        reward_threshold: Minimum reward (inclusive) a trajectory needs.

    Returns:
        A list of ``{"messages": [{"role": ..., "content": ...}, ...]}`` dicts.
        The list may be empty.
    """
    final_sft_dataset = []
    for traj in results:
        # Filter by reward threshold.
        if traj.reward < reward_threshold:
            continue

        for step in traj.steps:
            clean_messages = []
            for msg in step.chat_completions:
                if not isinstance(msg, dict) or not msg.get("role"):
                    continue
                content = msg.get("content")
                # None content (e.g. a tool-call-only turn) counts as empty.
                # Without this check str(None) would become the text "None".
                if content is None:
                    continue
                text = str(content).strip()
                if text:
                    clean_messages.append({"role": msg["role"], "content": text})
            if len(clean_messages) >= 2:
                final_sft_dataset.append({"messages": clean_messages})

    print(f"Processed {len(results)} trajectories -> {len(final_sft_dataset)} valid examples")
    # Guard the preview so an empty result is returned instead of raising
    # IndexError. The caller handles the empty case itself.
    if final_sft_dataset:
        print("Example:")
        print(final_sft_dataset[0])
    return final_sft_dataset
    

def _validation_path(output_path):
    """Return the validation parquet path: ``<root>_val<ext>`` (``.parquet`` if no extension)."""
    root, ext = os.path.splitext(output_path)
    return f"{root}_val{ext or '.parquet'}"


def main():
    """CLI entry point: replay dataset tasks, filter them into SFT examples, and write parquet files.

    Writes the full dataset to ``--output_path``. It also writes a random
    sample of up to ``--val_size`` examples next to it, with ``_val`` inserted
    before the extension. The validation examples are drawn from the same
    dataset, so they also appear in the training file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--split", type=str, default="sft")
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--reward_threshold", type=float, default=1.0)
    parser.add_argument("--env_type", type=str, required=True)
    parser.add_argument("--agent_type", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--max_steps", type=int, default=10)
    parser.add_argument("--max_turns", type=int, default=10)
    parser.add_argument("--use_accumulate_thinking", type=str2bool, default=False)
    parser.add_argument("--history_window", type=int, default=None)
    parser.add_argument("--progress_reward_type", type=str, default=None)
    parser.add_argument("--max_response_length", type=int, default=8192)
    parser.add_argument("--max_prompt_length", type=int, default=1024)
    parser.add_argument("--n_parallel_agents", type=int, default=8)
    parser.add_argument("--use_multi_turn_format", type=str2bool, default=False)
    # mode
    parser.add_argument("--mode", type=str, default="sft", choices=["sft", "interactive"])
    # Validation split: maximum sample size and an optional seed for reproducibility.
    parser.add_argument("--val_size", type=int, default=200)
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args()

    tasks = load_dataset(args.dataset, args.split)
    print(f"Generating {len(tasks)} trajectories ...")

    results = asyncio.run(replay_trajectories(tasks, args))
    if results:
        print(results[0].steps)
        print(results[0].reward)

    sft_dataset = process_trajectories(results, args.reward_threshold)

    # Create the output directory if needed. A bare filename has no directory
    # component, and os.makedirs("") would raise.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if not sft_dataset:
        print("No valid trajectories found.")
        return

    pd.DataFrame(sft_dataset).to_parquet(args.output_path, index=False)

    # Cap the sample at the dataset size. random.sample raises ValueError if
    # asked for more items than exist.
    val_size = max(0, min(args.val_size, len(sft_dataset)))
    if val_size:
        sft_val_dataset = random.Random(args.seed).sample(sft_dataset, val_size)
        pd.DataFrame(sft_val_dataset).to_parquet(_validation_path(args.output_path), index=False)

    lengths = [
        len(" ".join(m["content"] for m in ex["messages"] if m["role"] == "assistant"))
        for ex in sft_dataset
    ]
    print(f"Saved {len(sft_dataset)} examples. Response lengths: min={min(lengths)}, max={max(lengths)}, avg={sum(lengths) // len(lengths)}")


if __name__ == "__main__":
    main()