import omnisafe
import gymnasium as gym
import safety_gymnasium
import timeit
from typing import Any, Dict, Tuple
from omnisafe.envs.core import CMDP, env_register, env_unregister
from gymnasium import make
from UnsafeStateCounterWrapper import UnsafeStateCounterWrapper
from benchmarks.omnisafe_reg import SpiceEnvironment
from benchmarks.gymnasium_wrapper import GymnasiumWrapper
import torch
import argparse

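# 1) Set up a small “hidden” argparse (no -h/--help) that reads --csv-path
#    together with the run options --seed, --env-id, --algorithm and --steps;
#    arguments it does not recognise are collected separately, not rejected.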

def build_arg_parser() -> argparse.ArgumentParser:
    """Build the help-less command-line parser for this training script.

    The parser is created with ``add_help=False``, so ``-h``/``--help`` are
    not consumed by it. It defines five options:

    * ``--csv-path``  (str) destination of the per-step CSV log
    * ``--seed``      (int) environment seed
    * ``--env-id``    (str) one of the registered environment ids
    * ``--algorithm`` (str) one of the supported algorithm names
    * ``--steps``     (int) total number of training steps

    Returns:
        A configured ``argparse.ArgumentParser``. Callers are expected to use
        ``parse_known_args`` so that unrelated arguments are tolerated.
    """
    parser = argparse.ArgumentParser(add_help=False)

    # NOTE(review): the default is an absolute, user-specific path under an
    # "Ant/RCEPETS" directory while the default algorithm is SafeLOOP; pass
    # --csv-path explicitly to avoid mixing logs from different runs.
    parser.add_argument(
        "--csv-path",
        type=str,
        default="/scratch1/dsc5636/CPO/Ant/RCEPETS/run1.csv",
        help="where to dump your per-step CSV logs",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="seed for the environment",
    )
    parser.add_argument(
        "--env-id",
        type=str,
        default='Ant-v1',
        choices=[
            "AccEnv-v1",
            "CarRacing-v1",
            "Pendulum-v1",
            "Cheetah-v1",
            "BipedalWalker-v1",
            "Hopper-v1",
            "Ant-v1",
            "HumanoidEnv-v1",
        ],
        help="environment id",
    )
    parser.add_argument(
        "--algorithm",
        type=str,
        default='SafeLOOP',
        choices=[
            'PPOSaute',
            'PPOSimmerPID',
            'CUP',
            'P3O',
            'CAPPETS',
            'SafeLOOP',
            'RCEPETS',
        ],
        help="algorithm to use",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=500000,
        help="total steps for training",
    )
    return parser


if __name__ == "__main__":
    _parser = build_arg_parser()
    # parse_known_args keeps going when it meets options it does not define;
    # those leftovers end up in _unknown instead of aborting the run.
    _args, _unknown = _parser.parse_known_args()


    @env_register
    @env_unregister
    class CustomSpiceEnv(GymnasiumWrapper):
        """GymnasiumWrapper subclass that injects the CSV log path from the CLI.

        Decorators are applied bottom-up, so ``env_unregister`` is applied to
        the class before ``env_register``.
        NOTE(review): presumably this drops any earlier registration before
        registering again — confirm against the OmniSafe documentation.
        """

        # NOTE(review): the meaning of this value is not visible in this file.
        example_configs = 3

        def __init__(self, env_id: str, *, csv_path: str = 'run1.csv', **kwargs) -> None:
            """Forward construction to ``GymnasiumWrapper`` with a resolved CSV path.

            Args:
                env_id: Environment id, passed through unchanged.
                csv_path: Fallback CSV path, used only when ``_args.csv_path``
                    is falsy (for example an empty string).
                **kwargs: Remaining keyword arguments, passed through unchanged.
            """
            # The command-line value takes precedence whenever it is truthy.
            cli_csv = _args.csv_path
            chosen_csv = cli_csv if cli_csv else csv_path
            super().__init__(env_id=env_id, csv_path=chosen_csv, **kwargs)


    #['BipedalWalker-v1', "AccEnv-v1", "CarRacing-v1", "Pendulum-v1", "Cheetah-v1"]


    # Run selection comes straight from the parsed command-line options.
    algorithm, env_id = _args.algorithm, _args.env_id
    # Only referenced by the disabled env_cfgs override further down.
    run_id = f"{algorithm}_{env_id}_1"

    # Each configuration section is built on its own, then assembled below.
    train_cfgs = {
        'total_steps': _args.steps,  # 200000
        'vector_env_nums': 1,
        'parallel': 1,
        # 'device': 'cuda:3',
    }
    algo_cfgs = {
        'steps_per_epoch': 2000,
        # 'update_iters': 8,
        'reward_normalize': False,
        'cost_normalize': False,
    }
    logger_cfgs = {
        'use_wandb': False,
        'log_dir': './logs',
        'use_tensorboard': True,
    }
    evaluation_cfgs = {
        'use_eval': False,
    }

    custom_cfgs = {
        'seed': _args.seed,
        'train_cfgs': train_cfgs,
        'algo_cfgs': algo_cfgs,
        # 'model_cfgs': {'std_range': [-1, 0]},
        'logger_cfgs': logger_cfgs,
        'evaluation_cfgs': evaluation_cfgs,
    }
    # Disabled override of the environment's CSV path:
    # custom_cfgs.update({'env_cfgs': {'csv_path': run_id + '.csv'}})

    # Create the agent for the chosen algorithm/environment pair.
    print("starting agent")
    agent = omnisafe.Agent(algorithm, env_id, custom_cfgs=custom_cfgs)

    # Time the whole learn() call with timeit's default timer.
    t_0 = timeit.default_timer()
    agent.learn()
    t_1 = timeit.default_timer()
    # Disabled follow-up steps:
    # agent._save_model()
    # agent.evaluate(1)

    # The elapsed time is rounded to a whole number before printing.
    print(f"Time for training: {round(t_1 - t_0)}")

