import os
import time
import hydra
from omegaconf import OmegaConf

from examples.bugs.generator_flow import BugGeneratorWorkflow
from examples.bugs.data_utils import register_deepcoder_chunked_dataset, register_bigcodebench_dataset
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.reward_fn import code_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


def format_time(seconds):
    """Render a duration given in seconds as a compact human-readable string.

    Units are truncated to whole numbers. Leading zero units are dropped:
    ``"1h 0m 5s"`` when there is at least one hour, ``"2m 5s"`` when there is
    at least one minute, otherwise just ``"5s"``.
    """
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = int(seconds % 60)  # computed from `seconds` directly, not from the remainder

    pieces = []
    if h > 0:
        pieces.append(f"{h}h")
    if h > 0 or m > 0:
        pieces.append(f"{m}m")
    pieces.append(f"{s}s")
    return " ".join(pieces)


def _workflow_arg(workflow_args_cfg, key, default=None):
    """Return ``workflow_args_cfg.<key>``, or ``default`` if the node is absent/empty or lacks the key."""
    if not workflow_args_cfg:
        return default
    return getattr(workflow_args_cfg, key, default)


def _load_datasets(dataset_name):
    """Register/load the train and validation datasets selected by ``dataset_name``.

    Returns:
        ``(train_dataset, test_dataset)``. Either element may be ``None`` when the
        registry has nothing for that name/split (presumably what
        ``DatasetRegistry.load_dataset`` returns on a miss — verify).
    """
    key = dataset_name.lower()
    if key == "deepcoder":
        # Train on chunked DeepCoder; the DeepCoder test split returned by the
        # registration helper is discarded and BigCodeBench's test split is used instead.
        train_dataset, _ = register_deepcoder_chunked_dataset(return_test=True)
        return train_dataset, DatasetRegistry.load_dataset("bigcodebench", "test")
    if key == "bigcodebench":
        # NOTE(review): assumes bigcodebench is already registered; this path does
        # not call register_bigcodebench_dataset.
        return (
            DatasetRegistry.load_dataset("bigcodebench", "train"),
            DatasetRegistry.load_dataset("bigcodebench", "test"),
        )
    # Try to load custom dataset directly
    print(f"Loading custom dataset '{dataset_name}'...")
    return (
        DatasetRegistry.load_dataset(dataset_name, "train"),
        DatasetRegistry.load_dataset(dataset_name, "test"),
    )


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Train the bug generator with ``BugGeneratorWorkflow`` through ``AgentTrainer``.

    Optional keys read from ``config.rllm.workflow.workflow_args``:
    ``dataset_name`` (default ``"deepcoder"``), ``generator_system_prompt``,
    ``solver_system_prompt``, and the static-solver knobs ``solver_model``,
    ``solver_base_url``, ``solver_temperature`` (default 0.0), ``solver_top_p``
    (default 1.0), ``solver_max_prompt_length``, ``solver_max_response_length``.

    Args:
        config: Hydra-composed config (``agent_ppo_trainer`` from ``rllm.trainer.config``).
            ``config.rllm.workflow.use_workflow`` is forced to ``True``.

    Raises:
        Any exception from ``trainer.train()`` is printed and re-raised.
    """
    from omegaconf import open_dict  # local import keeps the file-level import block untouched

    # Start timing
    start_time = time.time()
    print(f"Starting training at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Hydra configs are normally in struct mode, which rejects new keys; open_dict
    # allows creating the `workflow` node / `use_workflow` key when they are absent.
    # `.get(...) is None` also covers an explicit `workflow: null`.
    with open_dict(config):
        if config.rllm.get("workflow") is None:
            config.rllm.workflow = OmegaConf.create({})
        config.rllm.workflow.use_workflow = True

    workflow_args_cfg = getattr(config.rllm.workflow, "workflow_args", None)

    # `or` also guards an explicit `dataset_name: null`, which would crash on .lower().
    dataset_name = _workflow_arg(workflow_args_cfg, "dataset_name", "deepcoder") or "deepcoder"

    train_dataset, test_dataset = _load_datasets(dataset_name)
    if train_dataset is None:
        print(f"Failed to register/load dataset '{dataset_name}'. Exiting.")
        print("Available datasets:", DatasetRegistry.list_datasets())
        return
    if test_dataset is None:
        print(f"Warning: no validation dataset loaded for '{dataset_name}'; passing val_dataset=None.")

    # Note: During training, the solver_rollout_engine will be None by default.
    # The generator will be trained to create bugs that fail tests.
    # For evaluation against a static solver (e.g., GPT-4o-mini), use run_generator_flow.py
    # with --solver_model flag.
    #
    # If you want to train against a reference model solver during training, you can:
    # 1. Set config.solver_model_path to use a separate frozen solver model
    # 2. Or modify this script to create a custom trainer similar to train_solver_cm_flow_separate.py
    trainer = AgentTrainer(
        workflow_class=BugGeneratorWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "generator_system_prompt": _workflow_arg(workflow_args_cfg, "generator_system_prompt"),
            "solver_system_prompt": _workflow_arg(workflow_args_cfg, "solver_system_prompt"),
            # If provided, BugGeneratorWorkflow will construct a static OpenAI-compatible solver engine.
            # (e.g., rllm.workflow.workflow_args.solver_model=gpt-4o-mini)
            "solver_model": _workflow_arg(workflow_args_cfg, "solver_model"),
            "solver_base_url": _workflow_arg(workflow_args_cfg, "solver_base_url"),
            "solver_temperature": _workflow_arg(workflow_args_cfg, "solver_temperature", 0.0),
            "solver_top_p": _workflow_arg(workflow_args_cfg, "solver_top_p", 1.0),
            "solver_max_prompt_length": _workflow_arg(workflow_args_cfg, "solver_max_prompt_length"),
            "solver_max_response_length": _workflow_arg(workflow_args_cfg, "solver_max_response_length"),
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
    )

    succeeded = False
    try:
        trainer.train()
        succeeded = True
    except Exception as e:
        print(f"Training failed with error: {e}")
        raise
    finally:
        # Always report wall-clock time, but don't claim completion after a failure
        # (this block also runs on KeyboardInterrupt).
        total_time = time.time() - start_time
        status = "completed" if succeeded else "terminated"

        print("\n" + "=" * 60)
        print(f"Training {status} at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Total run time: {format_time(total_time)} ({total_time:.2f} seconds)")
        print("=" * 60)


if __name__ == "__main__":
    # No arguments passed: the @hydra.main decorator on `main` composes and supplies `config`.
    main()
