from PG import PG
from SVRPG import SVRPG
from TSIVR_PG import TSIVR_PG
from SRVR_PG import SRVR_PG
from HSPGA import HSPGA
import torch
import numpy as np
import gym
import os


if __name__ == '__main__':
    # Experiment driver: runs the selected policy-gradient variant once per
    # seed on a single gym environment and saves the per-seed results stacked
    # in one .npy file under ./results/<alg_name>/.
    episode = 50
    gamma = 0.99
    env_name = 'FrozenLake8x8-v0'

    # Choose the algorithm here; its learning rate is looked up below so the
    # two settings can never get out of sync.
    alg_name = 'PG'
    learning_rates = {
        'PG': 0.05,
        'TSIVR_PG': 0.1,
        'SVRPG': 0.05,
        'SRVR_PG': 0.05,
        'HSPGA': 0.08,
    }
    if alg_name not in learning_rates:
        # Fail before any run starts instead of silently saving empty results.
        raise ValueError(
            'Unknown alg_name ' + repr(alg_name)
            + '; expected one of ' + ', '.join(sorted(learning_rates))
        )
    learning_rate = learning_rates[alg_name]

    env = gym.make(env_name)
    # Hard-coded for the 8x8 grid named in env_name; update it if env_name
    # is changed to an environment with a different number of states.
    n_state = 64
    n_action = env.action_space.n

    print(env_name, alg_name, str(learning_rate), str(episode))
    seeds = range(50)

    N = 100
    # B and m are both sqrt(N); they are passed to TSIVR_PG, SRVR_PG and HSPGA.
    B = int(np.sqrt(N))
    m = int(np.sqrt(N))

    # The output location does not depend on the seed, so build it once.
    lr_tag = str(learning_rate).replace('.', '_')
    filepath = './results/' + alg_name + '/' + env_name + '_lr' + lr_tag + '_eps_' + str(episode)
    filename = filepath + '/N=' + str(N)
    os.makedirs(filepath, exist_ok=True)

    # Must stay a Python list across iterations: rebinding it to an ndarray
    # would make the next append() raise AttributeError.
    res = []
    for i, seed in enumerate(seeds):
        print('N=' + str(N) + ' i=' + str(i))
        # Seed every source of randomness in use so each run is reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)
        env.seed(seed)

        # Each algorithm gets its own episode count and batch parameters.
        # NOTE(review): the 2x (PG) and 2/3 (HSPGA) episode factors presumably
        # equalise the sample budget across methods - confirm.
        if alg_name == 'PG':
            run = PG(env, 2 * episode, learning_rate, gamma, N, n_state, n_action)
        elif alg_name == 'TSIVR_PG':
            run = TSIVR_PG(env, episode, learning_rate, gamma, N, B, m, n_state, n_action)
        elif alg_name == 'SVRPG':
            # SVRPG uses N^(2/3) and N^(1/3) in place of B and m.
            run = SVRPG(env, episode, learning_rate, gamma, N, int(N ** (2/3)), int(N ** (1/3)), n_state, n_action)
        elif alg_name == 'SRVR_PG':
            run = SRVR_PG(env, episode, learning_rate, gamma, N, B, m, n_state, n_action)
        else:  # 'HSPGA' - every other name was rejected above
            run = HSPGA(env, int(episode * 2/3), learning_rate, gamma, N, B, m, n_state, n_action)
        res.append(run)

        # Checkpoint after every seed so an interrupted sweep keeps the runs
        # finished so far; the array is built only for saving.
        np.save(filename, np.array(res))
