#!/usr/bin/env python3
"""
Test script for the mixed dataset feature in GeneratorSolverWorkflow.

This script verifies that pregenerated bug datasets (in BugBench format) can be
mixed with the main training dataset and that the workflow correctly handles
tasks from both sources.

Usage:
    python examples/bugs/run_generator_solver_flow_mixed.py \
        --main-dataset bigcodebench --main-split train \
        --target-dataset bugbench --target-split test \
        --n_main_tasks 10 --n_target_tasks 5

The script will:
1. Load main training dataset tasks (generator creates bugs)
2. Load pregenerated bug dataset tasks (BugBench format, skip generator)
3. Mix them together
4. Run the workflow on all tasks
5. Report results showing proper handling of both task formats
"""

import argparse
import asyncio
import json
import os
from collections import Counter, defaultdict
from datetime import datetime
from typing import Any, Dict, List

from transformers import AutoTokenizer

from generator_solver_flow import GeneratorSolverWorkflow
from data_utils import load_data
from prompts import _build_code_generation_prompt
from rllm.data.dataset import DatasetRegistry
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.reward_fn import code_reward_fn


def evaluate_results(results, show_by_source: bool = True):
    """Aggregate per-episode metrics and print a summary report.

    Each episode's ``metrics`` mapping is read for ``bug_valid``,
    ``generator_reward``, ``solver_solve_rate`` and ``used_pregenerated_bug``.
    ``None`` metrics are treated as empty, and missing keys count as 0.

    The optional ``solver_codegen_pass`` metric is averaged only over episodes
    that report it. Episodes whose ``used_pregenerated_bug`` is truthy are
    bucketed as "pregenerated"; all others are bucketed as "generated".

    Args:
        results: Sequence of episode objects exposing a ``metrics`` attribute.
        show_by_source: If True and more than one bug source occurs, also
            print a per-source breakdown.

    Returns:
        None. The report is printed to stdout.
    """
    total_episodes = len(results)
    if total_episodes == 0:
        print("No results to evaluate.")
        return

    def _empty_stats():
        # One bucket shape shared by the overall totals and every per-source bucket.
        return {
            "total": 0,
            "bug_valid": 0,
            "generator_reward_sum": 0.0,
            "solver_solve_rate_sum": 0.0,
            "used_pregenerated": 0,
            "solver_codegen_present": 0,
            "solver_codegen_pass": 0,
        }

    overall = _empty_stats()
    source_stats = defaultdict(_empty_stats)

    for episode in results:
        metrics = episode.metrics or {}
        source = "pregenerated" if metrics.get("used_pregenerated_bug") else "generated"

        # Parse each metric once, then feed the same values to both tallies
        # so the overall and per-source numbers cannot drift apart.
        bug_valid = int(metrics.get("bug_valid", 0))
        generator_reward = float(metrics.get("generator_reward", 0))
        solver_solve_rate = float(metrics.get("solver_solve_rate", 0))
        used_pregenerated = int(metrics.get("used_pregenerated_bug", 0))
        has_codegen = "solver_codegen_pass" in metrics
        codegen_pass = int(metrics.get("solver_codegen_pass", 0)) if has_codegen else 0

        for stats in (overall, source_stats[source]):
            stats["total"] += 1
            stats["bug_valid"] += bug_valid
            stats["generator_reward_sum"] += generator_reward
            stats["solver_solve_rate_sum"] += solver_solve_rate
            stats["used_pregenerated"] += used_pregenerated
            if has_codegen:
                stats["solver_codegen_present"] += 1
                stats["solver_codegen_pass"] += codegen_pass

    used_pregenerated_count = overall["used_pregenerated"]
    bug_valid_count = overall["bug_valid"]

    # Print summary statistics
    print("\n" + "=" * 80)
    print("📊 MIXED DATASET WORKFLOW RESULTS")
    print("=" * 80)
    print(f"Total episodes: {total_episodes}")
    print(f"  - With generated bugs: {total_episodes - used_pregenerated_count}")
    print(f"  - With pregenerated bugs: {used_pregenerated_count}")

    print("\nOverall Metrics:")
    print(f"  Bug Valid Rate: {bug_valid_count}/{total_episodes} ({100*bug_valid_count/total_episodes:.1f}%)")
    print(f"  Avg Generator Reward: {overall['generator_reward_sum']/total_episodes:.3f}")
    print(f"  Avg Solver Solve Rate: {100*overall['solver_solve_rate_sum']/total_episodes:.1f}%")

    codegen_present_count = overall["solver_codegen_present"]
    if codegen_present_count > 0:
        solver_codegen_pass_count = overall["solver_codegen_pass"]
        print(
            f"  Solver CodeGen Pass Rate: {solver_codegen_pass_count}/{codegen_present_count} "
            f"({100*solver_codegen_pass_count/codegen_present_count:.1f}%)"
        )

    if show_by_source and len(source_stats) > 1:
        print("\n--- By Bug Source ---")
        for source, stats in sorted(source_stats.items()):
            total = stats["total"]
            if total == 0:
                continue
            print(f"\n  [{source.upper()}] ({total} episodes)")
            print(f"    Bug Valid Rate: {stats['bug_valid']}/{total} ({100*stats['bug_valid']/total:.1f}%)")
            print(f"    Avg Generator Reward: {stats['generator_reward_sum']/total:.3f}")
            print(f"    Avg Solver Solve Rate: {100*stats['solver_solve_rate_sum']/total:.1f}%")
            if stats["solver_codegen_present"] > 0:
                print(
                    f"    CodeGen Pass Rate: {stats['solver_codegen_pass']}/{stats['solver_codegen_present']} "
                    f"({100*stats['solver_codegen_pass']/stats['solver_codegen_present']:.1f}%)"
                )

    print("=" * 80)


def print_reward_distributions(results_dict):
    """Print histograms of generator rewards, solver solve rates and pregenerated-bug use.

    Args:
        results_dict: Iterable of serialized episodes (plain dicts, e.g. from
            ``Episode.to_dict()`` or a saved JSON file). Each may carry a
            ``"metrics"`` mapping with ``generator_reward``,
            ``solver_solve_rate`` and ``used_pregenerated_bug``.

    Missing or ``None`` metrics (a saved episode can contain
    ``"metrics": null``) and missing or ``None`` values count as 0.
    Rewards and rates are bucketed after rounding to two decimals.

    Returns:
        None. Output is printed to stdout.
    """
    generator_rewards = []
    solver_rates = []
    used_pregen = []

    for episode_dict in results_dict:
        # ``.get("metrics", {})`` would still return None for an explicit null.
        metrics = episode_dict.get("metrics") or {}
        generator_rewards.append(metrics.get("generator_reward") or 0.0)
        solver_rates.append(metrics.get("solver_solve_rate") or 0.0)
        used_pregen.append(int(metrics.get("used_pregenerated_bug") or 0))

    total = len(generator_rewards)
    if total == 0:
        print("No episodes found for reward distribution.")
        return

    gen_counter = Counter(round(r, 2) for r in generator_rewards)
    solver_counter = Counter(round(r, 2) for r in solver_rates)
    pregen_counter = Counter(used_pregen)

    print("\n" + "=" * 80)
    print("📈 REWARD DISTRIBUTIONS")
    print("=" * 80)
    print(f"Total episodes: {total}")

    print("\n🐛 Used Pregenerated Bug:")
    for val in sorted(pregen_counter):
        count = pregen_counter[val]
        label = "Yes" if val else "No"
        print(f"  {label}: {count:5d} episodes ({100.0*count/total:5.1f}%)")

    print("\n🔧 Generator Rewards:")
    for reward_value in sorted(gen_counter):
        count = gen_counter[reward_value]
        percentage = 100.0 * count / total
        print(f"  {reward_value:6.2f}: {count:5d} episodes ({percentage:5.1f}%)")

    print("\n✅ Solver Solve Rates:")
    for rate in sorted(solver_counter):
        count = solver_counter[rate]
        percentage = 100.0 * count / total
        print(f"  {rate:6.2f}: {count:5d} episodes ({percentage:5.1f}%)")

    print("=" * 80)


def _truncate(text, limit):
    """Return ``text`` cut to ``limit`` characters, with "..." appended if it was longer."""
    return text[:limit] + "..." if len(text) > limit else text


def _as_text(value):
    """Render a value (e.g. a step action) for printing; ``None`` becomes ""."""
    return "" if value is None else str(value)


def print_sample_episodes(results, n_samples=3):
    """Print detailed information for the first ``n_samples`` episodes.

    For each episode this prints its id, its bug source (from the
    ``used_pregenerated_bug`` metric) and the start of the task question. It
    then prints the first step of the trajectories it recognizes by name:

    - ``"bug_generator"``;
    - up to two trajectories whose name starts with ``"bug_fixer"``;
    - ``"code_generator"``.

    It ends with the raw metrics dict and ``episode.is_correct``.

    Args:
        results: Sequence of episode objects with ``id``, ``task``,
            ``metrics``, ``trajectories`` and ``is_correct`` attributes.
            ``task`` and ``metrics`` may be ``None``.
        n_samples: Number of leading episodes to print.

    Returns:
        None. Output is printed to stdout.
    """
    print("\n" + "=" * 80)
    print(f"📝 SAMPLE EPISODES (showing first {min(n_samples, len(results))})")
    print("=" * 80)

    for idx, episode in enumerate(results[:n_samples]):
        metrics = episode.metrics or {}
        # ``episode.task`` may be None; treat it as an empty task instead of crashing.
        task = episode.task or {}
        used_pregen = bool(metrics.get("used_pregenerated_bug"))

        print(f"\n--- Episode {idx + 1} ---")
        print(f"Task ID: {episode.id}")
        print(f"Bug Source: {'PREGENERATED' if used_pregen else 'GENERATED'}")
        question = _as_text(task.get("question", task.get("instruct_prompt", "")))
        print(f"Question: {question[:200]}...")

        bug_traj = None
        solver_trajs = []
        codegen_traj = None
        for traj in episode.trajectories:
            if traj.name == "bug_generator":
                bug_traj = traj
            elif traj.name and traj.name.startswith("bug_fixer"):
                solver_trajs.append(traj)
            elif traj.name == "code_generator":
                codegen_traj = traj

        if bug_traj and bug_traj.steps:
            bug_code = _as_text(bug_traj.steps[0].action)
            print("\n🐛 Bug Generator Output (first 300 chars):")
            print(_truncate(bug_code, 300))
            print(f"Reward: {bug_traj.steps[0].reward}")
        elif used_pregen:
            print("\n🐛 Pregenerated Bug (generator skipped)")

        if solver_trajs:
            print(f"\n🔧 Solver Attempts: {len(solver_trajs)}")
            for i, traj in enumerate(solver_trajs[:2]):  # Show first 2
                if traj.steps:
                    print(f"  Attempt {i}: reward={traj.steps[0].reward}")
                    if i == 0:
                        fixed_code = _as_text(traj.steps[0].action)
                        print(f"    Code (first 200 chars): {fixed_code[:200]}...")

        if codegen_traj and codegen_traj.steps:
            generated_code = _as_text(codegen_traj.steps[0].action)
            print("\n💻 Code Generator Output (first 300 chars):")
            print(_truncate(generated_code, 300))
            codegen_pass = int(metrics.get("solver_codegen_pass", 0))
            print(f"Pass: {bool(codegen_pass)}")

        print(f"\nMetrics: {metrics}")
        print(f"Episode Correct: {episode.is_correct}")
        print("-" * 80)


def exclude_token_ids(data):
    """Return a copy of ``data`` without any ``prompt_ids``/``completion_ids`` keys.

    Dicts and lists are rebuilt recursively, with dict key order preserved.
    Any other value, including tuples, is returned unchanged.
    """
    if isinstance(data, list):
        return [exclude_token_ids(element) for element in data]
    if isinstance(data, dict):
        return {
            name: exclude_token_ids(inner)
            for name, inner in data.items()
            if name not in ("prompt_ids", "completion_ids")
        }
    return data


def _load_task_list(dataset_name: str, split: str) -> List[Dict[str, Any]]:
    """Load a single, un-repeated copy of ``dataset_name:split``.

    ``load_data`` is tried first. If it returns nothing, the loader falls back
    to ``DatasetRegistry``. Returns an empty list when neither yields tasks.
    """
    # NOTE(review): assumes ``n`` in ``load_data`` is a repeat count and that
    # n=1 yields one copy of each task — confirm against data_utils.
    tasks = load_data(dataset_name=dataset_name, split=split, n=1)
    if not tasks:
        ds = DatasetRegistry.load_dataset(dataset_name, split)
        if ds:
            tasks = list(ds.get_data())
    return list(tasks) if tasks else []


def _select_and_repeat(tasks: List[Dict[str, Any]], limit: int, n_repeats: int) -> List[Dict[str, Any]]:
    """Keep the first ``limit`` tasks (all of them when ``limit <= 0``), each repeated ``n_repeats`` times.

    Repeats are shallow copies of dict tasks, so an in-place edit to one run's
    task dict cannot leak into another repeat.
    """
    if limit > 0:
        tasks = tasks[:limit]
    return [
        dict(task) if isinstance(task, dict) else task
        for task in tasks
        for _ in range(n_repeats)
    ]


def load_and_mix_datasets(
    main_dataset: str,
    main_split: str,
    target_dataset: str,
    target_split: str,
    n_main_tasks: int,
    n_target_tasks: int,
    n_repeats: int = 1,
) -> List[Dict[str, Any]]:
    """Load and mix the main dataset with the target dataset (pregenerated bugs).

    Each dataset is loaded once, through ``load_data`` or the
    ``DatasetRegistry`` fallback. It is then limited to ``n_main_tasks`` /
    ``n_target_tasks`` unique tasks (a value <= 0 keeps all of them), and each
    kept task is repeated ``n_repeats`` times.

    Limiting before repeating ensures every selected task gets all of its
    repeats, whichever loader succeeded.

    Args:
        main_dataset, main_split: Main training dataset, whose bugs are created
            by the generator.
        target_dataset, target_split: Dataset with pregenerated bugs.
        n_main_tasks: Maximum number of unique main tasks (<= 0 means all).
        n_target_tasks: Maximum number of unique target tasks (<= 0 means all).
        n_repeats: Number of times each selected task is repeated.

    Returns:
        Main tasks followed by target tasks. The list is empty if the main
        dataset could not be loaded. A missing target dataset only produces a
        warning.
    """
    print("\n📦 Loading datasets...")

    # Load main dataset
    print(f"  Main dataset: {main_dataset}:{main_split}")
    main_tasks = _load_task_list(main_dataset, main_split)
    if not main_tasks:
        print(f"  ERROR: Could not load main dataset {main_dataset}:{main_split}")
        return []
    main_tasks = _select_and_repeat(main_tasks, n_main_tasks, n_repeats)
    print(f"    Loaded {len(main_tasks)} main tasks")

    # Load target dataset (pregenerated bugs)
    print(f"  Target dataset: {target_dataset}:{target_split}")
    target_tasks = _load_task_list(target_dataset, target_split)
    if not target_tasks:
        print(f"  WARNING: Could not load target dataset {target_dataset}:{target_split}")
    target_tasks = _select_and_repeat(target_tasks, n_target_tasks, n_repeats)
    print(f"    Loaded {len(target_tasks)} target tasks")

    # Mix
    combined = main_tasks + target_tasks
    print(f"  Combined: {len(combined)} tasks total")

    return combined


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the mixed-dataset workflow test script."""
    parser = argparse.ArgumentParser(description="Run Generator-Solver workflow with mixed datasets")

    # Dataset args
    parser.add_argument("--main-dataset", type=str, default="bigcodebench", help="Main dataset name (default: bigcodebench)")
    parser.add_argument("--main-split", type=str, default="train", help="Main dataset split (default: train)")
    parser.add_argument("--target-dataset", type=str, default="bugbench", help="Target dataset with pregenerated bugs (default: bugbench)")
    parser.add_argument("--target-split", type=str, default="test", help="Target dataset split (default: test)")
    parser.add_argument("--n_main_tasks", type=int, default=5, help="Number of main tasks (default: 5)")
    parser.add_argument("--n_target_tasks", type=int, default=5, help="Number of target tasks with pregenerated bugs (default: 5)")
    parser.add_argument("--n_repeats", type=int, default=1, help="Number of times to repeat each task (default: 1)")

    # Model args
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct", help="Model name (default: Qwen/Qwen2.5-Coder-7B-Instruct)")
    parser.add_argument("--base_url", type=str, default="http://localhost:30000/v1", help="Base URL for the API")
    parser.add_argument("--api_key", type=str, default="None", help="API key (default: None)")
    parser.add_argument("--n_parallel", type=int, default=32, help="Number of parallel tasks (default: 32)")
    parser.add_argument("--max_prompt_length", type=int, default=8192, help="Maximum prompt length (default: 8192)")
    parser.add_argument("--max_response_length", type=int, default=8192, help="Maximum response length (default: 8192)")
    parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature (default: 0.6)")
    parser.add_argument("--top_p", type=float, default=0.95, help="Sampling top_p (default: 0.95)")

    # Workflow args
    parser.add_argument("--generator_system_prompt", type=str, default=None, help="Optional system prompt for bug generator")
    parser.add_argument("--solver_system_prompt", type=str, default=None, help="Optional system prompt for bug solver")
    parser.add_argument("--evaluate_codegen", action="store_true", help="Also evaluate solver on regular code generation")
    parser.add_argument("--solver_attempts", type=int, default=1, help="Number of solver attempts per bug (default: 1)")
    # Each on/off pair shares one ``dest``; the ``--no_*`` flag is the only way to turn it off.
    parser.add_argument("--use_pregenerated_bugs", action="store_true", default=True,
                        help="Use pregenerated bugs from bug dataset (default: True)")
    parser.add_argument("--no_pregenerated_bugs", action="store_false", dest="use_pregenerated_bugs",
                        help="Disable using pregenerated bugs (generate new bugs for all tasks)")
    parser.add_argument("--include_failed_test_output", action="store_true", default=True,
                        help="Include failed test output in solver prompts (default: True)")
    parser.add_argument("--no_failed_test_output", action="store_false", dest="include_failed_test_output",
                        help="Disable including failed test output in solver prompts")

    # Output args
    parser.add_argument("--save_results", action="store_true", help="Save results to JSON file")
    parser.add_argument("--output_dir", type=str, default="logs", help="Directory to save results (default: logs)")
    parser.add_argument("--print_samples", type=int, default=3, help="Number of sample episodes to print (default: 3)")
    parser.add_argument("--load_json", type=str, default=None, help="Load results from JSON file instead of running")
    return parser


def main() -> None:
    """Script entry point.

    With ``--load_json``, the script only prints reward distributions for a
    previously saved results file. Otherwise it loads and mixes the datasets,
    runs ``GeneratorSolverWorkflow`` over them, and prints an evaluation. With
    ``--save_results``, it also writes the episodes to a timestamped JSON file.

    Raises:
        SystemExit: With code 1 when the JSON file is missing or no tasks load.
    """
    args = _build_arg_parser().parse_args()

    # Handle loading from JSON
    if args.load_json:
        if not os.path.exists(args.load_json):
            print(f"Error: JSON file not found: {args.load_json}")
            raise SystemExit(1)
        print(f"Loading results from: {args.load_json}")
        # Explicit encoding so decoding does not depend on the platform locale.
        with open(args.load_json, "r", encoding="utf-8") as f:
            results_dict = json.load(f)
        print(f"Loaded {len(results_dict)} episodes")
        print_reward_distributions(results_dict)
        return

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    # Load and mix datasets
    all_tasks = load_and_mix_datasets(
        main_dataset=args.main_dataset,
        main_split=args.main_split,
        target_dataset=args.target_dataset,
        target_split=args.target_split,
        n_main_tasks=args.n_main_tasks,
        n_target_tasks=args.n_target_tasks,
        n_repeats=args.n_repeats,
    )

    if not all_tasks:
        print("No tasks loaded. Exiting.")
        raise SystemExit(1)

    print(f"\nLoading tokenizer for model: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    print("Initializing rollout engine...")
    rollout_engine = OpenAIEngine(
        model=args.model,
        tokenizer=tokenizer,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params={
            "temperature": args.temperature,
            "top_p": args.top_p,
        },
    )

    # Set validation mode (affects codegen evaluation)
    rollout_engine.validate = bool(args.evaluate_codegen)
    print(f"rollout_engine.validate = {rollout_engine.validate}")

    print("Creating workflow engine...")
    engine = AgentWorkflowEngine(
        workflow_cls=GeneratorSolverWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "generator_system_prompt": args.generator_system_prompt,
            "solver_system_prompt": args.solver_system_prompt,
            "evaluate_codegen": bool(args.evaluate_codegen),
            "use_pregenerated_bugs_in_training": args.use_pregenerated_bugs,
            "use_pregenerated_bugs_in_validation": args.use_pregenerated_bugs,
            "solver_attempts_train": args.solver_attempts,
            "solver_attempts_val": args.solver_attempts,
            "include_failed_test_output": args.include_failed_test_output,
        },
        rollout_engine=rollout_engine,
        config=None,
        n_parallel_tasks=args.n_parallel,
        retry_limit=1,
    )

    print("\n" + "=" * 80)
    print("🔧 MIXED DATASET TEST CONFIGURATION")
    print("=" * 80)
    print(f"Main dataset: {args.main_dataset}:{args.main_split} ({args.n_main_tasks} tasks)")
    print(f"Target dataset: {args.target_dataset}:{args.target_split} ({args.n_target_tasks} tasks)")
    print(f"Total tasks: {len(all_tasks)}")
    print(f"Use pregenerated bugs: {args.use_pregenerated_bugs}")
    print(f"Model: {args.model}")
    print(f"Parallel tasks: {args.n_parallel}")
    print(f"Temperature: {args.temperature}, Top-p: {args.top_p}")
    print(f"Solver attempts: {args.solver_attempts}")
    print(f"Evaluate CodeGen: {bool(args.evaluate_codegen)}")
    print("=" * 80)

    print(f"\n🚀 Executing workflow on {len(all_tasks)} tasks...")
    results = asyncio.run(engine.execute_tasks(all_tasks))

    print("\n📊 Evaluating results...")
    evaluate_results(results, show_by_source=True)

    if args.print_samples > 0:
        print_sample_episodes(results, n_samples=args.print_samples)

    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(args.output_dir, f"mixed_dataset_results_{timestamp}.json")

        results_dict = [exclude_token_ids(episode.to_dict()) for episode in results]
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results_dict, f, indent=2)

        print(f"\n💾 Results saved to: {output_file}")
        print_reward_distributions(results_dict)

    print("\n✅ Done!")


if __name__ == "__main__":
    main()
