import os
from time import time

import numpy as np
import torch

def get_batch(split, device, env_list, data_root='../data'):
    """Sample one random fixed-length window per environment and batch them.

    For every name in ``env_list`` the four files
    ``<data_root>/<name>/{encodings,actions,rewards,dones}_<split>.npy`` are
    memory-mapped read-only and a window of ``context_length`` (144)
    transitions is copied out at a random offset.

    Args:
        split: Dataset split name. It selects the file suffix; ``'train'``
            samples offsets from the first 90000 rows, any other value from
            the first 10000 rows.
        device: Device passed to ``torch.tensor`` for the returned tensors.
        env_list: Iterable of environment directory names under ``data_root``.
        data_root: Directory holding one sub-directory per environment.
            Defaults to ``'../data'`` (relative to the current working
            directory).

    Returns:
        Tuple ``(x_enc, x_action, x_reward, y_enc)``:

        * ``x_enc``    -- float tensor ``(len(env_list), 144, 512)``; rows
          ``start .. start + 143`` of the encodings.
        * ``x_action`` -- int64 tensor ``(len(env_list), 144)``; rows
          ``start + 1 .. start + 144`` of the actions.
        * ``x_reward`` -- int64 tensor ``(len(env_list), 144)``; reward
          classes (see below) for rows ``start + 1 .. start + 144``.
        * ``y_enc``    -- float tensor ``(len(env_list), 144, 512)``;
          ``x_enc`` shifted forward by one row.

        Reward classes are ``sign(reward) + 1`` (0, 1 or 2), plus 3 on rows
        whose ``done`` flag is set (3, 4 or 5).
    """
    context_length = 144
    # Only the first `cnt` rows of each file are ever sampled from.
    cnt = 90000 if split == 'train' else 10000
    # `randint(high)` requires high > 0, so the window must be strictly
    # shorter than the sampled range.
    assert cnt - context_length > 0
    # Seed from the last nine digits of the clock (10-decimal formatting),
    # reversed so the fastest-changing digits become the most significant.
    # Presumably this keeps near-simultaneous calls from sharing a seed.
    seed = int(f'{time():.10f}'[-9:][::-1])
    generator = np.random.RandomState(seed)
    enc = []
    x_action = []
    x_reward = []

    for env_name in env_list:
        data_dir = '{}/{}'.format(data_root, env_name)
        # NOTE(review): the files carry an ``.npy`` suffix but are mapped as
        # raw buffers at offset 0 with fixed shapes; if they were written with
        # ``np.save`` the header bytes would be read as data -- confirm they
        # are raw dumps.
        enc_path = os.path.join(data_dir, "encodings_{}.npy".format(split))
        enc_data = np.memmap(enc_path, dtype=np.float32, mode='r', shape=(100000, 512))
        act_path = os.path.join(data_dir, "actions_{}.npy".format(split))
        act_data = np.memmap(act_path, dtype=np.uint8, mode='r', shape=(100000,))
        reward_path = os.path.join(data_dir, "rewards_{}.npy".format(split))
        reward_data = np.memmap(reward_path, dtype=np.float32, mode='r', shape=(100000,))
        done_path = os.path.join(data_dir, "dones_{}.npy".format(split))
        done_data = np.memmap(done_path, dtype=bool, mode='r', shape=(100000,))

        start = generator.randint(cnt - context_length)
        stop = start + context_length + 1
        # One extra encoding row so inputs and shifted targets can both be
        # sliced from the same window. np.array() copies out of the memmap.
        e = np.array(enc_data[start:stop])
        act = np.array(act_data[start + 1:stop])
        done = np.array(done_data[start + 1:stop])
        # Map rewards to classes: sign -> {-1, 0, 1}, shift to {0, 1, 2},
        # then offset terminal steps by 3 so they get distinct classes.
        reward = np.sign(np.array(reward_data[start + 1:stop]))
        reward += 1
        reward[done] += 3

        enc.append(e)
        x_action.append(act)
        x_reward.append(reward)

    enc = torch.tensor(np.array(enc), device=device).float()
    x_enc = enc[:, :-1, :]
    y_enc = enc[:, 1:, :]
    x_action = torch.tensor(np.array(x_action).astype(np.int64), device=device)
    x_reward = torch.tensor(np.array(x_reward).astype(np.int64), device=device)

    return x_enc, x_action, x_reward, y_enc
