import argparse
import asyncio
import json
import os
from datetime import datetime

from transformers import AutoTokenizer

from examples.bugs.solver_flow import SolverWorkflow
from examples.bugs.data_utils import load_data
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.reward_fn import code_reward_fn


def evaluate_results(results):
    """Print aggregate metrics for a batch of finished episodes.

    Each episode's ``metrics`` mapping (``None`` is treated as empty) is read
    for the keys ``bug_incorrect``, ``bug_valid``, ``bug_has_compile_error``,
    ``solver_pass`` and ``solver_reward``. The four flag metrics are counted
    as integers. ``solver_reward`` is summed as a float so that fractional
    rewards are not truncated to zero before they are aggregated.

    Args:
        results: Sized iterable of episode objects exposing a ``metrics``
            attribute.

    Returns:
        None. The summary is written to stdout; when ``results`` is empty a
        notice is printed and nothing else happens.
    """
    total_episodes = len(results)
    if total_episodes == 0:
        print("No results to evaluate.")
        return

    flag_keys = ("bug_incorrect", "bug_valid", "bug_has_compile_error", "solver_pass")
    counts = dict.fromkeys(flag_keys, 0)
    solver_reward = 0.0

    for ep in results:
        m = ep.metrics or {}
        for key in flag_keys:
            # `or 0` lets a missing key and an explicit None both count as zero.
            counts[key] += int(m.get(key) or 0)
        # Keep the fractional part: int() here would turn a 0.5 reward into 0.
        solver_reward += float(m.get("solver_reward") or 0)

    # Show whole totals as plain integers (the common 0/1 reward case) and
    # fractional totals rounded to two decimals.
    if solver_reward.is_integer():
        reward_display = int(solver_reward)
    else:
        reward_display = round(solver_reward, 2)

    bug_incorrect = counts["bug_incorrect"]
    bug_valid = counts["bug_valid"]
    bug_compile_err = counts["bug_has_compile_error"]
    solver_pass = counts["solver_pass"]

    print("\n" + "=" * 80)
    print("📊 SOLVER WORKFLOW RESULTS (static bug generator)")
    print("=" * 80)
    print(f"Total episodes: {total_episodes}")
    print("\nOverall Metrics:")
    print(f"  Bug Incorrect Rate: {bug_incorrect}/{total_episodes} ({100*bug_incorrect/total_episodes:.1f}%)")
    print(f"  Bug Valid Rate: {bug_valid}/{total_episodes} ({100*bug_valid/total_episodes:.1f}%)")
    print(f"  Bug Compile-Error Rate: {bug_compile_err}/{total_episodes} ({100*bug_compile_err/total_episodes:.1f}%)")
    print(f"  Solver Pass Rate: {solver_pass}/{total_episodes} ({100*solver_pass/total_episodes:.1f}%)")
    print(f"  Solver Reward Rate: {reward_display}/{total_episodes} ({100*solver_reward/total_episodes:.1f}%)")
    print("=" * 80)


def print_reward_distributions(results_dict):
    """Print distribution of solver rewards from a results dict list."""
    from collections import Counter

    solver_rewards = []
    solver_passes = []
    bug_incorrects = []
    bug_compile_errs = []

    for episode_dict in results_dict:
        metrics = episode_dict.get("metrics", {}) or {}
        solver_rewards.append(metrics.get("solver_reward", 0.0))
        solver_passes.append(metrics.get("solver_pass", 0.0))
        bug_incorrects.append(metrics.get("bug_incorrect", 0.0))
        bug_compile_errs.append(metrics.get("bug_has_compile_error", 0.0))

    total = len(solver_rewards)
    if total == 0:
        print("No episodes found for reward distribution.")
        return

    sr = Counter(solver_rewards)
    sp = Counter(solver_passes)
    bi = Counter(bug_incorrects)
    bc = Counter(bug_compile_errs)

    print("\n" + "=" * 80)
    print("📈 DISTRIBUTIONS")
    print("=" * 80)
    print(f"Total episodes: {total}")

    print("\n✅ Solver Reward:")
    for v in sorted(sr.keys()):
        c = sr[v]
        print(f"  {v:.1f}: {c:5d} episodes ({100.0*c/total:5.1f}%)")

    print("\n✅ Solver Pass:")
    for v in sorted(sp.keys()):
        c = sp[v]
        print(f"  {v:.1f}: {c:5d} episodes ({100.0*c/total:5.1f}%)")

    print("\n🐛 Bug Incorrect:")
    for v in sorted(bi.keys()):
        c = bi[v]
        print(f"  {v:.1f}: {c:5d} episodes ({100.0*c/total:5.1f}%)")

    print("\n⚠️ Bug Compile Error:")
    for v in sorted(bc.keys()):
        c = bc[v]
        print(f"  {v:.1f}: {c:5d} episodes ({100.0*c/total:5.1f}%)")

    print("=" * 80)


def load_json_results(json_file):
    """Load previously saved episode results from a JSON file.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The decoded JSON content, or ``None`` (after printing an error
        message) when ``json_file`` does not exist.
    """
    if not os.path.exists(json_file):
        print(f"Error: JSON file not found: {json_file}")
        return None
    print(f"Loading results from: {json_file}")
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(json_file, encoding="utf-8") as f:
        results_dict = json.load(f)
    print(f"Loaded {len(results_dict)} episodes")
    return results_dict


def exclude_token_ids(data):
    """Return a copy of ``data`` without any ``prompt_ids`` / ``completion_ids`` entries.

    Dicts and lists are walked recursively and rebuilt; dict entries whose
    key is ``"prompt_ids"`` or ``"completion_ids"`` are dropped at every
    depth. Any other value (including tuples) is returned unchanged.
    """
    if isinstance(data, list):
        return [exclude_token_ids(element) for element in data]
    if not isinstance(data, dict):
        return data

    dropped_keys = ("prompt_ids", "completion_ids")
    return {
        name: exclude_token_ids(child)
        for name, child in data.items()
        if name not in dropped_keys
    }


def _truncate_preview(text, limit):
    """Return ``text`` cut to ``limit`` characters, appending "..." only if it was cut."""
    text = "" if text is None else str(text)
    return text[:limit] + "..." if len(text) > limit else text


def print_sample_episodes(results, n_samples=3):
    """Print a short human-readable preview of the first few episodes.

    For each shown episode this prints its id, the task question (first 200
    characters), the buggy code stored in ``ep.info["buggy_code"]`` (first
    300 characters), the action and reward of the first step of the first
    trajectory when there is one, the metrics and ``is_correct``.

    All three previews share one truncation rule: "..." is appended only when
    text was actually cut, and a missing/``None`` value prints as empty
    instead of raising.

    Args:
        results: Sliceable sequence of episode objects.
        n_samples: Maximum number of episodes to show; negative values are
            treated as zero.

    Returns:
        None. Output goes to stdout.
    """
    # Clamp so the header count always matches the episodes actually printed.
    shown = results[: max(n_samples, 0)]

    print("\n" + "=" * 80)
    print(f"📝 SAMPLE EPISODES (showing first {len(shown)})")
    print("=" * 80)

    for idx, ep in enumerate(shown):
        print(f"\n--- Episode {idx + 1} ---")
        print(f"Task ID: {ep.id}")
        question = (ep.task or {}).get("question", "")
        print(f"Question: {_truncate_preview(question, 200)}")
        buggy = (ep.info or {}).get("buggy_code", "")
        print("\n🐛 Buggy Code (first 300 chars):")
        print(_truncate_preview(buggy, 300))

        solver_traj = ep.trajectories[0] if ep.trajectories else None
        if solver_traj and solver_traj.steps:
            fixed = solver_traj.steps[0].action or ""
            print("\n🔧 Solver Output (first 300 chars):")
            print(_truncate_preview(fixed, 300))
            print(f"Reward: {solver_traj.steps[0].reward}")

        print(f"\nMetrics: {ep.metrics}")
        print(f"Episode Correct: {ep.is_correct}")
        print("-" * 80)


if __name__ == "__main__":
    # Script entry point. Two modes:
    #   * --load_json FILE : print distributions from a saved results file and exit.
    #   * otherwise        : build the solver and generator engines, run the
    #                        workflow on the selected tasks, then report.
    parser = argparse.ArgumentParser(description="Run Solver workflow (fix bugs from static generator) on tasks")
    parser.add_argument("--n_tasks", type=int, default=3, help="Number of tasks to process (default: 3)")
    parser.add_argument("--n_repeats", type=int, default=1, help="Number of times to repeat each task (default: 1)")
    parser.add_argument("--split", type=str, default="train", help="Dataset split to use (default: train)")
    parser.add_argument("--dataset", type=str, default="deepcoder_bugs", help="Dataset: deepcoder, bigcodebench, bugbench, etc.")

    # Solver (trainable model, but here just for evaluation)
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-4B", help="Solver model name")
    parser.add_argument("--base_url", type=str, default="http://localhost:30000/v1", help="Solver base URL")
    # NOTE(review): the default is the literal string "None", not the None
    # object; presumably a placeholder key for local servers. Confirm.
    parser.add_argument("--api_key", type=str, default="None", help="Solver API key (default: None)")
    parser.add_argument("--temperature", type=float, default=0.6, help="Solver sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.95, help="Solver sampling top_p")
    parser.add_argument("--max_prompt_length", type=int, default=8192, help="Max prompt length")
    parser.add_argument("--max_response_length", type=int, default=8192, help="Max response length")
    parser.add_argument("--solver_system_prompt", type=str, default=None, help="Optional system prompt for solver")

    # Static bug generator (e.g. gpt-oss-20b)
    # Not declared with required=True: the --load_json mode never touches the
    # generator, so the check is done after that mode has been handled.
    parser.add_argument("--generator_model", type=str, default=None, help="Generator model name (e.g., gpt-oss-20b); required unless --load_json is given")
    parser.add_argument("--generator_base_url", type=str, default="http://localhost:30001/v1", help="Generator base URL")
    parser.add_argument("--generator_api_key", type=str, default=None, help="Generator API key (default: OPENAI_API_KEY or dummy)")
    parser.add_argument("--generator_temperature", type=float, default=0.6, help="Generator sampling temperature")
    parser.add_argument("--generator_top_p", type=float, default=0.95, help="Generator sampling top_p")
    parser.add_argument("--generator_system_prompt", type=str, default=None, help="Optional system prompt for generator")

    # Bug validity policy
    parser.add_argument("--compile_errors_invalid", action="store_true", help="Treat compile errors as invalid bugs")

    # Execution / logging
    parser.add_argument("--n_parallel", type=int, default=32, help="Number of parallel tasks")
    parser.add_argument("--save_results", action="store_true", help="Save results to JSON")
    parser.add_argument("--output_dir", type=str, default="logs", help="Directory to save results")
    parser.add_argument("--print_samples", type=int, default=3, help="Number of sample episodes to print")
    parser.add_argument("--load_json", type=str, default=None, help="Load results JSON and print distributions")

    args = parser.parse_args()

    # Offline mode: only report on a saved file, never build any engine.
    if args.load_json:
        results_dict = load_json_results(args.load_json)
        if results_dict is None:
            # load_json_results already printed the reason; signal failure to the shell.
            raise SystemExit(1)
        print_reward_distributions(results_dict)
        raise SystemExit(0)

    if not args.generator_model:
        # parser.error prints usage and exits with status 2, like a missing required option.
        parser.error("--generator_model is required unless --load_json is given")

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    print(f"Loading tokenizer for solver model: {args.model}")
    solver_tokenizer = AutoTokenizer.from_pretrained(args.model)

    print("Initializing solver rollout engine...")
    solver_engine = OpenAIEngine(
        model=args.model,
        tokenizer=solver_tokenizer,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params={
            "temperature": float(args.temperature),
            "top_p": float(args.top_p),
        },
    )

    # Generator engine uses chat-completions endpoint (no tokenizer required).
    # Key resolution: explicit flag, else OPENAI_API_KEY, else the dummy
    # "EMPTY". An empty key is rejected only when the URL targets api.openai.com.
    api_key = args.generator_api_key
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY", "") or ""
    api_key = api_key.strip()
    is_openai_api = "api.openai.com" in str(args.generator_base_url)
    if is_openai_api and not api_key:
        raise ValueError("generator_base_url points to OpenAI API but OPENAI_API_KEY is missing/empty.")
    if not api_key:
        api_key = "EMPTY"

    print("Initializing generator rollout engine...")
    generator_engine = OpenAIEngine(
        model=args.generator_model,
        tokenizer=None,
        base_url=args.generator_base_url,
        api_key=api_key,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        sampling_params={
            "temperature": float(args.generator_temperature),
            "top_p": float(args.generator_top_p),
        },
    )

    print("Creating workflow engine...")
    engine = AgentWorkflowEngine(
        workflow_cls=SolverWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "generator_rollout_engine": generator_engine,
            "generator_system_prompt": args.generator_system_prompt,
            "solver_system_prompt": args.solver_system_prompt,
            "compile_errors_invalid": bool(args.compile_errors_invalid),
        },
        rollout_engine=solver_engine,
        config=None,
        n_parallel_tasks=args.n_parallel,
        retry_limit=1,
    )

    print(f"\nLoading tasks from dataset '{args.dataset}' split '{args.split}'...")
    # NOTE(review): --n_repeats is forwarded as load_data's `n`, and the
    # --n_tasks cut below applies to whatever list load_data returns. Confirm
    # how repeats interact with that cut.
    all_tasks = load_data(dataset_name=args.dataset, split=args.split, n=args.n_repeats)

    if not all_tasks:
        print("No tasks loaded. Exiting.")
        raise SystemExit(1)

    # A non-positive --n_tasks keeps every loaded task.
    if args.n_tasks > 0:
        all_tasks = all_tasks[: args.n_tasks]

    print(f"Loaded {len(all_tasks)} tasks")
    print("Configuration:")
    print(f"  Solver Model: {args.model}")
    print(f"  Generator Model: {args.generator_model}")
    print(f"  Parallel tasks: {args.n_parallel}")
    print(f"  Solver Temperature: {args.temperature}, Top-p: {args.top_p}")
    print(f"  Generator Temperature: {args.generator_temperature}, Top-p: {args.generator_top_p}")
    print(f"  Max prompt length: {args.max_prompt_length}")
    print(f"  Max response length: {args.max_response_length}")
    print(f"  compile_errors_invalid: {bool(args.compile_errors_invalid)}")

    print(f"\n🚀 Executing workflow on {len(all_tasks)} tasks...")
    results = asyncio.run(engine.execute_tasks(all_tasks))

    print("\n📊 Evaluating results...")
    evaluate_results(results)

    if args.print_samples > 0:
        print_sample_episodes(results, n_samples=args.print_samples)

    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(args.output_dir, f"solver_flow_results_{timestamp}.json")
        # Token id lists are stripped before the episodes are written out.
        results_dict = [exclude_token_ids(ep.to_dict()) for ep in results]
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results_dict, f, indent=2)
        print(f"\n💾 Results saved to: {output_file}")
        print_reward_distributions(results_dict)

    print("\n✅ Done!")


