import os
import time
from typing import Any, Dict, List, Optional

import hydra
from omegaconf import OmegaConf

from examples.bugs.generator_solver_flow import GeneratorSolverWorkflow
from examples.bugs.data_utils import (
    parse_val_dataset_specs,
    register_bigcodebench_dataset,
    register_deepcoder_chunked_dataset,
)
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.reward_fn import code_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


def format_time(seconds: float) -> str:
    """Render a duration in seconds as a compact string such as ``"1h 2m 3s"``.

    Leading zero units are omitted: hours appear only when positive, and
    minutes appear whenever hours are shown or minutes are positive. Seconds
    are always shown. Every component is truncated toward zero with ``int``.
    """
    whole_hours = int(seconds // 3600)
    whole_minutes = int((seconds % 3600) // 60)
    whole_secs = int(seconds % 60)

    pieces = []
    if whole_hours > 0:
        pieces.append(f"{whole_hours}h")
    # Once hours are shown, minutes are always shown too (even "0m").
    if whole_hours > 0 or whole_minutes > 0:
        pieces.append(f"{whole_minutes}m")
    pieces.append(f"{whole_secs}s")
    return " ".join(pieces)


def _wf_arg(workflow_args_cfg: Any, name: str, default: Any) -> Any:
    """Return ``workflow_args_cfg.<name>``, or *default* when the option is unset.

    "Unset" means any of the following:
    - the whole ``workflow_args`` section is missing (``workflow_args_cfg is None``);
    - the key is absent;
    - the key is explicitly ``null``.

    Treating ``null`` as unset keeps later ``str(...)``/``float(...)`` coercions
    from turning it into the literal string ``"None"`` or raising ``TypeError``.
    """
    if workflow_args_cfg is None:
        return default
    value = getattr(workflow_args_cfg, name, default)
    return default if value is None else value


def _wf_opt_str(workflow_args_cfg: Any, name: str) -> Optional[str]:
    """Return a stripped, non-empty string option from ``workflow_args``, else ``None``."""
    value = _wf_arg(workflow_args_cfg, name, None)
    if value is None:
        return None
    return str(value).strip() or None


def _load_train_dataset(dataset_name: str) -> Any:
    """Register/load the primary training dataset.

    "deepcoder" and "bigcodebench" (case-insensitive) use their dedicated
    registration helpers. Any other name is loaded from ``DatasetRegistry``
    with the "train" split. Returns ``None`` if the registry has no such dataset.
    """
    key = dataset_name.lower()
    if key == "deepcoder":
        train_dataset, _ = register_deepcoder_chunked_dataset(return_test=True)
        return train_dataset
    if key == "bigcodebench":
        return register_bigcodebench_dataset()
    print(f"Loading custom dataset '{dataset_name}'...")
    return DatasetRegistry.load_dataset(dataset_name, "train")


def _load_val_datasets(workflow_args_cfg: Any):
    """Resolve validation data and return a ``(val_dataset, val_datasets)`` pair.

    Exactly one element of the pair is used:
    - If ``workflow_args.val_datasets`` yields specs, each spec is loaded and
      evaluated separately (no concatenation). The result is
      ``(None, {alias: dataset})``. A missing spec raises ``ValueError``.
    - Otherwise a single ``val_dataset_name``/``val_dataset_split`` is loaded
      (default ``bugbench:test``). The result is ``(dataset_or_None, None)``.
    """
    val_specs = parse_val_dataset_specs(workflow_args_cfg)
    if val_specs:
        val_datasets: Dict[str, Any] = {}
        for alias, (ds_name, split) in val_specs.items():
            ds = DatasetRegistry.load_dataset(ds_name, split)
            if ds is None:
                raise ValueError(f"val dataset not found: {ds_name!r} split={split!r} (alias={alias!r})")
            val_datasets[alias] = ds
        print(f"Using multiple val datasets: {list(val_datasets.keys())}")
        return None, val_datasets

    val_dataset_name = str(_wf_arg(workflow_args_cfg, "val_dataset_name", "bugbench"))
    val_dataset_split = str(_wf_arg(workflow_args_cfg, "val_dataset_split", "test"))
    val_dataset = DatasetRegistry.load_dataset(val_dataset_name, val_dataset_split)
    if val_dataset is None:
        # Previously this was silent; surface it so a typo is not mistaken for "no validation".
        print(
            f"Warning: val dataset {val_dataset_name!r} split={val_dataset_split!r} not found; "
            "no single val dataset will be passed to the trainer."
        )
    return val_dataset, None


def _mix_in_human_bugs(
    train_dataset: Any,
    dataset_name: str,
    human_bug_dataset_name: str,
    human_bug_dataset_split: str,
    mixed_train_dataset_name: Optional[str],
) -> Any:
    """Concatenate a human-written bug dataset onto the training tasks.

    Returns the newly registered combined "train" dataset.

    Raises:
        ValueError: if the bug dataset/split is not in ``DatasetRegistry``.
    """
    bug_ds = DatasetRegistry.load_dataset(human_bug_dataset_name, human_bug_dataset_split)
    if bug_ds is None:
        raise ValueError(
            f"human_bug_dataset_name={human_bug_dataset_name!r} split={human_bug_dataset_split!r} not found."
        )

    main_tasks = train_dataset.get_data()
    bug_tasks = bug_ds.get_data()
    combined_tasks: List[Dict[str, Any]] = [*main_tasks, *bug_tasks]

    mix_name = mixed_train_dataset_name or f"{dataset_name}_plus_{human_bug_dataset_name}"
    print(
        f"Combining training data: main={dataset_name} ({len(main_tasks)}) + "
        f"bugs={human_bug_dataset_name}/{human_bug_dataset_split} ({len(bug_tasks)}) "
        f"= {len(combined_tasks)} -> registering '{mix_name}'"
    )
    return DatasetRegistry.register_dataset(mix_name, combined_tasks, "train")


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Train ``GeneratorSolverWorkflow`` with ``AgentTrainer`` using Hydra config.

    All workflow knobs are read from ``config.rllm.workflow.workflow_args``.
    A missing key, or one set to ``null``, falls back to the default shown
    at each lookup below.
    """
    from omegaconf import open_dict  # local import: only needed for the config patch below

    # Monotonic clock for the duration; wall clock only for the human-readable stamps.
    start_time = time.monotonic()
    print(f"Starting training at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Ensure the workflow section exists and enable workflow mode. Hydra
    # configs are typically in struct mode, where adding keys not in the
    # schema raises; open_dict lifts that restriction for this block only.
    with open_dict(config):
        if getattr(config.rllm, "workflow", None) is None:
            config.rllm.workflow = OmegaConf.create({})
        config.rllm.workflow.use_workflow = True

    workflow_args_cfg = getattr(config.rllm.workflow, "workflow_args", None)

    def arg(name: str, default: Any) -> Any:
        return _wf_arg(workflow_args_cfg, name, default)

    # Primary training dataset: fail fast before doing any other loading.
    dataset_name = str(arg("dataset_name", "deepcoder"))
    train_dataset = _load_train_dataset(dataset_name)
    if train_dataset is None:
        print(f"Failed to register/load dataset '{dataset_name}'. Exiting.")
        print("Available datasets:", DatasetRegistry.get_dataset_names())
        return

    val_dataset, val_datasets = _load_val_datasets(workflow_args_cfg)

    # Optionally concatenate a human-bug dataset into training. Its examples
    # are expected to carry pre-generated buggy code, which GeneratorSolverWorkflow
    # presumably uses to skip generation (see use_pregenerated_bugs_in_training).
    human_bug_dataset_name = _wf_opt_str(workflow_args_cfg, "human_bug_dataset_name")
    human_bug_dataset_split = str(arg("human_bug_dataset_split", "train")).strip() or "train"
    mixed_train_dataset_name = _wf_opt_str(workflow_args_cfg, "mixed_train_dataset_name")
    if human_bug_dataset_name:
        train_dataset = _mix_in_human_bugs(
            train_dataset,
            dataset_name,
            human_bug_dataset_name,
            human_bug_dataset_split,
            mixed_train_dataset_name,
        )

    # Role-conditioned advantage options: roles whose advantages are frozen.
    frozen_roles = [
        role
        for role, frozen in (
            ("generator", bool(arg("freeze_generator", False))),
            ("solver", bool(arg("freeze_solver", False))),
        )
        if frozen
    ]

    trainer = AgentTrainer(
        workflow_class=GeneratorSolverWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            # Optional prompts / eval flags
            "generator_system_prompt": arg("generator_system_prompt", None),
            "solver_system_prompt": arg("solver_system_prompt", None),
            "evaluate_codegen": bool(arg("evaluate_codegen", True)),

            # SSR-like knobs forwarded to GeneratorSolverWorkflow.
            # Solver attempts (K) used to compute solve_rate; validation is cheaper by default.
            "solver_attempts_train": int(arg("solver_attempts_train", 8)),
            "solver_attempts_val": int(arg("solver_attempts_val", 1)),
            # Generator reward shaping (difficulty frontier)
            "generator_reward_mode": str(arg("generator_reward_mode", "band")),
            "solve_rate_band_low": float(arg("solve_rate_band_low", 0.05)),
            "solve_rate_band_high": float(arg("solve_rate_band_high", 0.25)),
            "gen_alpha_extreme": float(arg("gen_alpha_extreme", 0.2)),
            "gen_invalid_bug_reward": float(arg("gen_invalid_bug_reward", -1.0)),
            # Reward style for solver attempts
            "solver_reward_pm1": bool(arg("solver_reward_pm1", False)),
            # Pre-generated bug usage; training defaults to on iff a human-bug set was mixed in.
            "use_pregenerated_bugs_in_validation": bool(arg("use_pregenerated_bugs_in_validation", True)),
            "use_pregenerated_bugs_in_training": bool(
                arg("use_pregenerated_bugs_in_training", bool(human_bug_dataset_name))
            ),
            "pregenerated_bug_train_probability": float(arg("pregenerated_bug_train_probability", 1.0)),
            # What val success means
            "episode_success_mode": str(arg("episode_success_mode", "bugfix")),
            # Include failed test output in solver prompts
            "include_failed_test_output": bool(arg("include_failed_test_output", True)),

            # Role-conditioned advantage options (passed to AgentWorkflowPPOTrainer)
            "freeze_cm": bool(frozen_roles),  # enable freezing if any roles frozen
            "cm_roles": frozen_roles,         # roles to freeze
            "use_role_advnorm": bool(arg("use_role_advnorm", False)),  # per-role advantage normalization
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        val_datasets=val_datasets,
    )

    succeeded = False
    try:
        trainer.train()
        succeeded = True
    except Exception as e:
        print(f"Training failed with error: {e}")
        raise
    finally:
        # Runs on success, on error, and on KeyboardInterrupt; report which one.
        total_time = time.monotonic() - start_time
        status = "completed" if succeeded else "aborted"

        print("\n" + "=" * 60)
        print(f"Training {status} at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Total run time: {format_time(total_time)} ({total_time:.2f} seconds)")
        print("=" * 60)


if __name__ == "__main__":
    main()

