import random
import copy
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt

import environment
import agents

# Pick the compute device once for the whole script: use CUDA when PyTorch
# reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    dev = torch.device('cuda')
else:
    dev = torch.device('cpu')



# Sweep for the agents.Agent_MLP_GC_BC agent.
# For every value of H, n_tests agents are built, trained with train_agent()
# and scored with env.eval(agent); the per-H mean of those scores is printed
# and collected in all_results.
print("Goal-conditioned")
Hs = [(i+1)*5 for i in range(10)]  # H = 5, 10, ..., 50
n_tests = 10  # number of independent train/eval runs per value of H
all_results = []  # one mean score per H, in the same order as Hs
for H in Hs:
    print()
    print("Testing H = "+str(H))
    results = []
    for i in range(n_tests):
        # Each run gets its own environment instance, so no per-H
        # environment has to be constructed outside this loop.
        env = environment.ENV(H)
        agent = agents.Agent_MLP_GC_BC(env, [256, 256]).to(dev)
        agent.train_agent()
        results.append(env.eval(agent))
    # Compute the mean once and reuse it for both the report and the summary.
    mean_result = sum(results)/n_tests
    print("results : ")
    print(results)
    print(mean_result)
    all_results.append(mean_result)
print(all_results)


# Sweep for the agents.Agent_Q agent: for each H, run n_tests independent
# train/eval trials and report the individual scores and their mean.
def _q_learning_trial(h_value):
    """Build, train and evaluate one Agent_Q on a new ENV(h_value).

    Returns whatever ``env.eval(agent)`` returns for the trained agent.
    """
    trial_env = environment.ENV(h_value)
    q_agent = agents.Agent_Q(trial_env, [512, 512]).to(dev)
    q_agent.train_agent()
    return trial_env.eval(q_agent)


print("Q-learning")
Hs = list(range(5, 51, 5))  # 5, 10, ..., 50
n_tests = 10
all_results = []
for H in Hs:
    print()
    print("Testing H = " + str(H))
    results = [_q_learning_trial(H) for _ in range(n_tests)]
    mean_result = sum(results) / n_tests
    print("results : ")
    print(results)
    print(mean_result)
    all_results.append(mean_result)
print(all_results)


# Sweep for the agents.Agent_PPO agent: for each H, train and evaluate
# n_tests agents, print the per-run scores and their mean, and keep the mean
# in all_results (printed once the sweep is finished).
print("PPO")
Hs = [5 * k for k in range(1, 11)]  # 5, 10, ..., 50
n_tests = 10
all_results = []
for H in Hs:
    print()
    print("Testing H = " + str(H))
    results = []
    for _ in range(n_tests):
        # A new environment and a new agent are created for every run.
        env = environment.ENV(H)
        agent = agents.Agent_PPO(env, [512, 512])
        agent = agent.to(dev)
        agent.train_agent()
        score = env.eval(agent)
        results.append(score)
    average = sum(results) / n_tests
    print("results : ")
    print(results)
    print(average)
    all_results.append(average)
print(all_results)
