import os
import pickle

import gym
import d4rl
import numpy as np
from loguru import logger

from offlinerl.utils.data import SampleBatch

import envs.build_envs.env_builder as env_builder

import h5py

def get_data(data_path):
    """Load a transition dataset from an HDF5 file into memory.

    Args:
        data_path: Path to an HDF5 file that contains the datasets
            ``observations``, ``next_observations``, ``actions``,
            ``rewards`` and ``terminals``.

    Returns:
        dict mapping each of those keys to an in-memory NumPy array.
        ``rewards`` is squeezed to drop singleton dimensions; the other
        arrays keep the shape stored in the file.

    Raises:
        KeyError: if one of the expected datasets is missing from the file.
    """
    # Copy every dataset into NumPy inside the ``with`` block so the HDF5
    # handle is closed on return (the previous version leaked it).
    with h5py.File(data_path, 'r') as data:
        dataset = {
            'observations': np.array(data['observations']),
            'next_observations': np.array(data['next_observations']),
            'actions': np.array(data['actions']),
            'rewards': np.array(np.squeeze(data['rewards'])),
            'terminals': np.array(data['terminals']),
        }
    return dataset

def load_d4rl_buffer(data_path,):
    """Build an offlinerl ``SampleBatch`` from an HDF5 transition dataset.

    Args:
        data_path: Path to the HDF5 file read by ``get_data``.

    Returns:
        SampleBatch with fields ``obs``, ``obs_next``, ``act``, ``rew`` and
        ``done``. ``rew`` and ``done`` are squeezed and then given a trailing
        axis, so each is a column of shape ``(N, 1)``.
    """
    # NOTE(review): ``env`` is not used below. It is kept only in case
    # constructing the terrain environment has side effects something relies
    # on (it was presumably the input to the old d4rl.qlearning_dataset path)
    # — confirm and remove if not needed.
    env = env_builder.build_terrains_env(enable_randomizer=True,
                                         enable_rendering=False,
                                         RECORD_VIDEO=False,
                                         max_step_num=1,
                                         env_id=1
                                         )

    # The HDF5 file is the only data source. The former hard-coded
    # ``data_process`` switch had an ``else`` branch that never assigned
    # ``dataset`` and would have raised UnboundLocalError, so it was removed.
    logger.info('Loading dataset from {}', data_path)
    dataset = get_data(data_path)

    buffer = SampleBatch(
        obs=dataset['observations'],
        obs_next=dataset['next_observations'],
        act=dataset['actions'],
        rew=np.expand_dims(np.squeeze(dataset['rewards']), 1),
        done=np.expand_dims(np.squeeze(dataset['terminals']), 1),
    )

    logger.info('obs shape: {}', buffer.obs.shape)
    logger.info('obs_next shape: {}', buffer.obs_next.shape)
    logger.info('act shape: {}', buffer.act.shape)
    logger.info('rew shape: {}', buffer.rew.shape)
    logger.info('done shape: {}', buffer.done.shape)

    num_terminals = np.sum(buffer.done)
    if num_terminals > 0:
        # Approximate mean return per episode: total reward divided by the
        # number of terminal flags.
        logger.info('Episode reward: {}', buffer.rew.sum() / num_terminals)
    else:
        # With no terminal flags the per-episode average is undefined; the
        # previous code divided by zero here (inf/nan plus a RuntimeWarning).
        logger.warning('Episode reward: undefined (no terminals); total reward: {}',
                       buffer.rew.sum())
    logger.info('Number of terminals on: {}', num_terminals)
    return buffer
