from pathlib import Path
import os
import json
import hydra
import torch

def _to_json_serializable(obj):
    """Recursively convert ``obj`` into values that ``json.dump`` accepts.

    Handles NumPy scalars (including ``np.bool_``, which is neither an
    ``np.integer`` nor an ``np.floating``), NumPy arrays, torch tensors, and
    nested dicts / lists / tuples. Dict keys are left unchanged. Any other
    object is returned as-is, so ``json.dump`` still raises ``TypeError`` for
    genuinely unsupported types instead of silently stringifying them.
    """
    import numpy as np

    if isinstance(obj, np.generic):
        # .item() yields the matching Python scalar (int, float, bool, ...).
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        # .tolist() yields nested Python scalars (a bare scalar for 0-d input).
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: _to_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_serializable(item) for item in obj]
    return obj


@hydra.main(
    config_path="robobase/cfgs", config_name="robobase_config", version_base=None
)
def main(cfg):
    """Run the Safe-IQL pipeline for the Hydra-composed config ``cfg``.

    Step 1 always samples safe transitions from demos through
    ``PolicyWorkspace``. If ``cfg.eval`` is truthy, the scheduling policy is
    evaluated on that data, the result is printed, and the function returns.
    Otherwise the policy is trained. When ``safe_iql.eval_after_training`` is
    enabled, the trained agent is also evaluated for ``safe_iql.eval_episodes``
    episodes (default 50), and the metrics are written as JSON to
    ``<workspace.work_dir>/safe_iql_eval_metrics.json``.
    """
    # Imported inside main so they are only loaded when the task actually runs.
    from omegaconf import OmegaConf
    from robobase.policy_workspace import PolicyWorkspace

    workspace = PolicyWorkspace(cfg)

    print("=" * 80)
    print("Step 1: Sampling safe transitions from demos...")
    print("=" * 80)
    data_path = workspace._sample_safe_iql_data()
    print(f"Data saved to: {data_path}")

    if cfg.eval:
        print("Evaluating Safe-IQL Scheduling Policy...")
        metric = workspace._train_safe_iql_scheduling_policy(data_path=data_path, eval=True)
        print(metric)
        # Return instead of exit(0): raising SystemExit inside a Hydra task
        # would also abort any remaining jobs of a multirun sweep.
        return

    print("=" * 80)
    print("Step 2: Training Safe-IQL Scheduling Policy...")
    print("=" * 80)
    agent = workspace._train_safe_iql_scheduling_policy(data_path=data_path)

    # Dotted keys on the root config: a missing `safe_iql` section simply yields
    # the default, rather than handing a plain dict to OmegaConf.select.
    eval_after_training = OmegaConf.select(
        cfg, "safe_iql.eval_after_training", default=False
    )
    if eval_after_training:
        print("=" * 80)
        print("Step 3: Evaluating trained Safe-IQL agent...")
        print("=" * 80)
        num_episodes = OmegaConf.select(cfg, "safe_iql.eval_episodes", default=50)
        metrics = workspace._eval_safe_iql_policy(agent, num_episodes=num_episodes, record_media=False)

        metrics_path = workspace.work_dir / "safe_iql_eval_metrics.json"
        serializable_metrics = _to_json_serializable(metrics)
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(serializable_metrics, f, indent=4)
        print(f"Evaluation metrics saved to: {metrics_path}")

    print("=" * 80)
    print("Safe-IQL training completed!")
    print("=" * 80)


if __name__ == "__main__":
    # Hydra's decorator composes the config and supplies `cfg`; call with no args.
    main()

