import os
import pickle

import gym
import d4rl
import numpy as np
from loguru import logger

from offlinerl.utils.data import SampleBatch

import h5py

def get_data(dataset, data_path, isMediumExpert):
    """Trim the target dataset and optionally load a source dataset from HDF5.

    The target ``dataset`` is modified IN PLACE: each of its
    ``observations``, ``next_observations``, ``actions``, ``rewards`` and
    ``terminals`` entries is replaced by a trimmed version.

    * ``isMediumExpert`` falsy: keep only the first 100000 entries.
    * ``isMediumExpert`` truthy: keep the first 100000 entries followed by
      the last 100000 entries (presumably the two halves of a mixed
      dataset -- confirm against the data layout). If the
      dataset holds fewer than 200000 entries the two slices overlap and
      the overlapping entries appear twice.

    Args:
        dataset: Mapping with the five keys above; values must support
            slicing and ``np.concatenate``.
        data_path: Path of an HDF5 file holding the source dataset under the
            same five keys. If falsy, no file is read and the returned
            source dataset is an empty dict.
        isMediumExpert: Selects the trimming strategy described above.

    Returns:
        Tuple ``(dataset_source, dataset_target)``. ``dataset_target`` is the
        same object as ``dataset``. ``dataset_source`` maps the five keys to
        ``float32`` arrays (``rewards`` is squeezed first), or is ``{}`` when
        ``data_path`` is falsy.

    Raises:
        KeyError: If ``dataset`` lacks one of the five keys.
    """
    num = int(1e5)
    keys = ('observations', 'next_observations', 'actions', 'rewards', 'terminals')

    print("---------------------")
    if not isMediumExpert:
        print("Not isMediumExpert.")
    else:
        print("IsMediumExpert.")
    print("---------------------")

    for key in keys:
        values = dataset[key]
        if isMediumExpert:
            # Head and tail of the dataset, stitched together.
            dataset[key] = np.concatenate((values[:num], values[-num:]), axis=0)
        else:
            dataset[key] = values[:num]

    dataset_source = {}
    if data_path:
        # The context manager closes the file even if a read fails;
        # np.array() copies the data into memory, so the arrays remain valid
        # after the file is closed.
        with h5py.File(data_path, 'r') as data:
            for key in keys:
                raw = data[key]
                if key == 'rewards':
                    raw = np.squeeze(raw)
                dataset_source[key] = np.array(raw, dtype=np.float32)
    return (dataset_source, dataset)

def _dataset_to_batch(dataset):
    """Wrap a transition dict in a SampleBatch with column-shaped rew/done."""
    return SampleBatch(
        obs=dataset['observations'],
        obs_next=dataset['next_observations'],
        act=dataset['actions'],
        # Squeeze then expand so rewards/terminals gain a trailing axis of 1.
        rew=np.expand_dims(np.squeeze(dataset['rewards']), 1),
        done=np.expand_dims(np.squeeze(dataset['terminals']), 1),
    )


def load_d4rl_buffer(task, data_path, isMediumExpert):
    """Build the source and target sample buffers for a D4RL task.

    The target data comes from ``d4rl.qlearning_dataset`` for the gym
    environment named by ``task`` and is trimmed by ``get_data``; the source
    data is read by ``get_data`` from the HDF5 file at ``data_path``.

    Args:
        task: Task identifier; its first five characters are stripped before
            it is passed to ``gym.make`` (presumably a ``"d4rl-"`` style
            prefix -- confirm against callers).
        data_path: Path to the HDF5 file holding the source dataset. It must
            be truthy: otherwise ``get_data`` returns an empty source dict
            and building the source buffer raises ``KeyError``.
        isMediumExpert: Forwarded to ``get_data`` to select how the target
            dataset is trimmed.

    Returns:
        Tuple ``(buffer, buffer_t)``: the source buffer and the target
        buffer, both ``SampleBatch`` instances.
    """
    env = gym.make(task[5:])
    dataset = d4rl.qlearning_dataset(env)

    print("------------------")
    print("data process!")
    print("------------------")
    dataset_source, dataset_target = get_data(dataset, data_path, isMediumExpert)

    buffer = _dataset_to_batch(dataset_source)
    buffer_t = _dataset_to_batch(dataset_target)

    # Only the source buffer is summarized in the log.
    logger.info('obs shape: {}', buffer.obs.shape)
    logger.info('obs_next shape: {}', buffer.obs_next.shape)
    logger.info('act shape: {}', buffer.act.shape)
    logger.info('rew shape: {}', buffer.rew.shape)
    logger.info('done shape: {}', buffer.done.shape)
    num_terminals = np.sum(buffer.done)
    # NOTE(review): total reward divided by the number of terminal flags; if
    # the source data has no terminals this divides by zero -- confirm that
    # is acceptable for the datasets in use.
    logger.info('Episode reward: {}', buffer.rew.sum() / num_terminals)
    logger.info('Number of terminals on: {}', num_terminals)
    return (buffer, buffer_t)
