from replay_buffer import ReplayBuffer
import h5py
import numpy as np
import torch
import pytorch_util
from sampler import StepSampler


def save_dataset(env, max_size, max_traj_length, path, device, policy=None, *, verbose=False):
    """Roll out a policy in ``env`` and write the sampled data to an HDF5 file.

    Args:
        env: Environment to sample from. ``env.observation_space.shape[0]`` and
            ``env.action_space.shape[0]`` size the default policy.
        max_size: Forwarded to ``StepSampler.sample`` (presumably the number of
            steps to collect -- confirm against ``StepSampler``).
        max_traj_length: Forwarded to the ``StepSampler`` constructor.
        path: Destination HDF5 file. It is opened in mode ``"w"``, so an
            existing file at this path is truncated.
        device: Device the default policy is moved to; also passed to
            ``pytorch_util.SamplerPolicy``.
        policy: Policy to roll out. When ``None``, a freshly constructed
            ``pytorch_util.TanhGaussianPolicy`` is created and moved to ``device``.
        verbose: When True, print each key, its value and the value's type
            before writing it. Off by default to keep stdout quiet.
    """
    if policy is None:
        policy = pytorch_util.TanhGaussianPolicy(
            env.observation_space.shape[0], env.action_space.shape[0]
        ).to(device)
    sampler_policy = pytorch_util.SamplerPolicy(policy, device)
    sampler = StepSampler(env, max_traj_length)
    datas = sampler.sample(sampler_policy, max_size)

    # The ``with`` statement closes the file on exit (including on error),
    # so no explicit ``close()`` is needed afterwards.
    with h5py.File(path, "w") as f:
        # NOTE(review): assumes ``datas`` maps dataset names to array-like
        # values that h5py can store -- confirm against StepSampler.sample.
        for k, v in datas.items():
            if verbose:
                print("k:", k, " v:", v)
                print('v type:', type(v))
            f.create_dataset(k, data=v)

def load_dataset(path, *, verbose=False):
    """Read every top-level dataset of an HDF5 file into an in-memory dict.

    Args:
        path: HDF5 file to read (for example one written by ``save_dataset``).
        verbose: When True, print a header line plus each key, its value and
            the value's type. Off by default to keep stdout quiet.

    Returns:
        dict mapping each top-level key in the file to its fully-read
        contents (a NumPy array for array datasets, the stored scalar value
        for 0-d datasets).
    """
    if verbose:
        print('load')
    dataset = {}
    with h5py.File(path, 'r') as f:
        for k in f:
            # Read each dataset from disk exactly once. ``[()]`` returns the
            # whole array like ``[:]`` does, but also works for scalar (0-d)
            # datasets, where ``[:]`` raises.
            value = f[k][()]
            if verbose:
                print("k:", k, " v:", value)
                print('v type:', type(value))
            dataset[k] = value
    return dataset


