import argparse
import json
import os

import d3rlpy
import torch
from d3rlpy.algos import DQN
from d3rlpy.base import _serialize_params, ImplBase, LearnableBase
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.models.encoders import VectorEncoderFactory
from IPython import embed
from sklearn.model_selection import train_test_split

# Command-line interface. All three options are mandatory, so argparse is
# asked to enforce that itself: a missing option now produces a usage error
# instead of an AssertionError (and asserts vanish entirely under
# ``python -O``, which would let ``None`` leak into the file names below).
parser = argparse.ArgumentParser()
parser.add_argument("--d", type=int, required=True,
                    help="dimension of middle layer")
parser.add_argument("--n", type=int, required=True,
                    help="number of episodes")
parser.add_argument("--T", type=int, required=True,
                    help="value of T embedded in the output file names")
res = parser.parse_args()

d = res.d  # dimension of middle layer
n = res.n  # number of episodes
T = res.T  # only used to label the output files in this part of the script

# Output locations: one JSON file for the hyperparameters and one .pt file
# for the model weights, both named after the (d, n, T) configuration.
path = 'models/models_cp/'
filename = 'dqn_d{d}_n{n}_T{T}'.format(d=d, n=n, T=T)
filepath_json = os.path.join(path, filename + '.json')
filepath_model = os.path.join(path, filename + '.pt')
# exist_ok avoids the check-then-create race when several runs start at once.
os.makedirs(path, exist_ok=True)




def serialize(old_params, algo):
    """Assemble a serialized parameter dict describing ``algo``.

    The entries of ``old_params`` are copied, leaving out every value that is
    an ``ImplBase`` or ``LearnableBase`` instance. The algorithm's class name
    and the observation shape / action size read from ``algo._impl`` are then
    added, and the combined dict is handed to ``_serialize_params``.

    Args:
        old_params: mapping of parameter names to values.
        algo: algorithm object whose ``_impl`` attribute exposes
            ``observation_shape`` and ``action_size``.

    Returns:
        The result of ``_serialize_params`` applied to the assembled dict.
    """
    excluded_types = (ImplBase, LearnableBase)

    # keep only the plain parameters; nested impl/learnable objects are dropped
    collected = {
        key: value
        for key, value in old_params.items()
        if not isinstance(value, excluded_types)
    }

    # record the algorithm name and the shapes taken from its implementation
    collected.update(
        algorithm=algo.__class__.__name__,
        observation_shape=algo._impl.observation_shape,
        action_size=algo._impl.action_size,
    )

    # serialize objects
    return _serialize_params(collected)




# obtain dataset
dataset, env = get_cartpole()
# never ask for more episodes than the dataset holds
n_cut = min(n, len(dataset.episodes))
episodes = dataset[:n_cut]
# the first 80% of the selected episodes are used for training; the remainder
# (test_eps) is split off but not used in the visible part of this script
n_train = int(.8 * n_cut)
train_eps, test_eps = episodes[:n_train], episodes[n_train:]


# setup algorithm: a single hidden layer of width d with ReLU activation
encoder_factory = VectorEncoderFactory(hidden_units=[d], activation='relu')
# Use the GPU only when CUDA is actually available so the script also runs on
# CPU-only machines instead of failing when the model is moved to the device.
dqn = DQN(encoder_factory=encoder_factory, use_gpu=torch.cuda.is_available())


# train
dqn.fit(train_eps, n_epochs=20)

# store the hyperparameters of the trained algorithm as JSON
params = dqn.get_params()
params_serial = serialize(params, dqn)
with open(filepath_json, 'w') as outfile:
    json.dump(params_serial, outfile)


# save the model itself next to the JSON file
dqn.save_model(filepath_model)


# evaluate trained algorithm on the environment over 32 trials
# NOTE: this rebinds ``res`` (previously the argparse namespace) to the score.
res = evaluate_on_environment(env, n_trials=32, render=False)(dqn)
print("\nAverage value of DQN d: " + str(d) + " n: " + str(n))
print(res)


