import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp
from matplotlib import pyplot as plt
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
slim = tf.contrib.slim
tfd = tfp.distributions
session_config = tf.ConfigProto(log_device_placement=False)
session_config.gpu_options.allow_growth = True

from absl import app
from absl import flags
from absl import logging
import gym
from gym import wrappers
from gym.envs.classic_control.pendulum import angle_normalize, PendulumEnv
import d4rl

from tensorflow.nn.rnn_cell import LSTMStateTuple
rnn = tf.contrib.rnn



def main(_):
	"""Fit a state-conditioned truncated-normal action model to a D4RL dataset.

	Loads the trajectories of the environment named by ``FLAGS.env_name``,
	normalizes observations and rewards with dataset-wide scalar statistics,
	trains the network by minimizing the negative log-likelihood of the
	dataset actions, and checkpoints the best model to ``save_path``.

	Args:
	    _: unused positional argument supplied by ``absl.app.run``.
	"""
	# input data
	env = gym.make(FLAGS.env_name)
	d4rl_original_data = list(d4rl.sequence_dataset(env))
	if not d4rl_original_data:
	    raise ValueError("d4rl.sequence_dataset returned no trajectories")

	# Dimensions and horizon are taken from the first trajectory.
	first_traj = d4rl_original_data[0]
	state_dim = first_traj['observations'].shape[1]
	action_dim = first_traj['actions'].shape[1]
	is_training = True
	CODE_SIZE = action_dim
	horizon = first_traj['observations'].shape[0]
	buffer_size = 3000
	MAX_EPISODES = 100 # 100 ~ 2000 according to data size
	MINIBATCH_SIZE = 4 # 4~64 according to data size

	lr = 0.001
	BEST_LOSS = 9999.

	EPS = 1e-8
	save_path = "./saved_dist/state_action_dist.ckpt"

	# get mean and std
	# One scalar mean/std over every observation entry (all trajectories,
	# timesteps and dimensions). Reducing in float64 with NumPy avoids the
	# element-by-element Python sum of float32 values, which is slow and loses
	# precision on large datasets.
	all_obs = np.concatenate([traj['observations'] for traj in d4rl_original_data], axis=0)
	obs_mean = float(np.mean(all_obs, dtype=np.float64))
	# Floor the std so a constant signal cannot cause a division by zero later.
	obs_std = max(float(np.std(all_obs, dtype=np.float64)), EPS)

	all_rew = np.concatenate([traj['rewards'] for traj in d4rl_original_data], axis=0)
	rew_mean = float(np.mean(all_rew, dtype=np.float64))
	rew_std = max(float(np.std(all_rew, dtype=np.float64)), EPS)

	class ReplayBuffer_Trajectory(object):
	    """Fixed-capacity ring buffer that stores and samples whole trajectories.

	    Each slot holds one trajectory of exactly ``horizon`` steps. Only the
	    current observation is stored; next observations are not kept.
	    """

	    def __init__(self, obs_dim, act_dim, horizon, size):
	        """Allocate zero-filled float32 storage.

	        Args:
	            obs_dim: size of one observation vector.
	            act_dim: size of one action vector.
	            horizon: number of steps in every stored trajectory.
	            size: capacity, in number of trajectories.
	        """
	        self.obs1_buf = np.zeros([size, horizon, obs_dim], dtype=np.float32)
	        self.acts_buf = np.zeros([size, horizon, act_dim], dtype=np.float32)
	        self.rews_buf = np.zeros([size, horizon], dtype=np.float32)
	        self.done_buf = np.zeros([size, horizon], dtype=np.float32)
	        # ptr0: trajectory slot being written; ptr1: step inside that slot.
	        self.ptr0, self.ptr1, self.size, self.max_size, self.horizon = 0, 0, 0, size, horizon
	        # Total number of trajectories ever written (not capped by max_size).
	        self.count = 0

	    def port_d4rl_data(self, d4rl_data, obs_mean, obs_std, rew_mean, rew_std):
	        """
	        Port d4rl sequence datasets (generator format) into buffer
	        Now only support running this **before training starts**

	        Observations and rewards are standardized with the given mean/std;
	        actions and terminals are copied as-is. Trajectories whose length
	        differs from ``self.horizon`` cannot be stored in the fixed-size
	        slots and are skipped with a warning.

	        Raises:
	            ValueError: if the usable trajectories do not fit in the buffer.
	        """
	        copied_keys = ('observations', 'actions', 'rewards', 'terminals')
	        trajectories = list(d4rl_data)  # convert generator to list
	        usable = [
	            traj for traj in trajectories
	            if all(len(traj[key]) == self.horizon for key in copied_keys)
	        ]

	        skipped = len(trajectories) - len(usable)
	        if skipped:
	            logging.warning(
	                "Skipping %d of %d trajectories whose length is not %d",
	                skipped, len(trajectories), self.horizon)

	        # Explicit raise instead of `assert`, which is stripped under -O.
	        if self.max_size < len(usable):
	            raise ValueError("Buffer size smaller than the size of d4rl data, cannot port in")

	        for traj in usable:
	            slot = self.ptr0
	            self.obs1_buf[slot, :, :] = (traj['observations'].astype(np.float32) - obs_mean) / obs_std
	            self.acts_buf[slot, :, :] = traj['actions'].astype(np.float32)
	            self.rews_buf[slot, :] = (traj['rewards'].astype(np.float32) - rew_mean) / rew_std
	            self.done_buf[slot, :] = traj['terminals'].astype(np.float32)
	            self.size = min(self.size+1, self.max_size)
	            self.ptr0 = (self.ptr0+1) % self.max_size
	            self.count += 1

	    def add(self, obs, act, rew, done):
	        """Write one transition at the current (trajectory, step) position.

	        After ``horizon`` calls the step pointer wraps to 0 and the
	        trajectory pointer advances (wrapping at ``max_size``).
	        """
	        self.obs1_buf[self.ptr0, self.ptr1] = obs
	        self.acts_buf[self.ptr0, self.ptr1] = act
	        self.rews_buf[self.ptr0, self.ptr1] = rew
	        self.done_buf[self.ptr0, self.ptr1] = done
	        self.ptr1 = (self.ptr1+1) % self.horizon
	        if self.ptr1 == 0:
	            # A full trajectory has just been completed.
	            self.size = min(self.size+1, self.max_size)
	            self.ptr0 = (self.ptr0+1) % self.max_size
	            self.count += 1

	    def sample_batch(self, batch_size=32):
	        """Return ``batch_size`` stored trajectories drawn uniformly with replacement.

	        Returns:
	            dict with keys ``obs1``, ``acts``, ``rews`` and ``done``; each
	            value has a leading dimension of ``batch_size``.
	        """
	        idxs = np.random.randint(0, self.size, size=batch_size)
	        return dict(obs1=self.obs1_buf[idxs],
	                    acts=self.acts_buf[idxs],
	                    rews=self.rews_buf[idxs],
	                    done=self.done_buf[idxs],
	                   )

	    def save(self, path):
	        """Write all four buffers (including unused slots) to ``path`` with ``np.savez``."""
	        np.savez(
	            path,
	            obs1_buf=self.obs1_buf,
	            acts_buf=self.acts_buf,
	            rews_buf=self.rews_buf,
	            done_buf=self.done_buf,
	        )

	def trun_normal_log_prob(x, mu, std, low, high):
	    """Log-density of ``x`` under Normal(mu, std) truncated to [low, high].

	    Per-dimension log-densities are summed over axis 1, so the result has
	    one value per row of ``x``.

	    Args:
	        x: points at which to evaluate the density.
	        mu: location of the untruncated normal.
	        std: scale of the untruncated normal; ``EPS`` is added before use.
	        low: lower truncation bound.
	        high: upper truncation bound.
	    """
	    sigma = std + EPS
	    std_normal = tfd.Normal(0., 1.)
	    # Normalizer: the probability mass that the *untruncated* Normal(mu, sigma)
	    # places on [low, high]. It depends on mu, not on the evaluation point x.
	    z = std_normal.cdf((high - mu) / sigma) - std_normal.cdf((low - mu) / sigma)
	    # EPS inside the log keeps it finite when the mass on [low, high] underflows to 0.
	    return tf.reduce_sum(
	        -0.5 * ((x - mu) / sigma) ** 2 - 0.5 * tf.log(2 * np.pi) - tf.log(sigma * z + EPS),
	        axis=1, name="log_prob")

	# Reset the default graph before any ops are created.
	tf.reset_default_graph()

	# Graph inputs: one row per sample, fed per time step by the training loop.
	state_holder = tf.placeholder(tf.float32, [None, state_dim], name='state_holder')
	action_holder = tf.placeholder(tf.float32, [None, action_dim], name='action_holder')


	def learn_dist_from_s(state, code_size, reuse=tf.AUTO_REUSE, is_training=True, var_scope="BC"):
	    """Build the state -> truncated-normal action distribution network.

	    Two ReLU hidden layers (128 and 64 units) feed a linear ``loc`` head and
	    a softplus ``scale`` head, each of width ``code_size``. Every layer goes
	    through ``slim.batch_norm``.

	    Args:
	        state: input tensor fed to the first fully connected layer.
	        code_size: width of the ``loc`` and ``scale`` heads.
	        reuse: reuse setting for the variable scope and the layers.
	        is_training: batch-norm mode; pass False for evaluation.
	        var_scope: name of the enclosing variable scope.

	    Returns:
	        (sample, log_prob): a draw from TruncatedNormal(loc, scale, -1, 1)
	        and the log-probability of ``action_holder`` (taken from the
	        enclosing scope) computed by ``trun_normal_log_prob``.
	    """
	    fc_defaults = dict(
	        activation_fn=tf.nn.relu,
	        weights_initializer=tf.glorot_uniform_initializer,
	        weights_regularizer=slim.l2_regularizer(0.001),
	        biases_regularizer=slim.l2_regularizer(0.001),
	        normalizer_fn=slim.batch_norm,
	        normalizer_params={"is_training": is_training},
	        reuse=reuse,
	    )
	    lower, upper = -1., 1.  # -1, 1 bound
	    with tf.variable_scope(var_scope, reuse=reuse) as scope:
	        with slim.arg_scope([slim.fully_connected], scope=scope, **fc_defaults):
	            hidden = slim.fully_connected(state, 128, scope="fc1")
	            hidden = slim.fully_connected(hidden, 64, scope="fc2")
	            loc = slim.fully_connected(hidden, code_size, activation_fn=None, scope="loc")
	            scale = slim.fully_connected(hidden, code_size, activation_fn=tf.nn.softplus, scope="scale")
	            truncated = tfd.TruncatedNormal(loc, scale, lower, upper)
	            out_sample = truncated.sample()
	            out_log_prob = trun_normal_log_prob(action_holder, loc, scale, lower, upper)
	            return out_sample, out_log_prob

	sample, log_prob = learn_dist_from_s(state_holder, CODE_SIZE)
	# Training objective: mean negative log-likelihood of the dataset actions.
	# NOTE(review): the L2 regularizers passed to slim.fully_connected do not
	# appear to be added to this loss -- confirm whether that is intended.
	loss = tf.reduce_mean(-log_prob)
	# slim.batch_norm registers its moving mean/variance updates in UPDATE_OPS.
	# They only run if the train op depends on them; without this the saved
	# model keeps the initial moving statistics.
	update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
	with tf.control_dependencies(update_ops):
	    optimize = tf.train.AdamOptimizer(lr).minimize(loss)

	saver = tf.train.Saver() # save all variables

	# Saver.save does not create a missing parent directory.
	save_dir = os.path.dirname(save_path)
	if save_dir:
	    os.makedirs(save_dir, exist_ok=True)

	sess = tf.InteractiveSession(config=session_config)
	try:
	    sess.run(tf.global_variables_initializer())

	    replay_buffer = ReplayBuffer_Trajectory(state_dim, action_dim, horizon, buffer_size)
	    replay_buffer.port_d4rl_data(d4rl_original_data, obs_mean, obs_std, rew_mean, rew_std, )

	    # The buffer is not modified during training, so the size check is
	    # loop-invariant: do it once and say so instead of silently idling.
	    num_episodes = MAX_EPISODES
	    if replay_buffer.size <= MINIBATCH_SIZE:
	        logging.warning(
	            "Replay buffer holds %d trajectories but more than %d are required; nothing to train",
	            replay_buffer.size, MINIBATCH_SIZE)
	        num_episodes = 0

	    for i in range(num_episodes):
	        batch = replay_buffer.sample_batch(MINIBATCH_SIZE)
	        s_batch, a_batch = batch["obs1"], batch["acts"]

	        # One gradient step per time step of the sampled trajectories.
	        loss_list = []
	        for _t in range(horizon):
	            feed_dict = {action_holder: a_batch[:, _t, :], state_holder: s_batch[:, _t, :]}
	            loss_val, _ = sess.run([loss, optimize], feed_dict)
	            loss_list.append(loss_val)

	        mean_loss = np.mean(loss_list)
	        print('epi: {}, loss: {}'.format(i, mean_loss))

	        # Keep only the checkpoint with the lowest mean training loss so far.
	        if mean_loss < BEST_LOSS:
	            BEST_LOSS = mean_loss
	            saver.save(sess, save_path)
	finally:
	    sess.close()

if __name__ == '__main__':
	ENV_NAME = 'ProbabilityITS'

	# main() reads FLAGS.env_name, so FLAGS must be bound at module scope and
	# the flag registered before app.run() parses the command line.
	FLAGS = flags.FLAGS
	# NOTE(review): the default mirrors ENV_NAME; confirm this id is actually
	# registered with gym/d4rl, otherwise pass --env_name explicitly.
	flags.DEFINE_string('env_name', ENV_NAME, 'Environment id passed to gym.make().')

	# use normalized data
	# NOTE: allow_pickle=True can execute arbitrary code while loading; only
	# point this at trusted files. DATA is not read by main().
	with open('../processed_data/{}/train.npy'.format(ENV_NAME), 'rb') as f: 
	    DATA = np.load(f, allow_pickle=True)
	app.run(main)
