import numpy as np
import os
import torch
import gym
import gym_franka
from gym_franka.ios_controller import Controller
from _thread import start_new_thread
import matplotlib.pyplot as plt


# How many successful demonstrations to record and where to store them.
NUM_DEMOS = 5
ROOT_FOLDER = './demo/franka_drawer/'
# One output sub-folder per demo count, e.g. './demo/franka_drawer/5'.
target_folder = f'{ROOT_FOLDER}{NUM_DEMOS}'
os.makedirs(target_folder, exist_ok=True)

env = gym.make('FrankaDrawer-v1')

# Flat transition buffers shared by every kept demo; index k in each list
# refers to the same transition.
obs_list, next_obs_list, action_list, reward_list, not_done_list = [], [], [], [], []

# Per-demo boundaries into the buffers above: start index and (exclusive)
# end index of each successful demo.
demo_starts, demo_ends = [], []

# Teleoperation controller at 192.168.0.1:6789; its spin() loop runs on a
# background thread so get_action() can be polled from the main loop.
controller = Controller('192.168.0.1', 6789)
start_new_thread(controller.spin, ())

# Record NUM_DEMOS successful demonstrations. An episode that terminates
# without the success reward is discarded and re-recorded from scratch.
i = 0  # number of successful demos recorded so far
while i < NUM_DEMOS:
    img_obs = env.reset()
    controller.reset()
    # Buffer index of the first transition of this attempt. Rolling back to
    # this index (rather than to demo_ends[-1]) also works when no demo has
    # succeeded yet and demo_ends is still empty.
    episode_start = len(obs_list)
    demo_starts.append(episode_start)
    while True:
        action = controller.get_action()
        next_img_obs, r, d, _info = env.step(action)
        obs_list.append(img_obs)
        next_obs_list.append(next_img_obs)
        action_list.append(action)
        # Sanity check on the reward scheme this script relies on: -1 per step,
        # 100 on success. Explicit raise so it is not stripped under `python -O`.
        if r not in (-1, 100):
            raise ValueError(f'unexpected reward {r!r}; expected -1 or 100')
        reward_list.append([r])
        not_done_list.append([not d])
        img_obs = next_img_obs

        # Dump the new observation as an image, numbered by transition count.
        # NOTE(review): assumes next_img_obs is channel-first (C, H, W) and that
        # channels from index 3 on form a displayable image — confirm with env.
        # Images from a discarded attempt are overwritten by the redo, but a
        # longer failed attempt can leave stale higher-numbered files behind.
        plt.imsave(target_folder + f'/{len(obs_list):04d}.png', next_img_obs[3:].transpose((1, 2, 0)))

        if d:
            if r <= 0:
                # Failed attempt: forget its start marker and every transition
                # recorded since the attempt began.
                demo_starts.pop()
                del obs_list[episode_start:]
                del next_obs_list[episode_start:]
                del action_list[episode_start:]
                del reward_list[episode_start:]
                del not_done_list[episode_start:]
                print('redo!')
            else:
                # Successful demo: record its exclusive end index.
                demo_ends.append(len(obs_list))
                print(i)
                i = i + 1
            break

# Stack every transition buffer into an array and persist them together with
# the per-demo start/end indices.
buffers = (obs_list, next_obs_list, action_list, reward_list, not_done_list)
payload = [np.array(buf) for buf in buffers]
num_transitions = len(obs_list)
torch.save(payload, f'{target_folder}/0_{num_transitions}.pt')
for boundary_name, boundary in (('demo_starts', demo_starts), ('demo_ends', demo_ends)):
    np.save(f'{target_folder}/{boundary_name}.npy', np.array(boundary))
