import os
import pickle

import gym
import d4rl
import numpy as np
from loguru import logger

from offlinerl.utils.data import SampleBatch

import h5py

def get_data(dataset, data_path, isMediumExpert, data_proportion):
    """Subsample a transition dataset and optionally append extra transitions.

    The arrays stored under ``'observations'``, ``'next_observations'``,
    ``'actions'``, ``'rewards'`` and ``'terminals'`` are truncated so that
    roughly ``len(dataset['terminals']) / data_proportion`` transitions remain.

    Args:
        dataset: Mapping holding the five arrays listed above. It is modified
            in place and also returned.
        data_path: Path of an HDF5 file containing the same five keys. When
            truthy, its contents are cast to ``float32`` and appended after
            the subsampled data; when falsy, nothing is appended.
        isMediumExpert: If false, keep the first ``num`` transitions. If true,
            keep the first ``num // 2`` and the last ``num // 2`` transitions
            (presumably the medium and expert halves of a medium-expert
            dataset -- confirm against the caller).
        data_proportion: Positive divisor applied to the dataset size.

    Returns:
        The same ``dataset`` object with its arrays replaced.

    Raises:
        ValueError: If ``data_proportion`` is not strictly positive.
    """
    if data_proportion <= 0:
        raise ValueError(
            "data_proportion must be positive, got {!r}".format(data_proportion)
        )

    keys = ('observations', 'next_observations', 'actions', 'rewards', 'terminals')
    N = len(dataset['terminals'])

    num = int(N / data_proportion)
    print("--------------------------")
    print("data num:", num)
    print("--------------------------")

    if not isMediumExpert:
        print("---------------------")
        print("Not isMediumExpert.")
        print("---------------------")
        for key in keys:
            dataset[key] = dataset[key][:num]
    else:
        num = num // 2
        print("---------------------")
        print("IsMediumExpert.")
        print("---------------------")
        for key in keys:
            arr = dataset[key]
            # Use an explicit start index for the tail: ``arr[-num:]`` would
            # return the WHOLE array when num == 0 because -0 == 0. Clamping
            # at 0 keeps the "take everything" result when num > len(arr).
            tail_start = max(len(arr) - num, 0)
            dataset[key] = np.concatenate((arr[:num], arr[tail_start:]), axis=0)

    print("data num:", num)

    if data_path:
        # Context manager guarantees the HDF5 handle is released even if a
        # key is missing or a concatenate fails.
        with h5py.File(data_path, 'r') as data:
            for key in keys:
                extra = data[key]
                if key == 'rewards':
                    # Rewards in the file may carry a trailing unit axis.
                    extra = np.squeeze(extra)
                dataset[key] = np.concatenate(
                    (dataset[key], np.array(extra, dtype=np.float32)), axis=0
                )
    return dataset

def load_d4rl_buffer(task, data_path, isMediumExpert, data_proportion):
    """Build a ``SampleBatch`` from a D4RL environment's q-learning dataset.

    Args:
        task: Task name; its first five characters are dropped to obtain the
            gym environment id (presumably a ``"d4rl-"`` style prefix --
            confirm against the caller).
        data_path: Forwarded to ``get_data``.
        isMediumExpert: Forwarded to ``get_data``.
        data_proportion: Forwarded to ``get_data``.

    Returns:
        A ``SampleBatch`` with fields ``obs``, ``obs_next``, ``act``, ``rew``
        and ``done``; ``rew`` and ``done`` are squeezed and then given a
        trailing axis of length one.
    """
    env_id = task[5:]
    raw = d4rl.qlearning_dataset(gym.make(env_id))

    # Hard-wired switch: flip manually to bypass the subsampling step.
    data_process = True
    rule = "------------------"
    print(rule)
    print("data process!" if data_process else "Not data process!")
    print(rule)
    if data_process:
        raw = get_data(raw, data_path, isMediumExpert, data_proportion)

    # Normalise the scalar-per-step arrays to column vectors.
    reward_column = np.expand_dims(np.squeeze(raw['rewards']), 1)
    done_column = np.expand_dims(np.squeeze(raw['terminals']), 1)

    buffer = SampleBatch(
        obs=raw['observations'],
        obs_next=raw['next_observations'],
        act=raw['actions'],
        rew=reward_column,
        done=done_column,
    )

    shape_reports = (
        ('obs shape: {}', buffer.obs),
        ('obs_next shape: {}', buffer.obs_next),
        ('act shape: {}', buffer.act),
        ('rew shape: {}', buffer.rew),
        ('done shape: {}', buffer.done),
    )
    for template, array in shape_reports:
        logger.info(template, array.shape)

    # Total reward divided by the number of terminal flags.
    logger.info('Episode reward: {}', buffer.rew.sum() / np.sum(buffer.done))
    logger.info('Number of terminals on: {}', np.sum(buffer.done))
    return buffer
