import omnisafe
import gymnasium as gym
import safety_gymnasium
import timeit
from typing import Any, Dict, Tuple
from omnisafe.envs.core import CMDP, env_register, env_unregister
from gymnasium import make
from UnsafeStateCounterWrapper import UnsafeStateCounterWrapper
from benchmarks.omnisafe_reg import SpiceEnvironment
from benchmarks.gymnasium_wrapper import GymnasiumWrapper

# Decorators apply bottom-up: env_unregister receives the class first and
# env_register receives its result.
# NOTE(review): presumably this clears any earlier registration before
# registering again, so re-running the script does not fail on a duplicate
# registration — confirm against omnisafe.envs.core.
@env_register
@env_unregister
class CustomSpiceEnv(GymnasiumWrapper):
    """GymnasiumWrapper subclass passed through OmniSafe's env decorators.

    The class body adds a single class attribute; everything else is
    inherited from GymnasiumWrapper (imported from
    benchmarks.gymnasium_wrapper).
    """

    # NOTE(review): the meaning of this value is not visible in this file;
    # check GymnasiumWrapper / the registry for how it is consumed.
    example_configs = 2


# env = CustomSpiceEnv('Example-v0')
# env = CustomSpiceEnv('LunarLander-v1')

# Configuration overrides handed to omnisafe.Agent via its custom_cfgs
# argument, grouped into three sections.
custom_cfgs: Dict[str, Dict[str, Any]] = {
    # Training-run settings.
    "train_cfgs": {
        # Same value as algo_cfgs.steps_per_epoch below.
        # NOTE(review): presumably this yields a single training epoch —
        # confirm against the OmniSafe configuration docs.
        "total_steps": 1024,
        "vector_env_nums": 1,
        "parallel": 1,
    },
    # Algorithm settings.
    "algo_cfgs": {
        "steps_per_epoch": 1024,
        "update_iters": 10,
    },
    # Logging settings.
    "logger_cfgs": {
        "use_wandb": False,
        "log_dir": "./logs",
    },
}

# Create the agent: algorithm name "CPO", environment id
# "CustomLunarLander-v2", with the overrides collected in custom_cfgs.
agent = omnisafe.Agent(
    "CPO",
    "CustomLunarLander-v2",
    custom_cfgs=custom_cfgs,
)

# Timestamps are taken immediately before and after learn(), so their
# difference covers training only and excludes the evaluation below.
t_0 = timeit.default_timer()
agent.learn()
t_1 = timeit.default_timer()

# NOTE(review): this requests 100000 evaluation episodes, and the training
# time is only printed once that evaluation has finished — confirm both are
# intended.
agent.evaluate(num_episodes=100000)
# Elapsed training time, rounded to a whole number of seconds.
print(f"Time for training: {round(t_1 - t_0)}")
# # Evaluate the agent
# agent.evaluate(num_episodes=500)
# print(f"Total number of unsafe states for trained policy: {agent.env._env.total_unsafe_states}")
