import os
import pickle
import random

import d3rlpy
import gym
import numpy as np
from IPython import embed

from mc.utils import _discretize_state

""" 
For the MountainCar behavior policy, we trained a policy using SARSA
from the following open-source implementation that is publicly available.
https://github.com/greatwallet/mountain-car
"""


# Bounds passed to _discretize_state when turning an observation into the
# key used to index the Q-table below.
min_state_val, max_state_val = 0, 40

# Horizon: upper bound on the number of env.step calls per collected episode.
H = 201

# Pickled Q-table of the SARSA-trained behavior policy.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
pickle_name = 'mc/pickles/good.pickle'

envname = "MountainCar-v0"
env = gym.make(envname)

with open(pickle_name, 'rb') as q_file:
    Q = pickle.load(q_file)

def ep(epsilon=0.3):
    """Roll out one epsilon-greedy episode of the Q-table behavior policy.

    At each step the observation is discretized with ``_discretize_state`` and
    the greedy action ``argmax(Q[s])`` is taken. With probability ``epsilon``,
    a uniformly random action from ``{0, 1, 2}`` is taken instead.

    For ``MountainCar-v0`` the environment reward is replaced by a shaped one:

    * On a step that ends the episode by reaching the goal (not by the time
      limit), the reward is ``250.0 +`` the running sum of raw environment
      rewards, including this step. The bonus therefore shrinks as the
      episode gets longer.
    * On every other step, including a time-limit cutoff, the reward is
      ``5 * |x'[0] - x[0]| + 3 * |x[1]|``. Here ``x`` is the observation
      before the step and ``x'`` is the one after it.

    NOTE(review): assumes the pre-0.26 gym API (``reset()`` returns only the
    observation and ``step()`` returns a 4-tuple); confirm the pinned version.

    Args:
        epsilon: Probability of replacing the greedy action with a random one.
            Defaults to 0.3, the previously hard-coded value.

    Returns:
        Five parallel lists with one entry per step: observations, actions,
        next observations, rewards (shaped for MountainCar-v0), and done flags.
    """
    xs, As, xprimes, rs, dones = [], [], [], [], []
    obs = env.reset()
    done = False
    episode_reward_default = 0

    for h in range(H):
        if done:
            break
        xs.append(obs)
        s = _discretize_state(env, obs, min_state_val, max_state_val)

        a = np.argmax(Q[s])
        if random.random() < epsilon:
            a = np.random.choice([0, 1, 2])

        xprime, r, done, info = env.step(a)
        episode_reward_default += r

        if envname == 'MountainCar-v0':
            # The old step-index test (h < 199) mislabels reaching the goal on
            # the last allowed step as a time-out. gym's TimeLimit wrapper
            # reports a genuine cutoff via info['TimeLimit.truncated']. When the
            # key is absent, fall back to the original index heuristic.
            truncated = info.get('TimeLimit.truncated', h >= 199)
            if done and not truncated:
                r = 250.0 + episode_reward_default
            else:
                # Uses the velocity *before* the step (obs[1]).
                r = 5*abs(xprime[0] - obs[0]) + 3*abs(obs[1])

        obs = xprime
        As.append(a)
        xprimes.append(obs)
        rs.append(r)
        dones.append(done)

    return xs, As, xprimes, rs, dones


def save_data(T, label='', *, rollout=None):
    """Collect ``T`` episodes and stack their transitions into NumPy arrays.

    Despite its name, this function writes nothing to disk. It gathers
    transitions, prints the total number of collected steps, and returns the
    arrays; the caller builds and dumps the dataset.

    Args:
        T: Number of episodes to roll out.
        label: Unused. Kept so existing call sites keep working.
        rollout: Zero-argument callable returning one episode as the
            ``(xs, As, xprimes, rs, dones)`` lists produced by ``ep``.
            Defaults to ``ep``.

    Returns:
        ``(observations, actions, rewards, terminals)`` arrays with one entry
        per step across all episodes. ``terminals`` holds the done flags cast
        to int. Next observations are not returned.
    """
    if rollout is None:
        rollout = ep

    xs_trajs, As_trajs, rs_trajs, dones_trajs = [], [], [], []
    for _ in range(T):
        # Next observations (third element) are not part of the return value,
        # so they are not accumulated.
        xs, As, _xprimes, rs, dones = rollout()
        xs_trajs.extend(xs)
        As_trajs.extend(As)
        rs_trajs.extend(rs)
        dones_trajs.extend(dones)

    print("dataset size: " + str(len(xs_trajs)))
    return np.array(xs_trajs), np.array(As_trajs), np.array(rs_trajs), np.array(dones_trajs).astype(int)

# Collect 1000 episodes and package them as a d3rlpy offline-RL dataset.
# NOTE(review): despite the "expert" file name, actions come from the
# epsilon-greedy behavior policy in ep() (30% random by default).
xs, As, rs, dones = save_data(1000)

dataset = d3rlpy.dataset.MDPDataset(
    observations=xs,
    actions=As,
    rewards=rs,
    terminals=dones,
)
# Opening a file inside a missing directory fails, so create the output
# directory first. This is a no-op when it already exists.
os.makedirs("data", exist_ok=True)
dataset.dump("data/mc_expert.h5")
# Drops into an interactive IPython shell for inspection; remove this call for
# non-interactive / batch runs.
embed()

