import argparse
import asyncio
import json
import os
from datetime import datetime

from transformers import AutoTokenizer

from generator_solver_flow import GeneratorSolverWorkflow
from data_utils import load_data
from prompts import _build_code_generation_prompt
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.reward_fn import code_reward_fn


def evaluate_results(results):
    """Print an aggregate summary of workflow episodes to stdout.

    Every episode must expose a ``metrics`` attribute; ``None`` or an empty
    dict is treated as "no metrics recorded". The following metric keys are
    read when present: ``bug_valid``, ``generator_reward``, ``solver_reward``,
    ``solver_codegen_pass``, ``bug_total_tests``, ``bug_passed_tests`` and
    ``bug_similarity_score``.

    Args:
        results: Sized iterable of episode objects.

    Returns:
        None. All output is printed.
    """
    total_episodes = len(results)
    if total_episodes == 0:
        print("No results to evaluate.")
        return

    # Reward/validity counters (values are truncated to int before summing).
    bug_valid_count = 0
    generator_reward_count = 0
    solver_reward_count = 0

    # Codegen is optional, so track how many episodes actually report it.
    codegen_present_count = 0
    solver_codegen_pass_count = 0

    # Bug-test metrics are optional too; keep a separate denominator for each
    # so the averages are taken over the episodes that reported the metric.
    bug_total_tests_sum = 0
    bug_total_tests_count = 0
    bug_passed_tests_sum = 0
    bug_passed_tests_count = 0

    bug_similarity_scores = []

    for episode in results:
        metrics = episode.metrics or {}

        bug_valid_count += int(metrics.get("bug_valid", 0))
        generator_reward_count += int(metrics.get("generator_reward", 0))
        solver_reward_count += int(metrics.get("solver_reward", 0))

        if "solver_codegen_pass" in metrics:
            codegen_present_count += 1
            solver_codegen_pass_count += int(metrics.get("solver_codegen_pass", 0))

        if "bug_total_tests" in metrics:
            bug_total_tests_sum += metrics["bug_total_tests"]
            bug_total_tests_count += 1
        if "bug_passed_tests" in metrics:
            bug_passed_tests_sum += metrics["bug_passed_tests"]
            bug_passed_tests_count += 1

        if "bug_similarity_score" in metrics:
            bug_similarity_scores.append(metrics["bug_similarity_score"])

    # Print summary statistics
    print("\n" + "=" * 80)
    print("📊 GENERATOR-SOLVER WORKFLOW RESULTS")
    print("=" * 80)
    print(f"Total episodes: {total_episodes}")
    print("\nOverall Metrics:")
    print(f"  Generator Reward Rate: {generator_reward_count}/{total_episodes} ({100*generator_reward_count/total_episodes:.1f}%)")
    print(f"  Solver Reward Rate (Bug Fixing): {solver_reward_count}/{total_episodes} ({100*solver_reward_count/total_episodes:.1f}%)")
    print(f"  Bug Valid Rate: {bug_valid_count}/{total_episodes} ({100*bug_valid_count/total_episodes:.1f}%)")

    # Always print codegen stats if present (even if pass count is 0)
    if codegen_present_count > 0:
        print(
            f"  Solver CodeGen Pass Rate: {solver_codegen_pass_count}/{codegen_present_count} "
            f"({100*solver_codegen_pass_count/codegen_present_count:.1f}%)  [on episodes with codegen]"
        )
    else:
        print("  Solver CodeGen Pass Rate: (not computed)  [did you set rollout_engine.validate=True?]")

    if bug_total_tests_sum > 0:
        # A positive sum implies at least one episode reported the total.
        avg_total_tests = bug_total_tests_sum / bug_total_tests_count
        avg_passed_tests = (
            bug_passed_tests_sum / bug_passed_tests_count if bug_passed_tests_count else 0.0
        )
        print("\nBug Test Statistics:")
        print(f"  Average total tests per bug: {avg_total_tests:.2f}")
        print(f"  Average passed tests per bug: {avg_passed_tests:.2f}")
        print(f"  Average failed tests per bug: {avg_total_tests - avg_passed_tests:.2f}")

    # Bug similarity judge statistics
    if bug_similarity_scores:
        avg_similarity = sum(bug_similarity_scores) / len(bug_similarity_scores)
        print("\n🎯 Bug Similarity Judge Statistics:")
        print(f"  Episodes with similarity scores: {len(bug_similarity_scores)}")
        print(f"  Average similarity score: {avg_similarity:.3f} (0-1 scale)")
        print(f"  Min similarity: {min(bug_similarity_scores):.3f}, Max: {max(bug_similarity_scores):.3f}")

    print("=" * 80)


def print_codegen_results(results, n_show=5, show_prompt=True, show_code=True):
    """
    Show a handful of code-generation outcomes, failures before passes.

    An episode counts as "codegen evaluated" when its metrics contain a
    non-None ``solver_codegen_pass`` entry. Up to ``n_show`` failures are
    printed, and any remaining budget is spent on passes.

    Args:
        results: Iterable of episode objects (``id``, ``task``, ``metrics``,
            ``trajectories`` attributes are read).
        n_show: Maximum number of episodes to print in total.
        show_prompt: When true, print the start of the rebuilt codegen prompt.
        show_code: When true, print the start of the generated code taken from
            the first step of the ``code_generator`` trajectory.
    """

    def _flag(episode):
        # Integer view of the pass metric; absent metrics count as 0.
        return int((episode.metrics or {}).get("solver_codegen_pass", 0))

    evaluated = [
        episode
        for episode in results
        if (episode.metrics or {}).get("solver_codegen_pass", None) is not None
    ]
    if not evaluated:
        print("\n🧾 CODEGEN RESULTS: (none found)")
        print("Tip: your workflow gates codegen on `rollout_engine.validate`; make sure it's True.")
        return

    # Split into passes (flag == 1) and failures (flag == 0); any other value
    # lands in neither bucket.
    passed, failed = [], []
    for episode in evaluated:
        flag = _flag(episode)
        if flag == 1:
            passed.append(episode)
        elif flag == 0:
            failed.append(episode)

    banner = "=" * 80
    print("\n" + banner)
    print("🧾 CODEGEN RESULTS (sample)")
    print(banner)
    print(f"Computed on {len(evaluated)} episodes | Passes: {len(passed)} | Fails: {len(failed)}")

    def _report(episode, label):
        print("\n" + "-" * 80)
        print(f"[{label}] Task ID: {episode.id}")
        question = episode.task.get("question", "")
        ellipsis = "..." if len(question) > 200 else ""
        print(f"Question (first 200 chars): {question[:200]}{ellipsis}")

        if show_prompt:
            # Prompt building is best-effort: an error is shown inline instead of aborting.
            try:
                prompt_text = _build_code_generation_prompt(episode.task)
            except Exception as e:
                prompt_text = f"[error building codegen prompt: {e}]"
            print("\nCodeGen Prompt (first 600 chars):")
            prompt_tail = "..." if len(prompt_text) > 600 else ""
            print(prompt_text[:600] + prompt_tail)

        if show_code:
            # First trajectory named "code_generator", if any.
            trajectory = next(
                (t for t in episode.trajectories if t.name == "code_generator"),
                None,
            )
            if trajectory and trajectory.steps:
                code_text = trajectory.steps[0].action or ""
                code_tail = "..." if len(code_text) > 600 else ""
                print("\nGenerated Code (first 600 chars):")
                print(code_text[:600] + code_tail)
            else:
                print("\nGenerated Code: [missing code_generator trajectory]")

        print(f"\nsolver_codegen_pass = {bool(_flag(episode))}")

    # Failures are more actionable, so they go first; passes fill what is left.
    for episode in failed[:n_show]:
        _report(episode, "FAIL")
    pass_budget = max(0, n_show - min(n_show, len(failed)))
    for episode in passed[:pass_budget]:
        _report(episode, "PASS")

    print("\n" + banner)


def print_reward_distributions(results_dict):
    """Print how generator/solver rewards are distributed across episodes.

    Args:
        results_dict: Iterable of episode dicts. Each may carry a ``metrics``
            mapping; a missing or null ``metrics`` entry is treated as empty,
            so both rewards default to ``0.0`` for that episode.

    Returns:
        None. Output goes to stdout: per-value counts for generator rewards,
        solver rewards, their joint distribution and, for episodes that report
        it, the ``solver_codegen_pass`` values.
    """
    from collections import Counter

    generator_rewards = []
    solver_rewards = []
    codegen_vals = []

    for episode_dict in results_dict:
        # `or {}` also covers an explicit null, which `.get("metrics", {})` would
        # pass through as None and then fail on the lookups below.
        metrics = episode_dict.get("metrics") or {}
        generator_rewards.append(metrics.get("generator_reward", 0.0))
        solver_rewards.append(metrics.get("solver_reward", 0.0))
        if "solver_codegen_pass" in metrics:
            codegen_vals.append(metrics.get("solver_codegen_pass", 0.0))

    total = len(generator_rewards)
    if total == 0:
        print("No episodes found for reward distribution.")
        return

    gen_counter = Counter(generator_rewards)
    solver_counter = Counter(solver_rewards)

    print("\n" + "=" * 80)
    print("📈 REWARD DISTRIBUTIONS")
    print("=" * 80)
    print(f"Total episodes: {total}")

    print("\n🔧 Generator Rewards:")
    for reward_value in sorted(gen_counter):
        count = gen_counter[reward_value]
        percentage = 100.0 * count / total
        print(f"  {reward_value:.1f}: {count:5d} episodes ({percentage:5.1f}%)")

    print("\n✅ Solver Rewards:")
    for reward_value in sorted(solver_counter):
        count = solver_counter[reward_value]
        percentage = 100.0 * count / total
        print(f"  {reward_value:.1f}: {count:5d} episodes ({percentage:5.1f}%)")

    joint_dist = Counter(zip(generator_rewards, solver_rewards))
    print("\n🔗 Joint Distribution (Generator, Solver):")
    for gen_reward, solver_reward in sorted(joint_dist):
        count = joint_dist[(gen_reward, solver_reward)]
        percentage = 100.0 * count / total
        print(f"  ({gen_reward:.1f}, {solver_reward:.1f}): {count:5d} episodes ({percentage:5.1f}%)")

    if codegen_vals:
        # Percentages here are relative to the episodes that ran codegen only.
        codegen_counter = Counter(codegen_vals)
        print("\n💻 Solver CodeGen Pass (distribution):")
        for v in sorted(codegen_counter):
            count = codegen_counter[v]
            percentage = 100.0 * count / len(codegen_vals)
            print(f"  {v:.1f}: {count:5d} episodes ({percentage:5.1f}%)  [on episodes with codegen]")

    print("=" * 80)


def load_json_results(json_file):
    """Read previously saved results from ``json_file``.

    Returns the decoded JSON content, or ``None`` (after printing an error)
    when the path does not exist.
    """
    if os.path.exists(json_file):
        print(f"Loading results from: {json_file}")
        with open(json_file, "r") as handle:
            episodes = json.load(handle)
        print(f"Loaded {len(episodes)} episodes")
        return episodes

    print(f"Error: JSON file not found: {json_file}")
    return None


def exclude_token_ids(data):
    """Return a copy of ``data`` without any ``prompt_ids`` / ``completion_ids`` keys.

    Dicts and lists are walked recursively and rebuilt; every other value is
    returned unchanged.
    """
    if isinstance(data, list):
        return [exclude_token_ids(entry) for entry in data]
    if isinstance(data, dict):
        return {
            name: exclude_token_ids(child)
            for name, child in data.items()
            if name not in ("prompt_ids", "completion_ids")
        }
    return data


def print_sample_episodes(results, n_samples=3):
    """Print detailed information for the first ``n_samples`` episodes.

    For each episode this prints the task id, the start of the question, the
    first-step output and reward of the ``bug_generator`` and ``bug_fixer``
    trajectories, the codegen prompt/output when a ``code_generator``
    trajectory exists, and finally the raw metrics and correctness flag.

    Args:
        results: Sliceable sequence of episode objects.
        n_samples: Number of leading episodes to print.
    """

    def _preview(text, limit):
        # A step's action may be None; show it as empty text instead of
        # failing on len(None).
        text = text or ""
        return (text[:limit] + "...") if len(text) > limit else text

    print("\n" + "=" * 80)
    print(f"📝 SAMPLE EPISODES (showing first {min(n_samples, len(results))})")
    print("=" * 80)

    for idx, episode in enumerate(results[:n_samples]):
        print(f"\n--- Episode {idx + 1} ---")
        print(f"Task ID: {episode.id}")
        print(f"Question: {episode.task.get('question', '')[:200]}...")

        # If several trajectories share a name, the last one wins.
        bug_traj = None
        solver_traj = None
        codegen_traj = None
        for traj in episode.trajectories:
            if traj.name == "bug_generator":
                bug_traj = traj
            elif traj.name == "bug_fixer":
                solver_traj = traj
            elif traj.name == "code_generator":
                codegen_traj = traj

        if bug_traj and bug_traj.steps:
            print("\n🐛 Bug Generator Output (first 300 chars):")
            print(_preview(bug_traj.steps[0].action, 300))
            print(f"Reward: {bug_traj.steps[0].reward}")

        if solver_traj and solver_traj.steps:
            print("\n🔧 Bug Fixer Output (first 300 chars):")
            print(_preview(solver_traj.steps[0].action, 300))
            print(f"Reward: {solver_traj.steps[0].reward}")

        if codegen_traj and codegen_traj.steps:
            # Prompt building is best-effort: an error is shown inline instead of aborting.
            try:
                codegen_prompt = _build_code_generation_prompt(episode.task)
            except Exception as e:
                codegen_prompt = f"[error building codegen prompt: {e}]"
            print("\n🧾 CodeGen Prompt (first 600 chars):")
            print(_preview(codegen_prompt, 600))

            print("\n💻 Code Generator Output (first 300 chars):")
            print(_preview(codegen_traj.steps[0].action, 300))
            codegen_pass = int((episode.metrics or {}).get("solver_codegen_pass", 0))
            print(f"Pass: {bool(codegen_pass)}")

        print(f"\nMetrics: {episode.metrics}")
        print(f"Episode Correct: {episode.is_correct}")
        print("-" * 80)


# Script entry point: parse CLI options, then either (a) print reward
# distributions from a previously saved JSON file, or (b) run the
# generator-solver workflow against a served model and report the results.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Generator-Solver workflow on DeepCoder tasks")
    parser.add_argument("--n_tasks", type=int, default=3, help="Number of tasks to process (default: 3)")
    parser.add_argument("--n_repeats", type=int, default=1, help="Number of times to repeat each task (default: 1)")
    parser.add_argument("--split", type=str, default="train", help="Dataset split to use (default: train)")
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-4B", help="Model name (default: Qwen/Qwen3-4B)")
    parser.add_argument("--base_url", type=str, default="http://localhost:30000/v1", help="Base URL for the API (default: http://localhost:30000/v1)")
    # NOTE: the default is the literal string "None", not the None object.
    parser.add_argument("--api_key", type=str, default="None", help="API key (default: None)")
    parser.add_argument("--n_parallel", type=int, default=32, help="Number of parallel tasks (default: 32)")
    parser.add_argument("--max_prompt_length", type=int, default=8192, help="Maximum prompt length (default: 8192)")
    parser.add_argument("--max_response_length", type=int, default=8192, help="Maximum response length (default: 8192)")
    parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature (default: 0.6)")
    parser.add_argument("--top_p", type=float, default=0.95, help="Sampling top_p (default: 0.95)")
    parser.add_argument("--generator_system_prompt", type=str, default=None, help="Optional system prompt for bug generator")
    parser.add_argument("--solver_system_prompt", type=str, default=None, help="Optional system prompt for bug solver")
    parser.add_argument("--evaluate_codegen", action="store_true", help="Also evaluate solver on regular code generation (val-mode only)")
    # The two flags below share one destination: the first is on by default,
    # the second switches it off.
    parser.add_argument("--include_failed_test_output", action="store_true", default=True,
                        help="Include failed test output in solver prompts (default: True)")
    parser.add_argument("--no_failed_test_output", action="store_false", dest="include_failed_test_output",
                        help="Disable including failed test output in solver prompts")

    # LLM-as-judge bug similarity options
    parser.add_argument("--use_bug_similarity_judge", action="store_true", default=False,
                        help="Use LLM-as-judge to score similarity between generated and target bugs")
    parser.add_argument("--bug_similarity_reward_weight", type=float, default=0.5,
                        help="Weight for bug similarity auxiliary reward (0-1, default: 0.5)")
    parser.add_argument("--bug_similarity_n_targets", type=int, default=3,
                        help="Number of target bugs to compare against (averaged for final score, default: 3)")
    parser.add_argument("--judge_system_prompt", type=str, default=None,
                        help="Optional system prompt for the bug similarity judge")
    parser.add_argument("--reference_bug_dataset", type=str, default="bugbench",
                        help="Dataset to load reference bugs from for similarity comparison (default: bugbench)")
    parser.add_argument("--reference_bug_split", type=str, default="test",
                        help="Split of reference bug dataset (default: test)")

    parser.add_argument("--save_results", action="store_true", help="Save results to JSON file")
    parser.add_argument("--output_dir", type=str, default="logs", help="Directory to save results (default: logs)")
    parser.add_argument("--print_samples", type=int, default=3, help="Number of sample episodes to print in detail (default: 3)")
    parser.add_argument("--print_codegen_samples", type=int, default=5, help="Number of codegen examples to print (default: 5)")
    parser.add_argument("--load_json", type=str, default=None, help="Load results from a JSON file and print reward distributions")
    parser.add_argument("--dataset", type=str, default="deepcoder", help="Dataset: deepcoder, bigcodebench, kodcode, bugbench, or custom (default: deepcoder)")

    args = parser.parse_args()

    # Offline mode: only summarize an existing results file, then stop.
    if args.load_json:
        results_dict = load_json_results(args.load_json)
        if results_dict is None:
            # The loader already printed the reason; report failure through
            # the exit status so calling scripts can detect it.
            raise SystemExit(1)
        print_reward_distributions(results_dict)
        raise SystemExit(0)

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    print(f"Loading tokenizer for model: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    print("Initializing rollout engine...")
    rollout_engine = OpenAIEngine(
        model=args.model,
        tokenizer=tokenizer,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params={
            "temperature": args.temperature,
            "top_p": args.top_p,
        },
    )

    # IMPORTANT: your workflow gates codegen (and "use pregenerated bug") on rollout_engine.validate
    is_val_mode = True
    setattr(rollout_engine, "validate", bool(is_val_mode))
    print(f"rollout_engine.validate = {getattr(rollout_engine, 'validate', False)}")

    # Load reference bugs for similarity judge if enabled
    reference_bugs = None
    if args.use_bug_similarity_judge:
        print(f"\nLoading reference bugs from {args.reference_bug_dataset}:{args.reference_bug_split}...")
        ref_bugs_data = load_data(dataset_name=args.reference_bug_dataset, split=args.reference_bug_split, n=1)
        if not ref_bugs_data:
            # Fall back to the dataset registry when load_data returned nothing.
            from rllm.data.dataset import DatasetRegistry
            ds = DatasetRegistry.load_dataset(args.reference_bug_dataset, args.reference_bug_split)
            if ds:
                ref_bugs_data = list(ds.get_data())
        if ref_bugs_data:
            reference_bugs = ref_bugs_data
            print(f"  Loaded {len(reference_bugs)} reference bugs for similarity comparison")
        else:
            print(f"  WARNING: Could not load reference bugs from {args.reference_bug_dataset}:{args.reference_bug_split}")

    print("Creating workflow engine...")
    engine = AgentWorkflowEngine(
        workflow_cls=GeneratorSolverWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "generator_system_prompt": args.generator_system_prompt,
            "solver_system_prompt": args.solver_system_prompt,
            "evaluate_codegen": bool(args.evaluate_codegen),
            "include_failed_test_output": args.include_failed_test_output,
            # LLM-as-judge bug similarity options
            "use_bug_similarity_judge": args.use_bug_similarity_judge,
            "bug_similarity_reward_weight": args.bug_similarity_reward_weight,
            "bug_similarity_n_targets": args.bug_similarity_n_targets,
            "judge_system_prompt": args.judge_system_prompt,
            "reference_bugs": reference_bugs,
        },
        rollout_engine=rollout_engine,
        config=None,
        n_parallel_tasks=args.n_parallel,
        retry_limit=1,
    )

    print(f"\nLoading tasks from dataset '{args.dataset}' split '{args.split}'...")
    all_tasks = load_data(dataset_name=args.dataset, split=args.split, n=args.n_repeats)
    if not all_tasks:
        print("No tasks loaded. Exiting.")
        raise SystemExit(1)

    # A non-positive --n_tasks keeps every loaded task.
    if args.n_tasks > 0:
        all_tasks = all_tasks[: args.n_tasks]

    print(f"Loaded {len(all_tasks)} tasks")
    print("Configuration:")
    print(f"  Model: {args.model}")
    print(f"  Parallel tasks: {args.n_parallel}")
    print(f"  Temperature: {args.temperature}, Top-p: {args.top_p}")
    print(f"  Max prompt length: {args.max_prompt_length}")
    print(f"  Max response length: {args.max_response_length}")
    print(f"  Evaluate CodeGen: {bool(args.evaluate_codegen)}")
    if args.use_bug_similarity_judge:
        print(f"  Bug Similarity Judge: ENABLED (weight={args.bug_similarity_reward_weight}, n_targets={args.bug_similarity_n_targets})")
        print(f"  Reference Bug Dataset: {args.reference_bug_dataset}:{args.reference_bug_split}")

    print(f"\n🚀 Executing workflow on {len(all_tasks)} tasks...")
    results = asyncio.run(engine.execute_tasks(all_tasks))

    print("\n📊 Evaluating results...")
    evaluate_results(results)

    # Print some codegen outcomes (if computed)
    if args.evaluate_codegen:
        print_codegen_results(results, n_show=args.print_codegen_samples)

    if args.print_samples > 0:
        print_sample_episodes(results, n_samples=args.print_samples)

    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(args.output_dir, f"generator_solver_flow_results_{timestamp}.json")

        # Token id arrays are stripped before serializing.
        results_dict = [exclude_token_ids(episode.to_dict()) for episode in results]
        with open(output_file, "w") as f:
            json.dump(results_dict, f, indent=2)

        print(f"\n💾 Results saved to: {output_file}")
        print_reward_distributions(results_dict)

    print("\n✅ Done!")

