import time

import hydra
from omegaconf import OmegaConf

from examples.bugs.codegen_flow import CodeGenWorkflow
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.reward_fn import code_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


def format_time(seconds: float) -> str:
    """Render a duration in seconds as a compact ``Hh Mm Ss`` string.

    Fractions of a second are truncated. Leading units are dropped while
    they are zero: the hours field appears only when it is positive, and
    the minutes field appears when either hours or minutes is positive.

    Args:
        seconds: Duration to format, in seconds.

    Returns:
        One of ``"{h}h {m}m {s}s"``, ``"{m}m {s}s"`` or ``"{s}s"``.
    """
    whole_hours = int(seconds // 3600)
    whole_minutes = int(seconds % 3600 // 60)
    leftover_secs = int(seconds % 60)

    # Build right-to-left: seconds are always shown, larger units are
    # prepended only when they (or a larger unit) are non-zero.
    pieces = [f"{leftover_secs}s"]
    if whole_hours > 0 or whole_minutes > 0:
        pieces.insert(0, f"{whole_minutes}m")
    if whole_hours > 0:
        pieces.insert(0, f"{whole_hours}h")
    return " ".join(pieces)


def _load_registered_dataset(name: str, split: str):
    """Load ``name``/``split`` from ``DatasetRegistry``, printing diagnostics on failure.

    Args:
        name: Registered dataset name.
        split: Split to load from that dataset.

    Returns:
        Whatever ``DatasetRegistry.load_dataset`` returns for the pair.

    Raises:
        Exception: Any error from the registry is re-raised unchanged after
            the failure and the list of available datasets are printed.
    """
    try:
        return DatasetRegistry.load_dataset(name, split)
    except Exception as e:
        print(f"Failed to load dataset {name}/{split}: {e}")
        print("Available datasets:", DatasetRegistry.list_datasets())
        raise


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Train ``CodeGenWorkflow`` with ``AgentTrainer`` and report the total run time.

    Forces ``rllm.workflow.use_workflow`` on, loads the training and
    validation datasets from ``DatasetRegistry``, forwards an optional
    ``rllm.workflow.workflow_args.system_prompt`` override to the workflow,
    and runs training. A timing summary is printed whether training
    succeeds or fails.

    Args:
        config: Hydra-composed configuration supplied by ``@hydra.main``.

    Raises:
        Exception: Dataset-loading and training errors are printed and then
            re-raised unchanged.
    """
    # Monotonic clock: elapsed time must not be skewed by wall-clock adjustments.
    started = time.monotonic()
    print(f"Starting training at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Ensure workflow training is enabled. Hydra-composed configs are in
    # struct mode, where plain attribute assignment cannot add a missing key;
    # force_add inserts the whole path regardless of the struct flag.
    OmegaConf.update(config, "rllm.workflow.use_workflow", True, force_add=True)

    # Datasets:
    # - train: DeepCoder chunked (registered as deepcoder_chunked/train)
    # - val:   BigCodeBench test (registered as bigcodebench/test)
    train_dataset = _load_registered_dataset("deepcoder_chunked", "train")
    val_dataset = _load_registered_dataset("bigcodebench", "test")

    # Optional system prompt from Hydra overrides; None when not provided.
    workflow_args_cfg = getattr(config.rllm.workflow, "workflow_args", None)
    system_prompt = getattr(workflow_args_cfg, "system_prompt", None) if workflow_args_cfg else None

    trainer = AgentTrainer(
        workflow_class=CodeGenWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "system_prompt": system_prompt,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )

    succeeded = False
    try:
        trainer.train()
        succeeded = True
    except Exception as e:
        print(f"Training failed with error: {e}")
        raise
    finally:
        # Always report timing, but do not claim completion after a failure
        # or an interrupt.
        total_time = time.monotonic() - started
        outcome = "completed" if succeeded else "aborted"
        print("\n" + "=" * 60)
        print(f"Training {outcome} at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Total run time: {format_time(total_time)} ({total_time:.2f} seconds)")
        print("=" * 60)


# Script entry point. main() is called with no arguments; the @hydra.main
# decorator on it supplies the `config` parameter.
if __name__ == "__main__":
    main()
