import d3rlpy
import gym
from IPython import embed
from d3rlpy.algos import DQN
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.models.encoders import VectorEncoderFactory
from d3rlpy.base import _serialize_params
import argparse
import os
import json
import numpy as np

# Command-line interface. --version selects which model family to evaluate:
# 'cp' models are evaluated on the environment from get_cartpole(), 'mc'
# models on gym's 'MountainCar-v0'. argparse enforces the allowed values
# (an `assert` would be stripped when Python runs with -O), and the arguments
# are parsed exactly once.
parser = argparse.ArgumentParser(
    description="Evaluate saved DQN models over a (T, d, n) grid.")
parser.add_argument("--version", type=str, required=True,
                    choices=['cp', 'mc'],
                    help="model family to evaluate ('cp' or 'mc')")
res = parser.parse_args()

version = res.version


# Grid of (T, d, n) settings covered by the saved models; model files are
# named 'dqn_d{d}_n{n}_T{T}'. What T, d and n mean is defined by the training
# script (not visible here). T is shared by both versions; d and n differ.
Ts = [0, 1, 2, 3, 4]
ds = [10, 50, 1000, 5000, 25000, 50000]
if version == 'mc':
    # MountainCar grid: additionally includes d=100 and uses its own n values.
    ds = [10, 50, 100, 1000, 5000, 25000, 50000]
    ns = [200, 400, 600, 800, 1000]
elif version == 'cp':
    ns = [500, 700, 900, 1100, 1300, 1500]

# values[k, i, j] receives the evaluation score for (Ts[k], ds[i], ns[j]).
values = np.zeros((len(Ts), len(ds), len(ns)))



# Pre-flight check: confirm that every model in the (T, d, n) grid is present
# before any evaluation starts, so a missing file fails immediately instead of
# partway through the evaluation loop.
path = {'cp': 'models/models_cp/', 'mc': 'models/models_mc/'}[version]

missing = []
for T in Ts:
    for d in ds:
        for n in ns:
            filename = 'dqn_d{d}_n{n}_T{T}'.format(d=d, n=n, T=T)
            # Evaluation loads both the .json config and the .pt checkpoint,
            # so both files must exist.
            for extension in ('.json', '.pt'):
                candidate = os.path.join(path, filename + extension)
                if not os.path.exists(candidate):
                    print("Couldn't find " + candidate)
                    missing.append(candidate)

# Report every missing file at once rather than stopping at the first one.
if missing:
    raise FileNotFoundError(
        "{} model file(s) missing, first: {}".format(len(missing), missing[0]))

# Evaluate every (T, d, n) model and store its score in values[k, i, j].
path = {'cp': 'models/models_cp/', 'mc': 'models/models_mc/'}[version]

# Build the evaluation environment once and reuse it for every model.
# get_cartpole() also returns a dataset, which this script does not need.
if version == 'cp':
    _, env = get_cartpole()
elif version == 'mc':
    env = gym.make('MountainCar-v0')

for k, T in enumerate(Ts):
    for i, d in enumerate(ds):
        for j, n in enumerate(ns):
            filename = 'dqn_d{d}_n{n}_T{T}'.format(d=d, n=n, T=T)
            filepath_json = os.path.join(path, filename + '.json')
            filepath_model = os.path.join(path, filename + '.pt')

            # Restore the model: build it from the saved .json first, then
            # load the .pt checkpoint into it.
            dqn = DQN.from_json(filepath_json)
            dqn.load_model(filepath_model)

            # `score`, not `res`, so the parsed CLI arguments are not clobbered.
            score = evaluate_on_environment(env, n_trials=32, render=False)(dqn)
            print("\nAverage value of d: " + str(d) + " n: " + str(n))
            print(score)

            values[k, i, j] = score


# Persist the (T, d, n) score grid as a .npy file under data/.
filename = {'cp': 'data/cp_values.npy', 'mc': 'data/mc_values.npy'}[version]

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs('data/', exist_ok=True)
np.save(filename, values)

print("Saving to " + str(filename))
# NOTE(review): embed() opens an interactive IPython shell and waits for
# input, so a non-interactive/batch run will not exit on its own here.
embed()
