# examples/bugs/train_generator_solver_with_code_embeddings.py
#
# Train GeneratorSolverWorkflow with *code-embedding similarity* as the auxiliary reward
# for the generator (instead of / in addition to LLM-as-judge).
#
# Usage (Voyage):
#   export VOYAGE_API_KEY="your-api-key"
#   python -m examples.bugs.train_generator_solver_with_code_embeddings
#
# Optional overrides via Hydra (examples):
#   python -m examples.bugs.train_generator_solver_with_code_embeddings \
#     rllm.workflow.workflow_args.use_code_embedding_similarity=true \
#     rllm.workflow.workflow_args.code_embedding_reward_weight=0.2 \
#     rllm.workflow.workflow_args.code_embedding_negative_bug_dataset=bugbench_qwen7b_sampled \
#     rllm.workflow.workflow_args.code_embedding_negative_bug_split=test
#
from __future__ import annotations

import os
import time
from typing import Any, Dict, List, Optional

import hydra
from omegaconf import OmegaConf

from examples.bugs.generator_solver_flow import GeneratorSolverWorkflow
from examples.bugs.data_utils import load_data, parse_val_dataset_specs
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.reward_fn import code_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


def format_time(seconds: float) -> str:
    """Render a duration in seconds as ``"Xh Ym Zs"``.

    Leading zero units are dropped: hours appear only when at least one full
    hour elapsed, minutes only when hours or minutes are non-zero. Fractional
    seconds are truncated.
    """
    hrs = int(seconds // 3600)
    mins = int((seconds % 3600) // 60)
    whole_secs = int(seconds % 60)

    if hrs > 0:
        pieces = (f"{hrs}h", f"{mins}m", f"{whole_secs}s")
    elif mins > 0:
        pieces = (f"{mins}m", f"{whole_secs}s")
    else:
        pieces = (f"{whole_secs}s",)
    return " ".join(pieces)


def _safe_get(cfg: Any, key: str, default: Any) -> Any:
    if cfg is None:
        return default
    return getattr(cfg, key, default)


def _load_tasks_for_pool(dataset_name: str, split: str, max_items: int) -> List[Dict[str, Any]]:
    """Load task dicts used to build a code-embedding reference pool.

    Tries ``load_data()`` first (it may apply project-specific sampling or
    filtering), then falls back to ``DatasetRegistry`` when ``load_data``
    raises or returns nothing.

    Args:
        dataset_name: dataset name understood by ``load_data`` / ``DatasetRegistry``.
        split: dataset split, e.g. ``"test"``.
        max_items: cap on the number of returned tasks; ``<= 0`` means no cap is
            applied here (the value is still forwarded to ``load_data`` as ``n``).

    Returns:
        At most ``max_items`` tasks (all tasks when ``max_items <= 0``).

    Raises:
        ValueError: if the ``DatasetRegistry`` fallback cannot find the dataset.
    """
    tasks: List[Dict[str, Any]] = []

    try:
        # NOTE: load_data's signature is project-specific (see examples.bugs.data_utils).
        loaded = load_data(dataset_name=dataset_name, split=split, n=max_items)
        tasks = list(loaded) if loaded else []
    except Exception as err:
        # Deliberately best-effort: any failure falls back to the registry below,
        # but surface the reason instead of swallowing it silently.
        print(
            f"[CodeEmbedding] load_data failed for {dataset_name!r} split={split!r} "
            f"({type(err).__name__}: {err}); falling back to DatasetRegistry"
        )
        tasks = []

    if not tasks:
        ds = DatasetRegistry.load_dataset(dataset_name, split)
        if ds is None:
            raise ValueError(f"dataset not found: {dataset_name!r} split={split!r}")
        tasks = list(ds.get_data())

    # Slicing is a no-op when the list is already short enough.
    if max_items > 0:
        tasks = tasks[:max_items]
    return tasks


def _cfg_str(cfg: Any, key: str, default: str) -> str:
    """Read string option *key* (whitespace-stripped).

    An explicit ``null`` in the config yields *default*; without this,
    ``str(None)`` would silently turn it into the literal text ``"None"``.
    """
    value = _safe_get(cfg, key, default)
    return default if value is None else str(value).strip()


def _cfg_optional_str(cfg: Any, key: str) -> Optional[str]:
    """Read optional string option *key*; missing, ``null``, or blank values become None."""
    return _cfg_str(cfg, key, "") or None


def _load_val_datasets(workflow_args_cfg: Any) -> tuple[Any, Optional[dict[str, Any]]]:
    """Resolve validation data as ``(val_dataset, val_datasets)``.

    When ``parse_val_dataset_specs`` yields specs, every aliased dataset is
    loaded into ``val_datasets`` and ``val_dataset`` is None. Otherwise the
    default ``bugbench``/``test`` dataset is returned with ``val_datasets`` None.

    Raises:
        ValueError: if any requested validation dataset cannot be found.
    """
    val_specs = parse_val_dataset_specs(workflow_args_cfg)
    if not val_specs:
        return DatasetRegistry.load_dataset("bugbench", "test"), None

    val_datasets: dict[str, Any] = {}
    for alias, (ds_name, split) in val_specs.items():
        ds = DatasetRegistry.load_dataset(ds_name, split)
        if ds is None:
            raise ValueError(f"val dataset not found: {ds_name!r} split={split!r} (alias={alias!r})")
        val_datasets[alias] = ds
    print(f"Using multiple val datasets: {list(val_datasets.keys())}")
    return None, val_datasets


def _concat_human_bugs(
    train_dataset: Any,
    human_bug_dataset_name: str,
    human_bug_dataset_split: str,
    combined_train_dataset_name: Optional[str],
) -> Any:
    """Register and return a train dataset made of *train_dataset* rows followed by the human-bug rows.

    Raises:
        ValueError: if the human-bug dataset/split is not registered.
    """
    bug_ds = DatasetRegistry.load_dataset(human_bug_dataset_name, human_bug_dataset_split)
    if bug_ds is None:
        raise ValueError(
            f"human_bug_dataset_name={human_bug_dataset_name!r} split={human_bug_dataset_split!r} not found."
        )

    main_tasks = list(train_dataset.get_data())
    bug_tasks = list(bug_ds.get_data())
    combined_tasks: List[Dict[str, Any]] = main_tasks + bug_tasks

    name = combined_train_dataset_name or f"bigcodebench_plus_{human_bug_dataset_name}"
    print(
        f"Combining training data: bigcodebench ({len(main_tasks)}) + "
        f"{human_bug_dataset_name}/{human_bug_dataset_split} ({len(bug_tasks)}) "
        f"= {len(combined_tasks)} -> registering '{name}'"
    )
    return DatasetRegistry.register_dataset(name, combined_tasks, "train")


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Hydra entry point: assemble datasets and workflow args, then run ``AgentTrainer.train()``.

    All knobs are read from ``config.rllm.workflow.workflow_args``. That node
    may be absent, in which case every option falls back to its default.
    """
    start_time = time.time()
    print(f"Starting training at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Ensure workflow config exists and enable workflow mode
    if not hasattr(config.rllm, "workflow"):
        config.rllm.workflow = OmegaConf.create({})
    config.rllm.workflow.use_workflow = True
    # config.rllm.workflow exists at this point; workflow_args is optional.
    workflow_args_cfg = getattr(config.rllm.workflow, "workflow_args", None)

    # -----------------------------
    # Train / Val datasets
    # -----------------------------
    train_dataset = DatasetRegistry.load_dataset("bigcodebench", "train")
    # Fail fast, before loading validation data that would go unused.
    if train_dataset is None:
        print("Failed to load BigCodeBench train dataset. Exiting.")
        print("Available datasets:", DatasetRegistry.get_dataset_names())
        return

    val_dataset, val_datasets = _load_val_datasets(workflow_args_cfg)

    # -----------------------------
    # Optional: concatenate a human-bug dataset into training (train only)
    # -----------------------------
    human_bug_dataset_name = _cfg_optional_str(workflow_args_cfg, "human_bug_dataset_name")
    human_bug_dataset_split = _cfg_str(workflow_args_cfg, "human_bug_dataset_split", "train")
    combined_train_dataset_name = _cfg_optional_str(workflow_args_cfg, "combined_train_dataset_name")

    if human_bug_dataset_name:
        train_dataset = _concat_human_bugs(
            train_dataset,
            human_bug_dataset_name,
            human_bug_dataset_split,
            combined_train_dataset_name,
        )

    # Optional prompts / eval flags
    generator_system_prompt = _safe_get(workflow_args_cfg, "generator_system_prompt", None)
    solver_system_prompt = _safe_get(workflow_args_cfg, "solver_system_prompt", None)
    evaluate_codegen = bool(_safe_get(workflow_args_cfg, "evaluate_codegen", True))

    # -----------------------------
    # SSR-like self-play knobs (forwarded to GeneratorSolverWorkflow)
    # -----------------------------
    solver_attempts_train = int(_safe_get(workflow_args_cfg, "solver_attempts_train", 8))
    solver_attempts_val = int(_safe_get(workflow_args_cfg, "solver_attempts_val", 1))

    generator_reward_mode = _cfg_str(workflow_args_cfg, "generator_reward_mode", "band")
    solve_rate_band_low = float(_safe_get(workflow_args_cfg, "solve_rate_band_low", 0.05))
    solve_rate_band_high = float(_safe_get(workflow_args_cfg, "solve_rate_band_high", 0.25))
    gen_alpha_extreme = float(_safe_get(workflow_args_cfg, "gen_alpha_extreme", 0.2))
    gen_invalid_bug_reward = float(_safe_get(workflow_args_cfg, "gen_invalid_bug_reward", -1.0))

    solver_reward_pm1 = bool(_safe_get(workflow_args_cfg, "solver_reward_pm1", False))

    use_pregenerated_bugs_in_validation = bool(_safe_get(workflow_args_cfg, "use_pregenerated_bugs_in_validation", True))

    # Defaults to on exactly when a human-bug dataset was mixed into training.
    use_pregenerated_bugs_in_training = bool(
        _safe_get(workflow_args_cfg, "use_pregenerated_bugs_in_training", bool(human_bug_dataset_name))
    )
    pregenerated_bug_train_probability = float(_safe_get(workflow_args_cfg, "pregenerated_bug_train_probability", 1.0))

    episode_success_mode = _cfg_str(workflow_args_cfg, "episode_success_mode", "bugfix")
    include_failed_test_output = bool(_safe_get(workflow_args_cfg, "include_failed_test_output", True))

    # -----------------------------
    # Role-conditioned advantage options
    # -----------------------------
    freeze_generator = bool(_safe_get(workflow_args_cfg, "freeze_generator", False))
    freeze_solver = bool(_safe_get(workflow_args_cfg, "freeze_solver", False))
    use_role_advnorm = bool(_safe_get(workflow_args_cfg, "use_role_advnorm", False))

    frozen_roles: List[str] = []
    if freeze_generator:
        frozen_roles.append("generator")
    if freeze_solver:
        frozen_roles.append("solver")

    # -----------------------------
    # Code-embedding similarity judge (aux reward for generator)
    # -----------------------------
    use_code_embedding_similarity = bool(_safe_get(workflow_args_cfg, "use_code_embedding_similarity", True))
    code_embedding_reward_weight = float(_safe_get(workflow_args_cfg, "code_embedding_reward_weight", 0.2))
    code_embedding_model_name = _cfg_str(workflow_args_cfg, "code_embedding_model_name", "voyage-code-3")
    code_embedding_include_problem = bool(_safe_get(workflow_args_cfg, "code_embedding_include_problem", True))
    code_embedding_top_k = int(_safe_get(workflow_args_cfg, "code_embedding_top_k", 5))

    # Precomputed pools on disk, if any, replace rebuilding them each run.
    code_embedding_target_pool_path = _cfg_optional_str(workflow_args_cfg, "code_embedding_target_pool_path")
    code_embedding_negative_pool_path = _cfg_optional_str(workflow_args_cfg, "code_embedding_negative_pool_path")

    # Recommended (if you have a negative pool): use margin->sigmoid
    code_embedding_use_margin = bool(_safe_get(workflow_args_cfg, "code_embedding_use_margin", True))
    code_embedding_margin_temperature = float(_safe_get(workflow_args_cfg, "code_embedding_margin_temperature", 10.0))

    # Where to get target/negative bugs if not loading pools from disk
    code_embedding_target_bug_dataset = _cfg_str(workflow_args_cfg, "code_embedding_target_bug_dataset", "bugbench")
    code_embedding_target_bug_split = _cfg_str(workflow_args_cfg, "code_embedding_target_bug_split", "test")

    # Set to "" (or null) to disable the negative pool (then aux = target_score only).
    raw_negative_dataset = _safe_get(workflow_args_cfg, "code_embedding_negative_bug_dataset", "bugbench_qwen7b_sampled")
    code_embedding_negative_bug_dataset = "" if raw_negative_dataset is None else str(raw_negative_dataset).strip()
    code_embedding_negative_bug_split = _cfg_str(workflow_args_cfg, "code_embedding_negative_bug_split", "test")

    # Avoid building enormous pools by default
    code_embedding_max_target_items = int(_safe_get(workflow_args_cfg, "code_embedding_max_target_items", 2000))
    code_embedding_max_negative_items = int(_safe_get(workflow_args_cfg, "code_embedding_max_negative_items", 2000))

    # Stay None when the feature is off or the corresponding pool comes from disk.
    reference_bugs: Optional[List[Dict[str, Any]]] = None
    negative_bugs: Optional[List[Dict[str, Any]]] = None

    if use_code_embedding_similarity:
        if code_embedding_model_name.startswith("voyage-") and not os.getenv("VOYAGE_API_KEY"):
            print(
                "[CodeEmbedding] WARNING: model is Voyage but VOYAGE_API_KEY is not set. "
                "You will likely get errors from the embedder."
            )

        if not code_embedding_target_pool_path:
            print(f"\n🧠 Loading TARGET bugs for code-embedding pool from {code_embedding_target_bug_dataset}:{code_embedding_target_bug_split}...")
            reference_bugs = _load_tasks_for_pool(
                dataset_name=code_embedding_target_bug_dataset,
                split=code_embedding_target_bug_split,
                max_items=code_embedding_max_target_items,
            )
            print(f"  Loaded {len(reference_bugs)} target tasks for embedding pool build")

        if code_embedding_negative_bug_dataset and not code_embedding_negative_pool_path:
            print(f"\n🧠 Loading NEGATIVE bugs for code-embedding pool from {code_embedding_negative_bug_dataset}:{code_embedding_negative_bug_split}...")
            negative_bugs = _load_tasks_for_pool(
                dataset_name=code_embedding_negative_bug_dataset,
                split=code_embedding_negative_bug_split,
                max_items=code_embedding_max_negative_items,
            )
            print(f"  Loaded {len(negative_bugs)} negative tasks for embedding pool build")

        print("\n[CodeEmbedding] Configuration:")
        print(f"  enabled: {use_code_embedding_similarity}")
        print(f"  model: {code_embedding_model_name}")
        print(f"  include_problem: {code_embedding_include_problem}")
        print(f"  top_k: {code_embedding_top_k}")
        print(f"  reward_weight: {code_embedding_reward_weight}")
        print(f"  use_margin: {code_embedding_use_margin}")
        print(f"  margin_temperature: {code_embedding_margin_temperature}")
        if code_embedding_target_pool_path:
            print(f"  target_pool_path: {code_embedding_target_pool_path}")
        if code_embedding_negative_pool_path:
            print(f"  negative_pool_path: {code_embedding_negative_pool_path}")

    # -----------------------------
    # Build trainer
    # -----------------------------
    trainer = AgentTrainer(
        workflow_class=GeneratorSolverWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "generator_system_prompt": generator_system_prompt,
            "solver_system_prompt": solver_system_prompt,
            "evaluate_codegen": evaluate_codegen,
            "solver_attempts_train": solver_attempts_train,
            "solver_attempts_val": solver_attempts_val,
            "generator_reward_mode": generator_reward_mode,
            "solve_rate_band_low": solve_rate_band_low,
            "solve_rate_band_high": solve_rate_band_high,
            "gen_alpha_extreme": gen_alpha_extreme,
            "gen_invalid_bug_reward": gen_invalid_bug_reward,
            "solver_reward_pm1": solver_reward_pm1,
            "use_pregenerated_bugs_in_validation": use_pregenerated_bugs_in_validation,
            "use_pregenerated_bugs_in_training": use_pregenerated_bugs_in_training,
            "pregenerated_bug_train_probability": pregenerated_bug_train_probability,
            "episode_success_mode": episode_success_mode,
            "include_failed_test_output": include_failed_test_output,
            # Role-conditioned advantage options
            "freeze_cm": bool(frozen_roles),
            "cm_roles": frozen_roles,
            "use_role_advnorm": use_role_advnorm,
            # ---- Code embedding aux reward ----
            "use_code_embedding_similarity": use_code_embedding_similarity,
            "code_embedding_reward_weight": code_embedding_reward_weight,
            "code_embedding_model_name": code_embedding_model_name,
            "code_embedding_include_problem": code_embedding_include_problem,
            "code_embedding_top_k": code_embedding_top_k,
            "code_embedding_reference_bugs": reference_bugs,   # None if loading pool from disk
            "code_embedding_negative_bugs": negative_bugs,     # optional
            "code_embedding_target_pool_path": code_embedding_target_pool_path,
            "code_embedding_negative_pool_path": code_embedding_negative_pool_path,
            "code_embedding_use_margin": code_embedding_use_margin,
            "code_embedding_margin_temperature": code_embedding_margin_temperature,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        val_datasets=val_datasets,
    )

    try:
        trainer.train()
    except Exception as e:
        print(f"Training failed with error: {e}")
        raise
    finally:
        # Always report wall-clock time, including when training raised.
        total_time = time.time() - start_time
        print("\n" + "=" * 60)
        print(f"Training completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Total run time: {format_time(total_time)} ({total_time:.2f} seconds)")
        print("=" * 60)


if __name__ == "__main__":
    # main() takes no arguments here: the @hydra.main decorator builds `config`
    # from the packaged agent_ppo_trainer config plus any CLI overrides.
    main()
