import json
import time
from typing import Any, Dict, List, Optional

import hydra
from omegaconf import OmegaConf

from examples.bugs.generator_solver_flow import GeneratorSolverWorkflow
from examples.bugs.data_utils import (
    load_data,
    parse_val_dataset_specs,
    register_deepcoder_chunked_dataset,
    register_bigcodebench_dataset,
)
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.reward_fn import code_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


def normalize_task_for_parquet(task: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``task`` with schema-unstable fields coerced to strings.

    When rows from different datasets are mixed, some fields may be a
    JSON-encoded string in one dataset (e.g. DeepCoder) and a raw list/dict in
    another (e.g. BugBench). pyarrow rejects such mixed-type columns when the
    rows are serialized to parquet. For the keys in ``complex_fields`` this
    function:

    * JSON-encodes lists, tuples and dicts;
    * keeps strings unchanged (they may already be JSON-encoded, so encoding
      them again would double-encode);
    * ``str()``-converts any other non-None value;
    * leaves ``None`` as ``None``.

    All other keys are copied through untouched. ``task`` itself is not mutated.

    Args:
        task: A single task/row dict.

    Returns:
        A new dict with the same keys as ``task``.
    """
    # Fields that may have inconsistent types across datasets.
    complex_fields = frozenset({"ground_truth", "test", "test_list", "tests", "metadata", "extra_info"})

    result: Dict[str, Any] = {}
    for key, value in task.items():
        if key not in complex_fields or value is None or isinstance(value, str):
            # Untouched: non-complex field, missing value, or already a string
            # (possibly pre-encoded JSON -- must not be re-encoded).
            result[key] = value
        elif isinstance(value, (list, tuple, dict)):
            # Tuples are JSON arrays too; str() would yield "(1, 2)", not JSON.
            result[key] = json.dumps(value)
        else:
            result[key] = str(value)
    return result


def _load_tasks(dataset_name: str, split: str) -> List[Dict[str, Any]]:
    """Load ``dataset_name:split`` via ``load_data``, falling back to ``DatasetRegistry``.

    Returns an empty list when neither source yields any tasks.
    """
    tasks = load_data(dataset_name=dataset_name, split=split, n=1)
    if not tasks:
        ds = DatasetRegistry.load_dataset(dataset_name, split)
        if ds is not None:
            tasks = list(ds.get_data())
    return tasks or []


def load_and_mix_datasets(
    main_dataset: str,
    main_split: str,
    target_dataset: str,
    target_split: str,
) -> List[Dict[str, Any]]:
    """Load the main dataset and the target (pregenerated-bug) dataset, then concatenate them.

    Each source is tried first through ``load_data`` and then through
    ``DatasetRegistry``. Every task is passed through
    ``normalize_task_for_parquet`` so that the combined list can be registered
    as one dataset even when the two sources use different field types.
    Main tasks come first, followed by target tasks.

    Same logic as run_generator_solver_flow_mixed.py but with normalization
    for parquet compatibility when registering the combined dataset.

    Args:
        main_dataset: Name of the primary training dataset.
        main_split: Split of the primary dataset to load.
        target_dataset: Name of the dataset with pregenerated bugs to mix in.
        target_split: Split of the target dataset to load.

    Returns:
        The normalized, concatenated task list, or ``[]`` if the main dataset
        could not be loaded. A missing target dataset only prints a warning.
    """
    print("\n📦 Loading datasets...")

    # Main dataset is mandatory.
    print(f"  Main dataset: {main_dataset}:{main_split}")
    main_tasks = _load_tasks(main_dataset, main_split)
    if not main_tasks:
        print(f"  ERROR: Could not load main dataset {main_dataset}:{main_split}")
        return []
    print(f"    Loaded {len(main_tasks)} main tasks")

    # Target dataset (pregenerated bugs) is optional.
    print(f"  Target dataset: {target_dataset}:{target_split}")
    target_tasks = _load_tasks(target_dataset, target_split)
    if not target_tasks:
        print(f"  WARNING: Could not load target dataset {target_dataset}:{target_split}")
    print(f"    Loaded {len(target_tasks)} target tasks")

    # Normalize before mixing: the sources may disagree on field types, which
    # breaks parquet serialization of the combined rows.
    combined = [normalize_task_for_parquet(t) for t in main_tasks]
    combined.extend(normalize_task_for_parquet(t) for t in target_tasks)
    print(f"  Combined: {len(combined)} tasks total")

    return combined


def format_time(seconds: float) -> str:
    """Render a duration in seconds as ``"Xh Ym Zs"``, omitting leading zero units.

    Hours are shown only when positive. Minutes are shown whenever hours are
    shown or minutes are positive. Seconds are always shown. Fractional
    seconds are dropped via ``int()``.
    """
    hrs = int(seconds // 3600)
    mins = int((seconds % 3600) // 60)
    secs = int(seconds % 60)

    pieces = []
    if hrs > 0:
        pieces.append(f"{hrs}h")
    if hrs > 0 or mins > 0:
        pieces.append(f"{mins}m")
    pieces.append(f"{secs}s")
    return " ".join(pieces)


def _workflow_arg(workflow_args_cfg: Any, name: str, default: Any) -> Any:
    """Read ``name`` from the ``workflow_args`` config node.

    Returns ``default`` when the node itself is ``None`` or when it has no
    attribute ``name``.
    """
    if workflow_args_cfg is None:
        return default
    return getattr(workflow_args_cfg, name, default)


def _enable_workflow_mode(config: Any) -> None:
    """Ensure ``config.rllm.workflow`` exists and set ``use_workflow = True``.

    Hydra composes configs in struct mode, where adding keys that are not
    already present raises. ``open_dict`` lifts that restriction temporarily
    so the node and flag can be created when missing.
    """
    from omegaconf import open_dict

    with open_dict(config):
        # Covers both an absent key and an explicit `workflow: null`.
        if getattr(config.rllm, "workflow", None) is None:
            config.rllm.workflow = OmegaConf.create({})
        config.rllm.workflow.use_workflow = True


def _load_val_datasets(workflow_args_cfg: Any):
    """Resolve validation data.

    If ``parse_val_dataset_specs`` returns any specs, every aliased dataset is
    loaded from ``DatasetRegistry``. Otherwise, the ``bugbench/test0.5`` split
    is used.

    Returns:
        ``(val_dataset, val_datasets)``: in the multi-dataset case
        ``(None, {alias: dataset})``; otherwise ``(bugbench_test_dataset, None)``.

    Raises:
        ValueError: if a configured val dataset, or the default
            ``bugbench/test0.5`` split, cannot be loaded.
    """
    val_specs = parse_val_dataset_specs(workflow_args_cfg)
    if val_specs:
        val_datasets: Dict[str, Any] = {}
        for alias, (ds_name, split) in val_specs.items():
            ds = DatasetRegistry.load_dataset(ds_name, split)
            if ds is None:
                raise ValueError(f"val dataset not found: {ds_name!r} split={split!r} (alias={alias!r})")
            val_datasets[alias] = ds
        print(f"Using multiple val datasets: {list(val_datasets.keys())}")
        return None, val_datasets

    val_dataset = DatasetRegistry.load_dataset("bugbench", "test0.5")
    if val_dataset is None:
        raise ValueError("Missing BugBench test0.5 split (expected `bugbench/test0.5`).")
    return val_dataset, None


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Hydra entry point: train a GeneratorSolverWorkflow on a mixed dataset.

    Steps:
    1. Load the main training dataset plus the pregenerated-bug dataset.
    2. Register the combined dataset.
    3. Resolve the validation dataset(s).
    4. Run ``AgentTrainer``.

    Workflow knobs are read from ``config.rllm.workflow.workflow_args``. Each
    knob falls back to the default shown below when it is absent.
    """
    start_time = time.time()
    print(f"Starting training at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    _enable_workflow_mode(config)
    # config.rllm.workflow is guaranteed to exist now; workflow_args may not.
    workflow_args_cfg = getattr(config.rllm.workflow, "workflow_args", None)

    def arg(name: str, default: Any) -> Any:
        return _workflow_arg(workflow_args_cfg, name, default)

    # -----------------------------
    # Load and mix datasets (same pattern as run_generator_solver_flow_mixed.py)
    # -----------------------------
    dataset_name = arg("dataset_name", "bigcodebench")
    target_dataset = arg("target_dataset", "bugbench")
    target_split = arg("target_split", "train0.5")

    combined_tasks = load_and_mix_datasets(
        main_dataset=str(dataset_name),
        main_split="train",
        target_dataset=str(target_dataset),
        target_split=str(target_split),
    )

    if not combined_tasks:
        print("Failed to load datasets. Exiting.")
        print("Available datasets:", DatasetRegistry.get_dataset_names())
        return

    # A YAML `null` must fall back to the derived name rather than become "None".
    raw_mix_name = arg("mixed_train_dataset_name", None)
    mixed_train_dataset_name: Optional[str] = (
        str(raw_mix_name).strip() if raw_mix_name is not None else ""
    ) or None
    mix_name = mixed_train_dataset_name or f"{dataset_name}_plus_{target_dataset}_{target_split}"
    print(f"Registering combined dataset as '{mix_name}'...")
    train_dataset = DatasetRegistry.register_dataset(mix_name, combined_tasks, "train")

    # -----------------------------
    # Validation datasets (multiple via val_datasets config, else bugbench/test0.5)
    # -----------------------------
    val_dataset, val_datasets = _load_val_datasets(workflow_args_cfg)

    # -----------------------------
    # Role-conditioned advantage options
    # -----------------------------
    frozen_roles: List[str] = [
        role
        for role, flag in (("generator", "freeze_generator"), ("solver", "freeze_solver"))
        if bool(arg(flag, False))
    ]

    workflow_args: Dict[str, Any] = {
        "reward_function": code_reward_fn,
        # Optional prompts / eval flags
        "generator_system_prompt": arg("generator_system_prompt", None),
        "solver_system_prompt": arg("solver_system_prompt", None),
        "evaluate_codegen": bool(arg("evaluate_codegen", True)),
        # SSR-like knobs
        "solver_attempts_train": int(arg("solver_attempts_train", 8)),
        "solver_attempts_val": int(arg("solver_attempts_val", 1)),
        "generator_reward_mode": str(arg("generator_reward_mode", "band")),
        "solve_rate_band_low": float(arg("solve_rate_band_low", 0.05)),
        "solve_rate_band_high": float(arg("solve_rate_band_high", 0.25)),
        "gen_alpha_extreme": float(arg("gen_alpha_extreme", 0.2)),
        "gen_invalid_bug_reward": float(arg("gen_invalid_bug_reward", -1.0)),
        "solver_reward_pm1": bool(arg("solver_reward_pm1", False)),
        # BugBench test0.5 and the mixed-in training data both contain
        # pregenerated bugs, so using them is enabled by default.
        "use_pregenerated_bugs_in_validation": bool(arg("use_pregenerated_bugs_in_validation", True)),
        "use_pregenerated_bugs_in_training": bool(arg("use_pregenerated_bugs_in_training", True)),
        "pregenerated_bug_train_probability": float(arg("pregenerated_bug_train_probability", 0.5)),
        "episode_success_mode": str(arg("episode_success_mode", "bugfix")),
        # Include failed test output in solver prompts
        "include_failed_test_output": bool(arg("include_failed_test_output", True)),
        # Role-conditioned advantage options
        "freeze_cm": bool(frozen_roles),
        "cm_roles": frozen_roles,
        "use_role_advnorm": bool(arg("use_role_advnorm", False)),
    }

    trainer = AgentTrainer(
        workflow_class=GeneratorSolverWorkflow,
        workflow_args=workflow_args,
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        val_datasets=val_datasets,
    )

    try:
        trainer.train()
    except Exception as e:
        print(f"Training failed with error: {e}")
        raise
    finally:
        # Always report wall-clock time, even when training raised.
        total_time = time.time() - start_time
        print("\n" + "=" * 60)
        print(f"Training completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Total run time: {format_time(total_time)} ({total_time:.2f} seconds)")
        print("=" * 60)


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on `main` builds and
    # passes the `config` object.
    main()


