import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.neighbors import NearestNeighbors
from utils import *
from matplotlib import pyplot as plt
from scipy.stats import spearmanr, pearsonr
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.neural_network import MLPRegressor
import multiprocessing as mp
from sklearn.neighbors import NearestNeighbors
import tqdm
# Short aliases for the TF1 contrib / TensorFlow Probability namespaces used below
tfd = tfp.distributions
slim = tf.contrib.slim
rnn = tf.contrib.rnn

# Expose only GPU 0 to this process; must be set before TensorFlow touches the GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Session options: no device-placement logging, GPU memory allocated on demand
session_config = tf.ConfigProto(
    log_device_placement=False,
    gpu_options=tf.GPUOptions(allow_growth=True),
)

def seq_length(sequence):
    """Count, per sequence, the steps that contain at least one non-zero feature.

    The tensor is reduced over axis 2 (features) and then over axis 1 (steps),
    so a step whose features are all zero contributes nothing to the count.
    The result is cast to int32.
    """
    # Largest absolute feature value of every step; zero only for all-zero steps
    step_peak = tf.reduce_max(tf.abs(sequence), axis=2)
    # sign() maps the (non-negative) peak to 1 for used steps and 0 otherwise
    used_steps = tf.sign(step_peak)
    return tf.cast(tf.reduce_sum(used_steps, axis=1), tf.int32)

def seq_length_np(sequence):
    """NumPy counterpart of ``seq_length``: number of non-zero steps per sequence.

    Reduces over axis 2 (features) and then axis 1 (steps). Unlike
    ``seq_length`` the result is not cast, so it keeps the dtype produced by
    the reductions.
    """
    # Largest absolute feature value of every step; zero only for all-zero steps
    step_peak = np.abs(sequence).max(axis=2)
    # sign() turns the peak into a 0/1 indicator, which is then summed per sequence
    return np.sign(step_peak).sum(axis=1)


#########################  Hyper-parameters  ################################

rl_params = dict(
    # Pull out the offline trajectories collected using behavioral policies
    # that only attain less than 60% of the best possible performance
    env_name="low",
    # Number of neighbors used to regularize IHR reconstruction
    n_neighbours=20,
)

# Per-step discount factor applied when summing the predicted rewards
gamma = 0.995

# Mini-batch size of the training pipeline
batch_size = 64

# Learning-rate schedule: initial rate, steps per decay period, decay factor
start_learning_rate, decay_step, decay_rate = 0.001, 1000, 0.95

# Number of units in each LSTM cell
num_hidden = 128

# Weight of the per-step (neighbor reward) likelihood term in the loss
C = 0.001

# Total number of optimizer steps
training_steps = 6000

############################################################################

# Load the t-SNE result table that carries, for every transition, the rewards
# of its K nearest neighbors

tsne_result_path = "tsne/OPE_SAC_latentPolicy_lstm_zt_zt1_d4rlOnly_low_3000epi_repeat1_0.0003_1000_0.95_64_0.05_2599.csv"

tsne_results = pd.read_csv(tsne_result_path, index_col=0)

# Maximum possible dimension of an action (taken from the transformer's output):
# questions asked by the agent are truncated when they contain more than
# 40 words, as per ILQL's original implementation

env_action_dim_ori = 47616

_dataset_file = "./{}.npy".format(rl_params['env_name'])
dataset = np.load(_dataset_file, allow_pickle=True).item()

# Shift the percentile ranking (last entry of each final-reward record) to be positive

for _final_reward in dataset['final_rewards']:
    _final_reward[-1] += 2.5

# Action width used by the model: 768 times the 60th percentile of the
# per-step action lengths (768 is presumably the token embedding size — TODO confirm)

_action_lengths = [
    step_action.shape[0]
    for episode_actions in dataset['actions']
    for step_action in episode_actions
]
env_action_dim = 768 * int(np.percentile(_action_lengths, 60))

# Longest trajectory in the dataset; every sequence is padded to this length

max_seq_len = max(len(episode_obs) for episode_obs in dataset['observations'])

# The first 80% of the trajectories are used for training, the remainder for
# validation; the validation set is evaluated as one single batch

_num_episodes = len(dataset['observations'])
train_test_split_thres = int(_num_episodes * .8)
batch_size_val = int(_num_episodes) - train_test_split_thres

# Map each transition key (column "sas2") to the rewards of its K neighbors

_neighbor_columns = [
    "{}neighbor_rew_{}".format(rl_params['n_neighbours'], j)
    for j in range(rl_params['n_neighbours'])
]
neighbor_dict = dict(zip(
    tsne_results["sas2"].values,
    tsne_results[_neighbor_columns].values,
))



def train():
    """Train the per-step reward model and log metrics on the hold-out split.

    A bidirectional LSTM followed by fully connected layers predicts, for
    every time step of a trajectory, the location and scale of a Gaussian
    reward. The loss combines

    * the negative log-likelihood of the trajectory's human feedback under
      the discounted sum of the per-step Gaussians, and
    * ``C`` times the negative log-likelihood of the neighbor rewards under
      the per-step Gaussians,

    plus the mean of the collected L2 regularization losses.

    Every ``display_step`` steps the latest checkpoint is written, the
    hold-out split is evaluated (relative absolute error and Spearman rank
    correlation), a separate "best" checkpoint is kept for the highest rank
    correlation seen so far, and one line is appended to
    ``./stats/<run>.txt``. If that stats file already exists the run resumes
    from the latest checkpoint.

    Reads the module-level configuration (``dataset``, ``neighbor_dict``,
    ``rl_params`` and the hyper-parameters); takes no arguments and returns
    ``None``.
    """
    # Imported here because this file has no module-level ``import random``;
    # without it ``random.seed`` only works if ``utils`` happens to export it.
    import random

    SEED = 2599

    # Per-step feature width: observation + next observation + action
    # + 2 (per-step reward and human feedback)
    num_input = dataset['observations'][0][0].shape[0]*2 + env_action_dim \
        + 2

    timesteps = max_seq_len  # every sequence is padded to this length

    display_step = 50

    np.random.seed(SEED)
    random.seed(SEED)

    file_appendix = "OPEHF_{}_new_".format(rl_params['env_name']) + str(start_learning_rate) + "_" \
        + str(decay_step) + "_" + str(decay_rate) + "_" + str(num_hidden) + "_" + str(C)

    stats_dir = "./stats"
    stats_path = "./stats/" + file_appendix + ".txt"
    model_dir = "./saved_model/" + file_appendix
    latest_ckpt_path = "./saved_model/" + file_appendix + "/opehf.ckpt"
    best_ckpt_path = "./saved_model/" + file_appendix + "/opehf_best.ckpt"

    def _flatten_actions(i, width, truncate):
        """Flatten every action of episode ``i`` and zero-pad it to ``width``.

        With ``truncate`` set, actions of length >= ``width`` are cut to
        ``width``; otherwise an over-long action raises (negative pad size).
        """
        rows = []
        for action in dataset['actions'][i]:
            flat = action.reshape(-1)
            if truncate and len(flat) >= width:
                rows.append(flat[:width])
            else:
                rows.append(np.concatenate([flat, np.zeros(width - flat.shape[0])]))
        return np.asarray(rows)

    def _stack_transitions(i, actions, rewards_col, human_feedback):
        """Lay out episode ``i`` as rows of [s, a, s', reward, human feedback]."""
        return np.hstack([
            dataset['observations'][i],
            actions,
            dataset['next_observations'][i],
            rewards_col,
            human_feedback,
        ])

    def _pad_to_horizon(block):
        """Append all-zero rows so that ``block`` has ``max_seq_len`` rows."""
        return np.vstack([
            block,
            np.zeros((max_seq_len - block.shape[0], block.shape[1]))
        ])

    def _episode_examples(episode_ids):
        """Yield one padded training example per episode index."""
        for i in episode_ids:
            rewards_col = np.asarray(dataset['rewards'][i]).reshape(-1, 1)
            # The trajectory-level human feedback is repeated on every step
            human_feedback = np.ones_like(rewards_col) \
                * dataset['final_rewards'][i][-1] * 1.

            # The neighbor table is keyed on the stringified (s, a, s') row
            # built with the *original* action width, so that layout has to
            # be reproduced for the lookup.
            lookup_rows = _stack_transitions(
                i,
                _flatten_actions(i, env_action_dim_ori, truncate=False),
                rewards_col,
                human_feedback,
            )
            neighbor_rewards = np.asarray([
                neighbor_dict[str(sas2.tolist())] for sas2 in lookup_rows[:, :-2]
            ])

            # Model input uses the reduced action width (padded or truncated)
            out = _stack_transitions(
                i,
                _flatten_actions(i, env_action_dim, truncate=True),
                rewards_col,
                human_feedback,
            )

            # 1 for real steps, 0 for the padding appended below
            len_indicator = np.ones_like(human_feedback)

            yield (_pad_to_horizon(out).astype(np.float32),
                   [np.float32(dataset['final_rewards'][i][-1]) * 1.],
                   _pad_to_horizon(neighbor_rewards).astype(np.float32),
                   _pad_to_horizon(len_indicator)
                  )

    def gen_train():
        """Examples of the training split (episodes before the split index)."""
        return _episode_examples(range(train_test_split_thres))

    def gen_test():
        """Examples of the hold-out split (episodes from the split index on)."""
        return _episode_examples(
            range(train_test_split_thres, len(dataset['observations'])))

    graph = tf.Graph()


    def build_net(x, graph=graph, reuse=tf.AUTO_REUSE, is_training=True, var_scope="bilstm"):
        """Build (or reuse) the BiLSTM + MLP that outputs per-step reward Gaussians.

        Returns ``(dist, sample, seq_len, outputs, loc, scale)``.
        ``is_training`` is accepted but not used by the network.
        """
        with graph.as_default():
            with tf.variable_scope(var_scope, reuse=reuse) as scope:

                # Number of non-padding steps, so the RNN stops at the true length
                seq_len = seq_length(x)
                # static_bidirectional_rnn expects a list with one tensor per step
                x = tf.unstack(x, timesteps, 1)
                lstm_cell_fw = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0, reuse=reuse)
                lstm_cell_bw = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0, reuse=reuse)
                outputs, _, _ = tf.nn.static_bidirectional_rnn(
                    lstm_cell_fw, lstm_cell_bw, x,
                    sequence_length=seq_len,
                    dtype=tf.float32,
                )

                with slim.arg_scope([slim.fully_connected],
                                    activation_fn=tf.nn.relu,
                                    weights_initializer=tf.glorot_uniform_initializer,
                                    weights_regularizer=slim.l2_regularizer(0.001),
                                    biases_regularizer=slim.l2_regularizer(0.001),
                                    normalizer_fn = None,
                                    reuse = reuse,
                                    scope = scope):
                    # `outputs` is a per-step list; the layers below therefore
                    # work on a time-major tensor (presumably [time, batch, ...])
                    fc1 = slim.fully_connected(outputs, 128, scope="fc1")
                    fc2 = slim.fully_connected(fc1, 64, scope="fc2")
                    loc = slim.fully_connected(fc2, 1, activation_fn=None, normalizer_fn=None, weights_regularizer=None, biases_regularizer=None, biases_initializer=None, scope="loc")
                    # softplus keeps the predicted scale non-negative
                    scale = slim.fully_connected(fc2, 1, activation_fn=tf.nn.softplus, normalizer_fn=None, weights_regularizer=None, biases_regularizer=None, biases_initializer=None, scope="scale")
                    dist = tfd.MultivariateNormalDiag(loc, scale)
                    return dist, dist.sample(), seq_len, outputs, loc, scale


    with graph.as_default():

        # The graph-level seed applies to the *current default* graph, so it
        # has to be set here for it to affect the graph that is trained.
        tf.set_random_seed(SEED)

        example_types = (tf.float32, tf.float32, tf.float32, tf.float32)
        example_shapes = (
            [timesteps, num_input],
            [1],
            [timesteps, rl_params['n_neighbours']],
            [timesteps, 1],
        )

        dataset_train = tf.data.Dataset.from_generator(
            gen_train, example_types, example_shapes
        ).repeat(100).shuffle(500).batch(batch_size)

        # The whole hold-out split forms one batch
        dataset_test = tf.data.Dataset.from_generator(
            gen_test, example_types, example_shapes
        ).repeat(800).batch(batch_size_val)

        (input_train,
         humanfeedback_train,
         neighbor_rewards_train,
         len_indicator_train) = dataset_train.make_one_shot_iterator().get_next()

        (input_test,
         humanfeedback_test,
         neighbor_rewards_test,
         len_indicator_test) = dataset_test.make_one_shot_iterator().get_next()

        dists, _, _, _, locs, scales = build_net(input_train)
        _, _, _, _, locs_val, _ = build_net(input_test, is_training=False)

        def _time_major(batch_major):
            """Swap the first two axes: [batch, time, ...] -> [time, batch, ...]."""
            return tf.stack(tf.unstack(batch_major, timesteps, 1), 0)

        # gamma**t for every step, shape [1, time, 1]. It broadcasts over the
        # batch axis, so no fixed batch size is baked into the graph.
        discounts = tf.constant([[[gamma**t] for t in range(max_seq_len)]], dtype=tf.float32)

        # Discounts with padded steps zeroed out, in time-major layout
        gammas = _time_major(tf.multiply(discounts, len_indicator_train))
        gammas_val = _time_major(tf.multiply(discounts, len_indicator_test))

        # Discounted sum of independent Gaussians: the means add up and the
        # variances add up weighted by the squared discounts
        sum_locs = tf.reduce_sum(tf.multiply(locs, gammas), [0])
        sum_scales = tf.sqrt(tf.reduce_sum(tf.multiply(tf.square(scales), tf.square(gammas)), 0))

        sum_locs_val = tf.reduce_sum(tf.multiply(locs_val, gammas_val), [0])

        returns_dist = tfd.MultivariateNormalDiag(sum_locs, sum_scales)
        returns_dist_logprob = returns_dist.log_prob(humanfeedback_train)

        # One time-major target per neighbor; log-probs are summed over the
        # neighbors and padded steps are masked out
        neighbor_targets = tf.split(
            _time_major(neighbor_rewards_train), rl_params['n_neighbours'], 2)
        reward_out_dists_logprobs = tf.multiply(
            tf.reduce_sum(dists.log_prob(neighbor_targets), 0)[..., tf.newaxis],
            _time_major(len_indicator_train))

        sum_likelihood = tf.reduce_mean(returns_dist_logprob)

        per_step_likelihood = tf.reduce_mean(tf.reduce_sum(reward_out_dists_logprobs, 0))

        loss = - sum_likelihood - C * per_step_likelihood \
            + tf.reduce_mean(
                tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            )

        # Relative absolute error of the predicted return on the hold-out split
        mae_val = tf.abs((sum_locs_val - humanfeedback_test) / (tf.abs(humanfeedback_test) + 1e-04) )

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        global_step = tf.Variable(0., trainable=False, name="training_step")
        global_step_increment = global_step.assign(global_step+1)
        learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, decay_steps=decay_step, decay_rate=decay_rate)

        with tf.control_dependencies(update_ops + [loss]):
            optimize = tf.train.AdamOptimizer(learning_rate)

            optimize_gradients = optimize.compute_gradients(loss)
            # Only the gradients of the "scale" head are clipped; a variable
            # without a gradient is passed through untouched
            optimize_clipped_gradients = [
                (tf.clip_by_value(grad, -10., 10.), var)
                if (grad is not None and var.name.find("scale")!=-1)
                else (grad, var)
                for (grad, var) in optimize_gradients
            ]
            optimizer = optimize.apply_gradients(optimize_clipped_gradients)

        # Only trainable variables are checkpointed; the step counter is
        # rebuilt from the stats file when resuming
        saver = tf.train.Saver(var_list = [v for v in tf.trainable_variables()])

    # Make sure the output locations exist before anything is written to them
    for out_dir in (stats_dir, model_dir):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    with tf.Session(config=session_config, graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        if os.path.exists(stats_path):
            saver.restore(sess, latest_ckpt_path)
            # A multi-character separator is treated as a regular expression
            # by pandas; " | " therefore splits on single spaces, which puts
            # the step number of "Step: N | ..." into column 1.
            df = pd.read_csv(stats_path, delimiter = " | ", header=None, engine="python")
            starting_iter = int(df[1].max())
            # Replay the step counter so the learning-rate schedule continues
            for _ in range(starting_iter):
                sess.run(global_step_increment)
        else:
            starting_iter = 0

        best_mae_val = 9999999.
        best_rank_val = -9999999.
        for i in range(starting_iter, training_steps):
            sess.run(optimizer)
            # Advance the step counter; without this the exponential decay
            # would stay at its initial learning rate for the whole run
            sess.run(global_step_increment)
            if (i+1) % display_step == 0:
                (train_sum_likelihood,
                 train_per_step_likelihood,
                 train_loss
                ) = sess.run([sum_likelihood, per_step_likelihood, loss])

                # Identify and save the model achieving the best rank over the
                # hold-out validation set (its MAE is recorded alongside)

                current_mae_val, current_sums_locs_val, current_hf_val = sess.run([mae_val, sum_locs_val, humanfeedback_test])
                rank_val = spearmanr(current_sums_locs_val, current_hf_val)[0]
                saver.save(sess, latest_ckpt_path)
                if rank_val > best_rank_val:
                    best_rank_val = rank_val
                    best_mae_val = np.mean(current_mae_val)
                    saver.save(sess, best_ckpt_path)

                report = "Step: {} | Train Loss: {} | Train Sum Likelihood : {} | Train Per-Step LikeliHood : {} | MAE_val: {} | RANK_val: {} | Best MAE_val: {} | Best RANK_val: {} \n".format(
                    i+1,
                    train_loss,
                    train_sum_likelihood,
                    train_per_step_likelihood,
                    np.mean(current_mae_val),
                    rank_val,
                    best_mae_val,
                    best_rank_val
                )
                print(report)
                with open(stats_path, "a") as myfile:
                    myfile.write(report)

# Start training only when this file is executed as a script, not on import
if __name__ == "__main__":
    train()

