import numpy as np

from envs.d4rl_env import Maze2d
from d4rl.pointmaze.waypoint_controller import WaypointController
from d4rl.pointmaze.maze_model import U_MAZE
import argparse
import joblib
import os
from tqdm import tqdm

# Command-line options for one data-collection run.
parser = argparse.ArgumentParser(
    description='Collect noisy waypoint-controller rollouts in '
                'maze2d-umaze-v1 and store them with joblib.',
)
parser.add_argument(
    '--p',
    type=float,
    # Float default so args.p has the same type whether or not the flag is given.
    default=0.0,
    help='value forwarded to the Maze2d environment as its `p` argument; '
         'rounded to one decimal it also names the entry in the saved dataset',
)
parser.add_argument(
    '--data_dir',
    type=str,
    default='./',
    help='directory in which maze2d-umaze-v1.pkl is read and written',
)
args = parser.parse_args()


def _jittered_corner(x, y):
    """Return the point (x, y) shifted by uniform noise in [-0.1, 0.1) per axis."""
    return np.array([x, y], dtype=np.float32) + np.random.uniform(-0.1, 0.1, size=2)


def set_waypoint(position, target):
    """Choose the point the controller should steer towards next.

    The decision compares x/y coordinates against the fixed thresholds 1.25
    and 2.25. Depending on where ``position`` and ``target`` fall, the
    result is a randomly jittered point near (1, 3), a randomly jittered
    point near (3, 3), or ``target`` itself.
    (Presumably the two jittered points are the corners of the U-maze's
    connecting corridor -- confirm against the maze layout.)

    Args:
        position: Indexable holding the current x at [0] and y at [1].
        target: Indexable holding the goal x at [0] and y at [1].

    Returns:
        A freshly built 2-element array for a detour corner, otherwise the
        ``target`` object that was passed in (not a copy).
    """
    pos_x, pos_y = position[0], position[1]
    goal_x, goal_y = target[0], target[1]

    # Low-y, small-x position with the goal further along x: go via (1, 3).
    if pos_x < 1.25 and pos_y < 2.25 and goal_x > 1.25:
        return _jittered_corner(1., 3.)

    # Mirror case: low-y, large-x position with the goal back along x.
    if pos_x > 2.25 and pos_y < 2.25 and goal_x < 2.25:
        return _jittered_corner(3., 3.)

    # Goal has a low y: detour via the corner on the goal's side if the
    # position has not yet crossed the matching x threshold.
    if goal_y < 2.25:
        if pos_x >= 1.25 and goal_x < 1.25:
            return _jittered_corner(1., 3.)
        if pos_x <= 2.25 and goal_x > 2.25:
            return _jittered_corner(3., 3.)

    # No detour applies: head straight for the goal.
    return target


# Environment to roll out in (configured with the --p option) and the d4rl
# waypoint controller built for the U-maze layout.
env = Maze2d('maze2d-umaze-v1', p=args.p)
controller = WaypointController(U_MAZE)

# Number of environment steps to record, and zero-initialised float32
# buffers with one row/entry per recorded step.
n_data = 1_000_000
observations = np.zeros((n_data, 4), dtype=np.float32)
actions = np.zeros((n_data, 2), dtype=np.float32)
timeouts, terminals, rewards = (
    np.zeros(n_data, dtype=np.float32) for _ in range(3)
)

# Fill every buffer slot by rolling out the noisy waypoint controller.
# `done` starts as True so the first iteration resets the environment.
done = True
for n in tqdm(range(n_data)):
    if done:
        # Begin a new episode and draw a new goal for it.
        obs = env.reset()
        target = env.sample_target()
        done = False

    # obs[:2] and obs[2:] are handed to the controller separately
    # (presumably position and velocity -- confirm against Maze2d).
    waypoint = set_waypoint(obs[:2], target)
    # The controller's second return value is not needed here.
    act, _ = controller.get_action(obs[:2], obs[2:], waypoint)
    # Add standard-normal noise to the controller output, then clip each
    # component to [-1, 1].
    act = np.clip(act + np.random.randn(2), -1, 1)

    # Record the observation the action was computed from, then step.
    observations[n] = obs
    actions[n] = act
    obs, rew, done, _ = env.step(act)

    # The env's `done` flag is stored as a timeout; `terminals` is never
    # written in this loop and therefore stays all zeros.
    timeouts[n] = done
    rewards[n] = rew

# Bundle the collected buffers under the keys used in the saved dataset.
data = {
    'observations': observations,
    'actions': actions,
    'rewards': rewards,
    'timeouts': timeouts,
    'terminals': terminals
}

# All runs share one pickle file: load it if present, add or replace this
# run's entry, and write the whole dictionary back.
dataset_path = os.path.join(args.data_dir, 'maze2d-umaze-v1.pkl')
if os.path.exists(dataset_path):
    # NOTE: joblib.load unpickles the file, so --data_dir should only point
    # at a location whose contents are trusted.
    dataset = joblib.load(dataset_path)
else:
    dataset = dict()

# NOTE: the key keeps only one decimal of p, so two values of p that format
# identically with '.1f' share a key and the later run replaces the earlier.
dataset[f'p{args.p:.1f}'] = data

# Create the output directory if needed so the dump cannot fail on a missing
# directory after the rollout has already been collected. An empty data_dir
# means the current directory, which os.makedirs would reject.
if args.data_dir:
    os.makedirs(args.data_dir, exist_ok=True)
joblib.dump(dataset, dataset_path)
