import pickle
import time
import copy
import numpy as np
import os

# D4RL consults D4RL_SUPPRESS_IMPORT_ERROR only while `import d4rl` executes,
# so the variable has to be set *before* that import to have any effect.
os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'

import gym
import d4rl  # noqa: F401 -- imported for its side effect of registering the D4RL tasks used by gym.make
from .factory import get_env



def training_set_construction(data_dict):
    """Unpack a single-source dataset mapping into positional training arrays.

    Args:
        data_dict: mapping with exactly one entry (e.g. the dict returned by
            ``load_dataset``) whose value is itself a mapping containing the
            keys 'states', 'actions', 'rewards', 'next_states' and
            'terminations'.

    Returns:
        ``[states, actions, rewards, next_states, terminations]`` taken from
        that single inner mapping, in that order.

    Raises:
        AssertionError: if ``data_dict`` does not have exactly one entry.
        KeyError: if the inner mapping lacks one of the required keys.
    """
    assert len(data_dict) == 1
    only_entry = list(data_dict.values())[0]
    field_order = ('states', 'actions', 'rewards', 'next_states', 'terminations')
    return [only_entry[field] for field in field_order]




# D4RL task-name prefix for each supported MuJoCo environment name.
_D4RL_ENV_PREFIXES = {
    'HalfCheetah': 'halfcheetah',
    'Walker2d': 'walker2d',
    'Hopper': 'hopper',
    'Ant': 'ant',
}

# D4RL dataset-quality component for each short dataset alias.
_D4RL_DATASET_SUFFIXES = {
    'expert': 'expert',
    'medexp': 'medium-expert',
    'medium': 'medium',
    'medrep': 'medium-replay',
}


def _resolve_source(env_name, dataset):
    """Return the ``{source_kind: argument}`` spec for a request, or None if unsupported."""
    if env_name == 'SimEnv3':
        # Synthetic environment: data is generated, the dataset alias is ignored.
        return {"generate": None}
    prefix = _D4RL_ENV_PREFIXES.get(env_name)
    suffix = _D4RL_DATASET_SUFFIXES.get(dataset)
    if prefix is None or suffix is None:
        return None
    return {"env": "{}-{}-v2".format(prefix, suffix)}


def _load_d4rl(task_name):
    """Load a D4RL task's offline data and rename its keys to this module's schema."""
    env = gym.make(task_name)
    try:
        data = env.get_dataset()
    except AttributeError:
        # Some gym versions wrap the D4RL env in wrappers that do not forward
        # get_dataset; fall back to the innermost (unwrapped) environment.
        data = env.unwrapped.get_dataset()
    return {
        'states': data['observations'],
        'actions': data['actions'],
        'rewards': data['rewards'],
        'next_states': data['next_observations'],
        'terminations': data['terminals'],
    }


def load_dataset(env_name, dataset, cfg, seed):
    """Load the offline dataset for an environment / dataset-quality pair.

    Args:
        env_name: one of 'HalfCheetah', 'Walker2d', 'Hopper', 'Ant' (D4RL
            MuJoCo v2 tasks) or 'SimEnv3' (data generated by a local env).
        dataset: 'expert', 'medexp', 'medium' or 'medrep'; ignored for
            'SimEnv3'.
        cfg: config object; only ``cfg.env`` is read, and only for 'SimEnv3'.
        seed: forwarded to ``get_env`` for 'SimEnv3'.

    Returns:
        A dict with a single key, "env" (D4RL data) or "generate"
        (``get_env(cfg.env, seed).get_dataset()``). For "env" the value has
        keys 'states', 'actions', 'rewards', 'next_states', 'terminations'.

    Raises:
        AssertionError: if the env/dataset combination is not supported.
    """
    print(env_name, dataset)
    path = _resolve_source(env_name, dataset)
    assert path is not None, (
        "unsupported env/dataset combination: {!r}/{!r}".format(env_name, dataset)
    )

    datasets = {}
    for name, source in path.items():
        if name == "env":
            datasets[name] = _load_d4rl(source)
        elif name == "generate":
            env = get_env(cfg.env, seed)
            datasets[name] = env.get_dataset()
        else:
            raise NotImplementedError
    # Return only after every source has been loaded (previously the return
    # sat inside the loop, so any entry after the first was silently ignored).
    return datasets