"""Run BugGeneratorWorkflow for evaluation/inference.

This script evaluates a bug generator model against a static fixer baseline.
The generator tries to create bugs that are:
  1. Valid (compile and fail at least one test)
  2. Hard enough that the static fixer cannot fix them

Example usage:
    # Generator only (no fixer baseline)
    python -m examples.bugs_refactor.run_generator_flow \
        --model Qwen/Qwen2.5-Coder-7B-Instruct \
        --base_url http://localhost:30000/v1 \
        --dataset bigcodebench \
        --n_tasks 100

    # Generator with GPT-4o-mini as static fixer
    python -m examples.bugs_refactor.run_generator_flow \
        --model Qwen/Qwen2.5-Coder-7B-Instruct \
        --base_url http://localhost:30000/v1 \
        --fixer_model gpt-4o-mini \
        --dataset bigcodebench \
        --n_tasks 100

    # Generator with local fixer model
    python -m examples.bugs_refactor.run_generator_flow \
        --model Qwen/Qwen2.5-Coder-7B-Instruct \
        --base_url http://localhost:30000/v1 \
        --fixer_model Qwen/Qwen2.5-Coder-32B-Instruct \
        --fixer_base_url http://localhost:30001/v1 \
        --dataset bigcodebench
"""

import argparse
import asyncio
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

from transformers import AutoTokenizer

from examples.bugs.data_utils import load_data
from examples.bugs_refactor.generator_flow import BugGeneratorWorkflow
from rllm.data.dataset import DatasetRegistry
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.reward_fn import code_reward_fn


def _metric_number(metrics: Dict[str, Any], key: str) -> float:
    """Return ``metrics[key]`` as a float; missing, ``None`` or non-numeric values count as 0.0."""
    value = metrics.get(key)
    if value is None:
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def evaluate_results(results: List[Any], dataset_alias: Optional[str] = None) -> Dict[str, Any]:
    """Evaluate the results, print a report and compute summary metrics.

    Each episode's ``metrics`` dict (may be ``None``) is read for the flags
    ``bug_valid``, ``generator_reward``, ``fixer_pass`` and
    ``bug_has_compile_error`` (an episode counts when the value is > 0), and for
    the optional counters ``bug_total_tests`` / ``bug_passed_tests``.

    Args:
        results: List of Episode objects to evaluate.
        dataset_alias: Optional dataset name for display purposes.

    Returns:
        dict: Summary of key metrics (rates are percentages rounded to one
        decimal). ``avg_total_tests`` / ``avg_passed_tests`` are averaged over
        the ``bug_tested_episodes`` episodes that report ``bug_total_tests`` and
        are ``None`` when no episode does. Returns ``{}`` for empty ``results``.
    """
    total_episodes = len(results)
    if total_episodes == 0:
        print("No results to evaluate.")
        return {}

    # Aggregate metrics
    bug_valid_count = 0
    generator_reward_count = 0
    fixer_pass_count = 0
    compile_error_count = 0

    # Test statistics are averaged only over episodes that actually report them;
    # counting episodes that lack these keys as zero would skew the means downward.
    tested_episodes = 0
    bug_total_tests_sum = 0.0
    bug_passed_tests_sum = 0.0

    for episode in results:
        metrics = episode.metrics or {}

        bug_valid_count += int(_metric_number(metrics, "bug_valid") > 0.0)
        generator_reward_count += int(_metric_number(metrics, "generator_reward") > 0.0)
        fixer_pass_count += int(_metric_number(metrics, "fixer_pass") > 0.0)
        compile_error_count += int(_metric_number(metrics, "bug_has_compile_error") > 0.0)

        if metrics.get("bug_total_tests") is not None:
            tested_episodes += 1
            bug_total_tests_sum += _metric_number(metrics, "bug_total_tests")
            bug_passed_tests_sum += _metric_number(metrics, "bug_passed_tests")

    def pct(count: int) -> float:
        return 100 * count / total_episodes

    # Print summary statistics
    print("\n" + "=" * 80)
    header = "📊 BUG GENERATOR WORKFLOW RESULTS"
    if dataset_alias:
        header += f" [{dataset_alias}]"
    print(header)
    print("=" * 80)
    print(f"Total episodes: {total_episodes}")
    print("\nOverall Metrics:")
    print(f"  Bug Valid Rate: {bug_valid_count}/{total_episodes} ({pct(bug_valid_count):.1f}%)")
    print(f"  Compile Error Rate: {compile_error_count}/{total_episodes} ({pct(compile_error_count):.1f}%)")
    print(f"  Generator Reward Rate: {generator_reward_count}/{total_episodes} ({pct(generator_reward_count):.1f}%)")
    print(f"  Fixer Pass Rate: {fixer_pass_count}/{total_episodes} ({pct(fixer_pass_count):.1f}%)")

    avg_total_tests: Optional[float] = None
    avg_passed_tests: Optional[float] = None
    if tested_episodes > 0:
        avg_total_tests = bug_total_tests_sum / tested_episodes
        avg_passed_tests = bug_passed_tests_sum / tested_episodes
        print("\nBug Test Statistics:")
        print(f"  Episodes with test results: {tested_episodes}/{total_episodes}")
        print(f"  Average total tests per bug: {avg_total_tests:.2f}")
        print(f"  Average passed tests per bug: {avg_passed_tests:.2f}")
        print(f"  Average failed tests per bug: {avg_total_tests - avg_passed_tests:.2f}")

    print("=" * 80)

    # Return summary dict (total_episodes > 0 is guaranteed by the early return above)
    return {
        "dataset": dataset_alias or "unknown",
        "total_episodes": total_episodes,
        "bug_valid_rate": round(pct(bug_valid_count), 1),
        "compile_error_rate": round(pct(compile_error_count), 1),
        "generator_reward_rate": round(pct(generator_reward_count), 1),
        "fixer_pass_rate": round(pct(fixer_pass_count), 1),
        "bug_valid_count": bug_valid_count,
        "generator_reward_count": generator_reward_count,
        "fixer_pass_count": fixer_pass_count,
        "compile_error_count": compile_error_count,
        "bug_tested_episodes": tested_episodes,
        "avg_total_tests": None if avg_total_tests is None else round(avg_total_tests, 2),
        "avg_passed_tests": None if avg_passed_tests is None else round(avg_passed_tests, 2),
    }


def evaluate_results_grouped(results: List[Any]) -> tuple:
    """Split episodes by the ``_dataset_alias`` tag on their task and evaluate each split.

    Episodes whose task carries no alias are collected under ``"unknown"``.
    Groups keep first-seen order, and ``evaluate_results`` is called (and prints
    its report) once per group in that same order.

    Args:
        results: List of Episode objects whose ``task`` dict may hold '_dataset_alias'.

    Returns:
        Tuple ``(grouped_results, summaries)``: a plain dict alias -> list of
        Episode objects, and a dict alias -> summary dict from ``evaluate_results``.
    """
    from collections import defaultdict

    by_alias: Dict[str, List[Any]] = defaultdict(list)
    for ep in results:
        by_alias[ep.task.get("_dataset_alias", "unknown")].append(ep)

    grouped = dict(by_alias)
    summaries = {
        alias: evaluate_results(episodes, dataset_alias=alias)
        for alias, episodes in grouped.items()
    }
    return grouped, summaries


def exclude_token_ids(data: Any) -> Any:
    """Return a copy of *data* with every ``prompt_ids`` / ``completion_ids`` key dropped.

    Dicts and lists are rebuilt recursively (as plain ``dict``/``list``); any
    other value, including tuples, is returned unchanged.
    """
    if isinstance(data, list):
        return list(map(exclude_token_ids, data))
    if not isinstance(data, dict):
        return data
    dropped_keys = ("prompt_ids", "completion_ids")
    return {key: exclude_token_ids(val) for key, val in data.items() if key not in dropped_keys}


def print_sample_episodes(results: List[Any], n_samples: int = 3) -> None:
    """Print detailed information for a few sample episodes."""
    print("\n" + "=" * 80)
    print(f"📝 SAMPLE EPISODES (showing first {min(n_samples, len(results))})")
    print("=" * 80)

    for idx, episode in enumerate(results[:n_samples]):
        print(f"\n--- Episode {idx + 1} ---")
        print(f"Task ID: {episode.id}")
        print(f"Question: {episode.task.get('question', '')[:200]}...")

        bug_traj = None
        fixer_traj = None
        for traj in episode.trajectories:
            if traj.name == "bug_generator":
                bug_traj = traj
            elif traj.name == "bug_fixer":
                fixer_traj = traj

        if bug_traj and bug_traj.steps:
            bug_code = bug_traj.steps[0].action
            print("\n🐛 Bug Generator Output (first 500 chars):")
            print(bug_code[:500] + "..." if len(bug_code) > 500 else bug_code)
            print(f"Reward: {bug_traj.steps[0].reward}")

        if fixer_traj and fixer_traj.steps:
            fixed_code = fixer_traj.steps[0].action
            print("\n🔧 Static Fixer Output (first 500 chars):")
            print(fixed_code[:500] + "..." if len(fixed_code) > 500 else fixed_code)

        metrics = episode.metrics or {}
        print(f"\nMetrics: {metrics}")
        print(f"Episode Correct (bug valid & fixer failed): {episode.is_correct}")
        print("-" * 80)


_OPENAI_BASE_URL = "https://api.openai.com/v1"


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the bug-generator evaluation script."""
    parser = argparse.ArgumentParser(description="Run Bug Generator workflow on tasks")

    # Basic args
    parser.add_argument("--n_tasks", type=int, default=-1, help="Number of tasks to process (default: -1 = all)")
    parser.add_argument("--n_repeats", type=int, default=1, help="Number of times to repeat each task (default: 1)")
    parser.add_argument("--split", type=str, default="train", help="Dataset split to use (default: train)")
    parser.add_argument("--dataset", type=str, default="bigcodebench",
                        help="Dataset: bigcodebench, kodcode, deepcoder, etc. (default: bigcodebench)")

    # Generator model args
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct",
                        help="Generator model name (default: Qwen/Qwen2.5-Coder-7B-Instruct)")
    parser.add_argument("--base_url", type=str, default="http://localhost:30000/v1",
                        help="Base URL for generator API (default: http://localhost:30000/v1)")
    parser.add_argument("--api_key", type=str, default="None", help="API key for generator (default: None)")
    parser.add_argument("--temperature", type=float, default=0.6, help="Generator sampling temperature (default: 0.6)")
    parser.add_argument("--top_p", type=float, default=0.95, help="Generator sampling top_p (default: 0.95)")
    parser.add_argument("--max_prompt_length", type=int, default=8192, help="Maximum prompt length (default: 8192)")
    parser.add_argument("--max_response_length", type=int, default=8192, help="Maximum response length (default: 8192)")
    parser.add_argument("--generator_system_prompt", type=str, default=None,
                        help="Optional system prompt for bug generator")

    # Static fixer model args (optional)
    parser.add_argument("--fixer_model", type=str, default=None,
                        help="Static fixer model name (e.g., 'gpt-4o-mini'). If not provided, no fixer is used.")
    parser.add_argument("--fixer_base_url", type=str, default=None,
                        help="Base URL for fixer API (default: https://api.openai.com/v1 for OpenAI models)")
    parser.add_argument("--fixer_api_key", type=str, default=None,
                        help="API key for fixer (default: uses OPENAI_API_KEY env var for OpenAI models)")
    parser.add_argument("--fixer_temperature", type=float, default=0.0,
                        help="Fixer sampling temperature (default: 0.0 for deterministic)")
    parser.add_argument("--fixer_top_p", type=float, default=1.0, help="Fixer sampling top_p (default: 1.0)")
    parser.add_argument("--fixer_max_prompt_length", type=int, default=None,
                        help="Max prompt length for fixer (default: same as generator)")
    parser.add_argument("--fixer_max_response_length", type=int, default=None,
                        help="Max response length for fixer (default: same as generator)")
    parser.add_argument("--fixer_system_prompt", type=str, default=None,
                        help="Optional system prompt for bug fixer")

    # Execution args
    parser.add_argument("--n_parallel", type=int, default=32, help="Number of parallel tasks (default: 32)")

    # Multiple validation datasets
    parser.add_argument("--val_datasets", nargs="+", default=None,
                        help="Multiple validation datasets in format 'alias=dataset:split' or 'dataset:split' "
                             "(e.g., 'bcb=bigcodebench:test' 'kodcode:test')")

    # Output args
    parser.add_argument("--save_results", action="store_true", help="Save results to JSON file")
    parser.add_argument("--output_dir", type=str, default="logs", help="Directory to save results (default: logs)")
    parser.add_argument("--print_samples", type=int, default=3,
                        help="Number of sample episodes to print in detail (default: 3)")
    return parser


def _parse_val_dataset_specs(specs: Optional[List[str]]) -> Dict[str, tuple]:
    """Parse ``--val_datasets`` entries into ``{alias: (dataset_name, split)}``.

    Accepted forms: ``alias=dataset:split``, ``dataset:split`` and ``dataset``.
    A missing or empty split defaults to ``"test"``. Without an explicit alias,
    the key is ``f"{dataset}_{split}"``. A repeated alias replaces the earlier
    entry, and a warning is printed.

    Raises:
        ValueError: if an entry has an empty dataset name.
    """
    parsed: Dict[str, tuple] = {}
    for raw_spec in specs or []:
        alias: Optional[str] = None
        spec = raw_spec
        if "=" in spec:
            alias_part, spec = spec.split("=", 1)
            alias = alias_part.strip() or None
        ds_name, _, split = spec.strip().partition(":")
        ds_name = ds_name.strip()
        split = split.strip() or "test"
        if not ds_name:
            raise ValueError(f"Invalid --val_datasets entry {raw_spec!r}: missing dataset name")
        alias = alias or f"{ds_name}_{split}"
        if alias in parsed:
            print(f"WARNING: duplicate --val_datasets alias '{alias}'; "
                  f"{parsed[alias]} is replaced by {(ds_name, split)}")
        parsed[alias] = (ds_name, split)
    return parsed


def _load_tasks(ds_name: str, split: str, n_repeats: int) -> List[Any]:
    """Load tasks with ``load_data``, falling back to ``DatasetRegistry``; ``[]`` if neither yields any."""
    tasks = load_data(dataset_name=ds_name, split=split, n=n_repeats)
    if tasks:
        return tasks
    dataset = DatasetRegistry.load_dataset(ds_name, split)
    if not dataset:
        return []
    # load_data receives n_repeats as ``n`` (presumably repeating each task itself);
    # previously this fallback ignored it, so --n_repeats silently had no effect here.
    # NOTE(review): confirm this per-task repetition order matches load_data's.
    return [task for task in dataset.get_data() for _ in range(max(n_repeats, 1))]


def _build_fixer_engine(args: argparse.Namespace) -> "Optional[OpenAIEngine]":
    """Create the static-fixer rollout engine, or return ``None`` when no --fixer_model is given.

    Raises:
        ValueError: if the fixer targets the OpenAI API and no API key is available.
    """
    if not args.fixer_model:
        return None

    fixer_base_url = args.fixer_base_url or _OPENAI_BASE_URL
    is_openai_api = "api.openai.com" in fixer_base_url

    # Get API key: explicit flag > OPENAI_API_KEY (OpenAI endpoints only) > placeholder.
    if args.fixer_api_key:
        fixer_api_key = args.fixer_api_key
    elif is_openai_api:
        fixer_api_key = os.getenv("OPENAI_API_KEY", "")
        if not fixer_api_key:
            raise ValueError(
                "fixer_model uses OpenAI API but OPENAI_API_KEY is not set. "
                "Please export OPENAI_API_KEY or provide --fixer_api_key."
            )
    else:
        fixer_api_key = "EMPTY"

    print("\nInitializing fixer rollout engine...")
    print(f"  Fixer model: {args.fixer_model}")
    print(f"  Fixer base_url: {fixer_base_url}")
    print(f"  Fixer temperature: {args.fixer_temperature}, top_p: {args.fixer_top_p}")

    # Try to load fixer tokenizer, fall back to None for API models
    fixer_tokenizer = None
    if not is_openai_api:
        try:
            fixer_tokenizer = AutoTokenizer.from_pretrained(args.fixer_model)
        except Exception as e:  # deliberate best-effort: run without a tokenizer
            print(f"  Warning: Could not load fixer tokenizer: {e}")

    return OpenAIEngine(
        model=args.fixer_model,
        tokenizer=fixer_tokenizer,
        max_prompt_length=args.fixer_max_prompt_length or args.max_prompt_length,
        max_response_length=args.fixer_max_response_length or args.max_response_length,
        base_url=fixer_base_url,
        api_key=fixer_api_key,
        sampling_params={
            "temperature": args.fixer_temperature,
            "top_p": args.fixer_top_p,
        },
    )


def _print_configuration(args: argparse.Namespace) -> None:
    """Print the run configuration (generator and optional fixer settings)."""
    print("\nConfiguration:")
    print(f"  Generator model: {args.model}")
    if args.fixer_model:
        print(f"  Fixer model: {args.fixer_model}")
        print(f"  Fixer base_url: {args.fixer_base_url or _OPENAI_BASE_URL}")
        print(f"  Fixer temperature: {args.fixer_temperature}, top_p: {args.fixer_top_p}")
    else:
        print("  Fixer model: (none - generator only)")
    print(f"  Parallel tasks: {args.n_parallel}")
    print(f"  Generator temperature: {args.temperature}, top_p: {args.top_p}")
    print(f"  Max prompt length: {args.max_prompt_length}")
    print(f"  Max response length: {args.max_response_length}")


def _safe_filename_component(name: str) -> str:
    """Replace characters unsafe in a file name (e.g. '/' from 'org/dataset') with '_'."""
    cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in name)
    return cleaned or "unknown"


def _write_json(path: str, payload: Any) -> None:
    """Write *payload* to *path* as indented JSON."""
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)


def _print_cross_dataset_summary(summaries: Dict[str, Dict[str, Any]]) -> None:
    """Print one line per dataset, reusing the per-dataset summaries from evaluate_results."""
    print("\n" + "=" * 80)
    print("📊 SUMMARY ACROSS ALL VALIDATION DATASETS")
    print("=" * 80)
    for alias, summary in summaries.items():
        total = summary.get("total_episodes", 0)
        if not total:
            continue
        print(f"  {alias}: {summary['generator_reward_count']}/{total} gen reward "
              f"({summary['generator_reward_rate']:.1f}%), "
              f"{summary['bug_valid_count']}/{total} bug valid ({summary['bug_valid_rate']:.1f}%), "
              f"{summary['fixer_pass_count']}/{total} fixer pass ({summary['fixer_pass_rate']:.1f}%)")


def _run_multi_dataset(args: argparse.Namespace, engine: "AgentWorkflowEngine",
                       val_dataset_specs: Dict[str, tuple]) -> None:
    """Run the workflow over several tagged validation datasets and report per dataset."""
    print(f"\n📊 Loading {len(val_dataset_specs)} validation datasets...")
    all_tasks: List[Any] = []
    dataset_task_counts: Dict[str, int] = {}

    for alias, (ds_name, split) in val_dataset_specs.items():
        print(f"\nLoading: {alias} ({ds_name}:{split})...")
        tasks = _load_tasks(ds_name, split, args.n_repeats)
        if not tasks:
            print(f"  WARNING: No tasks loaded from {ds_name}:{split}. Skipping.")
            continue

        # NOTE(review): the cap is applied after repetition, so --n_tasks limits the
        # total (possibly repeated) task count per dataset — confirm this is intended.
        if args.n_tasks > 0:
            tasks = tasks[:args.n_tasks]

        # Tag each task with dataset alias for later grouping
        for task in tasks:
            task["_dataset_alias"] = alias

        all_tasks.extend(tasks)
        dataset_task_counts[alias] = len(tasks)
        print(f"  Loaded {len(tasks)} tasks")

    if not all_tasks:
        print("No tasks loaded from any dataset. Exiting.")
        raise SystemExit(1)

    print("\n" + "=" * 80)
    print(f"Total tasks across all datasets: {len(all_tasks)}")
    for alias, count in dataset_task_counts.items():
        print(f"  {alias}: {count} tasks")
    print("=" * 80)
    _print_configuration(args)

    print(f"\n🚀 Executing workflow on {len(all_tasks)} tasks...")
    results = asyncio.run(engine.execute_tasks(all_tasks))

    print("\n📊 Evaluating results per dataset...")
    all_results, all_summaries = evaluate_results_grouped(results)

    if args.print_samples > 0:
        for alias, group_results in all_results.items():
            print(f"\n📝 Sample episodes for {alias}:")
            print_sample_episodes(group_results, n_samples=args.print_samples)

    _print_cross_dataset_summary(all_summaries)

    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        summary_file = os.path.join(args.output_dir, f"generator_flow_summary_{timestamp}.json")
        _write_json(summary_file, all_summaries)
        print(f"\n📊 Summary saved to: {summary_file}")

        for alias, group_results in all_results.items():
            # Aliases may come from dataset names containing '/', which would
            # otherwise be interpreted as (non-existent) sub-directories.
            output_file = os.path.join(
                args.output_dir, f"generator_flow_{_safe_filename_component(alias)}_{timestamp}.json"
            )
            _write_json(output_file, [exclude_token_ids(episode.to_dict()) for episode in group_results])
            print(f"💾 Results for {alias} saved to: {output_file}")


def _run_single_dataset(args: argparse.Namespace, engine: "AgentWorkflowEngine") -> None:
    """Run the workflow over ``--dataset``/``--split`` and report the results."""
    print(f"\nLoading tasks from dataset '{args.dataset}' split '{args.split}'...")
    all_tasks = _load_tasks(args.dataset, args.split, args.n_repeats)
    if not all_tasks:
        print("No tasks loaded. Exiting.")
        raise SystemExit(1)

    if args.n_tasks > 0:
        all_tasks = all_tasks[:args.n_tasks]

    print(f"Loaded {len(all_tasks)} tasks")
    _print_configuration(args)

    print(f"\n🚀 Executing workflow on {len(all_tasks)} tasks...")
    results = asyncio.run(engine.execute_tasks(all_tasks))

    print("\n📊 Evaluating results...")
    summary = evaluate_results(results, dataset_alias=args.dataset)

    if args.print_samples > 0:
        print_sample_episodes(results, n_samples=args.print_samples)

    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        summary_file = os.path.join(args.output_dir, f"generator_flow_summary_{timestamp}.json")
        _write_json(summary_file, {args.dataset: summary})
        print(f"\n📊 Summary saved to: {summary_file}")

        output_file = os.path.join(args.output_dir, f"generator_flow_results_{timestamp}.json")
        _write_json(output_file, [exclude_token_ids(episode.to_dict()) for episode in results])
        print(f"💾 Results saved to: {output_file}")


def main() -> None:
    """Parse CLI args, build the rollout engines, run the workflow and report/save metrics."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    try:
        # Parse before any model loading so a malformed spec fails fast.
        val_dataset_specs = _parse_val_dataset_specs(args.val_datasets)
    except ValueError as err:
        parser.error(str(err))

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    # Generator engine setup
    print(f"Loading tokenizer for generator model: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    print("Initializing generator rollout engine...")
    generator_rollout_engine = OpenAIEngine(
        model=args.model,
        tokenizer=tokenizer,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params={
            "temperature": args.temperature,
            "top_p": args.top_p,
        },
    )

    fixer_rollout_engine = _build_fixer_engine(args)

    # IMPORTANT: set validate flag for val mode
    generator_rollout_engine.validate = True

    print("\nCreating workflow engine...")
    engine = AgentWorkflowEngine(
        workflow_cls=BugGeneratorWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "generator_system_prompt": args.generator_system_prompt,
            "fixer_rollout_engine": fixer_rollout_engine,
            "fixer_system_prompt": args.fixer_system_prompt,
        },
        rollout_engine=generator_rollout_engine,
        config=None,
        n_parallel_tasks=args.n_parallel,
        retry_limit=1,
    )

    if val_dataset_specs:
        _run_multi_dataset(args, engine, val_dataset_specs)
    else:
        _run_single_dataset(args, engine)

    print("\n✅ Done!")


if __name__ == "__main__":
    main()
