import omnisafe
import gymnasium as gym
import safety_gymnasium
import timeit
from typing import Any, Dict, Tuple
from omnisafe.envs.core import CMDP, env_register, env_unregister
from gymnasium import make
from UnsafeStateCounterWrapper import UnsafeStateCounterWrapper
from benchmarks.omnisafe_reg import SpiceEnvironment
from benchmarks.gymnasium_wrapper import GymnasiumWrapper
import torch
import argparse

# Pre-parse the command line for a single option, --csv-path, before anything
# else looks at sys.argv. add_help=False keeps this parser from registering
# -h/--help, and parse_known_args() returns unrecognised arguments in a list
# (_unknown) instead of exiting with an error.
_csv_path_option = {
    "type": str,
    "default": None,
    "help": "where to dump your per-step CSV logs",
}
_parser = argparse.ArgumentParser(add_help=False)
_parser.add_argument("--csv-path", **_csv_path_option)
_args, _unknown = _parser.parse_known_args()


# Decorators apply bottom-up: env_unregister runs on the class first, then
# env_register. NOTE(review): presumably this lets the module be re-imported
# without a duplicate-registration error -- confirm against omnisafe's registry.
@env_register
@env_unregister
class CustomSpiceEnv(GymnasiumWrapper):
    """GymnasiumWrapper subclass whose CSV log path can come from the CLI.

    The only behaviour added on top of GymnasiumWrapper is choosing the
    ``csv_path`` forwarded to the parent constructor: a non-empty
    ``--csv-path`` command-line value takes precedence over the ``csv_path``
    keyword argument.
    """

    # NOTE(review): not read anywhere in this file; presumably consumed by
    # omnisafe's environment registry -- confirm what the value 3 means there.
    example_configs = 3

    def __init__(
        self,
        env_id: str,
        *,
        csv_path: str = 'run1.csv',
        **kwargs,
    ) -> None:
        """Initialise the wrapped environment with the resolved CSV path.

        Args:
            env_id: Environment identifier, forwarded unchanged to the parent.
            csv_path: Fallback CSV path, used when ``--csv-path`` was not
                given on the command line (or was given as an empty string).
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``GymnasiumWrapper.__init__``.
        """
        cli_csv_path = _args.csv_path
        if cli_csv_path:
            resolved_csv = cli_csv_path
        else:
            resolved_csv = csv_path
        super().__init__(env_id=env_id, csv_path=resolved_csv, **kwargs)



# ---------------------------------------------------------------------------
# Experiment selection
# ---------------------------------------------------------------------------
algorithm = 'PPOSimmerPID'
env_id = 'Cheetah-v1'
# Only referenced by the disabled 'env_cfgs' override further down.
run_id = f"{algorithm}_{env_id}_1"

# ---------------------------------------------------------------------------
# Configuration overrides passed to omnisafe.Agent, one section at a time
# ---------------------------------------------------------------------------
custom_cfgs = {}

# A 'device' entry (e.g. 'cuda:3') was previously tried here and is disabled.
custom_cfgs['train_cfgs'] = {
    'total_steps': 500000,
    'vector_env_nums': 1,
    'parallel': 1,
}

custom_cfgs['algo_cfgs'] = {
    'steps_per_epoch': 2000,
    'update_iters': 8,
    'reward_normalize': False,
    'cost_normalize': False,
}

custom_cfgs['logger_cfgs'] = {
    'use_wandb': False,
    'log_dir': './logs',
    'use_tensorboard': True,
}

# Disabled: give each run its own CSV file through the environment config.
# custom_cfgs.update({'env_cfgs': {'csv_path': run_id + '.csv'}})

# ---------------------------------------------------------------------------
# Build the agent, train it, and report wall-clock training time
# ---------------------------------------------------------------------------
agent = omnisafe.Agent(algorithm, env_id, custom_cfgs=custom_cfgs)

t_0 = timeit.default_timer()
agent.learn()
t_1 = timeit.default_timer()

# Disabled post-training steps: agent._save_model() and agent.evaluate(1).

# The difference of two timeit.default_timer() readings is in seconds; it is
# rounded to a whole number before printing.
print(f"Time for training: {round(t_1 - t_0)}")

