import pyvirtualdisplay

import atexit

# Start an off-screen virtual display before the remaining imports run, so that
# anything needing an X server can find one.
# NOTE(review): presumably required by the gym renderer used for screen
# capture -- confirm.
_display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
_ = _display.start()
# The display was previously never stopped; make sure its backing process is
# shut down when the interpreter exits instead of being left running.
atexit.register(_display.stop)

from FQE import FQEV_resize, initialize_zero
from FQE_utils import *
from NN import PolicyNet, QVRNet, net_param_num
from PG import PG
from PG_utils import CartPoleEnvR, set_seed
import gym
import torch
import numpy as np

# --- Environment, screen preprocessing and target policy -------------------
env = CartPoleEnvR()

# Screen preprocessing pipeline: tensor -> PIL image -> resize -> tensor.
# Resize input 40 gives 3x40x150 screens; input 20 gives 3x20x75.
# Image.BICUBIC is the same filter as the legacy Image.CUBIC alias, which was
# removed in Pillow 10 (assumes `Image` is PIL.Image, brought in by the
# FQE_utils star import -- confirm).
resize = T.Compose([T.ToPILImage(), T.Resize(20, interpolation=Image.BICUBIC), T.ToTensor()])
n_state = env.observation_space.shape[0]
n_action = env.action_space.n

# Target policy whose value is being estimated; weights come from a local
# checkpoint file.
policy_net = PolicyNet(n_state, n_action)
# map_location="cpu" lets a checkpoint saved on a CUDA device load on a
# CPU-only machine; load_state_dict copies the values into the module's own
# parameters either way.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
policy_net.load_state_dict(torch.load("target_policy_net.pickle", map_location="cpu"))
m = net_param_num(policy_net)  # not used in the visible part of this script

# Probe one preprocessed screen to get the image dimensions for the Q-network;
# indices 1 and 2 of the shape are taken as height and width.
screen_shape = get_screen(env, resize).shape
screen_height, screen_width = screen_shape[1], screen_shape[2]

# --- Reproducibility --------------------------------------------------------
seed_init = 2005
set_seed(env, seed_init)

# --- Epsilon-greedy levels --------------------------------------------------
eps = 0.2           # behavior policy
policy_net.eps = 0  # target policy

# --- Logged data ------------------------------------------------------------
# K and eps select the dataset directory; H is also handed to the loader.
K, H = 20000, 100
D = StreamingDataset(f'data/IID_eps-{eps}_K-{K}', H)

# --- Fitted Q-evaluation ----------------------------------------------------
# Optimisation settings forwarded to the evaluator as a single dict.
max_epoch, lr, batch_size = 20, 0.01, 256
Q_training_param = dict(max_epoch=max_epoch, lr=lr, batch_size=batch_size)
sample_s = 1000

# Q-network sized from the preprocessed screen and the number of actions.
Q = QVRNet(screen_height, screen_width, env.action_space.n)

# NOTE(review): only `FQEV_resize` is imported explicitly from FQE; `FQEV` must
# come from the `FQE_utils` star import, otherwise this raises NameError --
# confirm which function is intended.
v_hat, Q = FQEV(
    env,
    D,
    policy_net,
    resize,
    screen_height,
    screen_width,
    Q_training_param,
    H,
    Q=Q,
    sample_s=sample_s,
)
print('The estimated policy value by FQE is', v_hat)
