"""
Train a DQN agent on CartPole-v1 and save its weights together with the
observations recorded during training.

Simulation outputs are produced by the follow-up script "simulation.py".
"""


import os

import gym
import numpy as np
from tqdm import tqdm

from model import DQN_Agent


# Environment: the observation vector sizes the network input and the number
# of discrete actions sizes its output.
env = gym.make('CartPole-v1')
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n

# Agent. layer_sizes presumably describes an MLP with one 64-unit hidden
# layer (interpretation is up to DQN_Agent). exp_replay_size is also used by
# the replay warm-up loop further down.
exp_replay_size = 256
agent = DQN_Agent(
    seed=1423,
    layer_sizes=[input_dim, 64, output_dim],
    lr=1e-3,
    sync_freq=5,
    exp_replay_size=exp_replay_size,
)

# Per-episode metric histories, filled by the training loop.
losses_list = []
reward_list = []
episode_len_list = []
epsilon_list = []

# Training length and initial exploration rate (fully random at the start).
episodes = 10000
epsilon = 1





# Initialize the experience replay: collect exactly exp_replay_size
# transitions under a uniformly random policy (epsilon=1) before training.
#
# Loop on the transition count rather than on a fixed number of episodes, so
# that no extra environment resets happen once enough transitions have been
# collected. Otherwise the buffer would be flooded with start-of-episode
# samples.
index = 0
while index < exp_replay_size:
    obs = env.reset()
    done = False
    while not done and index < exp_replay_size:
        A, _, _ = agent.get_action(obs, env.action_space.n, epsilon=1)
        obs_next, reward, done, _ = env.step(A.item())
        # NOTE(review): `done` is not stored in the transition, so the agent
        # cannot distinguish terminal steps; confirm DQN_Agent expects this.
        agent.collect_experience([obs, A.item(), reward, obs_next])
        obs = obs_next
        index += 1


# Main training loop: run `episodes` epsilon-greedy episodes and record every
# pre-action observation per episode in obs_train. Whenever the global step
# counter `index` exceeds 128 (i.e. every 129 environment steps), run 4
# mini-batch updates of size 16. `index` starts at 128, so the first update
# happens on the very first step.
obs_train = list()
index = 128
for i in tqdm(range(episodes)):
    ep_obs = list()
    obs, done, losses, ep_len, rew = env.reset(), False, 0, 0, 0
    while not done:
        ep_len += 1

        ep_obs.append(obs.tolist())

        A, _, _ = agent.get_action(obs, env.action_space.n, epsilon)
        obs_next, reward, done, _ = env.step(A.item())
        agent.collect_experience([obs, A.item(), reward, obs_next])

        obs = obs_next
        rew += reward
        index += 1

        if index > 128:
            index = 0
            for _ in range(4):
                losses += agent.train(batch_size=16)

    # Linear epsilon decay of 1/5000 per episode, clamped so the last step
    # cannot undershoot the 0.05 exploration floor.
    if epsilon > 0.05:
        epsilon = max(0.05, epsilon - 1 / 5000)

    # Per-episode metrics: summed training loss divided by episode length,
    # total reward, episode length and the epsilon for the next episode.
    losses_list.append(losses / ep_len)
    reward_list.append(rew)
    episode_len_list.append(ep_len)
    epsilon_list.append(epsilon)

    # tqdm.write prints without breaking the progress bar, unlike print().
    if i % 100 == 0:
        tqdm.write(str(ep_len))

    obs_train.append(ep_obs)


def _episodes_to_array(episodes_obs):
    """Convert a list of per-episode observation lists into a NumPy array.

    Episodes usually have different lengths. NumPy >= 1.24 raises ValueError
    when building an array from such ragged nested lists, so in that case a
    1-D object array holding one list per episode is returned. This matches
    what older NumPy produced (with a deprecation warning). Input whose
    episodes all have the same length still yields a regular numeric array.
    """
    if len({len(ep) for ep in episodes_obs}) <= 1:
        return np.array(episodes_obs)
    ragged = np.empty(len(episodes_obs), dtype=object)
    for k, ep in enumerate(episodes_obs):
        ragged[k] = ep  # store each episode as-is, one object per slot
    return ragged


print("Saving trained model")
# Create the output folders up front so a long training run does not crash
# at the very end because a directory is missing.
os.makedirs("weights", exist_ok=True)
os.makedirs("data", exist_ok=True)
agent.save_trained_model("weights/cartpole-dqn.pth")


obs_train = _episodes_to_array(obs_train)
# np.save pickles object arrays; read them back with
# np.load(..., allow_pickle=True).
np.save('data/obs_train.npy', obs_train)
print("Num instances produced:", len(obs_train))





