import torch
import os


def get_rl_dataset(path, dataset,
                   quality='medium', type='1', noise='0.5'):
    """Load a saved transition dataset and return it as NumPy arrays.

    The file that is read is
    ``<path>/<dataset>/<quality>-type_<type>-noise_<noise>.pt``. It must
    contain a mapping with the keys ``'states'``, ``'actions'`` and
    ``'next_states'``, each holding a tensor.

    Args:
        path: Directory under which the dataset folders live.
        dataset: Name of the sub-directory of ``path`` to read from.
        quality: First component of the file name.
        type: Value placed after ``type_`` in the file name. The name
            shadows the builtin but is kept so existing keyword callers
            keep working.
        noise: Value placed after ``noise_`` in the file name.

    Returns:
        A tuple ``(X, y)`` of NumPy arrays, where ``X`` is ``states`` and
        ``actions`` concatenated along dimension 1 and ``y`` is
        ``next_states``.
    """
    fn = os.path.join(path, dataset, f'{quality}-type_{type}-noise_{noise}.pt')
    # NOTE: torch.load unpickles the file contents; only point this at
    # dataset files from a trusted source.
    # map_location='cpu' lets a file saved from GPU tensors load on a host
    # without that device; the data is converted to NumPy (CPU) anyway.
    ckpt = torch.load(fn, map_location='cpu')
    states, actions, next_states = ckpt['states'], ckpt['actions'], ckpt['next_states']
    # detach() so tensors saved with requires_grad=True can still be
    # converted; .numpy() refuses tensors that are tracked by autograd.
    X = torch.cat((states, actions), dim=1).detach().numpy()
    y = next_states.detach().numpy()
    return X, y
