import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.neighbors import NearestNeighbors
from utils import *
from matplotlib import pyplot as plt
from scipy.stats import spearmanr, pearsonr
from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPRegressor
from copy import deepcopy
import pickle

# Short-hand handles for the TF1 contrib and TFP modules used below.
rnn, slim = tf.contrib.rnn, tf.contrib.slim
tfd = tfp.distributions

# Expose only GPU 0 and let the session grow GPU memory on demand instead of
# reserving all of it at start-up.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
session_config = tf.ConfigProto(log_device_placement=False)
session_config.gpu_options.allow_growth = True

rl_params = dict(
    env_name="low",   # the dataset is read from ./<env_name>.npy
    n_neighbours=20,
)

# ================================== Config ==================================

# CSV of t-SNE results, loaded with its first column as the index.
tsne_result_path = "./tsne/VLM-H_3000epi_repeat1_0.0003_1000_0.95_64_0.05_2599.csv"

# Checkpoint of the trained reward model, restored below to reconstruct
# per-step rewards. Its run-directory name also encodes the architecture
# ("batchnorm" / "relu" variants and the LSTM hidden size).
ckpt_path = "./saved_model/OPEHF_low_new_0.00101_500_0.9_128_0.001/opehf_best.ckpt"

# ============================================================================

tsne_results = pd.read_csv(tsne_result_path, index_col=0)

# The .npy file holds a single pickled dict-like object; .item() unwraps it.
dataset = np.load("./{}.npy".format(rl_params['env_name']),
                  allow_pickle=True).item()

# Width of one flattened action vector: 768 times the 60th-percentile row
# count over every action in every episode. Actions are later zero-padded or
# truncated to exactly this width.
env_action_dim = 768 * int(np.percentile(
    [action.shape[0] for episode in dataset['actions'] for action in episode],
    60,
))

# Longest episode; every model input is zero-padded to this many steps.
max_seq_len = max(len(episode) for episode in dataset['observations'])

# Per-step feature width: observation + flattened action + next observation,
# plus two scalar columns (per-step reward and human feedback).
num_input = (2 * dataset['observations'][0][0].shape[0]
             + env_action_dim
             + 2)

timesteps = max_seq_len  # unroll length of the static BiLSTM

gamma = .995  # discount factor

UPSAMPLE_SIZE = 1  # number of copies of each episode in the exported dataset

# Rescale each episode's final reward in place: x -> 100 * x + 25.
for episode_final_rewards in dataset['final_rewards']:
    episode_final_rewards[-1] *= 100
    episode_final_rewards[-1] += 25

# The LSTM hidden size is stored in the checkpoint's run-directory name
# (third "/" component), as one of its "_"-separated fields. Run names of the
# "batchnorm" / "relu" variants presumably carry one extra field before it,
# so the index shifts from 6 to 7.
_run_name_fields = ckpt_path.split("/")[2].split("_")
num_hidden = int(_run_name_fields[
    7 if ("batchnorm" in ckpt_path or "relu" in ckpt_path) else 6
])

def seq_length(sequence):
    """Count, per batch item, the time steps that carry any non-zero feature.

    A step counts as used when the maximum absolute value over its features is
    non-zero. All-zero rows, such as the zero padding appended to episodes, are
    therefore not counted.

    Args:
        sequence: rank-3 tensor, reduced over axis 2 (features) and then
            axis 1 (time); presumably laid out as (batch, time, feature).

    Returns:
        int32 tensor with one step count per batch item.
    """
    step_used = tf.sign(tf.reduce_max(tf.abs(sequence), axis=2))
    return tf.cast(tf.reduce_sum(step_used, axis=1), tf.int32)


# Single graph shared by the model builders and the inference session below.
graph = tf.Graph()


def _build_bilstm_gaussian(x, graph, reuse, var_scope, normalizer_fn,
                           normalizer_params, loc_activation_fn):
    """Build the BiLSTM encoder and per-step Gaussian head shared by both builders.

    ``build_net`` and ``build_net_bc`` differ only in the normalizer used on the
    hidden fully-connected layers and in the activation applied to ``loc``.
    Layer scopes are fixed by name (``fc1``, ``fc2``, ``loc``, ``scale``) and
    created in a fixed order under ``var_scope``. Checkpoints saved by either
    variant therefore restore into the graph built here.

    Args:
        x: input tensor, unstacked along axis 1 into ``timesteps`` steps
            (assumed (batch, timesteps, features) -- see the placeholder).
        graph: ``tf.Graph`` to build into.
        reuse: reuse mode for the variable scope, LSTM cells and layers.
        var_scope: name of the enclosing variable scope.
        normalizer_fn: normalizer for the hidden layers (``None`` disables it).
        normalizer_params: kwargs for ``normalizer_fn``, or ``None`` to omit.
        loc_activation_fn: activation applied to the ``loc`` output.

    Returns:
        ``(dist, loc, seq_len, outputs, loc, scale)``. ``loc`` appears twice to
        keep the tuple layout that callers unpack.
    """
    with graph.as_default():
        with tf.variable_scope(var_scope, reuse=reuse) as scope:
            seq_len = seq_length(x)
            # static_bidirectional_rnn consumes a Python list of per-step tensors.
            steps = tf.unstack(x, timesteps, 1)
            lstm_cell_fw = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0, reuse=reuse)
            lstm_cell_bw = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0, reuse=reuse)
            outputs, _, _ = tf.nn.static_bidirectional_rnn(
                lstm_cell_fw, lstm_cell_bw, steps,
                sequence_length=seq_len,
                dtype=tf.float32,
            )

            # Defaults for the hidden layers (fc1, fc2).
            fc_defaults = dict(
                activation_fn=tf.nn.relu,
                weights_initializer=tf.glorot_uniform_initializer,
                weights_regularizer=slim.l2_regularizer(0.001),
                biases_regularizer=slim.l2_regularizer(0.001),
                normalizer_fn=normalizer_fn,
                reuse=reuse,
                scope=scope,
            )
            if normalizer_params is not None:
                fc_defaults["normalizer_params"] = normalizer_params

            # Output heads: no normalization, no regularization and no bias
            # (biases_initializer=None).
            head_overrides = dict(
                normalizer_fn=None,
                weights_regularizer=None,
                biases_regularizer=None,
                biases_initializer=None,
            )

            with slim.arg_scope([slim.fully_connected], **fc_defaults):
                fc1 = slim.fully_connected(outputs, 128, scope="fc1")
                fc2 = slim.fully_connected(fc1, 64, scope="fc2")
                loc = slim.fully_connected(fc2, 1, activation_fn=loc_activation_fn,
                                           scope="loc", **head_overrides)
                # softplus keeps the per-step scale non-negative.
                scale = slim.fully_connected(fc2, 1, activation_fn=tf.nn.softplus,
                                             scope="scale", **head_overrides)
                dist = tfd.MultivariateNormalDiag(loc, scale)
                return dist, loc, seq_len, outputs, loc, scale


def build_net(x, graph=graph, reuse=tf.AUTO_REUSE, is_training=True, var_scope="bilstm"):
    """Build the BiLSTM reward model without normalization on its hidden layers.

    ``is_training`` is accepted for signature parity with ``build_net_bc`` but
    has no effect here. ``loc`` gets a ReLU when ``ckpt_path`` contains "relu",
    and is linear otherwise.

    Returns:
        ``(dist, loc, seq_len, outputs, loc, scale)``.
    """
    return _build_bilstm_gaussian(
        x, graph, reuse, var_scope,
        normalizer_fn=None,
        normalizer_params=None,
        loc_activation_fn=tf.nn.relu if "relu" in ckpt_path else None,
    )


def build_net_bc(x, graph=graph, reuse=tf.AUTO_REUSE, is_training=True, var_scope="bilstm"):
    """Build the BiLSTM reward model with batch norm on its hidden layers.

    ``is_training`` is forwarded to ``slim.batch_norm``. ``loc`` is always
    linear in this variant.

    Returns:
        ``(dist, loc, seq_len, outputs, loc, scale)``.
    """
    return _build_bilstm_gaussian(
        x, graph, reuse, var_scope,
        normalizer_fn=slim.batch_norm,
        normalizer_params={"is_training": is_training},
        loc_activation_fn=None,
    )

def _build_episode_input(episode_idx):
    """Assemble the zero-padded model input for one episode.

    Each step's feature row is: observation, flattened action (zero-padded or
    truncated to ``env_action_dim``), next observation, reward, and human
    feedback. Human feedback is the episode's (rescaled) final reward repeated
    on every step. Rows are zero-padded up to ``max_seq_len``.

    Returns:
        ``(features, len_indicator, n_steps)``:
        - ``features``: padded matrix with ``max_seq_len`` rows.
        - ``len_indicator``: ``(max_seq_len, 1)`` 0/1 mask of real steps.
        - ``n_steps``: the unpadded episode length.
    """
    def fit_action(action):
        # Flatten, then zero-pad short actions / truncate long ones.
        flat = np.asarray(action).reshape(-1)
        if len(flat) < env_action_dim:
            return np.concatenate([flat, np.zeros(env_action_dim - flat.shape[0])])
        return flat[:env_action_dim]

    rewards = np.asarray(dataset['rewards'][episode_idx]).reshape(-1, 1)
    human_feedback = np.ones_like(rewards) \
        * dataset['final_rewards'][episode_idx][-1] * 1.
    features = np.hstack([
        dataset['observations'][episode_idx],
        np.asarray([fit_action(a) for a in dataset['actions'][episode_idx]]),
        dataset['next_observations'][episode_idx],
        rewards,
        human_feedback,
    ])
    n_steps = len(human_feedback)
    # hstack above guarantees features has exactly n_steps rows.
    pad_steps = max_seq_len - n_steps
    features = np.vstack([features, np.zeros((pad_steps, features.shape[1]))])
    len_indicator = np.vstack([np.ones_like(human_feedback), np.zeros((pad_steps, 1))])
    return features, len_indicator, n_steps


with graph.as_default():
    # Single-episode inputs: padded features, 0/1 real-step mask, true length.
    input_test = tf.placeholder(shape=[1, timesteps, num_input], dtype=tf.float32)
    len_indicator_test = tf.placeholder(shape=[1, timesteps, 1], dtype=tf.float32)
    input_len = tf.placeholder(shape=[], dtype=tf.int32)

    # The checkpoint path decides which architecture variant to rebuild.
    if "batchnorm" not in ckpt_path:
        dists_val, samples_val, seq_lens_val, outputs_val, locs_val, scales_val = build_net(input_test, is_training=False)
    else:
        dists_val, samples_val, seq_lens_val, outputs_val, locs_val, scales_val = build_net_bc(input_test, is_training=False)

    # Per-step discount factors gamma**t, masked to the real episode length.
    # NOTE(review): gammas_val is not fetched anywhere below.
    gammas_val = tf.repeat(tf.constant([[[gamma**i] for i in range(max_seq_len)]], dtype=tf.float32), 1, 0)
    gammas_val = tf.multiply(gammas_val, len_indicator_test)
    gammas_val = tf.stack(tf.unstack(gammas_val, timesteps, 1), 0)

    # samples_val is the per-step ``loc`` (the Gaussian mean, not a random
    # sample). Flatten to rank 1 and keep the first ``input_len`` steps.
    # reshape([-1]) is used instead of squeeze(): with a one-step horizon,
    # squeeze collapses to a scalar and tf.gather fails.
    dists_val_samples = tf.gather(tf.reshape(samples_val, [-1]), tf.range(input_len))

    saver = tf.train.Saver()

    with tf.Session(config=session_config, graph=graph) as sess:

        saver.restore(sess, ckpt_path)

        # rs[k] collects one reconstructed reward array per episode for copy k.
        rs = [[] for _ in range(UPSAMPLE_SIZE)]
        for i in range(len(dataset['observations'])):
            out, len_indicator, inputlen = _build_episode_input(i)
            feed = {
                input_test: [out],
                len_indicator_test: [len_indicator],
                input_len: inputlen,
            }
            # NOTE(review): the fetched tensor has no sampling op, so every copy
            # appears identical; the loop only matters if sampling is added.
            for j in range(UPSAMPLE_SIZE):
                rs[j].append(sess.run(dists_val_samples, feed_dict=feed))

        # Flat, copy-major list of per-episode 1-D reward arrays.
        # np.concatenate(rs) would first coerce each ragged inner list into a
        # single array. That is an error on NumPy >= 1.24 (and a deprecation
        # warning before) whenever episode lengths differ.
        dataset['rewards'] = [episode for per_copy in rs for episode in per_copy]


def _flatten_padded_action(action, width):
    """Flatten ``action`` and zero-pad or truncate it to exactly ``width`` values."""
    flat = np.asarray(action).reshape(-1)
    if flat.shape[0] < width:
        return np.concatenate([flat, np.zeros(width - flat.shape[0])])
    return flat[:width]


# Flatten per-episode containers into per-step arrays, tiled UPSAMPLE_SIZE
# times so they line up with the reward copies.
# env_action_dim from the data-loading section is reused: dataset['actions']
# is not modified in between, so recomputing it gave the same value.
dataset['observations'] = np.concatenate(
    [np.vstack([np.asarray(ep) for ep in dataset['observations']])] * UPSAMPLE_SIZE
)
dataset['next_observations'] = np.concatenate(
    [np.vstack([np.asarray(ep) for ep in dataset['next_observations']])] * UPSAMPLE_SIZE
)
dataset['actions'] = np.concatenate(
    [np.vstack([
        _flatten_padded_action(action, env_action_dim)
        for episode in dataset['actions'] for action in episode
    ])] * UPSAMPLE_SIZE
)
dataset['terminals'] = np.concatenate(
    [np.concatenate([np.asarray(ep) for ep in dataset['terminals']])] * UPSAMPLE_SIZE
)
# Rewards already hold UPSAMPLE_SIZE copies, one entry per episode copy.
# Accept a list, an object array or a 2-D array of per-episode rewards, and
# flatten to one value per step.
dataset['rewards'] = np.concatenate(
    [np.asarray(ep_rewards).reshape(-1) for ep_rewards in dataset['rewards']]
)

new_dataset = {
    key: dataset[key].copy()
    for key in ('observations', 'next_observations', 'rewards', 'actions', 'terminals')
}
new_dataset['timeouts'] = np.zeros_like(new_dataset['rewards'])

# Create the output directory on first run instead of failing in open().
os.makedirs("./d4rl_data_from_opehf", exist_ok=True)
# NOTE: despite the ".json" suffix this is a binary pickle (protocol 4); the
# name is kept so existing consumers still find the file.
with open("./d4rl_data_from_opehf/d4rl_typed_data_from_opehf_{}.json".format(rl_params['env_name']), "wb") as myfile:
    pickle.dump(new_dataset, myfile, protocol=4)
                                                                        
