import gym
import torch
import os
import random
import math
import numpy as np
from itertools import count
import matplotlib.pyplot as plt
from collections import defaultdict
from sac.replay_memory import ReplayMemory, ReplayMemoryPER
from sac.sac import SAC
from model import EnsembleDynamicsModel
from predict_env import PredictEnv
from sample_env import EnvSampler
from tf_models.constructor import construct_model
from rf_env import rf_env_cont
import pandas as pd
from xgb_env import xgb_env_cont
import hyperopt as hp
import xgboost as xgb
from hyperopt import Trials, fmin
from sklearn.model_selection import train_test_split
import time
import sklearn as sk
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn.metrics import accuracy_score
import seaborn as sns
import ray
from ray import air, tune
from ray.air import session
from ray.tune import Stopper
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.search.bohb import TuneBOHB
from ray.tune.search.skopt import SkOptSearch
import ConfigSpace as CS
from ray.tune.search.hebo import HEBOSearch
import gc
import psutil
import os.path
from os import path
from scipy.stats import rankdata

class ArgumentReader:
    """Plain attribute container for a run configuration.

    Every constructor argument is stored unchanged as an instance attribute of
    the same name (e.g. ``ArgumentReader(..., seed=0, ...).seed == 0``). No
    validation, conversion or defaults are applied.
    """

    def __init__(self, env_name, data, dataname, seed, use_decay, gamma, tau, policy, target_update_interval,
                 automatic_entropy_tuning, hidden_size, lr, num_networks, num_elites, pred_hidden_size, reward_size,
                 replay_size, model_retain_epochs, model_train_freq, rollout_batch_size, epoch_length,
                 rollout_min_epoch, rollout_max_epoch, rollout_min_length, rollout_max_length, num_epoch, min_pool_size,
                 real_ratio, train_every_n_steps, num_train_repeat, policy_train_batch_size,
                 init_exploration_steps, max_path_length, model_type, cuda, emb_size, train_emb,
                 save_model, load_model, save_name, load_name, test_mode, act_noise, target_noise, noise_clip,
                 max_model_error, save_graphs, save_buffer, load_buffer, eval_freq, save_exploration, load_exploration,
                 emb_name, meta_training, eval_mode, action_norm, per, show_plots, epsilon, maxtime, save_raw_exp,
                 pretrain_dyn_model):
        # Snapshot the arguments first, while they are the only local names;
        # they are then copied onto the instance in signature order.
        arguments = dict(locals())
        del arguments['self']
        for attr_name, value in arguments.items():
            setattr(self, attr_name, value)


def _plot_agent_losses(agent):
    """Draw the agent's recorded training curves onto the current matplotlib figure.

    Policy loss is green, the two critic losses red and pink, Q-values blue and
    target Q-values purple. The figure is not shown, saved or cleared here.
    """
    plt.plot(agent.policy_losses, color='green')
    plt.plot(agent.critic1_losses, color='red')
    plt.plot(agent.critic2_losses, color='pink')
    plt.plot(agent.qvals, color='blue')
    plt.plot(agent.target_qvals, color='purple')


def train(args, env_sampler, predict_env, agent, env_pool, model_pool, action_size):
    """Run the model-based RL search and return the best-reward trace.

    Seeds ``env_pool`` with exploration data, either loaded from disk or
    collected with random actions. It then runs ``args.num_epoch`` epochs of
    ``args.epoch_length`` environment steps. During each epoch it:

    * periodically refits the dynamics model;
    * adds model rollouts to ``model_pool`` once the model's recent reward
      errors fall within the spread of the observed rewards;
    * trains the agent;
    * tracks the 10 best rewards and their actions;
    * optionally evaluates, plots and saves results.

    If ``args.pretrain_dyn_model`` is set, only the dynamics model is pretrained
    and ``None`` is returned.

    Args:
        args: run configuration (see ``ArgumentReader``). This function
            overwrites ``args.load_buffer`` and ``args.max_model_error``.
        env_sampler: draws transitions via ``sample(agent, ...)``.
        predict_env: wrapper around the learned dynamics/reward model.
        agent: the policy being trained.
        env_pool: replay buffer of real transitions.
        model_pool: replay buffer of model-generated transitions. It is replaced
            by a resized pool whenever the rollout length changes.
        action_size: dimensionality of an action vector.

    Returns:
        ``(epscores, epnums)``: the running best reward after each step, and the
        matching x values. When ``args.maxtime > 0`` the x values are elapsed
        seconds and the run stops after ``args.maxtime`` minutes; otherwise they
        are the trial count.
    """
    start = time.time()
    maxtime = int(args.maxtime * 60)
    expscores = []
    times = []

    epscores = []
    epnums = []
    total_step = 0
    max_step = int(args.epoch_length * args.num_epoch)
    best_test_reward = 0
    rollout_length = args.rollout_min_length  # TODO
    # Exploration noise decays linearly over the first 80% of the epochs
    # (floored at 0.1 in the epoch loop).
    total_steps = args.num_epoch
    noise_steps = total_steps * 0.8
    noise_decrease = args.act_noise
    # With num_epoch == 0 the epoch loop never runs, so any step size works.
    noise_step_size = noise_decrease / noise_steps if noise_steps else 0.0

    # The 10 best rewards seen so far and the actions that produced them.
    best_rewards = np.zeros(10)
    best_params = np.zeros((10, action_size))
    best_reward_epoch = []

    args.load_buffer = False
    trial_num = args.init_exploration_steps
    rewards = []

    if args.pretrain_dyn_model:
        pretrain_dynamics_model(args, agent, predict_env)
        return

    if args.load_model:
        print('loading model parameters')
        load_model(agent, args)

    if args.load_exploration:
        expname = 'models/exploration_{}_{}.npy'.format(args.env_name, args.dataname)
        print(expname)
        try:
            print('loading exploration buffer for dataset {}'.format(args.dataname))
            batch = np.load(expname, allow_pickle=True).tolist()
            env_pool.push_batch(batch)
            print('exploration of size {} added to replay buffer'.format(len(env_pool)))
        except Exception:
            # Fall back to embedding a raw exploration file, then to fresh random exploration.
            try:
                convert_raw_exp(args, env_pool, agent)
            except Exception:
                print('no exploration buffer found for dataset {}'.format(args.dataname))
                print('filling buffer with some random samples for dataset {}'.format(args.dataname))
                best_rewards[0], best_params[0], rewards, times, expscores = exploration_before_start(
                    args, env_sampler, env_pool, action_size, args.dataname, agent, start)
    else:
        print('filling buffer with some random samples for dataset {}'.format(args.dataname))
        best_rewards[0], best_params[0], rewards, times, expscores = exploration_before_start(
            args, env_sampler, env_pool, action_size, args.dataname, agent, start)
        epscores.append(np.max(best_rewards))
        epnums.append(trial_num)

    if args.maxtime > 0:
        # Under a wall-clock budget the trace continues the exploration's timed trace.
        epscores = expscores
        epnums = times

    args.max_model_error = np.std(rewards)

    test_rewards = []
    test_actions = []
    model_errors = []

    end = False
    trained_model = False
    rolled_out = False
    for epoch_step in range(args.num_epoch):
        if epoch_step > 0:
            agent.exploration_noise = max(0.1, agent.exploration_noise - noise_step_size)

        if end:
            break
        start_step = total_step
        best = 0
        for _ in count():
            cur_step = total_step - start_step

            # Refit the dynamics model every model_train_freq steps once the
            # real buffer holds at least 100 transitions.
            if (cur_step > 0 or total_step == 0) and total_step % args.model_train_freq == 0 \
                    and args.real_ratio < 1.0 and total_step != max_step and len(env_pool) >= 100:
                trained_model = train_predict_model(args, env_pool, predict_env)
                rolled_out = False

                new_rollout_length = set_rollout_length(args, epoch_step)
                if rollout_length != new_rollout_length:
                    rollout_length = new_rollout_length
                    model_pool = resize_model_pool(args, rollout_length, model_pool)

                model_errors = []

            if cur_step >= args.epoch_length:
                break

            cur_state, action, reward, next_state, done, raw_state = env_sampler.sample(agent)
            rewards.append(reward)
            # The model-error tolerance tracks the spread of all real rewards seen so far.
            args.max_model_error = np.std(rewards)

            if args.real_ratio < 1.0 and trained_model:
                model_reward, _, _ = predict_env.step(cur_state, action, deterministic=False)
                model_errors.append(abs(model_reward - reward))

            acc = reward
            print('step action: ' + str(action))  # TODO uncomment
            print('step reward: ' + str(acc))

            if acc > np.min(best_rewards):
                idx = np.argmin(best_rewards)
                best_rewards[idx] = acc
                best_params[idx] = action

            trial_num += 1

            if args.maxtime > 0:
                timed = time.time() - start
                if timed >= maxtime:
                    end = True
                    break
                epnums.append(timed)
            else:
                epnums.append(trial_num)

            epscores.append(np.max(best_rewards))

            if acc > best:
                best = acc

            if done:
                env_sampler.current_state = None

            # The first step of each epoch is not stored (mirrors exploration_before_start).
            if cur_step % args.epoch_length != 0:
                env_pool.push((cur_state, action, reward, next_state, done, raw_state), reward)

            if len(env_pool) > args.min_pool_size:
                model = False
                if args.real_ratio < 1.0 and trained_model:
                    # Use model data only after >= 5 recent errors whose mean is within tolerance.
                    # The length is checked first so np.mean never sees an empty list.
                    if len(model_errors) >= 5 and np.mean(model_errors) <= args.max_model_error:
                        model = True
                        if not rolled_out:
                            rolled_out = rollout_model(args, predict_env, agent, model_pool, env_pool)
                train_policy_repeats(args, total_step, env_pool, model_pool, agent, model)

            total_step += 1

            if ((cur_step + 1) % args.epoch_length == 0) and ((total_step / args.epoch_length) % args.eval_freq == 0):

                if args.meta_training:
                    rewards_test = np.zeros(args.epoch_length - 1)
                    reward_max = 0
                    actions = np.zeros(action_size)
                    # NOTE(review): this sets an attribute on the replay buffer, while the
                    # sampler is reset elsewhere via env_sampler.current_state — confirm intent.
                    env_pool.current_state = None
                    for eval_step in range(args.epoch_length):
                        cur_state, action, reward, next_state, done, _ = env_sampler.sample(agent, eval_t=True)
                        # The first evaluation step is skipped when scoring.
                        if eval_step > 0:
                            acc = reward
                            print('step {} test action: {}'.format(eval_step, action))
                            print('step {} test reward: {}'.format(eval_step, acc))
                            if acc > reward_max:
                                reward_max = acc
                            rewards_test[eval_step - 1] = acc
                            actions += action

                            if acc > np.min(best_rewards):
                                idx = np.argmin(best_rewards)
                                best_rewards[idx] = acc
                                best_params[idx] = action
                                print('added to best rewards list')

                    rewards_name = np.mean(rewards_test)
                    actions_name = actions / (args.epoch_length - 1)

                    print('{} max reward: {}'.format(args.dataname, reward_max))
                    print('{} mean reward: {}'.format(args.dataname, rewards_name))
                    print('{} mean action: {}'.format(args.dataname, actions_name))
                    # Epoch test score: best evaluation reward penalised by its spread.
                    epoch_reward = (reward_max - np.std(rewards_test))
                    epoch_avg_action = actions_name

                    print('epoch ' + str(epoch_step) + ' test score: ' + str(epoch_reward))

                    if epoch_reward > best_test_reward:
                        best_test_reward = epoch_reward
                        if args.save_model:
                            save_models(agent, args)

                    test_rewards.append(epoch_reward)
                    test_actions.append(epoch_avg_action)

                    if args.show_plots:
                        _plot_agent_losses(agent)
                        plt.show()

                        actions_array = np.asarray(test_actions)
                        for j in range(len(test_actions[0])):
                            plt.plot(actions_array[:, j])
                        plt.show()

                        plt.plot(test_rewards)
                        plt.show()

                    env_pool.current_state = None

                else:
                    best_reward_epoch.append(best)
                    if args.show_plots:
                        _plot_agent_losses(agent)
                        plt.show()

                        plt.plot(best_reward_epoch)
                        plt.show()

    if args.save_graphs and not args.meta_training:
        plt.plot(best_reward_epoch)
        plt.savefig('best_rewards_{}_{}.png'.format(args.env_name, args.save_name))
        if args.show_plots:
            plt.show()
        else:
            plt.clf()

    _plot_agent_losses(agent)
    if args.save_graphs:
        plt.savefig('losses_{}_{}.png'.format(args.env_name, args.save_name))
    if args.show_plots:
        plt.show()
    else:
        plt.clf()

    if args.meta_training:
        plt.plot(test_rewards)
        if args.save_graphs:
            plt.savefig('rewards_{}_{}.png'.format(args.env_name, args.save_name))
        if args.show_plots:
            plt.show()
        else:
            plt.clf()

        # No evaluation may have run (e.g. time budget hit early); nothing to plot then.
        if test_actions:
            actions_array = np.asarray(test_actions)
            for j in range(len(test_actions[0])):
                plt.plot(actions_array[:, j])
            if args.save_graphs:
                plt.savefig('actions_{}_{}.png'.format(args.env_name, args.save_name))
            if args.show_plots:
                plt.show()
            else:
                plt.clf()

    if args.save_buffer:
        buffername = 'models/buffer_{}_{}'.format(args.env_name, args.save_name)
        np.save(buffername, env_pool.buffer)

    print('Best found accuracy: {}'.format(np.max(best_rewards)))

    return epscores, epnums

def convert_reward(reward):
    """Rescale ``reward`` logarithmically as log_11(10 * reward + 1).

    This maps 0 to 0 and 1 to 1, stretching small rewards and compressing
    large ones.
    """
    # math.log(x, base) computes log(x) / log(base), exactly as before.
    return math.log(10 * reward + 1, 11)

def save_models(agent, args):
    """Persist the agent's networks via ``agent.save_model(env_name, save_name)``."""
    env_name, suffix = args.env_name, args.save_name
    agent.save_model(env_name, suffix)

def save_dynamics_model(args, predict_env):
    """Write the ensemble dynamics model's ``state_dict`` to disk.

    The destination is ``models/dynamics_model_new_<env_name>_<save_name>``
    (relative to the working directory). Note the ``_new`` infix: the loader in
    this file reads ``models/dynamics_model_<env_name>``, so a saved model has to
    be renamed before it can be loaded.
    """
    target = "models/dynamics_model_new_{}_{}".format(args.env_name, args.save_name)
    print('Saving dynamic model to {}'.format(target))
    weights = predict_env.model.ensemble_model.state_dict()
    torch.save(weights, target)


def load_model(agent, args):
    """Load the actor and critic weights stored under ``args.load_name`` into ``agent``.

    The paths are ``models/sac_actor_<env_name>_<load_name>`` and
    ``models/sac_critic_<env_name>_<load_name>``.
    """
    templates = ("models/sac_actor_{}_{}", "models/sac_critic_{}_{}")
    actor_path, critic_path = (t.format(args.env_name, args.load_name) for t in templates)
    agent.load_model(actor_path, critic_path)

def load_dynamics_model(args, pred_model):
    """Load ``models/dynamics_model_<env_name>`` into ``pred_model.ensemble_model``.

    Returns:
        The same ``pred_model`` object, now holding the loaded weights.
    """
    weights_path = "models/dynamics_model_{}".format(args.env_name)
    state = torch.load(weights_path)
    pred_model.ensemble_model.load_state_dict(state)
    return pred_model


def exploration_before_start(args, env_sampler, env_pool, action_size, name, agent, start):
    """Seed the replay buffer with uniformly random actions before training.

    Takes ``args.init_exploration_steps`` steps in ``env_sampler.env`` with
    actions drawn uniformly from [-1, 1]^action_size. The env is reset every
    ``args.epoch_length`` steps. The step taken from a freshly reset state counts
    towards the best reward but is not stored. Every other transition is pushed
    into ``env_pool`` as ``(state_emb, action, reward, next_state_emb, done,
    raw_state)``. When ``args.save_raw_exp`` is set, it is also pushed into a raw
    buffer without embeddings.

    Args:
        args: run configuration.
        env_sampler: holds the environment (``env_sampler.env``).
        env_pool: replay buffer receiving the embedded transitions.
        action_size: dimensionality of an action vector.
        name: dataset name used in the saved file names.
        agent: provides the state embedder used by ``embed_states``.
        start: reference ``time.time()`` timestamp for the elapsed-time trace.

    Returns:
        ``(max_acc, best_action, rewards, times, scores)``:

        * ``max_acc``: the best reward seen, never below 0.
        * ``best_action``: the action that produced ``max_acc``; zeros of length
          ``action_size`` if no step beat 0.
        * ``rewards``: the stored rewards.
        * ``times``, ``scores``: only filled when ``args.maxtime > 0``, holding
          elapsed seconds and the running best stored reward.
    """
    if args.save_raw_exp:
        raw_buffer = ReplayMemory(args.init_exploration_steps)
    times = []
    scores = []
    rewards = []
    max_acc = 0
    # A valid action vector by default, so callers can always assign it into a
    # (n, action_size) array even when no step beats 0 or no step is taken.
    best_action = np.zeros(action_size)
    for i in range(args.init_exploration_steps):
        episode_start = i % args.epoch_length == 0
        if episode_start:
            env_sampler.env.reset()
            state_emb = embed_states(agent, [env_sampler.env.state])[0]
        else:
            # The previous step's next state is this step's current state.
            state_emb = next_state_emb

        cur_state = env_sampler.env.state
        action = np.random.uniform(low=-1, high=1, size=action_size)
        _, reward, done = env_sampler.env.step(action)

        next_state = env_sampler.env.state
        next_state_emb = embed_states(agent, [next_state])[0]

        if reward > max_acc:
            max_acc = reward
            best_action = action
        if episode_start:
            continue

        rewards.append(reward)
        if args.maxtime > 0:
            scores.append(np.max(rewards))
            times.append(time.time() - start)

        env_pool.push((state_emb, action, reward, next_state_emb, done, cur_state), reward)
        if args.save_raw_exp:
            raw_buffer.push((cur_state, action, reward, next_state, done), reward)

    print('Best found action: {} with accuracy: {}'.format(best_action, max_acc))
    env_sampler.current_state = None

    if args.save_exploration:
        expname = 'models/exploration_{}_{}'.format(args.env_name, name)
        np.save(expname, env_pool.buffer)

    if args.save_raw_exp:
        expname = 'models/exploration_raw_{}_{}'.format(args.env_name, name)
        np.save(expname, raw_buffer.buffer)

    return max_acc, best_action, rewards, times, scores

def convert_raw_exp(args, env_pool, agent):
    """Embed a raw (un-embedded) exploration file and push it into ``env_pool``.

    Reads ``models/exploration_raw_<env_name>_<dataname>.npy``. Columns 0-4 are
    indexed as state, action, reward, next_state and done, matching the tuples
    ``exploration_before_start`` pushes into its raw buffer. States and next
    states are embedded with ``embed_states``. Each row is then pushed as
    ``[state_emb, action, reward, next_state_emb, done, raw_state]``.

    If ``args.save_exploration`` is set, the resulting buffer is saved as
    ``models/exploration_<env_name>_<dataname>``. That is the same name
    ``exploration_before_start`` uses, and the path ``train`` reloads when
    ``load_exploration`` is set.
    """
    expname = 'models/exploration_raw_{}_{}.npy'.format(args.env_name, args.dataname)
    raw_exp = np.load(expname, allow_pickle=True)
    states = raw_exp[:, 0]
    next_states = raw_exp[:, 3]
    states_emb = embed_states(agent, states)
    next_states_emb = embed_states(agent, next_states)
    converted_batch = [
        [state_emb, act, rew, next_emb, done, raw_state]
        for state_emb, act, rew, next_emb, done, raw_state
        in zip(states_emb, raw_exp[:, 1], raw_exp[:, 2], next_states_emb, raw_exp[:, 4], states)
    ]
    env_pool.push_batch(converted_batch)

    if args.save_exploration:
        # Named by the dataset (previously an undefined ``name`` raised NameError here).
        expname = 'models/exploration_{}_{}'.format(args.env_name, args.dataname)
        np.save(expname, env_pool.buffer)


def set_rollout_length(args, epoch_step):
    """Return the model-rollout horizon for ``epoch_step`` (0-based).

    Epochs are counted as ``epoch_step + 1``. The length grows linearly from
    ``args.rollout_min_length`` at ``args.rollout_min_epoch`` to
    ``args.rollout_max_length`` at ``args.rollout_max_epoch``. It is clamped to
    ``[rollout_min_length, rollout_max_length]`` and truncated to ``int``.

    When ``rollout_max_epoch == rollout_min_epoch`` the schedule becomes a step
    function: the minimum length up to and including ``rollout_min_epoch``, and
    the maximum after it.
    """
    epoch = epoch_step + 1
    min_len, max_len = args.rollout_min_length, args.rollout_max_length
    min_epoch, max_epoch = args.rollout_min_epoch, args.rollout_max_epoch

    if max_epoch == min_epoch:
        # The linear formula would divide by zero; use its limiting step function.
        length = max_len if epoch > min_epoch else min_len
    else:
        length = min_len + (epoch - min_epoch) / (max_epoch - min_epoch) * (max_len - min_len)
    return int(min(max(length, min_len), max_len))

def pretrain_dynamics_model(args, agent, predict_env):
    """Pretrain the dynamics/reward model on saved replay buffers, then save it.

    Loads every buffer file matching ``args.env_name`` (see ``load_buffer_func``)
    into a ``ReplayMemory`` of capacity 100000, skipping ``None`` entries. It
    then fits ``predict_env.model`` on ``[state, action]`` inputs with reward
    labels and writes the weights with ``save_dynamics_model``.

    ``agent`` is not used; it is kept for call-site compatibility.
    """
    print('pretraining dynamics model')
    transferdata = load_buffer_func(args)
    transferbuffer = ReplayMemory(100000)
    for interaction in transferdata:
        # ``is None`` rather than ``== None``: entries may be arrays, for which
        # ``==`` is elementwise and has no single truth value.
        if interaction is not None:
            transferbuffer.push(interaction)
    state, action, reward, next_state, done = transferbuffer.sample(len(transferbuffer))
    action = np.stack(action)
    state = np.stack(state)
    inputs = np.concatenate((state, action), axis=-1)
    labels = reward

    predict_env.model.train(inputs, labels, batch_size=256, holdout_ratio=0.2,
                            max_epochs_since_update=20)
    save_dynamics_model(args, predict_env)


def train_predict_model(args, env_pool, predict_env):
    """Fit the dynamics/reward model on the entire real-transition buffer.

    Inputs are ``[state, action]`` rows and labels are the observed rewards.
    Training uses batch size 32, a 20% holdout and stops after 5 epochs without
    improvement (the ``max_epochs_since_update`` argument).

    Returns:
        True, so the caller can record that a trained model is available.
    """
    # Sample the whole pool; the PER sampler additionally returns (indices, weights).
    if args.per:
        state, action, reward, next_state, done, _, _ = env_pool.sample(len(env_pool))
    else:
        state, action, reward, next_state, done = env_pool.sample(len(env_pool))

    # Stack in both branches so per-row states/actions (lists or object arrays)
    # become proper 2-D arrays before concatenation.
    state = np.stack(state)
    action = np.stack(action)
    inputs = np.concatenate((state, action), axis=-1)
    labels = reward

    predict_env.model.train(inputs, labels, batch_size=32, holdout_ratio=0.2,
                            max_epochs_since_update=5)

    return True

def resize_model_pool(args, rollout_length, model_pool):
    """Return a new model pool sized for ``rollout_length``, keeping existing samples.

    The new capacity is ``model_retain_epochs * int(rollout_length * rollouts_per_epoch)``,
    where ``rollouts_per_epoch = rollout_batch_size * epoch_length / model_train_freq``.
    Everything in ``model_pool`` is copied across; the old pool is left untouched.
    """
    rollouts_each_epoch = args.rollout_batch_size * args.epoch_length / args.model_train_freq
    capacity = args.model_retain_epochs * int(rollout_length * rollouts_each_epoch)

    carried_over = model_pool.return_all()
    resized = ReplayMemory(capacity)
    resized.push_batch(carried_over)
    return resized


def rollout_model(args, predict_env, agent, model_pool, env_pool):
    """Add one-step synthetic transitions, generated by the learned model, to ``model_pool``.

    Steps:

    1. Sample up to ``args.rollout_batch_size`` real transitions from ``env_pool``.
    2. Choose new actions with the current policy, using
       ``agent.exploration_noise``.
    3. Predict rewards and terminals with ``predict_env``.
    4. Build each next raw state by stacking the row
       ``[action, predicted_reward]`` under the sampled raw state, then embed it.

    Transitions are pushed as ``(state, action, reward, next_state_emb,
    terminal, 0)``.

    Returns:
        True, so the caller can record that a rollout happened.
    """
    state, action, _reward, _next_states, _done, raw_state = env_pool.sample_all_batch(args.rollout_batch_size)
    state = np.stack(state)

    action = agent.select_action(torch.FloatTensor(state), noise=agent.exploration_noise)

    rewards, terminals, info = predict_env.step(state, action, deterministic=False)

    # Iterate over the rows actually sampled, so the next-state list always
    # matches the transitions pushed below.
    n_rows = state.shape[0]
    next_states_model = [
        np.vstack((raw_state[i], np.concatenate((action[i], [rewards[i]]))))
        for i in range(n_rows)
    ]
    next_states = embed_states(agent, next_states_model)

    model_pool.push_batch([(state[k], action[k], rewards[k], next_states[k], terminals[k], 0)
                           for k in range(n_rows)])

    return True


def embed_states(agent, states):
    """Embed a sequence of states with ``agent.embedder``, preserving input order.

    States are grouped by shape so that each group can be stacked into one batch
    for a single embedder call (run under ``torch.no_grad``). The per-group
    results are then scattered back, so row ``i`` of the result is the embedding
    of ``states[i]``. Callers rely on this when pairing embeddings with the
    corresponding actions and rewards.

    Args:
        agent: object whose ``embedder`` maps a stacked FloatTensor batch to
            embeddings, one row per state.
        states: iterable of array-like states (each exposing ``.data.shape``).

    Returns:
        A numpy array of embeddings in the same order as ``states``.
    """
    with torch.no_grad():
        states = list(states)
        # Only equally-shaped states can be stacked, so group their indices by shape.
        groups = defaultdict(list)
        for idx, state in enumerate(states):
            groups[state.data.shape].append(idx)

        chunks = []
        order = []
        for indices in groups.values():
            batch = np.stack([states[i] for i in indices])
            chunks.append(agent.embedder(torch.FloatTensor(batch)))
            order.extend(indices)

        grouped_emb = np.concatenate(chunks)
        # Undo the grouping: place each embedding back at its state's original index.
        states_emb = np.empty_like(grouped_emb)
        states_emb[np.asarray(order)] = grouped_emb

        return states_emb  # TODO return tensor


def train_policy_repeats(args, total_step, env_pool, model_pool, agent, use_model):
    """Run the agent updates scheduled for this environment step.

    Does nothing unless ``(total_step + 1)`` is a multiple of
    ``args.train_every_n_steps``.

    Without the model (``use_model`` false), one update is made on a batch of
    ``args.policy_train_batch_size`` real transitions. With the model,
    ``int(1 / args.real_ratio)`` updates are made. Each uses a batch whose
    ``real_ratio`` fraction comes from ``env_pool``; the remainder comes from
    ``model_pool`` when that pool is non-empty.

    Each batch is passed to ``agent.update_parameters`` as the 7-tuple
    ``(state, action, reward, next_state, done, idx, weight)`` together with an
    update counter and ``env_pool``.

    NOTE(review): ``args.real_ratio == 0`` with ``use_model`` set would divide by
    zero below — confirm callers never combine them.
    """
    if (total_step+1) % args.train_every_n_steps > 0:
        return

    if not use_model:
        train_amount = 1
        ratio = 1
    else:
        print('Using dynamics model')
        train_amount = int(1 / args.real_ratio)
        ratio = args.real_ratio


    for i in range(train_amount):

        # Split the batch between real (env) and model-generated transitions.
        env_batch_size = int(args.policy_train_batch_size * ratio)
        model_batch_size = int(args.policy_train_batch_size - env_batch_size)

        if args.per:
            env_state, env_action, env_reward, env_next_state, env_done, env_idx, env_weight = env_pool.sample(env_batch_size)
        else:
            env_state, env_action, env_reward, env_next_state, env_done = env_pool.sample(env_batch_size)
            env_action = np.stack(env_action)
            # Placeholder indices/weights (weight 1) so the batch has the same
            # 7-field layout as a PER sample.
            env_idx = [None] * len(env_state)
            env_weight = [1] * len(env_state)


        if model_batch_size > 0 and len(model_pool) > 0:
            model_state, model_action, model_reward, model_next_state, model_done, _ = model_pool.sample_all_batch(model_batch_size)
            # Model transitions get no PER index and unit weight.
            model_idx = [None] * len(model_state)
            model_weight = [1] * len(model_state)
            # Wrap each real state in a 1-D object array so np.concatenate joins
            # states element-wise rather than trying to merge their contents.
            # Presumably model_state is already such an array — TODO confirm.
            env_state_obj = np.empty((len(env_state),), dtype=object)
            env_next_state_obj = np.empty((len(env_state),), dtype=object)
            for j in range(len(env_state)):
                env_state_obj[j] = env_state[j]
                env_next_state_obj[j] = env_next_state[j]


            # Real transitions first, then model transitions, field by field.
            batch_state, batch_action, batch_reward, batch_next_state, batch_done, batch_idx, batch_weight = (np.concatenate((env_state_obj, model_state), axis=0),
                                                                   np.concatenate((env_action, model_action.tolist()), axis=0),
                                                                   np.concatenate(
                                                                       (env_reward,
                                                                        model_reward), axis=0),
                                                                   np.concatenate((env_next_state_obj, model_next_state),
                                                                                    axis=0),
                                                                   np.concatenate((env_done,
                                                                                   model_done), axis=0),
                                                                   np.concatenate((env_idx,
                                                                                   model_idx), axis=0),
                                                                   np.concatenate((env_weight,
                                                                                   model_weight), axis=0),
                                                                   )
            # NOTE(review): this branch passes the repeat index (i+1) as the update
            # counter, while the real-only branch passes a global count — confirm intended.
            updates = i+1
        else:
            batch_state, batch_action, batch_reward, batch_next_state, batch_done, batch_idx, batch_weight = (env_state, env_action,
                                                                                            env_reward, env_next_state, env_done,
                                                                                            env_idx, env_weight)
            updates = int((total_step+1) / args.train_every_n_steps)

        batch_reward, batch_done = np.squeeze(batch_reward), np.squeeze(batch_done)
        batch_done = batch_done.astype(int)
        agent.update_parameters((batch_state, batch_action, batch_reward, batch_next_state, batch_done, batch_idx, batch_weight),
                                updates, env_pool)

    return

def load_dataset(path):
    """Load a tabular classification dataset whose last column is the label.

    The file type is chosen from the last four characters of ``path``,
    case-insensitively: ``data``, ``.csv`` and ``.txt`` are read with
    ``pandas.read_csv``, and ``xlsx`` with ``pandas.read_excel``. ``'?'`` marks
    a missing value, and any row containing one is dropped.

    Returns:
        ``(x, y)`` on success, otherwise ``None`` (after printing
        ``'data not readable'``) when the file type is not supported.

        * ``y`` holds integer class codes, indexing the sorted unique labels.
        * In ``x``, every feature column whose first value is a string is
          replaced by integer category codes.
    """
    extension = path[-4:].lower()
    if extension in ['data', '.csv', '.txt']:
        data = pd.read_csv(path, na_values=['?']).dropna()
    elif extension == 'xlsx':
        data = pd.read_excel(path, na_values=['?']).dropna()
    else:
        print('data not readable')
        return None
    data_array = np.array(data)

    x = data_array[:, :-1]
    y = data_array[:, -1]
    _, y = np.unique(y, return_inverse=True)

    # A column counts as categorical based on its first value only. An empty
    # table (all rows dropped) has nothing to encode.
    if len(x):
        for col in range(x.shape[1]):
            if isinstance(x[0, col], str):
                _, codes = np.unique(x[:, col], return_inverse=True)
                x[:, col] = codes
    return x, y

def get_data_paths(datafolder='data/'):
    """Map each dataset sub-directory of ``datafolder`` to a data file inside it.

    Only direct sub-directories are considered; plain files at the top level are
    skipped. Within each sub-directory only regular files count, and
    sub-directories without any are omitted.

    NOTE(review): if a sub-directory holds several files, the last one returned
    by ``os.listdir`` wins. That order is arbitrary — keep one data file per
    directory.
    """
    datapaths = {}

    for subdir_name in os.listdir(datafolder):
        subdir = os.path.join(datafolder, subdir_name)
        # Stray top-level files (e.g. a README) would make os.listdir fail below.
        if not os.path.isdir(subdir):
            continue
        for file_name in os.listdir(subdir):
            filepath = os.path.join(subdir, file_name)
            if os.path.isfile(filepath):
                datapaths[subdir_name] = filepath

    return datapaths

def load_buffer_func(args, path='models/buffers/'):
    """Concatenate all saved replay buffers for ``args.env_name`` found in ``path``.

    A file qualifies if it is a regular file whose name contains both
    ``'buffer'`` and ``args.env_name``. Files are read in sorted name order, so
    the result is deterministic. Each is loaded with
    ``np.load(..., allow_pickle=True)`` and converted with ``.tolist()``.

    Only load trusted files: ``allow_pickle=True`` can execute arbitrary code.

    Returns:
        A flat list of all entries from the matching files.
    """
    bufferfiles = sorted(
        os.path.join(path, name) for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name)) and 'buffer' in name and args.env_name in name
    )

    buffer = []
    for filename in bufferfiles:
        print(filename)
        buffer.extend(np.load(filename, allow_pickle=True).tolist())

    return buffer

def hypopt_TPE(data, dataname, iterations=100, mode='TPE', algo='xgb', maxtime=0):
    """Tune an XGBoost or random-forest classifier with hyperopt (TPE or random search).

    Each trial scores a configuration by 10-fold cross-validated accuracy
    (shuffled ``KFold``, ``random_state=69``) on ``data``.

    Args:
        data: ``(x, y)`` feature matrix and labels.
        dataname: dataset name, used only for logging.
        iterations: maximum number of hyperopt evaluations.
        mode: ``'TPE'`` for ``hyperopt.tpe``; anything else uses random search.
        algo: ``'xgb'`` for XGBoost; anything else tunes a random forest.
        maxtime: wall-clock budget in minutes; 0 disables it.

    Returns:
        The running best accuracy after each trial (floored at 0). When
        ``maxtime > 0``, only trials that finished within the budget are kept,
        and the tuple ``(accuracies, times)`` is returned instead; ``times``
        holds their elapsed seconds.
    """
    # ``import sklearn`` alone does not guarantee the ``sklearn.ensemble``
    # submodule is loaded, so import the estimator explicitly.
    from sklearn.ensemble import RandomForestClassifier

    x, y = data
    maxtime = maxtime * 60

    ### Step 1 : defining the objective function
    def _cv_result(model):
        """Score ``model`` by 10-fold CV accuracy and wrap it as a hyperopt result."""
        kfold = KFold(n_splits=10, shuffle=True, random_state=69)
        score = cross_val_score(model, x, y, n_jobs=-1, cv=kfold).mean()
        # 'time' is seconds since the search started (``start`` is set just before fmin).
        return {'loss': 1 - score, 'status': 'ok', 'time': time.time() - start}

    def objective_xgb(params):
        model = xgb.XGBClassifier(max_depth=params['max_depth'], learning_rate=params['learning_rate'],
                                  n_estimators=params['n_estimators'], gamma=params['gamma'],
                                  min_child_weight=params['min_child_weight'], subsample=params['subsample'],
                                  colsample_bytree=params['colsample_bytree'],
                                  colsample_bylevel=params['colsample_bylevel'],
                                  reg_alpha=params['reg_alpha'], reg_lambda=params['reg_lambda'])
        return _cv_result(model)

    def objective_rf(params):
        model = RandomForestClassifier(params['estimators'], max_depth=params['max_depth'],
                                       min_samples_split=params['min_split'],
                                       min_samples_leaf=params['min_leaf'],
                                       max_features=params['max_feat'])
        return _cv_result(model)

    ### step 2 : defining the search space
    xgb_space = {
        'max_depth': hp.hp.choice('max_depth', np.arange(3, 15, 1, dtype=int)),
        'learning_rate': hp.hp.uniform('learning_rate', 0.001, 0.1),
        'n_estimators': hp.hp.choice('n_estimators', np.arange(50, 200, 1, dtype=int)),
        'gamma': hp.hp.uniform('gamma', 0.05, 1.0),
        'subsample': hp.hp.uniform('subsample', 0.6, 1.0),
        'colsample_bytree': hp.hp.uniform('colsample_bytree', 0.5, 1.0),
        'colsample_bylevel': hp.hp.uniform('colsample_bylevel', 0.5, 1.0),
        'min_child_weight': hp.hp.uniform('min_child_weight', 1, 7),
        'reg_alpha': hp.hp.uniform('reg_alpha', 0.0, 1.0),
        'reg_lambda': hp.hp.uniform('reg_lambda', 0.01, 1.0)}

    rf_space = {
        'estimators': hp.hp.choice('n_estimators', np.arange(50, 120, 1, dtype=int)),
        'max_depth': hp.hp.choice('max_depth', np.arange(3, 30, 1, dtype=int)),
        'min_split': hp.hp.choice('min_split', np.arange(2, 100, 1, dtype=int)),
        'min_leaf': hp.hp.choice('min_leaf', np.arange(1, 100, 1, dtype=int)),
        'max_feat': hp.hp.uniform('max_feat', 0.1, 0.9)}

    if algo == 'xgb':
        space, objective = xgb_space, objective_xgb
    else:
        space, objective = rf_space, objective_rf

    ### step 3 : storing the results of every iteration
    bayes_trials = Trials()
    opt = hp.tpe.suggest if mode == 'TPE' else hp.rand.suggest
    # Stop launching trials once the budget is spent. Any trial started later
    # would finish past maxtime and be discarded below anyway. Rounding up
    # ensures no in-budget trial is skipped.
    timeout = math.ceil(maxtime) if maxtime > 0 else None

    print('Finding best parameters for dataset {} using {}'.format(dataname, mode))
    start = time.time()
    fmin(fn=objective, space=space, algo=opt,
         max_evals=iterations, trials=bayes_trials, timeout=timeout)

    accuracies = []
    times = []
    for step in bayes_trials.results:
        if maxtime > 0 and step['time'] > maxtime:
            continue
        accuracies.append(1 - step['loss'])
        times.append(step['time'])

    # Convert per-trial accuracies into the running best (floored at 0).
    max_acc = 0
    for i in range(len(accuracies)):
        max_acc = max(max_acc, accuracies[i])
        accuracies[i] = max_acc

    if maxtime > 0:
        return accuracies, times

    return accuracies


def hypopt_rand_timed():  # TODO: not implemented yet
    """Placeholder, presumably for a time-budgeted random-search variant.

    Currently does nothing.

    Returns:
        None.
    """
    return None

def hypopt_ray(data, dataname, iterations=100, mode='BOHB', algo='xgb', maxtime=0):
    """Tune an XGBoost or random-forest classifier with Ray Tune.

    Each trial scores one configuration with shuffled 10-fold cross-validation
    and reports ``mean_loss`` (1 - mean CV accuracy) and ``real_time``
    (seconds since the search started) to Tune.

    Args:
        data: ``(x, y)`` features and labels.
        dataname: dataset name, used only for logging.
        iterations: number of Tune samples (trials).
        mode: ``'BOHB'`` uses TuneBOHB with HyperBandForBOHB; any other value
            uses SkOptSearch without a scheduler.
        algo: ``'xgb'`` tunes an XGBClassifier; any other value tunes a
            RandomForestClassifier.
        maxtime: wall-clock budget in minutes; ``0`` disables the limit.

    Returns:
        A list with the running-best accuracy over the rows of the Tune
        results dataframe. When ``maxtime > 0``, returns
        ``(accuracies, times)`` where ``times`` holds each row's ``real_time``.
    """
    x, y = data
    maxtime = maxtime * 60  # minutes -> seconds for TimeoutStopper

    ### Step 1 : defining the objective function
    def objective_xgb(params):
        model = xgb.XGBClassifier(max_depth=params['max_depth'], learning_rate=params['learning_rate'],
                                  n_estimators=params['n_estimators'], gamma=params['gamma'],
                                  min_child_weight=params['min_child_weight'], subsample=params['subsample'],
                                  colsample_bytree=params['colsample_bytree'],
                                  colsample_bylevel=params['colsample_bylevel'],
                                  reg_alpha=params['reg_alpha'], reg_lambda=params['reg_lambda'])

        # 10-fold cross-validated accuracy of this configuration.
        kfold = KFold(n_splits=10, shuffle=True, random_state=69)
        results = cross_val_score(model, x, y, n_jobs=-1, cv=kfold)
        score = results.mean()

        timed = time.time() - start
        loss = 1 - score
        # Tune only records metrics passed to session.report; without this
        # call XGB trials would have no 'mean_loss' / 'real_time' for the
        # get_best_result() / get_dataframe() lookups below.
        session.report({"real_time": timed, "mean_loss": loss})

    def objective_rf(params):
        model = sk.ensemble.RandomForestClassifier(params['estimators'], max_depth=params['max_depth'],
                                                   min_samples_split=params['min_split'],
                                                   min_samples_leaf=params['min_leaf'],
                                                   max_features=params['max_feat'])

        kfold = KFold(n_splits=10, shuffle=True, random_state=69)
        results = cross_val_score(model, x, y, n_jobs=-1, cv=kfold)
        score = results.mean()

        timed = time.time() - start
        loss = 1 - score
        session.report({"real_time": timed, "mean_loss": loss})

    ### step 2 : defining the search space
    xgb_space = {
        'max_depth': tune.randint(3, 15),
        'learning_rate': tune.uniform(0.001, 0.1),
        'n_estimators': tune.randint(50, 200),
        'gamma': tune.uniform(0.05, 1.0),
        'subsample': tune.uniform(0.6, 1.0),
        'colsample_bytree': tune.uniform(0.5, 1.0),
        'colsample_bylevel': tune.uniform(0.5, 1.0),
        'min_child_weight': tune.uniform(1, 7),
        'reg_alpha': tune.uniform(0.0, 1.0),
        'reg_lambda': tune.uniform(0.01, 1.0)}

    rf_space = {
        'estimators': tune.randint(50, 120),
        'max_depth': tune.randint(3, 30),
        'min_split': tune.randint(2, 100),
        'min_leaf': tune.randint(1, 100),
        'max_feat': tune.uniform(0.1, 0.9)}

    if algo == 'xgb':
        space = xgb_space
        objective = objective_xgb
    else:
        space = rf_space
        objective = objective_rf

    ### step 3 : optimize
    ray.init(object_store_memory=200 * 1024 * 1024, num_cpus=1)
    # Always shut Ray down, even if fitting fails, so a later ray.init() works.
    try:
        start = time.time()

        if mode == 'BOHB':
            search_alg = TuneBOHB(metric='mean_loss', mode='min')
            scheduler = HyperBandForBOHB()
        else:
            search_alg = SkOptSearch(metric="mean_loss", mode="min")
            scheduler = None  # Tune's default: no scheduler

        # Optional wall-clock budget for the whole experiment.
        stop = tune.stopper.TimeoutStopper(maxtime) if maxtime > 0 else None

        tuner = tune.Tuner(
            objective,
            run_config=air.RunConfig(stop=stop, verbose=0),
            tune_config=tune.TuneConfig(
                metric="mean_loss",
                mode="min",
                scheduler=scheduler,
                search_alg=search_alg,
                num_samples=iterations,
            ),
            param_space=space,
        )

        print('Finding best parameters for dataset {} using {}'.format(dataname, mode))
        results = tuner.fit()
        print('Best found accuracy: {}'.format(1 - results.get_best_result().metrics['mean_loss']))

        df = results.get_dataframe()
    finally:
        ray.shutdown()

    losses = np.asarray(df['mean_loss'])
    times = np.asarray(df['real_time'])

    # Running best accuracy (floored at 0, as before).
    accuracies = []
    best_so_far = 0
    for loss in losses:
        best_so_far = max(best_so_far, 1 - loss)
        accuracies.append(best_so_far)

    if maxtime > 0:
        return accuracies, times

    return accuracies


def hypopt_bohb_subsamp(data, dataname, iterations=100, mode='BOHB', algo='xgb', maxtime=0):
    """Tune a classifier with Ray Tune, using data subsamples as RF fidelities.

    For ``algo='xgb'`` each trial reports once: 10-fold CV loss on all data.

    For any other ``algo`` (random forest), each trial reports ``folds`` times.
    Rung ``i`` (1-based) trains on a fraction ``i / folds`` of the data, and
    the last rung uses the full data. A multi-fidelity scheduler such as
    HyperBandForBOHB can therefore stop weak configurations cheaply.

    Args:
        data: ``(x, y)`` features and labels.
        dataname: dataset name, used only for logging.
        iterations: number of Tune samples (trials).
        mode: ``'BOHB'`` uses TuneBOHB with HyperBandForBOHB; any other value
            uses SkOptSearch.
        algo: ``'xgb'`` tunes an XGBClassifier; anything else a random forest.
        maxtime: wall-clock budget in minutes; ``0`` disables the limit.

    Returns:
        A list with the running-best accuracy over the rows of the Tune
        results dataframe. When ``maxtime > 0``, returns
        ``(accuracies, times)`` where ``times`` holds each row's ``real_time``.
    """
    x, y = data
    maxtime = maxtime * 60  # minutes -> seconds for TimeoutStopper

    ### Step 1 : defining the objective function
    def objective_xgb(params):
        model = xgb.XGBClassifier(max_depth=params['max_depth'], learning_rate=params['learning_rate'],
                                  n_estimators=params['n_estimators'], gamma=params['gamma'],
                                  min_child_weight=params['min_child_weight'], subsample=params['subsample'],
                                  colsample_bytree=params['colsample_bytree'],
                                  colsample_bylevel=params['colsample_bylevel'],
                                  reg_alpha=params['reg_alpha'], reg_lambda=params['reg_lambda'])

        # 10-fold cross-validated accuracy of this configuration.
        kfold = KFold(n_splits=10, shuffle=True, random_state=69)
        results = cross_val_score(model, x, y, n_jobs=-1, cv=kfold)
        score = results.mean()

        timed = time.time() - start
        loss = 1 - score
        session.report({"real_time": timed, "mean_loss": loss})

    def objective_rf(params, x, y):
        """Return 1 - mean 10-fold CV accuracy of a random forest on (x, y)."""
        model = sk.ensemble.RandomForestClassifier(params['estimators'], max_depth=params['max_depth'],
                                                   min_samples_split=params['min_split'],
                                                   min_samples_leaf=params['min_leaf'],
                                                   max_features=params['max_feat'])

        kfold = KFold(n_splits=10, shuffle=True, random_state=69)
        results = cross_val_score(model, x, y, n_jobs=-1, cv=kfold)
        return 1 - results.mean()

    def objective_steps_rf(params):
        """Score ``params`` on growing subsamples and report after each rung."""
        folds = params['folds']
        for i in range(folds):
            frac = (i + 1) / folds
            if frac < 1:
                # Keep `frac` of the data as the training subsample.
                x_it, _, y_it, _ = train_test_split(x, y, train_size=frac, random_state=69)
            else:
                # train_test_split rejects a size of 1.0; the last rung uses all data.
                x_it, y_it = x, y

            loss = objective_rf(params, x_it, y_it)
            timed = time.time() - start
            session.report({"iterations": i + 1, "real_time": timed, "mean_loss": loss})

    ### step 2 : defining the search space
    xgb_space = {
        'max_depth': tune.randint(3, 15),
        'learning_rate': tune.uniform(0.001, 0.1),
        'n_estimators': tune.randint(50, 200),
        'gamma': tune.uniform(0.05, 1.0),
        'subsample': tune.uniform(0.6, 1.0),
        'colsample_bytree': tune.uniform(0.5, 1.0),
        'colsample_bylevel': tune.uniform(0.5, 1.0),
        'min_child_weight': tune.uniform(1, 7),
        'reg_alpha': tune.uniform(0.0, 1.0),
        'reg_lambda': tune.uniform(0.01, 1.0)}

    rf_space = {
        'folds': 4,  # number of subsample rungs per RF trial
        'estimators': tune.randint(50, 120),
        'max_depth': tune.randint(3, 30),
        'min_split': tune.randint(2, 100),
        'min_leaf': tune.randint(1, 100),
        'max_feat': tune.uniform(0.1, 0.9)}

    if algo == 'xgb':
        space = xgb_space
        objective = objective_xgb
    else:
        space = rf_space
        objective = objective_steps_rf

    ### step 3 : optimize
    ray.init(object_store_memory=200 * 1024 * 1024, num_cpus=1)
    # Always shut Ray down, even if fitting fails, so a later ray.init() works.
    try:
        start = time.time()

        if mode == 'BOHB':
            search_alg = TuneBOHB(metric='mean_loss', mode='min')
            scheduler = HyperBandForBOHB()
            run_name = None if maxtime > 0 else 'bohb'
        else:
            search_alg = SkOptSearch(metric="mean_loss", mode="min")
            scheduler = None
            run_name = "scikit-opt"

        # Optional wall-clock budget, honoured by both search modes.
        stop = tune.stopper.TimeoutStopper(maxtime) if maxtime > 0 else None

        tuner = tune.Tuner(
            objective,
            run_config=air.RunConfig(name=run_name, stop=stop, verbose=0),
            tune_config=tune.TuneConfig(
                metric="mean_loss",
                mode="min",
                scheduler=scheduler,
                search_alg=search_alg,
                num_samples=iterations,
            ),
            param_space=space,
        )

        print('Finding best parameters for dataset {} using {}'.format(dataname, mode))
        results = tuner.fit()
        print('Best found accuracy: {}'.format(1 - results.get_best_result().metrics['mean_loss']))

        df = results.get_dataframe()
    finally:
        ray.shutdown()

    losses = np.asarray(df['mean_loss'])
    times = np.asarray(df['real_time'])

    # Running best accuracy (floored at 0, as before).
    accuracies = []
    best_so_far = 0
    for loss in losses:
        best_so_far = max(best_so_far, 1 - loss)
        accuracies.append(best_so_far)

    if maxtime > 0:
        return accuracies, times

    return accuracies


def datasets_stats(data, datanames):
    """Print reward statistics of 100 random XGBoost configurations per dataset.

    For each dataset, 100 actions are drawn uniformly from [-1, 1]^10 and
    mapped linearly onto the environment's ``[lb, ub]`` box. Each action is
    scored with ``env.score_RF(..., std=True)``. The function then prints:
    - mean, std, min and max of the rewards;
    - the average of the returned stds;
    - the elapsed wall-clock time.

    Args:
        data: sequence of ``(x, y)`` pairs, indexed in parallel with ``datanames``.
        datanames: dataset names, used for logging.
    """
    for setnum, name in enumerate(datanames):
        print('Calculating statistics for dataset {}'.format(name))

        t_begin = time.time()
        x, y = data[setnum]
        env = xgb_env_cont(x, y, 11, lb=(3, 0.001, 50, 0.05, 1, 0.6, 0.5, 0.5, 0, 0.01),
                           ub=(15, 0.1, 200, 1, 7, 1, 1, 1, 1, 1))

        rewards = np.zeros(100)
        std_values = []
        for k in range(100):
            raw = np.random.uniform(low=-1, high=1, size=10)
            unit = np.squeeze((raw + 1) / 2)  # [-1, 1] -> [0, 1]
            # Linear map of each coordinate onto its [lb, ub] hyperparameter range.
            hyper = np.array([env.lb[j] + unit[j] * (env.ub[j] - env.lb[j]) for j in range(10)],
                             dtype=float)
            rewards[k], std = env.score_RF(*hyper, std=True)
            std_values.append(std)
        std_tot = sum(std_values)

        avg_reward = np.mean(rewards)
        reward_std = np.std(rewards)
        elapsed = time.time() - t_begin
        print('Average reward: {}, reward std: {}, average std: {}'.format(avg_reward, reward_std, (std_tot/100)))
        print('Min reward: {}, max reward: {}'.format(np.min(rewards), np.max(rewards)))
        print('Time taken: {} minutes and {} seconds'.format(int(elapsed / 60), elapsed % 60))
        print()
        print()

def plot_results(scores_list, epnums, names):
    """Placeholder for plotting results; currently draws nothing.

    Args:
        scores_list: presumably per-run score sequences (unused for now).
        epnums: presumably per-run episode numbers (unused for now).
        names: one label per curve; only its length is used.

    Returns:
        None.
    """
    curve_count = len(names)
    for _ in range(curve_count):
        continue  # TODO: plot the entry for this name


def main(env_name, data, dataname, seed=69, use_decay=True, gamma=0.99, tau=0.005,
         policy='Gaussian', target_update_interval=2, automatic_entropy_tuning=False,
         hidden_size=256, lr=0.001, num_networks=7, num_elites=5,
         pred_hidden_size=200, reward_size=1, replay_size=1000000,
         model_retain_epochs=1, model_train_freq=250,
         rollout_batch_size=100000, epoch_length=1000, rollout_min_epoch=20,
         rollout_max_epoch=150, rollout_min_length=1,
         rollout_max_length=15, num_epoch=1000, min_pool_size=1000,
         real_ratio=0.05, train_every_n_steps=1, num_train_repeat=20,
         policy_train_batch_size=256, init_exploration_steps=5000,
         max_path_length=1000, model_type='pytorch',
         cuda=False, emb_size=5, train_emb=True,
         save_model=False, load_model=False, save_name='x', load_name='x', test_mode=False,
         act_noise=0.1, target_noise=0.2, noise_clip=0.5, max_model_error=0.2, save_graphs=False,
         save_buffer=False, load_buffer=False, eval_freq=5, save_exploration=False, load_exploration=False,
         emb_name='EmbedNet', meta_training=False, eval_mode=False, action_norm=True, per=True, show_plots=True,
         epsilon=0.0, maxtime=0, save_raw_exp=False, pretrain_dyn_model=False):
    """Build the environment, SAC agent, dynamics model and buffers, then train.

    All settings are bundled into an ``ArgumentReader``. Three environment
    names are recognised:
    - ``'rf'``: builds an ``rf_env_cont``;
    - ``'xgboost'``: builds an ``xgb_env_cont``;
    - anything else: passed to ``gym.make``.

    Returns:
        ``(scores, epnums)`` from a single ``train`` call when ``eval_mode``
        is False. Otherwise ``(scorelist, epslist)``: lists holding the
        results of 3 independent training runs.
    """
    #TODO add act_noise, target_noise, noise_clip, policy_update

    args = ArgumentReader(env_name, data, dataname, seed, use_decay, gamma, tau, policy, target_update_interval,
            automatic_entropy_tuning, hidden_size, lr, num_networks, num_elites, pred_hidden_size, reward_size,
            replay_size, model_retain_epochs, model_train_freq, rollout_batch_size, epoch_length,
            rollout_min_epoch, rollout_max_epoch, rollout_min_length, rollout_max_length, num_epoch, min_pool_size,
            real_ratio, train_every_n_steps, num_train_repeat, policy_train_batch_size,
            init_exploration_steps, max_path_length, model_type, cuda, emb_size, train_emb,
            save_model, load_model, save_name, load_name, test_mode, act_noise, target_noise, noise_clip,
            max_model_error, save_graphs, save_buffer, load_buffer, eval_freq, save_exploration, load_exploration,
            emb_name, meta_training, eval_mode, action_norm, per, show_plots, epsilon, maxtime, save_raw_exp,
                          pretrain_dyn_model)

    scorelist, epslist = [], []
    n_runs = 3 if args.eval_mode else 1

    for _run in range(n_runs):
        x, y = data

        # Hyperparameter-tuning environments; any other name is a gym id.
        if args.env_name == 'rf':
            env = rf_env_cont(x, y, 6, lb=[50, 3, 2, 1, 0.1], ub=[120, 30, 100, 100, 0.9])
            # Wider bounds previously used outside test mode:
            #     lb=[100, 3, 2, 1, 0.1], ub=[1200, 30, 100, 100, 0.9]
        elif args.env_name == 'xgboost':
            env = xgb_env_cont(x, y, 11, lb=(3, 0.001, 50, 0.05, 1, 0.6, 0.5, 0.5, 0, 0.01),
                               ub=(15, 0.1, 200, 1, 7, 1, 1, 1, 1, 1))
            # Wider bounds previously used outside test mode:
            #     ub=(25, 0.1, 1200, 1, 7, 1, 1, 1, 1, 1)
        else:
            env = gym.make(args.env_name)

        # Initial agent
        agent = SAC(env.observation_space.shape[0], env.action_space, args)

        # The dynamics model works in embedding space of size emb_size.
        state_size = args.emb_size
        action_size = np.prod(env.action_space.shape)
        args.action_size = action_size

        # Capacity of the pool that stores model-generated rollouts.
        rollouts_per_epoch = args.rollout_batch_size * args.epoch_length / args.model_train_freq
        new_pool_size = args.model_retain_epochs * int(rollouts_per_epoch)

        # Dynamics models
        if args.model_type != 'pytorch':
            env_model = construct_model(obs_dim=state_size, act_dim=action_size, hidden_dim=args.pred_hidden_size,
                                        num_networks=args.num_networks,
                                        num_elites=args.num_elites)
        elif args.load_buffer and not args.pretrain_dyn_model:
            env_model = EnsembleDynamicsModel(args.num_networks, args.num_elites, state_size, action_size,
                                              args.reward_size, args.pred_hidden_size,
                                              use_decay=args.use_decay, no_transfer=False)
            env_model = load_dynamics_model(args, env_model)
        else:
            env_model = EnsembleDynamicsModel(args.num_networks, args.num_elites, state_size, action_size,
                                              args.reward_size, args.pred_hidden_size,
                                              use_decay=args.use_decay)

        # Predict environments
        predict_env = PredictEnv(env_model, args.env_name, args.model_type)

        # Real-environment pool (prioritised when args.per), model pool and sampler.
        replay_cls = ReplayMemoryPER if args.per else ReplayMemory
        env_pool = replay_cls(args.replay_size)
        model_pool = ReplayMemory(new_pool_size)
        env_sampler = EnvSampler(env, max_path_length=args.max_path_length)

        print('creating replay buffer of size {}'.format(args.replay_size))

        print('start training')
        scores, epnums = train(args, env_sampler, predict_env, agent, env_pool, model_pool, action_size)

        if not args.eval_mode:
            return scores, epnums

        scorelist.append(scores)
        epslist.append(epnums)

    return scorelist, epslist


if __name__ == '__main__':

    random.seed(69)
    good_data_names = get_data_paths('data/good_data/')
    datanames_list = list(good_data_names.keys())

    # Held-out datasets, excluded from meta-training below.
    # Other held-out splits that have been tried:
    #   ['TeachingAssistent', 'MaternalHealthRisk', 'CervicalCancer', 'Obesity', 'Musk', 'AndroidMalware']
    #   ['Obesity', 'Musk', 'AndroidMalware']
    #   ['Cervicalcancer', 'Musk', 'Obesity', 'AndroidMalware']
    #   ['AndroidMalware', 'Obesity']
    #   single datasets: 'TeachingAssistent', 'MaternalHealthRisk', 'CervicalCancer', 'Musk', 'AndroidMalware'
    testnames = ['TeachingAssistent', 'Musk', 'AndroidMalware']

    random.shuffle(datanames_list)
    # NOTE: trainnames aliases datanames_list, so the removals mutate both.
    trainnames = datanames_list
    for name in testnames:
        trainnames.remove(name)

    # hypopt_TPE baselines and datasets_stats were previously invoked here.

    ################## TRAINING #########################
    # Settings shared by every meta-training run. Only load_name, load_model
    # and show_plots differ between the first dataset and the rest.
    shared_kwargs = dict(
        init_exploration_steps=1000, epoch_length=10, model_train_freq=10, max_path_length=200,
        emb_size=10, num_epoch=500, min_pool_size=10, rollout_min_epoch=1,
        rollout_max_epoch=20, rollout_min_length=1, rollout_max_length=1, policy_train_batch_size=50,
        rollout_batch_size=10000, real_ratio=1, hidden_size=256, replay_size=500000, save_model=True,
        train_emb=False, gamma=0.99, max_model_error=0.02, act_noise=0.1, target_noise=0.1,
        lr=0.00005, save_exploration=True, tau=0.05,
        save_graphs=True, load_buffer=False, save_buffer=True, train_every_n_steps=5,
        meta_training=True, load_exploration=True, emb_name='EmbedNet10_mean_max_newtest_rf',
        eval_freq=5, save_raw_exp=False,
    )

    for name in trainnames:
        print('loading {} dataset'.format(name))
        data = load_dataset(good_data_names[name])

        if name == trainnames[0]:
            # First dataset: no model is loaded.
            main('rf', data, name, load_model=False, show_plots=True,
                 save_name=str(name + '10_gam099_acc_custq_PER_final_rf'), **shared_kwargs)
        else:
            # Later datasets: load the model saved under the previous dataset's name.
            main('rf', data, name, load_name=loadname, load_model=True, show_plots=False,
                 save_name=str(name + '10_gam099_acc_custq_PER_final_rf'), **shared_kwargs)

        loadname = str(name + '10_gam099_acc_custq_PER_final_rf')