import os

from spinup import cup_pytorch as cup
from spinup import cppo_pytorch as cppo
from spinup import ppolag_pytorch as ppolag

from spinup import fuzcup_pytorch as fuzcup
from spinup import fuzcppo_pytorch as fuzcppo
from spinup import fuzppolag_pytorch as fuzppolag

import argparse
from envs.DoubleIntegratorEnv import DoubleIntegratorEnv

parser = argparse.ArgumentParser()

# --- Run settings shared by every algorithm ---
parser.add_argument('--seed', type=int, default=0, help='Random seed')
parser.add_argument('--len', type=int, default=100, help='Max Episode length')
parser.add_argument('--epoch', type=int, default=1000, help='Number of training epochs')
parser.add_argument('--alg', type=str, choices=['ppolag', 'fuzppolag', 'cppo', 'fuzcppo', 'cup', 'fuzcup'], default='ppolag', help='Algorithm to use')
parser.add_argument('--env', type=str, choices=['DoubleIntegrator', 'CartPole', 'QuadRotor', 'QuadRotor3D'], default='CartPole', help='Safety Control Gym System')
parser.add_argument('--task', type=str, choices=['stab', 'track'], default='stab', help='Safety Control Gym tasks')
# Only consumed by the fuzzy variants (fuzppolag, fuzcppo, fuzcup).
parser.add_argument('--level', type=int, default=10, help='Fuzzy weight dimension')
# Appended to "cuda:" to form the torch device string.
parser.add_argument('--device_id', type=int, default=0, help='CUDA device index')

# --- cppo / fuzcppo hyperparameters ---
parser.add_argument('--cppo_beta', type=float, default=100.0, help='Penalty coefficient for CVaR constraint violation')
parser.add_argument('--cppo_nu_start', type=float, default=10, help='Initial threshold value for the CVaR constraint.')
# --- cup hyperparameters ---
parser.add_argument('--cup_nu', type=float, default=0.1, help='Upper bound for the Lagrange multiplier')
parser.add_argument('--cup_lambda', type=float, default=0.9, help='Trade-off parameter for GAE')
args = parser.parse_args()

def _make_safe_control_gym_env(env_name, yaml_config_path):
    """Create a safe_control_gym environment from a YAML override file.

    Args:
        env_name: Environment id passed to safe_control_gym's ``make``.
        yaml_config_path: Path to a YAML file whose ``task_config`` mapping
            is forwarded to ``make`` as keyword arguments.

    Returns:
        Whatever ``make`` returns for ``env_name``.
    """
    # Imported here so that only the safe_control_gym environments need
    # yaml / safe_control_gym installed; DoubleIntegrator runs without them.
    import yaml
    from safe_control_gym.utils.registration import make

    with open(yaml_config_path, 'r') as file:
        # safe_load (not load) so the config cannot instantiate arbitrary objects.
        yaml_config = yaml.safe_load(file)
    return make(env_name, **yaml_config['task_config'])


# NOTE(review): args.seed is passed only to DoubleIntegratorEnv; the
# safe_control_gym environments below are not given it here — confirm intended.
if args.env == "DoubleIntegrator":
    env = DoubleIntegratorEnv(seed=args.seed)
elif args.env == "CartPole":
    env = _make_safe_control_gym_env(
        'cartpole',
        f'./envs/safe_control_gym/config_overrides/cartpole/cartpole_{args.task}.yaml',
    )
elif args.env == "QuadRotor":
    env = _make_safe_control_gym_env(
        'quadrotor',
        f'./envs/safe_control_gym/config_overrides/quadrotor/quadrotor_{args.task}.yaml',
    )
elif args.env == "QuadRotor3D":
    # NOTE(review): unlike CartPole/QuadRotor, this path is not under
    # ./envs/safe_control_gym/ — confirm the 3D overrides really live here.
    # Both quadrotor variants use the 'quadrotor' id; presumably the YAML's
    # task_config selects the 3D model — verify.
    env = _make_safe_control_gym_env(
        'quadrotor',
        f'./config_overrides/quadrotor_3D/quadrotor_3D_{args.task}.yaml',
    )
else:
    # argparse `choices` on --env makes this unreachable from the CLI, but
    # fail with a readable message if args is altered programmatically.
    raise ValueError(f"Unsupported environment: {args.env!r}")

def generate_unique_output_dir(base_dir, alg, seed):
    """Create and return a fresh run directory for ``alg`` and ``seed``.

    Candidates ``{base_dir}/{alg}/{alg}_seed_{seed}_v0``, ``..._v1``, ... are
    tried in order. The first one that does not already exist is created,
    along with any missing parent directories.

    Args:
        base_dir: Root directory for the runs of this env/task.
        alg: Algorithm name, used as a subdirectory and as a name prefix.
        seed: Random seed, embedded in the directory name.

    Returns:
        The path of the newly created directory, as a '/'-joined string.
    """
    version = 0
    while True:
        output_dir = f"{base_dir}/{alg}/{alg}_seed_{seed}_v{version}"
        try:
            # Let makedirs itself claim the slot. Checking os.path.exists()
            # first and creating afterwards races with runs started at the
            # same time, and the loser would crash with FileExistsError. Any
            # existing entry at this path makes makedirs raise, so we move on.
            os.makedirs(output_dir)
        except FileExistsError:
            version += 1
        else:
            return output_dir

# Runs are grouped as ./src/data/<env>-<task>/<alg>/, with a newly created
# versioned directory for every invocation of this script.
base_dir = f"./src/data/{args.env}-{args.task}"
unique_output_dir = generate_unique_output_dir(
    base_dir=base_dir,
    alg=args.alg,
    seed=args.seed,
)
print(f"The unique output directory is: {unique_output_dir}")


# Keyword arguments shared by every trainer; the dispatch that follows merges
# in the algorithm-specific options before calling the selected trainer.
alg_kwargs = {
    'ac_kwargs': dict(hidden_sizes=(64, 64)),
    'seed': args.seed,
    # Four max-length episodes' worth of steps per epoch.
    'steps_per_epoch': 4 * args.len,
    'max_ep_len': args.len,
    # Taken from --epoch (default 1000) rather than a hard-coded 1000, so the
    # command-line flag actually takes effect.
    'epochs': args.epoch,
    # Always a CUDA device string; there is no CPU fallback here.
    'device': f"cuda:{args.device_id}",
    'logger_kwargs': dict(output_dir=unique_output_dir, exp_name=args.env),
}

# Launch the selected trainer. Every trainer gets the shared alg_kwargs; each
# branch first merges in the options only that algorithm (or its fuzzy
# variant) takes. argparse `choices` on --alg means exactly one branch runs.
if args.alg == "ppolag":
    ppolag(env_fn=lambda: env, **alg_kwargs)
elif args.alg == "fuzppolag":
    alg_kwargs['level'] = args.level
    fuzppolag(env_fn=lambda: env, **alg_kwargs)
elif args.alg == "cppo":
    alg_kwargs.update(beta=args.cppo_beta, nu_start=args.cppo_nu_start)
    cppo(env_fn=lambda: env, **alg_kwargs)
elif args.alg == "fuzcppo":
    alg_kwargs.update(
        beta=args.cppo_beta,
        nu_start=args.cppo_nu_start,
        level=args.level,
    )
    fuzcppo(env_fn=lambda: env, **alg_kwargs)
elif args.alg == "cup":
    alg_kwargs.update(
        cup_lambda=args.cup_lambda,
        cup_nu=args.cup_nu,
        # CUP overrides the shared (64, 64) network with a wider one.
        ac_kwargs=dict(hidden_sizes=[256, 128]),
    )
    cup(env_fn=lambda: env, **alg_kwargs)
elif args.alg == "fuzcup":
    # NOTE(review): unlike "cup", --cup_lambda / --cup_nu are not forwarded
    # here, so those flags have no effect on fuzcup — confirm intended.
    alg_kwargs.update(
        ac_kwargs=dict(hidden_sizes=[256, 128]),
        level=args.level,
    )
    fuzcup(env_fn=lambda: env, **alg_kwargs)