import math
import random
import time
from copy import deepcopy
from typing import Any, Dict, List, Optional

import hydra
from omegaconf import OmegaConf

from examples.bugs.solver_flow import SolverWorkflow
from examples.bugs.data_utils import parse_val_dataset_specs
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.reward_fn import code_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


def build_mixed_training_tasks(
    *,
    main_tasks: List[Dict[str, Any]],
    bug_tasks: List[Dict[str, Any]],
    bug_mix_prob: float,
    seed: int = 0,
) -> List[Dict[str, Any]]:
    """Create a mixed task list by replacing some main tasks with bug tasks.

    - Output length equals len(main_tasks).
    - Each main task is independently replaced by a bug task with probability
      ``bug_mix_prob`` (clamped to [0, 1]). Bug tasks are sampled with
      replacement and deep-copied, so repeated picks never share state.
    - When replacement is possible (p > 0 and ``bug_tasks`` non-empty), the
      result is shuffled with the same seeded RNG. The output is therefore
      deterministic for a given ``seed``.
    - A new list is always returned, and ``main_tasks`` is never mutated. Main
      task dicts are shared with the input, not copied.

    Raises:
        ValueError: if ``bug_mix_prob`` is NaN. Without this check it would
            silently clamp to 1.0 and replace every task.
    """
    raw_p = float(bug_mix_prob)
    if math.isnan(raw_p):
        raise ValueError("bug_mix_prob must be a number in [0, 1], got NaN")
    p = max(0.0, min(1.0, raw_p))
    if p <= 0.0 or not bug_tasks:
        # Always hand back a fresh list so callers never alias the input.
        return list(main_tasks)

    rng = random.Random(int(seed))
    # Per task: draw once to decide; only draw a bug task when replacing.
    # This keeps the RNG call order identical to the original loop.
    mixed: List[Dict[str, Any]] = [
        deepcopy(rng.choice(bug_tasks)) if rng.random() < p else task for task in main_tasks
    ]
    rng.shuffle(mixed)
    return mixed


def format_time(seconds):
    """Render a duration in seconds as a short human-readable string.

    Examples: 3725 -> "1h 2m 5s", 65 -> "1m 5s", 7 -> "7s". Hours, and then
    minutes, are omitted when they are not positive. Each component is
    truncated to an integer.
    """
    whole_hours, within_hour = divmod(seconds, 3600)
    h = int(whole_hours)
    m = int(within_hour // 60)
    s = int(seconds % 60)

    if h <= 0:
        return f"{m}m {s}s" if m > 0 else f"{s}s"
    return f"{h}h {m}m {s}s"


def _wf_arg(workflow_args_cfg: Any, key: str, default: Any = None) -> Any:
    """Return ``workflow_args_cfg.<key>``, or ``default`` when the config or the key is missing.

    An explicit ``null`` stored under ``key`` is returned as ``None``, not replaced by ``default``.
    """
    if workflow_args_cfg is None:
        return default
    return getattr(workflow_args_cfg, key, default)


def _wf_optional_str(workflow_args_cfg: Any, key: str) -> Optional[str]:
    """Return a stripped string option, or ``None`` when it is missing, ``null`` or blank.

    This avoids ``str(None)`` turning a YAML ``null`` into the literal name ``"None"``.
    """
    value = _wf_arg(workflow_args_cfg, key)
    if value is None:
        return None
    return str(value).strip() or None


def _wf_number(workflow_args_cfg: Any, key: str, default: Any, cast: Any) -> Any:
    """Return ``cast(value)`` for a numeric option, using ``default`` when it is missing or ``null``."""
    value = _wf_arg(workflow_args_cfg, key)
    return cast(default if value is None else value)


def _load_val_datasets(workflow_args_cfg: Any):
    """Load validation data as ``(val_dataset, val_datasets)``.

    Returns ``(bugbench/test, None)`` by default. When multiple val datasets are
    configured, returns ``(None, {alias: dataset})``; each dataset is evaluated
    separately, with no concatenation.

    Raises:
        ValueError: if a configured val dataset cannot be loaded.
    """
    val_specs = parse_val_dataset_specs(workflow_args_cfg)
    if not val_specs:
        # Backward-compatible default.
        return DatasetRegistry.load_dataset("bugbench", "test"), None

    val_datasets: Dict[str, Any] = {}
    for alias, (ds_name, split) in val_specs.items():
        ds = DatasetRegistry.load_dataset(ds_name, split)
        if ds is None:
            raise ValueError(f"val dataset not found: {ds_name!r} split={split!r} (alias={alias!r})")
        val_datasets[alias] = ds
    print(f"Using multiple val datasets: {list(val_datasets.keys())}")
    return None, val_datasets


def _mix_human_bugs_into_train(
    train_dataset: Any, dataset_name: str, workflow_args_cfg: Any, human_bug_mix_prob: float
) -> Any:
    """Return ``train_dataset``, or a newly registered mix of it with a human-bug dataset when configured.

    Raises:
        ValueError: if the configured human-bug dataset cannot be loaded.
    """
    human_bug_dataset_name = _wf_optional_str(workflow_args_cfg, "human_bug_dataset_name")
    if not (human_bug_dataset_name and human_bug_mix_prob > 0.0):
        if human_bug_mix_prob > 0.0:
            print("Warning: human_bug_mix_prob > 0 but human_bug_dataset_name is not set; skipping bug mixing.")
        return train_dataset

    human_bug_dataset_split = _wf_optional_str(workflow_args_cfg, "human_bug_dataset_split") or "train"
    human_bug_mix_seed = _wf_number(workflow_args_cfg, "human_bug_mix_seed", 0, int)
    mixed_train_dataset_name = _wf_optional_str(workflow_args_cfg, "mixed_train_dataset_name")

    bug_ds = DatasetRegistry.load_dataset(human_bug_dataset_name, human_bug_dataset_split)
    if bug_ds is None:
        raise ValueError(
            f"human_bug_dataset_name={human_bug_dataset_name!r} split={human_bug_dataset_split!r} not found."
        )

    main_tasks = train_dataset.get_data()
    bug_tasks = bug_ds.get_data()
    mixed_tasks = build_mixed_training_tasks(
        main_tasks=main_tasks,
        bug_tasks=bug_tasks,
        bug_mix_prob=human_bug_mix_prob,
        seed=human_bug_mix_seed,
    )

    mix_name = mixed_train_dataset_name or f"{dataset_name}_mix_{human_bug_dataset_name}_p{human_bug_mix_prob:g}"
    print(
        f"Mixing training data: main={dataset_name} ({len(main_tasks)}) + "
        f"bugs={human_bug_dataset_name}/{human_bug_dataset_split} ({len(bug_tasks)}), "
        f"p={human_bug_mix_prob} -> registering '{mix_name}'"
    )
    return DatasetRegistry.register_dataset(mix_name, mixed_tasks, "train")


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Train a SolverWorkflow agent with AgentTrainer.

    All options are optional and are read from ``config.rllm.workflow.workflow_args``:
    - training dataset (``dataset_name``) and validation dataset specs;
    - optional mixing of a human-bug dataset into training (``human_bug_*``,
      ``mixed_train_dataset_name``);
    - generator and solver knobs forwarded to SolverWorkflow;
    - pregenerated-bug usage during training.

    Raises:
        ValueError: if the train, val, or human-bug dataset cannot be loaded.
    """
    # Monotonic clock for the duration, so wall-clock adjustments cannot skew it.
    start_perf = time.perf_counter()
    print(f"Starting training at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Enable workflow training. This also handles an explicit `workflow: null`.
    if getattr(config.rllm, "workflow", None) is None:
        config.rllm.workflow = OmegaConf.create({})
    config.rllm.workflow.use_workflow = True

    # Resolve workflow args before reading any option. The section above may have
    # just been created empty, in which case every option uses its default.
    workflow_args_cfg = getattr(config.rllm.workflow, "workflow_args", None)

    # Load the training dataset.
    dataset_name = _wf_optional_str(workflow_args_cfg, "dataset_name") or "deepcoder_bugs"
    train_dataset = DatasetRegistry.load_dataset(dataset_name, "train")
    if train_dataset is None:
        raise ValueError(f"train dataset not found: {dataset_name!r} split='train'")

    val_dataset, val_datasets = _load_val_datasets(workflow_args_cfg)

    # Optional: mix a human-bug dataset into training.
    human_bug_mix_prob = _wf_number(workflow_args_cfg, "human_bug_mix_prob", 0.0, float)
    train_dataset = _mix_human_bugs_into_train(train_dataset, dataset_name, workflow_args_cfg, human_bug_mix_prob)

    # Static generator knobs (required for SolverWorkflow); explicit nulls pass through.
    generator_model = _wf_arg(workflow_args_cfg, "generator_model")
    generator_base_url = _wf_arg(workflow_args_cfg, "generator_base_url")
    generator_api_key = _wf_arg(workflow_args_cfg, "generator_api_key")
    generator_temperature = _wf_arg(workflow_args_cfg, "generator_temperature", 0.6)
    generator_top_p = _wf_arg(workflow_args_cfg, "generator_top_p", 0.95)
    generator_system_prompt = _wf_arg(workflow_args_cfg, "generator_system_prompt")

    # Solver knobs.
    solver_system_prompt = _wf_arg(workflow_args_cfg, "solver_system_prompt")
    compile_errors_invalid = _wf_arg(workflow_args_cfg, "compile_errors_invalid", True)

    # Training behavior: if tasks contain buggy_code, optionally use them.
    # Defaults to on whenever human_bug_mix_prob > 0.
    use_pregenerated_bugs_in_training = bool(
        _wf_arg(workflow_args_cfg, "use_pregenerated_bugs_in_training", human_bug_mix_prob > 0.0)
    )
    pregenerated_bug_train_probability = _wf_number(
        workflow_args_cfg, "pregenerated_bug_train_probability", 1.0, float
    )

    trainer = AgentTrainer(
        workflow_class=SolverWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "generator_model": generator_model,
            "generator_base_url": generator_base_url,
            "generator_api_key": generator_api_key,
            "generator_temperature": generator_temperature,
            "generator_top_p": generator_top_p,
            "generator_system_prompt": generator_system_prompt,
            "solver_system_prompt": solver_system_prompt,
            "compile_errors_invalid": compile_errors_invalid,
            "use_pregenerated_bugs_in_training": use_pregenerated_bugs_in_training,
            "pregenerated_bug_train_probability": pregenerated_bug_train_probability,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        val_datasets=val_datasets,
    )

    try:
        trainer.train()
    except Exception as e:
        print(f"Training failed with error: {e}")
        raise
    finally:
        total_time = time.perf_counter() - start_perf
        print("\n" + "=" * 60)
        print(f"Training completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Total run time: {format_time(total_time)} ({total_time:.2f} seconds)")
        print("=" * 60)


if __name__ == "__main__":
    # `main` is wrapped by @hydra.main, which supplies `config`; hence the zero-argument call.
    main()


