import argparse
import asyncio
import json
import os
from datetime import datetime

from transformers import AutoTokenizer

from examples.bugs.codegen_flow import CodeGenWorkflow
from examples.bugs.data_utils import load_data
from examples.bugs.prompts import _build_code_generation_prompt
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.reward_fn import code_reward_fn


def evaluate_results(results):
    """Print aggregate pass-rate, reward, and test-count statistics for episodes.

    Args:
        results: Sequence of episode objects exposing a ``metrics`` attribute
            (a dict or ``None``). Recognised keys are ``codegen_pass``,
            ``codegen_reward``, ``total_tests`` and ``passed_tests``; keys that
            are missing or set to ``None`` are treated as zero / not reported.

    Returns:
        None. The report is written to stdout.
    """
    total_episodes = len(results)
    if total_episodes == 0:
        print("No results to evaluate.")
        return

    pass_count = 0
    reward_sum = 0.0

    # Each optional test metric gets its own denominator so that an episode
    # reporting only one of them cannot skew the other's average.
    total_tests_sum = 0
    total_tests_present = 0
    passed_tests_sum = 0
    passed_tests_present = 0

    for ep in results:
        metrics = ep.metrics or {}
        # `or 0` also covers keys that are present but explicitly None.
        pass_count += int(metrics.get("codegen_pass") or 0)
        reward_sum += float(metrics.get("codegen_reward") or 0.0)

        if metrics.get("total_tests") is not None:
            total_tests_sum += int(metrics["total_tests"])
            total_tests_present += 1
        if metrics.get("passed_tests") is not None:
            passed_tests_sum += int(metrics["passed_tests"])
            passed_tests_present += 1

    print("\n" + "=" * 80)
    print("📊 CODEGEN WORKFLOW RESULTS")
    print("=" * 80)
    print(f"Total episodes: {total_episodes}")
    print(f"Pass rate: {pass_count}/{total_episodes} ({100.0 * pass_count / total_episodes:.1f}%)")
    print(f"Average reward: {reward_sum / total_episodes:.4f}")

    if total_tests_present or passed_tests_present:
        print("\nTest statistics (when available):")
        if total_tests_present:
            print(f"  Avg total tests: {total_tests_sum / total_tests_present:.2f}")
        if passed_tests_present:
            print(f"  Avg passed tests: {passed_tests_sum / passed_tests_present:.2f}")

    print("=" * 80)


def exclude_token_ids(data):
    """Return a copy of *data* without any ``prompt_ids`` / ``completion_ids`` keys.

    Dicts and lists are walked recursively and rebuilt as new containers.
    Any other value, including tuples, is returned as-is.
    """
    dropped_keys = ("prompt_ids", "completion_ids")
    if isinstance(data, list):
        return [exclude_token_ids(item) for item in data]
    if not isinstance(data, dict):
        return data
    return {
        key: exclude_token_ids(value)
        for key, value in data.items()
        if key not in dropped_keys
    }


def _truncate_text(text, limit):
    """Return *text* as a string cut to *limit* characters, appending "..." if cut.

    ``None`` becomes the empty string. Any other non-string value goes through
    ``str()`` first, so slicing never fails.
    """
    text = "" if text is None else str(text)
    return text[:limit] + ("..." if len(text) > limit else "")


def print_sample_episodes(results, n_samples=3):
    """Print question, prompt, generated code, and metrics for the first episodes.

    Args:
        results: Sequence of episode objects exposing ``id``, ``task`` (a dict
            or ``None``), ``trajectories`` (objects with ``name`` and ``steps``)
            and ``metrics`` (a dict or ``None``).
        n_samples: Maximum number of episodes to print.

    Returns:
        None. Output is written to stdout.
    """
    print("\n" + "=" * 80)
    print(f"📝 SAMPLE EPISODES (showing first {min(n_samples, len(results))})")
    print("=" * 80)

    for idx, ep in enumerate(results[:n_samples]):
        print(f"\n--- Episode {idx + 1} ---")
        print(f"Task ID: {ep.id}")
        task = ep.task or {}
        # A "question" key holding None must not break the preview.
        print(f"Question (first 200 chars): {_truncate_text(task.get('question'), 200)}")

        try:
            prompt = _build_code_generation_prompt(ep.task)
        except Exception as e:  # best-effort: one malformed task must not abort the report
            prompt = f"[error building prompt: {e}]"

        print("\nPrompt (first 600 chars):")
        print(_truncate_text(prompt, 600))

        code_traj = next(
            (traj for traj in (ep.trajectories or []) if traj.name == "code_generator"),
            None,
        )
        if code_traj and code_traj.steps:
            print("\nGenerated (first 600 chars):")
            print(_truncate_text(code_traj.steps[0].action, 600))
        else:
            print("\nGenerated: [missing code_generator trajectory]")

        metrics = ep.metrics or {}
        # `or 0` also covers metric keys that are present but explicitly None.
        print(f"\ncodegen_pass = {bool(int(metrics.get('codegen_pass') or 0))}")
        print(f"codegen_reward = {float(metrics.get('codegen_reward') or 0.0)}")


def load_json_results(json_file):
    """Load a results JSON file (such as one written by ``--save_results``).

    Args:
        json_file: Path to the JSON file.

    Returns:
        The decoded JSON value, or ``None`` (after printing an error) when the
        file does not exist or does not contain valid JSON.
    """
    if not os.path.exists(json_file):
        print(f"Error: JSON file not found: {json_file}")
        return None
    print(f"Loading results from: {json_file}")
    # JSON text is UTF-8 by spec; do not depend on the platform's locale encoding.
    with open(json_file, "r", encoding="utf-8") as f:
        try:
            return json.load(f)
        except json.JSONDecodeError as e:
            # Report unusable input the same way as a missing file.
            print(f"Error: invalid JSON in {json_file}: {e}")
            return None


# Script entry point: parse CLI args, run the CodeGen workflow over the loaded
# tasks, print aggregate metrics and samples, and optionally save episodes as JSON.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run CodeGen workflow on DeepCoder/BigCodeBench tasks")
    parser.add_argument("--n_tasks", type=int, default=-1, help="Number of tasks to process (default: -1)")
    parser.add_argument("--n_repeats", type=int, default=1, help="Number of times to repeat each task (default: 1)")
    parser.add_argument("--split", type=str, default="all", help="Dataset split to use (default: all)")
    parser.add_argument(
        "--dataset",
        type=str,
        default="bigcodebench",
        help="Dataset to use: deepcoder, bigcodebench, bugbench, or custom dataset name (default: bigcodebench)",
    )

    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2.5-Coder-7B-Instruct",
        help="Model name (default: Qwen/Qwen2.5-Coder-7B-Instruct)",
    )
    parser.add_argument("--base_url", type=str, default="http://localhost:30000/v1", help="Base URL for the API")
    # NOTE(review): the default is the literal string "None", not Python None;
    # presumably the local OpenAI-compatible server ignores the key — confirm.
    parser.add_argument("--api_key", type=str, default="None", help="API key (default: None)")

    parser.add_argument("--n_parallel", type=int, default=32, help="Number of parallel tasks (default: 32)")
    parser.add_argument("--max_prompt_length", type=int, default=8192, help="Maximum prompt length (default: 8192)")
    parser.add_argument("--max_response_length", type=int, default=8192, help="Maximum response length (default: 8192)")
    parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature (default: 0.6)")
    parser.add_argument("--top_p", type=float, default=0.95, help="Sampling top_p (default: 0.95)")

    parser.add_argument("--system_prompt", type=str, default=None, help="Optional system prompt")

    parser.add_argument("--save_results", action="store_true", help="Save results to JSON")
    parser.add_argument("--output_dir", type=str, default="logs", help="Directory to save results (default: logs)")
    parser.add_argument("--print_samples", type=int, default=3, help="Number of sample episodes to print")
    parser.add_argument("--load_json", type=str, default=None, help="Load results JSON and exit")

    args = parser.parse_args()

    # --load_json only reports how many episodes a saved file holds, then exits.
    if args.load_json:
        data = load_json_results(args.load_json)
        if data is not None:
            print(f"Loaded {len(data)} episodes")
        raise SystemExit(0)

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    print(f"Loading tokenizer for model: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    print("Initializing rollout engine...")
    rollout_engine = OpenAIEngine(
        model=args.model,
        tokenizer=tokenizer,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params={
            "temperature": args.temperature,
            "top_p": args.top_p,
        },
    )

    # Some workflows gate extra eval on `validate`; every non-train split counts as validation.
    rollout_engine.validate = args.split.lower() != "train"

    print("Creating workflow engine...")
    engine = AgentWorkflowEngine(
        workflow_cls=CodeGenWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "system_prompt": args.system_prompt,
        },
        rollout_engine=rollout_engine,
        config=None,
        n_parallel_tasks=args.n_parallel,
        retry_limit=1,
    )

    print(f"\nLoading tasks from dataset '{args.dataset}' split '{args.split}'...")
    all_tasks = load_data(dataset_name=args.dataset, split=args.split, n=args.n_repeats)
    print(f"Loaded {len(all_tasks)} tasks")
    if not all_tasks:
        print("No tasks loaded. Exiting.")
        raise SystemExit(1)

    # NOTE(review): this slice runs after load_data has applied n_repeats, so
    # --n_tasks caps the total number of rollouts rather than distinct tasks
    # (how many distinct tasks survive depends on load_data's ordering) — confirm intent.
    if args.n_tasks > 0:
        all_tasks = all_tasks[: args.n_tasks]

    print(f"Using {len(all_tasks)} tasks")
    print("Configuration:")
    print(f"  Model: {args.model}")
    print(f"  Dataset: {args.dataset}/{args.split}")
    print(f"  Repeats per task: {args.n_repeats}")
    print(f"  Parallel tasks: {args.n_parallel}")
    print(f"  Temperature: {args.temperature}, Top-p: {args.top_p}")
    print(f"  Max prompt length: {args.max_prompt_length}")
    print(f"  Max response length: {args.max_response_length}")

    print(f"\n🚀 Executing workflow on {len(all_tasks)} tasks...")
    results = asyncio.run(engine.execute_tasks(all_tasks))

    print("\n📊 Evaluating results...")
    evaluate_results(results)

    if args.print_samples > 0:
        print_sample_episodes(results, n_samples=args.print_samples)

    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(args.output_dir, f"codegen_flow_results_{timestamp}.json")

        # Token id arrays are dropped to keep the saved file readable and small.
        results_dict = [exclude_token_ids(ep.to_dict()) for ep in results]
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results_dict, f, indent=2)

        print(f"\n💾 Results saved to: {output_file}")

    print("\n✅ Done!")
