import argparse
import asyncio
import json
import os
from datetime import datetime

from transformers import AutoTokenizer

from examples.bugs.data_utils import load_data
from examples.bugs_refactor.frozen_generator_fixer_flow import FrozenGeneratorFixerWorkflow
from examples.bugs_refactor.utils import _get_pregenerated_bug, _resolve_api_key
from rllm.data.dataset import DatasetRegistry
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.reward_fn import code_reward_fn


def evaluate_results(results):
    """Print aggregate pass-rate metrics for a batch of workflow episodes.

    Args:
        results: Sized iterable of episodes. Each episode exposes a
            ``metrics`` mapping; a falsy ``metrics`` is treated as empty.

    Returns:
        None. The summary is written to stdout. When ``results`` is empty a
        single notice line is printed instead of the summary.
    """
    n_episodes = len(results)
    if not n_episodes:
        print("No results to evaluate.")
        return

    def _positive(metrics, key):
        # A metric counts as a hit when its float value is strictly above zero;
        # a missing key is treated as 0.0.
        return float(metrics.get(key, 0.0)) > 0.0

    n_bug_valid = n_fixer_pass = n_codegen_pass = n_with_codegen = 0
    for episode in results:
        metrics = episode.metrics or {}
        if _positive(metrics, "bug_valid"):
            n_bug_valid += 1
        if _positive(metrics, "fixer_pass"):
            n_fixer_pass += 1
        # The codegen rate is computed only over episodes that report the metric.
        if "fixer_codegen_pass" not in metrics:
            continue
        n_with_codegen += 1
        if _positive(metrics, "fixer_codegen_pass"):
            n_codegen_pass += 1

    banner = "=" * 80
    bug_valid_pct = 100 * n_bug_valid / n_episodes
    fixer_pass_pct = 100 * n_fixer_pass / n_episodes

    print("\n" + banner)
    print("📊 FROZEN GENERATOR-FIXER WORKFLOW RESULTS (frozen generator -> trainable fixer)")
    print(banner)
    print(f"Total episodes: {n_episodes}")
    print("\nOverall Metrics:")
    print(f"  Bug Valid Rate: {n_bug_valid}/{n_episodes} ({bug_valid_pct:.1f}%)")
    print(f"  Fixer Pass Rate: {n_fixer_pass}/{n_episodes} ({fixer_pass_pct:.1f}%)")
    if n_with_codegen <= 0:
        print("  Fixer CodeGen Pass Rate: (not computed)")
    else:
        codegen_pct = 100 * n_codegen_pass / n_with_codegen
        print(f"  Fixer CodeGen Pass Rate: {n_codegen_pass}/{n_with_codegen} ({codegen_pct:.1f}%) [on episodes with codegen]")
    print(banner)


def exclude_token_ids(data):
    """Return a copy of ``data`` without ``prompt_ids`` / ``completion_ids`` keys.

    Dicts and lists are walked recursively and rebuilt; every other value
    (including tuples and their contents) is returned unchanged. The input
    is not mutated.
    """
    if isinstance(data, list):
        return [exclude_token_ids(element) for element in data]
    if not isinstance(data, dict):
        return data
    dropped_keys = ("prompt_ids", "completion_ids")
    return {
        name: exclude_token_ids(child)
        for name, child in data.items()
        if name not in dropped_keys
    }


def _clip_text(value, limit=400):
    """Return ``value`` as text, cut to ``limit`` chars plus "..." when longer."""
    text = "" if value is None else str(value)
    return text[:limit] + "..." if len(text) > limit else text


def print_sample_episodes(results, n_samples=3):
    """Print the first few episodes with the generator bug and the fixer output.

    Missing or ``None`` episode fields (``task``, ``info``, ``buggy_code``,
    ``trajectories``, trajectory ``name``, step ``action``) are treated as
    empty instead of raising, so one malformed episode does not abort the
    report.

    Args:
        results: Sliceable sequence of episodes exposing ``id``, ``task``,
            ``info``, ``trajectories``, ``metrics`` and ``is_correct``.
        n_samples: Maximum number of episodes to print. Values below zero
            are treated as zero so the header and the printed episodes agree.

    Returns:
        None. Output is written to stdout.
    """
    n_shown = max(0, min(n_samples, len(results)))

    print("\n" + "=" * 80)
    print(f"📝 SAMPLE EPISODES (showing first {n_shown})")
    print("=" * 80)

    for idx, ep in enumerate(results[:n_shown], start=1):
        print(f"\n--- Episode {idx} ---")
        print(f"Task ID: {ep.id}")
        task = ep.task or {}
        print(f"Question: {str(task.get('question', ''))[:200]}...")

        info = ep.info or {}
        bug_source = info.get("bug_source", "unknown")
        print(f"\n🧊 Frozen Generator Bug ({bug_source}, first 400 chars):")
        # The key may be present with a None value, so do not rely on the
        # .get() default alone.
        print(_clip_text(info.get("buggy_code")))

        # Only the first step of the first fixer trajectory is shown.
        fixer_trajs = [
            t for t in (ep.trajectories or [])
            if (t.name or "").startswith("bug_fixer")
        ]
        if fixer_trajs and fixer_trajs[0].steps:
            first_step = fixer_trajs[0].steps[0]
            print("\n🔧 Fixer Output (first 400 chars):")
            print(_clip_text(first_step.action or ""))
            print(f"Reward: {first_step.reward}")

        print(f"\nMetrics: {ep.metrics}")
        print(f"Episode Correct: {ep.is_correct}")
        print("-" * 80)


if __name__ == "__main__":
    # Script entry point: parse CLI options, build the fixer and frozen-generator
    # engines, run the workflow over either a single dataset or several
    # validation datasets, print metrics, and optionally save episodes as JSON.
    parser = argparse.ArgumentParser(description="Run Frozen Generator-Fixer workflow (repair frozen generator bugs) on tasks")
    parser.add_argument("--n_tasks", type=int, default=-1, help="Number of tasks to process (default: -1 = all)")
    parser.add_argument("--n_repeats", type=int, default=1, help="Number of times to repeat each task (default: 1)")
    parser.add_argument("--split", type=str, default="test", help="Dataset split to use (default: test)")
    parser.add_argument("--dataset", type=str, default="bigcodebench", help="Dataset: deepcoder, bigcodebench, kodcode, bugbench, bugbench_human, bugbench_adversarial, or custom (default: bigcodebench)")

    # Fixer (trainable model, but here just for evaluation)
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct", help="Fixer model name")
    parser.add_argument("--base_url", type=str, default="http://localhost:30000/v1", help="Fixer base URL")
    # The default is the literal string "None" (not the None object); it is
    # forwarded unchanged to the fixer engine as its api_key.
    parser.add_argument("--api_key", type=str, default="None", help="Fixer API key (default: None)")
    parser.add_argument("--temperature", type=float, default=0.6, help="Fixer sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.95, help="Fixer sampling top_p")
    parser.add_argument("--max_prompt_length", type=int, default=2048, help="Max prompt length")
    parser.add_argument("--max_response_length", type=int, default=2048, help="Max response length")
    parser.add_argument("--fixer_system_prompt", type=str, default=None, help="Optional system prompt for fixer")

    # Frozen generator (OpenAI-compatible; chat endpoint, no tokenizer required)
    parser.add_argument("--generator_model", type=str, required=True, help="Frozen generator model name")
    parser.add_argument("--generator_base_url", type=str, default="http://localhost:30001/v1", help="Frozen generator base URL")
    parser.add_argument("--generator_api_key", type=str, default=None, help="Frozen generator API key (default: OPENAI_API_KEY or dummy)")
    parser.add_argument("--generator_temperature", type=float, default=0.6, help="Frozen generator sampling temperature")
    parser.add_argument("--generator_top_p", type=float, default=0.95, help="Frozen generator sampling top_p")
    parser.add_argument("--generator_system_prompt", type=str, default=None, help="Optional system prompt for frozen generator")

    # Fixer behavior
    parser.add_argument("--fixer_attempts_train", type=int, default=8, help="Number of fixer attempts during training (default: 8)")
    parser.add_argument("--fixer_attempts_val", type=int, default=1, help="Number of fixer attempts during validation (default: 1)")
    parser.add_argument("--fixer_reward_pm1", action="store_true", default=False, help="Use {-1,+1} reward instead of {0,1}")
    # These two options default to True, so a store_true flag alone could never
    # switch them off. The --no_* counterparts write False to the same dest;
    # the original flags are kept so existing command lines still parse.
    parser.add_argument("--include_failed_test_output", action="store_true", default=True, help="Include unit test failure output in fixer prompt")
    parser.add_argument("--no_include_failed_test_output", dest="include_failed_test_output", action="store_false",
                        help="Do not include unit test failure output in fixer prompt")
    parser.add_argument("--evaluate_codegen", action="store_true", default=True, help="Also evaluate fixer on regular code generation (val-mode only)")
    parser.add_argument("--no_evaluate_codegen", dest="evaluate_codegen", action="store_false",
                        help="Do not evaluate fixer on regular code generation")

    # Execution / logging
    parser.add_argument("--eval_pregenerated_only", action="store_true",
                        help="Only evaluate on tasks with pregenerated bugs (skip tasks without bugs)")
    parser.add_argument("--val_datasets", nargs="+", default=None,
                        help="Multiple validation datasets in format 'alias=dataset:split' or 'dataset:split' (e.g., 'bugbench:test' 'bcb=bigcodebench:test')")
    parser.add_argument("--n_parallel", type=int, default=32, help="Number of parallel tasks")
    parser.add_argument("--save_results", action="store_true", help="Save results to JSON")
    parser.add_argument("--output_dir", type=str, default="logs", help="Directory to save results")
    parser.add_argument("--print_samples", type=int, default=3, help="Number of sample episodes to print")

    args = parser.parse_args()

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    print(f"Loading tokenizer for fixer model: {args.model}")
    fixer_tokenizer = AutoTokenizer.from_pretrained(args.model)

    print("Initializing fixer rollout engine...")
    fixer_engine = OpenAIEngine(
        model=args.model,
        tokenizer=fixer_tokenizer,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params={
            "temperature": float(args.temperature),
            "top_p": float(args.top_p),
        },
    )

    # IMPORTANT: Set validate mode for eval_pregenerated_only (workflow uses pregenerated bugs in validation mode)
    if args.eval_pregenerated_only:
        fixer_engine.validate = True
        print("  Set rollout_engine.validate = True (eval mode)")

    # Generator key resolution: explicit CLI value, else OPENAI_API_KEY (empty
    # string when unset), then handed to _resolve_api_key with the base URL.
    generator_api_key = args.generator_api_key
    if generator_api_key is None:
        generator_api_key = os.getenv("OPENAI_API_KEY", "") or ""
    generator_api_key = str(generator_api_key).strip()
    generator_api_key = _resolve_api_key(str(args.generator_base_url), generator_api_key)

    print("Initializing frozen generator engine...")
    generator_engine = OpenAIEngine(
        model=args.generator_model,
        tokenizer=None,  # frozen generator; chat-completions endpoint
        base_url=args.generator_base_url,
        api_key=generator_api_key,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        sampling_params={
            "temperature": float(args.generator_temperature),
            "top_p": float(args.generator_top_p),
        },
        verbose=False,
    )

    print("Creating workflow engine...")
    engine = AgentWorkflowEngine(
        workflow_cls=FrozenGeneratorFixerWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "generator_rollout_engine": generator_engine,
            "generator_system_prompt": args.generator_system_prompt,
            "fixer_system_prompt": args.fixer_system_prompt,
            "fixer_attempts_train": args.fixer_attempts_train,
            "fixer_attempts_val": args.fixer_attempts_val,
            "fixer_reward_pm1": bool(args.fixer_reward_pm1),
            "include_failed_test_output": bool(args.include_failed_test_output),
            "use_pregenerated_bugs_in_validation": True,
            "evaluate_codegen": bool(args.evaluate_codegen),
        },
        rollout_engine=fixer_engine,
        config=None,
        n_parallel_tasks=args.n_parallel,
        retry_limit=1,
    )

    # Parse val_datasets if provided. Each spec is "[alias=]dataset[:split]";
    # the split defaults to "test" and the alias to "<dataset>_<split>".
    val_dataset_specs = {}
    if args.val_datasets:
        for spec in args.val_datasets:
            # Strip up front so specs without an "alias=" prefix are trimmed too.
            spec = spec.strip()
            alias = None
            if "=" in spec:
                alias, spec = spec.split("=", 1)
                alias = alias.strip() or None
                spec = spec.strip()

            if ":" in spec:
                ds_name, split = spec.split(":", 1)
                ds_name = ds_name.strip()
                split = split.strip()
            else:
                ds_name, split = spec, "test"

            if not alias:
                alias = f"{ds_name}_{split}"
            # A repeated alias overwrites the earlier entry.
            val_dataset_specs[alias] = (ds_name, split)

    # Load tasks - either single dataset or multiple val datasets
    if val_dataset_specs:
        # Multiple validation datasets mode
        print(f"\n📊 Loading {len(val_dataset_specs)} validation datasets...")
        all_results = {}

        for alias, (ds_name, split) in val_dataset_specs.items():
            print(f"\n{'='*80}")
            print(f"Evaluating on: {alias} ({ds_name}:{split})")
            print(f"{'='*80}")

            tasks = load_data(dataset_name=ds_name, split=split, n=args.n_repeats)
            if not tasks:
                print(f"  WARNING: No tasks loaded from {ds_name}:{split}. Skipping.")
                continue

            # Filter to only pregenerated bugs if eval mode is enabled
            if args.eval_pregenerated_only:
                original_count = len(tasks)
                tasks = [task for task in tasks if _get_pregenerated_bug(task) is not None]
                filtered_count = len(tasks)
                print(f"  Filtered to tasks with pregenerated bugs: {filtered_count}/{original_count} tasks")
                if filtered_count == 0:
                    print(f"  WARNING: No tasks with pregenerated bugs found. Skipping {alias}.")
                    continue

            # The task cap is applied after filtering, so it limits evaluated tasks.
            if args.n_tasks > 0:
                tasks = tasks[: args.n_tasks]

            print(f"  Loaded {len(tasks)} tasks")
            print("  Configuration:")
            print(f"    Fixer Model: {args.model}")
            print(f"    Frozen Generator Model: {args.generator_model}")
            print(f"    Parallel tasks: {args.n_parallel}")
            print(f"    Eval pregenerated only: {args.eval_pregenerated_only}")

            print(f"\n  🚀 Executing workflow on {len(tasks)} tasks...")
            # NOTE(review): asyncio.run starts and closes a fresh event loop on
            # every call while the same engine objects are reused across
            # datasets - confirm the engines hold no loop-bound state.
            results = asyncio.run(engine.execute_tasks(tasks))

            print(f"\n  📊 Results for {alias}:")
            evaluate_results(results)

            all_results[alias] = results

            if args.print_samples > 0:
                print(f"\n  📝 Sample episodes for {alias}:")
                print_sample_episodes(results, n_samples=args.print_samples)

        # Summary across all datasets
        print("\n" + "=" * 80)
        print("📊 SUMMARY ACROSS ALL VALIDATION DATASETS")
        print("=" * 80)
        for alias, results in all_results.items():
            total = len(results)
            if total == 0:
                continue
            fixer_pass = sum(1 for ep in results if float((ep.metrics or {}).get("fixer_pass", 0.0)) > 0.0)
            bug_valid = sum(1 for ep in results if float((ep.metrics or {}).get("bug_valid", 0.0)) > 0.0)
            print(f"  {alias}: {fixer_pass}/{total} fixer pass ({100*fixer_pass/total:.1f}%), "
                  f"{bug_valid}/{total} bug valid ({100*bug_valid/total:.1f}%)")

        if args.save_results:
            os.makedirs(args.output_dir, exist_ok=True)
            # One timestamp is shared by every per-dataset file from this run.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            for alias, results in all_results.items():
                output_file = os.path.join(args.output_dir, f"frozen_generator_fixer_flow_{alias}_{timestamp}.json")
                results_dict = [exclude_token_ids(ep.to_dict()) for ep in results]
                with open(output_file, "w", encoding="utf-8") as f:
                    json.dump(results_dict, f, indent=2)
                print(f"\n💾 Results for {alias} saved to: {output_file}")

    else:
        # Single dataset mode (original behavior)
        print(f"\nLoading tasks from dataset '{args.dataset}' split '{args.split}'...")
        all_tasks = load_data(dataset_name=args.dataset, split=args.split, n=args.n_repeats)
        if not all_tasks:
            print("No tasks loaded. Exiting.")
            raise SystemExit(1)

        # Filter to only pregenerated bugs if eval mode is enabled
        if args.eval_pregenerated_only:
            original_count = len(all_tasks)
            all_tasks = [task for task in all_tasks if _get_pregenerated_bug(task) is not None]
            filtered_count = len(all_tasks)
            print(f"  Filtered to tasks with pregenerated bugs: {filtered_count}/{original_count} tasks")
            if filtered_count == 0:
                print("  ERROR: No tasks with pregenerated bugs found. Exiting.")
                raise SystemExit(1)

        if args.n_tasks > 0:
            all_tasks = all_tasks[: args.n_tasks]

        print(f"Loaded {len(all_tasks)} tasks")
        print("Configuration:")
        print(f"  Fixer Model: {args.model}")
        print(f"  Frozen Generator Model: {args.generator_model}")
        print(f"  Parallel tasks: {args.n_parallel}")
        print(f"  include_failed_test_output: {bool(args.include_failed_test_output)}")
        print(f"  Eval pregenerated only: {args.eval_pregenerated_only}")

        print(f"\n🚀 Executing workflow on {len(all_tasks)} tasks...")
        results = asyncio.run(engine.execute_tasks(all_tasks))

        print("\n📊 Evaluating results...")
        evaluate_results(results)

        if args.print_samples > 0:
            print_sample_episodes(results, n_samples=args.print_samples)

        if args.save_results:
            os.makedirs(args.output_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = os.path.join(args.output_dir, f"frozen_generator_fixer_flow_results_{timestamp}.json")
            results_dict = [exclude_token_ids(ep.to_dict()) for ep in results]
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(results_dict, f, indent=2)
            print(f"\n💾 Results saved to: {output_file}")

    print("\n✅ Done!")
