import time
from typing import Any, Dict, List, Optional

import hydra
from omegaconf import OmegaConf

from examples.bugs.generator_solver_flow import GeneratorSolverWorkflow
from examples.bugs.data_utils import load_data, parse_val_dataset_specs
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.reward_fn import code_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


def format_time(seconds: float) -> str:
    """Render a duration given in seconds as a compact ``h/m/s`` string.

    Only the units that are needed are shown: ``"5s"``, ``"2m 5s"`` or
    ``"1h 2m 5s"``. Fractional seconds are truncated.
    """
    whole_hours, within_hour = divmod(seconds, 3600)
    hours = int(whole_hours)
    minutes = int(within_hour // 60)
    secs = int(seconds % 60)

    # Always show seconds; prepend the larger units only when they apply.
    parts = [f"{secs}s"]
    if hours > 0:
        parts = [f"{hours}h", f"{minutes}m"] + parts
    elif minutes > 0:
        parts.insert(0, f"{minutes}m")
    return " ".join(parts)


def _workflow_arg(cfg: Any, key: str, default: Any = None, cast: Optional[type] = None) -> Any:
    """Read one option from the ``workflow_args`` config node.

    Args:
        cfg: The ``workflow_args`` config node, or ``None`` when it is absent.
        key: Option name, looked up as an attribute of ``cfg``.
        default: Returned unchanged when ``cfg`` is ``None``, the key is
            missing, or the key is explicitly set to null.
        cast: Optional callable (``int``, ``float``, ``bool``, ``str``) applied
            to a value that was actually found in ``cfg``.

    Returns:
        The configured value (cast if requested), or ``default``.
    """
    value = getattr(cfg, key, None) if cfg is not None else None
    if value is None:
        # An explicit null is treated like a missing key. Casting it instead
        # would yield the string "None" for str() and a TypeError for
        # int()/float().
        return default
    return cast(value) if cast is not None else value


def _optional_str_arg(cfg: Any, key: str) -> Optional[str]:
    """Read a string option, returning ``None`` when it is unset, null or blank."""
    return _workflow_arg(cfg, key, "", str).strip() or None


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Configure and run generator/solver workflow training.

    Forces ``config.rllm.workflow.use_workflow`` on, loads the training and
    validation datasets from ``DatasetRegistry``, reads the optional settings
    under ``config.rllm.workflow.workflow_args`` and forwards them to
    ``GeneratorSolverWorkflow`` through an ``AgentTrainer``, then calls
    ``trainer.train()``. A wall-clock summary is printed when training ends,
    whether it succeeded or raised.

    Args:
        config: The composed Hydra configuration. ``config.rllm`` must exist.

    Raises:
        ValueError: If a requested validation dataset or the human-bug
            dataset cannot be loaded.
        Exception: Any error raised by ``trainer.train()`` is re-raised.
    """
    start_time = time.time()
    print(f"Starting training at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Ensure workflow config exists and enable workflow mode
    if not hasattr(config.rllm, "workflow"):
        config.rllm.workflow = OmegaConf.create({})
    config.rllm.workflow.use_workflow = True

    # workflow_args is optional; every option below falls back to its default
    # when the node (or an individual key) is missing or null.
    workflow_args_cfg = getattr(config.rllm.workflow, "workflow_args", None)

    def arg(key: str, default: Any = None, cast: Optional[type] = None) -> Any:
        return _workflow_arg(workflow_args_cfg, key, default, cast)

    # -----------------------------
    # Train: BigCodeBench (train)
    # Eval:  BugBench (test) by default; can override with workflow_args.val_datasets
    # -----------------------------
    train_dataset = DatasetRegistry.load_dataset("bigcodebench", "train")
    if train_dataset is None:
        # Bail out before loading any validation data.
        print("Failed to load BigCodeBench train dataset. Exiting.")
        print("Available datasets:", DatasetRegistry.get_dataset_names())
        return

    val_dataset = None
    val_datasets: Optional[Dict[str, Any]] = None
    val_specs = parse_val_dataset_specs(workflow_args_cfg)
    if val_specs:
        val_datasets = {}
        for alias, (ds_name, split) in val_specs.items():
            ds = DatasetRegistry.load_dataset(ds_name, split)
            if ds is None:
                raise ValueError(f"val dataset not found: {ds_name!r} split={split!r} (alias={alias!r})")
            val_datasets[alias] = ds
        print(f"Using multiple val datasets: {list(val_datasets.keys())}")
    else:
        val_dataset = DatasetRegistry.load_dataset("bugbench", "test")

    # -----------------------------
    # Optional: concatenate a human-bug dataset into training (train only)
    # -----------------------------
    human_bug_dataset_name = _optional_str_arg(workflow_args_cfg, "human_bug_dataset_name")
    human_bug_dataset_split: str = arg("human_bug_dataset_split", "train", str)
    combined_train_dataset_name = _optional_str_arg(workflow_args_cfg, "combined_train_dataset_name")

    if human_bug_dataset_name:
        bug_ds = DatasetRegistry.load_dataset(human_bug_dataset_name, human_bug_dataset_split)
        if bug_ds is None:
            raise ValueError(
                f"human_bug_dataset_name={human_bug_dataset_name!r} split={human_bug_dataset_split!r} not found."
            )

        main_tasks = train_dataset.get_data()
        bug_tasks = bug_ds.get_data()
        combined_tasks: List[Dict[str, Any]] = list(main_tasks) + list(bug_tasks)

        name = combined_train_dataset_name or f"bigcodebench_plus_{human_bug_dataset_name}"
        print(
            f"Combining training data: bigcodebench ({len(main_tasks)}) + "
            f"{human_bug_dataset_name}/{human_bug_dataset_split} ({len(bug_tasks)}) "
            f"= {len(combined_tasks)} -> registering '{name}'"
        )
        train_dataset = DatasetRegistry.register_dataset(name, combined_tasks, "train")

    # -----------------------------
    # Role-conditioned advantage options
    # -----------------------------
    freeze_generator = arg("freeze_generator", False, bool)
    freeze_solver = arg("freeze_solver", False, bool)
    role_freeze_flags = (("generator", freeze_generator), ("solver", freeze_solver))
    frozen_roles = [role for role, frozen in role_freeze_flags if frozen]

    # -----------------------------
    # LLM-as-judge for bug similarity (optional auxiliary reward for generator)
    # -----------------------------
    use_bug_similarity_judge = arg("use_bug_similarity_judge", False, bool)
    bug_similarity_reward_weight = arg("bug_similarity_reward_weight", 0.5, float)
    bug_similarity_n_targets = arg("bug_similarity_n_targets", 3, int)
    reference_bug_dataset = arg("reference_bug_dataset", "bugbench", str)
    reference_bug_split = arg("reference_bug_split", "test", str)

    # Load reference bugs if LLM-as-judge is enabled
    reference_bugs: Optional[List[Dict[str, Any]]] = None
    if use_bug_similarity_judge:
        print(f"\n📊 Loading reference bugs for LLM-as-judge from {reference_bug_dataset}:{reference_bug_split}...")
        ref_bugs_data = load_data(dataset_name=reference_bug_dataset, split=reference_bug_split, n=1)
        if not ref_bugs_data:
            # Fall back to the dataset registry when load_data returns nothing.
            ds = DatasetRegistry.load_dataset(reference_bug_dataset, reference_bug_split)
            if ds:
                ref_bugs_data = list(ds.get_data())
        if ref_bugs_data:
            reference_bugs = ref_bugs_data
            print(f"  Loaded {len(reference_bugs)} reference bugs for similarity comparison")
            print(f"  Bug similarity reward weight: {bug_similarity_reward_weight}")
            print(f"  Number of targets per comparison: {bug_similarity_n_targets}")
        else:
            print("  WARNING: Could not load reference bugs. LLM-as-judge will be disabled.")
            use_bug_similarity_judge = False

    trainer = AgentTrainer(
        workflow_class=GeneratorSolverWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            # Optional prompts / eval flags
            "generator_system_prompt": arg("generator_system_prompt"),
            "solver_system_prompt": arg("solver_system_prompt"),
            "evaluate_codegen": arg("evaluate_codegen", True, bool),
            # SSR-like self-play knobs (forwarded to GeneratorSolverWorkflow)
            "solver_attempts_train": arg("solver_attempts_train", 8, int),
            "solver_attempts_val": arg("solver_attempts_val", 1, int),
            "generator_reward_mode": arg("generator_reward_mode", "band", str),
            "solve_rate_band_low": arg("solve_rate_band_low", 0.05, float),
            "solve_rate_band_high": arg("solve_rate_band_high", 0.25, float),
            "gen_alpha_extreme": arg("gen_alpha_extreme", 0.2, float),
            "gen_invalid_bug_reward": arg("gen_invalid_bug_reward", -1.0, float),
            "solver_reward_pm1": arg("solver_reward_pm1", False, bool),
            "use_pregenerated_bugs_in_validation": arg("use_pregenerated_bugs_in_validation", True, bool),
            # Defaults to on exactly when a human-bug dataset was mixed into training.
            "use_pregenerated_bugs_in_training": arg(
                "use_pregenerated_bugs_in_training", bool(human_bug_dataset_name), bool
            ),
            "pregenerated_bug_train_probability": arg("pregenerated_bug_train_probability", 1.0, float),
            "episode_success_mode": arg("episode_success_mode", "bugfix", str),
            # Include failed test output in solver prompts
            "include_failed_test_output": arg("include_failed_test_output", True, bool),
            # Role-conditioned advantage options
            "freeze_cm": bool(frozen_roles),
            "cm_roles": frozen_roles,
            "use_role_advnorm": arg("use_role_advnorm", False, bool),
            # LLM-as-judge for bug similarity (auxiliary reward for generator)
            "use_bug_similarity_judge": use_bug_similarity_judge,
            "bug_similarity_reward_weight": bug_similarity_reward_weight,
            "bug_similarity_n_targets": bug_similarity_n_targets,
            "judge_system_prompt": arg("judge_system_prompt"),
            "judge_base_url": _optional_str_arg(workflow_args_cfg, "judge_base_url"),
            "judge_model_name": _optional_str_arg(workflow_args_cfg, "judge_model_name"),
            "reference_bugs": reference_bugs,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        val_datasets=val_datasets,
    )

    try:
        trainer.train()
    except Exception as e:
        print(f"Training failed with error: {e}")
        raise
    finally:
        # Runs on success and on failure, so the elapsed time is always reported.
        total_time = time.time() - start_time
        print("\n" + "=" * 60)
        print(f"Training completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Total run time: {format_time(total_time)} ({total_time:.2f} seconds)")
        print("=" * 60)


# Script entry point. main() is called with no arguments because the
# @hydra.main decorator applied to it supplies the `config` argument.
if __name__ == "__main__":
    main()


