from PG import PG
from SVRPG import SVRPG
from TSIVR_PG import TSIVR_PG
from SRVR_PG import SRVR_PG
from HSPGA import HSPGA
import torch
import numpy as np
import gym
import os


def integer_root(value, k):
    """Return floor(value ** (1 / k)) for a non-negative integer ``value``.

    A plain ``int(value ** (1 / k))`` can be off by one because the float
    power may land just below an exact root (``64 ** (1 / 3)`` evaluates to
    ``3.9999999999999996``).  The float estimate is therefore only a starting
    point and is corrected with exact integer arithmetic.

    Args:
        value: Non-negative integer radicand.
        k: Root degree, at least 1.

    Returns:
        The largest integer ``r`` with ``r ** k <= value``.

    Raises:
        ValueError: If ``value`` is negative or ``k`` is less than 1.
    """
    if value < 0 or k < 1:
        raise ValueError('integer_root requires value >= 0 and k >= 1')
    root = int(round(value ** (1.0 / k)))
    # Nudge the float estimate onto the exact floor root.
    while root ** k > value:
        root -= 1
    while (root + 1) ** k <= value:
        root += 1
    return root


if __name__ == '__main__':
    # ---- Experiment configuration ---------------------------------------
    # eps is passed straight through to the algorithm (PG receives 2 * eps)
    # and is recorded in the output path; its exact meaning is defined by
    # the algorithm modules -- TODO confirm.
    eps = 50
    gamma = 0.99  # presumably the discount factor; passed to every algorithm

    env_name = 'CartPole-v0'

    # Exactly one (alg_name, learning_rate) pair is active; the commented-out
    # pairs are the learning rates used for the other algorithms.
    alg_name = 'PG'
    learning_rate = 0.005

    # alg_name = 'TSIVR_PG'
    # learning_rate = 0.01

    # alg_name = 'SVRPG'
    # learning_rate = 0.005

    # alg_name = 'SRVR_PG'
    # learning_rate = 0.005

    # alg_name = 'HSPGA'
    # learning_rate = 0.008

    n_runs = 50  # independent runs, seeded 0 .. n_runs - 1
    seeds = range(n_runs)
    N = 25
    # B and m are both floor(sqrt(N)) for the variance-reduced methods.
    B = integer_root(N, 2)
    m = integer_root(N, 2)

    # One factory per algorithm. The lambdas read env / n_state / n_action
    # when called, i.e. after the environment is created below.
    # PG takes no (B, m) and is given 2 * eps; SVRPG uses
    # floor(N ** (2/3)) and floor(N ** (1/3)) instead of B and m.
    runners = {
        'PG': lambda: PG(env, 2 * eps, learning_rate, gamma, N, n_state, n_action),
        'TSIVR_PG': lambda: TSIVR_PG(env, eps, learning_rate, gamma, N, B, m, n_state, n_action),
        'SVRPG': lambda: SVRPG(env, eps, learning_rate, gamma, N,
                               integer_root(N * N, 3), integer_root(N, 3),
                               n_state, n_action),
        'SRVR_PG': lambda: SRVR_PG(env, eps, learning_rate, gamma, N, B, m, n_state, n_action),
        'HSPGA': lambda: HSPGA(env, eps, learning_rate, gamma, N, B, m, n_state, n_action),
    }
    if alg_name not in runners:
        # Fail before any training instead of silently saving an empty array.
        raise ValueError('Unknown alg_name %r; expected one of %s'
                         % (alg_name, sorted(runners)))

    env = gym.make(env_name)
    n_state = env.observation_space.shape[0]
    n_action = env.action_space.n

    print(env_name, alg_name, str(learning_rate), str(eps))

    res = []
    for run, seed in enumerate(seeds):
        print('N=' + str(N) + ' i=' + str(run))
        # Seed torch, numpy and the environment so each run is reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)
        # NOTE(review): Env.seed() was removed in gym 0.26 (seeding moved to
        # env.reset(seed=...)); this line assumes an older gym -- confirm.
        env.seed(seed)
        res.append(runners[alg_name]())

    # NOTE(review): assumes every run returns a same-shaped result so that
    # np.array stacks them into a regular array -- confirm.
    res = np.array(res)
    filepath = './results/' + alg_name + '/' + env_name + '_lr' + str(learning_rate).replace('.', '_') + '_eps_' + str(eps)
    filename = filepath + '/N=' + str(N)
    os.makedirs(filepath, exist_ok=True)
    # np.save appends the '.npy' extension to filename.
    np.save(filename, res)
