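"""
Baseline trajectory generator using a uniformly random policy.

Runs episodes of gym's CartPole-v1 with random actions and writes every
observation, followed by a terminal flag (1 on the last step of an episode,
0 otherwise), as CSV rows to ./dataset_train/base.csv.
"""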

import gym
import numpy as np
import os
import sys
import csv

def obs2state(obs, status):
    """Append a scalar status flag to an observation vector.

    Args:
        obs: Observation array returned by the environment.
        status: Value appended as the last element; this script passes 0 for
            the reset state and intermediate steps, and 1 for the final step
            of an episode.

    Returns:
        A new array equal to ``obs`` with ``status`` appended, built with
        ``np.hstack``.
    """
    status_part = np.array([status])
    pieces = [obs, status_part]
    return np.hstack(pieces)

def _reset(env):
    """Reset ``env`` and return only the initial observation.

    gym releases before 0.26 return the observation directly; gym>=0.26
    returns an ``(obs, info)`` tuple. Both forms are accepted.
    """
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step(env, act):
    """Apply ``act`` to ``env`` and return ``(obs, reward, done)``.

    Accepts both the 4-tuple step API ``(obs, reward, done, info)`` and the
    gym>=0.26 5-tuple API ``(obs, reward, terminated, truncated, info)``.
    With the latter, the episode is considered done on either termination
    or truncation, which matches the old single ``done`` flag.
    """
    result = env.step(act)
    if len(result) == 5:
        obs, reward, terminated, truncated, _ = result
        return obs, reward, terminated or truncated
    obs, reward, done, _ = result
    return obs, reward, done


def _record_baseline(env, path, n_episodes):
    """Roll out ``n_episodes`` random-action episodes and write them as CSV.

    Each row is one observation followed by a status flag: 0 for the reset
    state and intermediate steps, 1 for the final observation of an episode.
    The summed reward of each episode is printed.

    Args:
        env: Environment to roll out; it is reset at the start of each episode.
        path: Output CSV path; the file is overwritten.
        n_episodes: Number of episodes to record.
    """
    # newline="" lets the csv module control line endings, so every row ends
    # with exactly "\n" on all platforms (no "\r\n" translation on Windows).
    with open(path, "w", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        for _ in range(n_episodes):
            obs = _reset(env)
            writer.writerow(obs2state(obs, 0))
            done = False
            reward_sum = 0

            while not done:
                act = np.random.randint(2)  # uniform random action in {0, 1}
                obs, reward, done = _step(env, act)
                reward_sum += reward
                # Mark only the last observation of the episode with 1.
                writer.writerow(obs2state(obs, 1 if done else 0))

            print("Reward sum: ", reward_sum)


# Dataset directories to populate, each receiving its own base.csv.
# The validation split is currently disabled; add "./dataset_val" to this
# list to record it with the same random policy.
dataset_dirs = ["./dataset_train"]

n_episodes = 1000

# Make dirs if not exist
for dataset_dir in dataset_dirs:
    os.makedirs(dataset_dir, exist_ok=True)

# Make env
env = gym.make("CartPole-v1")
try:
    # Record baseline trajectories for every configured split.
    for dataset_dir in dataset_dirs:
        _record_baseline(env, os.path.join(dataset_dir, "base.csv"), n_episodes)
finally:
    env.close()
