import argparse
import asyncio
import json
import os
from datetime import datetime

from transformers import AutoTokenizer

from examples.bugs.data_utils import load_data
from examples.bugs.fixer_flow import FixerWorkflow
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.reward_fn import code_reward_fn


def evaluate_results(results):
    """Evaluate results and print summary metrics.

    Args:
        results: Sequence of episodes. Each episode exposes a ``metrics``
            mapping (or ``None``) that may contain ``solver_pass``,
            ``fix_called`` and ``fixer_reward``. Missing keys count as 0.

    Returns:
        None. The summary is printed to stdout. When ``results`` is empty,
        only a short message is printed.
    """
    total_episodes = len(results)
    if total_episodes == 0:
        print("No results to evaluate.")
        return

    solver_pass = 0
    fix_called = 0
    fix_pass = 0

    for ep in results:
        m = ep.metrics or {}
        solver_pass += int(m.get("solver_pass", 0))
        fix_called += int(m.get("fix_called", 0))
        # An episode counts as a successful fix when its fixer reward is strictly positive.
        fix_pass += int(float(m.get("fixer_reward", 0.0)) > 0.0)

    def _rate(numerator, denominator):
        # fix_called can sum to 0 (no episode invoked the fixer), so a zero
        # denominator must not crash the report.
        if denominator == 0:
            return "n/a"
        return f"{100 * numerator / denominator:.1f}%"

    print("\n" + "=" * 80)
    print("📊 FIXER WORKFLOW RESULTS (frozen solver -> trainable fixer)")
    print("=" * 80)
    print(f"Total episodes: {total_episodes}")
    print("\nOverall Metrics:")
    print(f"  Solver Pass Rate: {solver_pass}/{total_episodes} ({_rate(solver_pass, total_episodes)})")
    print(f"  Fix Called Rate: {fix_called}/{total_episodes} ({_rate(fix_called, total_episodes)})")
    print(f"  Fix Pass Rate:   {fix_pass}/{fix_called} ({_rate(fix_pass, fix_called)})")
    print("=" * 80)


def exclude_token_ids(data):
    """Return a copy of ``data`` with token-id fields stripped at every depth.

    Nested dicts and lists are rebuilt recursively. Dict entries whose key is
    ``"prompt_ids"`` or ``"completion_ids"`` are omitted. Any other value
    (including tuples) is returned as-is.
    """
    token_keys = ("prompt_ids", "completion_ids")

    if isinstance(data, dict):
        return {
            key: exclude_token_ids(value)
            for key, value in data.items()
            if key not in token_keys
        }
    if isinstance(data, list):
        return [exclude_token_ids(element) for element in data]
    # Leaves (and non-list sequences) pass through untouched.
    return data


def print_sample_episodes(results, n_samples=3):
    """Print up to ``n_samples`` episodes from the front of ``results``.

    For each episode this shows:
    - the task id and a 200-character question preview;
    - the frozen solver code from ``ep.info["solver_code"]``;
    - the first fixer step's action and reward, when a trajectory with steps exists;
    - the episode metrics and correctness flag.

    Long code and action text is truncated to 400 characters plus "...".
    """

    def _clip(text, limit=400):
        # Append an ellipsis only when something was actually cut off.
        return text[:limit] + "..." if len(text) > limit else text

    shown = results[:n_samples]

    print("\n" + "=" * 80)
    print(f"📝 SAMPLE EPISODES (showing first {min(n_samples, len(results))})")
    print("=" * 80)

    for number, episode in enumerate(shown, start=1):
        print(f"\n--- Episode {number} ---")
        print(f"Task ID: {episode.id}")
        question_preview = str(episode.task.get('question', ''))[:200]
        print(f"Question: {question_preview}...")

        episode_info = episode.info or {}
        code = episode_info.get("solver_code", "")
        print("\n🧊 Frozen Solver Code (first 400 chars):")
        print(_clip(code))

        trajectories = episode.trajectories
        if trajectories:
            first_trajectory = trajectories[0]
            if first_trajectory and first_trajectory.steps:
                first_step = first_trajectory.steps[0]
                fixer_text = first_step.action or ""
                print("\n🔧 Fixer Output (first 400 chars):")
                print(_clip(fixer_text))
                print(f"Reward: {first_step.reward}")

        print(f"\nMetrics: {episode.metrics}")
        print(f"Episode Correct: {episode.is_correct}")
        print("-" * 80)


def build_arg_parser():
    """Build the command-line parser for the fixer-workflow evaluation script.

    Returns:
        argparse.ArgumentParser: Parser covering dataset selection, the fixer
        model, the frozen solver, workflow behavior flags, and execution/logging
        options.
    """
    parser = argparse.ArgumentParser(description="Run Fixer workflow (repair frozen solver failures) on tasks")
    parser.add_argument("--n_tasks", type=int, default=-1, help="Number of tasks to process (default: -1 = all)")
    parser.add_argument("--n_repeats", type=int, default=1, help="Number of times to repeat each task (default: 1)")
    parser.add_argument("--split", type=str, default="test", help="Dataset split to use (default: test)")
    parser.add_argument("--dataset", type=str, default="bigcodebench", help="Dataset: deepcoder, bigcodebench, bugbench, etc.")

    # Fixer (trainable model, but here just for evaluation)
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct", help="Fixer model name")
    parser.add_argument("--base_url", type=str, default="http://localhost:30000/v1", help="Fixer base URL")
    # NOTE(review): the default is the literal string "None" (not the None object),
    # and it is passed to OpenAIEngine unchanged — presumably a dummy key for a
    # local server; confirm this is intended.
    parser.add_argument("--api_key", type=str, default="None", help="Fixer API key (default: None)")
    parser.add_argument("--temperature", type=float, default=0.6, help="Fixer sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.95, help="Fixer sampling top_p")
    parser.add_argument("--max_prompt_length", type=int, default=8192, help="Max prompt length")
    parser.add_argument("--max_response_length", type=int, default=8192, help="Max response length")
    parser.add_argument("--fixer_system_prompt", type=str, default=None, help="Optional system prompt for fixer")

    # Frozen solver (OpenAI-compatible; chat endpoint, no tokenizer required)
    parser.add_argument("--solver_model", type=str, required=True, help="Frozen solver model name")
    parser.add_argument("--solver_base_url", type=str, default="http://localhost:30001/v1", help="Frozen solver base URL")
    parser.add_argument("--solver_api_key", type=str, default=None, help="Frozen solver API key (default: OPENAI_API_KEY or dummy)")
    parser.add_argument("--solver_temperature", type=float, default=0.6, help="Frozen solver sampling temperature")
    parser.add_argument("--solver_top_p", type=float, default=0.95, help="Frozen solver sampling top_p")
    parser.add_argument("--solver_system_prompt", type=str, default=None, help="Optional system prompt for frozen solver")

    # Training behavior (still affects evaluation workflow semantics)
    parser.add_argument("--only_train_on_failures", action="store_true", help="If set, solver passes yield empty trajectories")
    parser.add_argument("--reward_pm1", action="store_true", help="Use {-1,+1} reward instead of {0,1}")
    parser.add_argument("--include_failed_test_output", action="store_true", help="Include unit test failure output in fixer prompt")
    parser.add_argument("--max_failed_test_output_chars", type=int, default=4000, help="Max chars of failure output to include in prompt")

    # Execution / logging
    parser.add_argument("--n_parallel", type=int, default=32, help="Number of parallel tasks")
    parser.add_argument("--save_results", action="store_true", help="Save results to JSON")
    parser.add_argument("--output_dir", type=str, default="logs", help="Directory to save results")
    parser.add_argument("--print_samples", type=int, default=3, help="Number of sample episodes to print")
    return parser


def resolve_solver_api_key(solver_api_key, solver_base_url):
    """Resolve the API key used by the frozen solver engine.

    Args:
        solver_api_key: Key given on the command line, or ``None`` to fall back
            to the ``OPENAI_API_KEY`` environment variable.
        solver_base_url: Solver endpoint URL. It is used to detect the hosted
            OpenAI API, which requires a real key.

    Returns:
        str: The stripped key. If no key is available for a non-OpenAI
        endpoint, returns the placeholder ``"EMPTY"``.

    Raises:
        ValueError: If ``solver_base_url`` contains ``api.openai.com`` and no
            non-empty key is available.
    """
    api_key = solver_api_key
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY", "") or ""
    api_key = str(api_key).strip()
    is_openai_api = "api.openai.com" in str(solver_base_url)
    if is_openai_api and not api_key:
        raise ValueError("solver_base_url points to OpenAI API but OPENAI_API_KEY is missing/empty.")
    # Non-OpenAI endpoints get a non-empty placeholder (presumably local
    # OpenAI-compatible servers that ignore the key).
    if not api_key:
        api_key = "EMPTY"
    return api_key


def main(argv=None):
    """Run the fixer workflow (frozen solver -> fixer) over a dataset and report results.

    Steps:
    1. Parse CLI arguments.
    2. Build the fixer and frozen-solver OpenAI engines.
    3. Execute ``FixerWorkflow`` over the loaded tasks.
    4. Print summary metrics and sample episodes.
    5. Optionally save the episodes (minus token ids) to a timestamped JSON file.

    Args:
        argv: Optional list of CLI arguments. ``None`` parses ``sys.argv[1:]``.

    Raises:
        SystemExit: With code 1 when the dataset yields no tasks, or from
            argparse on invalid arguments.
    """
    args = build_arg_parser().parse_args(argv)

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    print(f"Loading tokenizer for fixer model: {args.model}")
    fixer_tokenizer = AutoTokenizer.from_pretrained(args.model)

    print("Initializing fixer rollout engine...")
    fixer_engine = OpenAIEngine(
        model=args.model,
        tokenizer=fixer_tokenizer,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params={
            "temperature": float(args.temperature),
            "top_p": float(args.top_p),
        },
    )

    solver_api_key = resolve_solver_api_key(args.solver_api_key, args.solver_base_url)

    print("Initializing frozen solver engine...")
    solver_engine = OpenAIEngine(
        model=args.solver_model,
        tokenizer=None,  # frozen solver; chat-completions endpoint
        base_url=args.solver_base_url,
        api_key=solver_api_key,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        sampling_params={
            "temperature": float(args.solver_temperature),
            "top_p": float(args.solver_top_p),
        },
        verbose=False,
    )

    print("Creating workflow engine...")
    engine = AgentWorkflowEngine(
        workflow_cls=FixerWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "solver_rollout_engine": solver_engine,
            "solver_system_prompt": args.solver_system_prompt,
            "fixer_system_prompt": args.fixer_system_prompt,
            "only_train_on_failures": bool(args.only_train_on_failures),
            "reward_pm1": bool(args.reward_pm1),
            "include_failed_test_output": bool(args.include_failed_test_output),
            "max_failed_test_output_chars": int(args.max_failed_test_output_chars),
        },
        rollout_engine=fixer_engine,
        config=None,
        n_parallel_tasks=args.n_parallel,
        retry_limit=1,
    )

    print(f"\nLoading tasks from dataset '{args.dataset}' split '{args.split}'...")
    all_tasks = load_data(dataset_name=args.dataset, split=args.split, n=args.n_repeats)
    if not all_tasks:
        print("No tasks loaded. Exiting.")
        raise SystemExit(1)

    # NOTE(review): slicing happens after load_data has applied n_repeats, so
    # n_tasks presumably limits repeated entries rather than unique tasks — confirm.
    if args.n_tasks > 0:
        all_tasks = all_tasks[: args.n_tasks]

    print(f"Loaded {len(all_tasks)} tasks")
    print("Configuration:")
    print(f"  Fixer Model: {args.model}")
    print(f"  Frozen Solver Model: {args.solver_model}")
    print(f"  Parallel tasks: {args.n_parallel}")
    print(f"  only_train_on_failures: {bool(args.only_train_on_failures)}")
    print(f"  include_failed_test_output: {bool(args.include_failed_test_output)}")

    print(f"\n🚀 Executing workflow on {len(all_tasks)} tasks...")
    results = asyncio.run(engine.execute_tasks(all_tasks))

    print("\n📊 Evaluating results...")
    evaluate_results(results)

    if args.print_samples > 0:
        print_sample_episodes(results, n_samples=args.print_samples)

    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(args.output_dir, f"fixer_flow_results_{timestamp}.json")
        # Token-id arrays are dropped to keep the saved JSON small and readable.
        results_dict = [exclude_token_ids(ep.to_dict()) for ep in results]
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results_dict, f, indent=2)
        print(f"\n💾 Results saved to: {output_file}")

    print("\n✅ Done!")


if __name__ == "__main__":
    main()


