import pickle,time,sys,os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import gym
from replay_buffer import ReplayBuffer
import hashlib
import matplotlib.pyplot as plt
from option_env import OptionEnv

    
def one_hot(n, a):
    """Encode index ``a`` as a float32 indicator vector of length ``n``.

    The entry at ``a`` (standard NumPy indexing, so negative indices count
    from the end) is 1 and every other entry is 0.
    """
    encoded = np.zeros(n, dtype=np.float32)
    encoded[a] = 1
    return encoded


def new_model(n_features, n_actions, n_quantiles, width_per_feature=50):
    """Build a two-hidden-layer MLP that outputs every quantile of every action.

    Args:
        n_features: length of the (flat) input feature vector.
        n_actions: number of discrete actions.
        n_quantiles: number of quantiles predicted per action.
        width_per_feature: hidden-layer width multiplier; each hidden layer has
            ``width_per_feature * n_features`` ReLU units (default 50, the
            previously hard-coded value).

    Returns:
        A ``keras.Model`` with a flat linear output of size
        ``n_actions * n_quantiles``. Callers reshape it to
        ``[n_actions, n_quantiles]``.
    """
    hidden = width_per_feature * n_features
    # Keras documents ``shape`` as a tuple excluding the batch axis; a bare int
    # is rejected by newer Keras versions.
    inputs = layers.Input(shape=(n_features,))
    dense = layers.Dense(hidden, activation='relu')(inputs)
    dense = layers.Dense(hidden, activation='relu')(dense)
    outputs = layers.Dense(n_actions * n_quantiles)(dense)
    return keras.Model(inputs=inputs, outputs=outputs)

    
class DQR4_Policy:
    """Quantile-regression Q-learning agent with a lower-tail action criterion.

    The network maps an observation to ``n_actions * n_quantiles`` outputs.
    These are reshaped to ``[n_actions, n_quantiles]`` and trained with a
    pinball loss at the midpoint levels ``tau_i = (i + 0.5) / n_quantiles``.

    Action selection works in two phases:

    * At the start of an episode (``eval_model_init``), pick the action whose
      first ``n_cvar`` quantile outputs have the highest mean. Record that
      action's output at index ``n_cvar - 1`` as the *cutoff*.
    * After each step, update the cutoff as ``(cutoff - r) / gamma``. Then pick
      the action maximising ``mean(min(quantiles - cutoff, 0))``
      (``eval_model``).

    NOTE(review): this looks like a dynamic CVaR scheme at level
    ``n_cvar / n_quantiles``; confirm against the intended algorithm. The
    quantile outputs are not constrained to be sorted.
    """

    def __init__(self, q_alpha, gamma, seed, n_obs, n_actions, n_quantiles, n_cvar,
                 clip_norm=10.0, target_update_freq=500):
        """Create the online and target networks and the optimiser.

        Args:
            q_alpha: Adam learning rate.
            gamma: discount factor. It is used in the TD target and to roll the
                cutoff forward between steps.
            seed: passed to ``tf.random.set_seed`` before the networks are built.
            n_obs: input feature size.
            n_actions: number of discrete actions.
            n_quantiles: number of quantiles predicted per action.
            n_cvar: number of lowest-level quantiles averaged to pick the first
                action of an episode.
            clip_norm: global-norm threshold for gradient clipping.
            target_update_freq: copy online weights into the target network
                every this many training iterations.
        """
        tf.random.set_seed(seed)
        self.gamma = gamma
        self.model = new_model(n_obs, n_actions, n_quantiles)
        self.target_model = new_model(n_obs, n_actions, n_quantiles)
        self.update_target()
        self.n_actions = n_actions
        self.n_quantiles = n_quantiles
        self.n_cvar = tf.constant(n_cvar, dtype='int64')
        # Quantile midpoints tau_i = (i + 0.5) / N, ascending.
        self.q_levels = tf.linspace(0.5 / n_quantiles, 1 - 0.5 / n_quantiles, n_quantiles)
        # NOTE: kappa is not read anywhere. update_q uses the plain pinball loss,
        # not the Huber variant. Kept so external code reading it keeps working.
        self.kappa = tf.constant(1.0)
        self.q_optimizer = tf.keras.optimizers.Adam(learning_rate=q_alpha)
        self.clip_norm = tf.constant(clip_norm, dtype='float32')
        self.target_update_freq = target_update_freq

    def update_target(self):
        """Copy the online network's trainable weights into the target network."""
        for var, var_target in zip(self.model.trainable_variables,
                                   self.target_model.trainable_variables):
            var_target.assign(var)

    def extract_state_feature(self, obs):
        """Convert a raw observation to the float32 features fed to the network."""
        return obs.astype('float32')

    @tf.function
    def eval_model(self, phi, cutoff):
        """Score the actions for one observation against the current cutoff.

        Args:
            phi: a single feature vector; a batch axis is added here.
            cutoff: scalar float32 tensor.

        Returns:
            ``(action, scores, quantiles)``. ``scores[a]`` is
            ``mean(min(quantiles[a] - cutoff, 0))``, ``action`` is its argmax,
            and ``quantiles`` has shape ``[n_actions, n_quantiles]``.
        """
        quantiles = tf.reshape(self.model(phi[np.newaxis, :]),
                               [self.n_actions, self.n_quantiles])
        q = tf.reduce_mean(tf.math.minimum(quantiles - cutoff, 0), axis=1)
        return tf.argmax(q), q, quantiles

    def query(self, phi, cutoff):
        """Return ``(action, per-action scores)`` as NumPy values."""
        a, q, _ = self.eval_model(phi, tf.constant(cutoff, dtype='float32'))
        return a.numpy(), q.numpy()

    def query_all(self, phi, cutoff):
        """Return ``(action, per-action scores, quantiles)`` as NumPy values."""
        a, q, quantiles = self.eval_model(phi, tf.constant(cutoff, dtype='float32'))
        return a.numpy(), q.numpy(), quantiles.numpy()

    @tf.function
    def eval_model_init(self, phi):
        """Choose the first action of an episode and its initial cutoff.

        Returns:
            ``(action, cutoff, scores)``. ``scores[a]`` is the mean of the first
            ``n_cvar`` quantile outputs of action ``a``, and ``action`` is its
            argmax. ``cutoff`` is that action's output at index ``n_cvar - 1``.
        """
        quantiles = tf.reshape(self.model(phi[np.newaxis, :]),
                               [self.n_actions, self.n_quantiles])
        q = tf.reduce_mean(quantiles[:, :self.n_cvar], axis=1)
        a = tf.argmax(q)
        cutoff = quantiles[a, self.n_cvar - 1]
        return a, cutoff, q

    def query_init(self, phi):
        """NumPy wrapper around ``eval_model_init``: ``(action, cutoff, scores)``."""
        a, cutoff, q = self.eval_model_init(phi)
        return a.numpy(), cutoff.numpy(), q.numpy()

    def eval_policy(self, env, n_rep, max_episode_len, gamma=1.0, render=False, s_step=None):
        """Roll out the policy without exploration for ``n_rep`` episodes.

        Args:
            env: environment exposing ``reset(s_step)``, ``step(a)``,
                ``action_space.n`` and, if ``render``, ``render()``.
            n_rep: number of episodes.
            max_episode_len: maximum steps per episode.
            gamma: discount used only for the *reported* return (default 1.0,
                undiscounted). The cutoff is always rolled with ``self.gamma``.
            render: call ``env.render()`` after reset and after each step.
            s_step: forwarded to ``env.reset``; its meaning is env-specific.

        Returns:
            ``(returns, decisions)``. ``returns`` is a list of per-episode
            discounted returns. ``decisions`` is an array of one-hot actions,
            one row per step across all episodes.
        """
        n_actions = env.action_space.n
        ep_rewards = []
        decs = []
        for _ in range(n_rep):
            s = env.reset(s_step)
            if render:
                env.render()
            phi = self.extract_state_feature(s)
            ep_r = 0
            steps = 0
            a, cutoff, _ = self.query_init(phi)
            while steps < max_episode_len:
                decs.append(one_hot(n_actions, a))
                next_s, r, done, _ = env.step(a)
                if render:
                    env.render()
                ep_r += (gamma ** steps) * r
                phi = self.extract_state_feature(next_s)
                steps += 1
                if done:
                    break
                cutoff = (cutoff - r) / self.gamma
                a, _ = self.query(phi, cutoff)
            ep_rewards.append(ep_r)
        return ep_rewards, np.array(decs)

    @tf.function
    def update_q(self, phis, acts, rewards, next_phis, dones):
        """Take one clipped Adam step on the quantile (pinball) loss.

        Args:
            phis, next_phis: batched features ``[B, n_obs]``.
            acts: one-hot actions ``[B, n_actions]``.
            rewards, dones: ``[B]`` float32 (dones is 1.0 for terminal
                transitions).

        Returns:
            ``(clipped_gradients, loss)``.
        """
        with tf.GradientTape() as tape:
            quantiles = tf.reshape(self.model(phis), [-1, self.n_actions, self.n_quantiles])
            # Quantile outputs of the action actually taken: [B, N].
            q = tf.math.reduce_sum(acts[:, :, tf.newaxis] * quantiles, axis=1)
            # Next-state cutoff, rolled forward from the output at index n_cvar-1.
            cutoff = (q[:, self.n_cvar - 1] - rewards) / self.gamma
            qn = tf.reshape(self.target_model(next_phis), [-1, self.n_actions, self.n_quantiles])
            # Greedy next action under the same cutoff criterion as eval_model.
            qn_acts = tf.math.argmax(
                tf.math.reduce_mean(
                    tf.math.minimum(qn - cutoff[:, tf.newaxis, tf.newaxis], 0), axis=2),
                axis=1)
            next_quantiles = tf.gather_nd(qn, qn_acts[:, tf.newaxis], batch_dims=1)
            target = tf.stop_gradient(
                rewards[:, tf.newaxis] + self.gamma * (1 - dones[:, tf.newaxis]) * next_quantiles)
            # Pairwise TD errors u[b, i, j] = target[b, i] - q[b, j], each paired with
            # tau_j. Broadcasting avoids len(q), which fails when the batch size is
            # unknown at trace time.
            u = target[:, :, tf.newaxis] - q[:, tf.newaxis, :]
            tau = self.q_levels[tf.newaxis, tf.newaxis, :]
            # Pinball loss: (tau - 1{u < 0}) * u; at u == 0 the product is 0 anyway.
            q_loss = tf.reduce_mean((tau - (1 - tf.math.sign(u)) * 0.5) * u)

        grad = tape.gradient(q_loss, self.model.trainable_variables)
        grad, _ = tf.clip_by_global_norm(grad, self.clip_norm)
        self.q_optimizer.apply_gradients(zip(grad, self.model.trainable_variables))
        return grad, q_loss

    def _eval_row(self, val_env, tst_env, n_rep, max_episode_len, train_stats):
        """Evaluate on the validation and test envs and build one results row.

        ``train_stats`` is ``(median grad norm, max grad norm, mean loss)``.
        """
        r_val, d_val = self.eval_policy(val_env, n_rep, max_episode_len, s_step=0)
        r_tst, d_tst = self.eval_policy(tst_env, n_rep, max_episode_len, s_step=0)
        head = [self.iter, *train_stats, np.mean(self.ep_rewards[-100:]),
                np.mean(r_val), np.mean(r_tst)]
        return np.concatenate((head, np.mean(d_val, axis=0), np.mean(d_tst, axis=0)))

    def train(self, epsilon, buffer_size, batch_size, seed, env, val_env, tst_env, max_iter,
              explore_iter, eval_freq, max_episode_len):
        """Run online training: one env step and one gradient update per iteration.

        Args:
            epsilon: final exploration probability. The random-action
                probability anneals linearly from 1 to ``epsilon`` over
                ``explore_iter`` iterations. If ``explore_iter <= 0``,
                ``epsilon`` is used from the start.
            buffer_size: capacity passed to ``ReplayBuffer``.
            batch_size: minibatch size per update.
            seed: if not None, reseed NumPy and the three envs and restart the
                training state. If None, resume the state left by a previous
                call, or start fresh (unseeded) if there is none.
            env, val_env, tst_env: training, validation and test environments.
            max_iter: stop once ``self.iter`` (persisted across calls) reaches it.
            eval_freq: evaluate and log every this many iterations.
            max_episode_len: truncation length. Truncated transitions are stored
                with ``done = 0``.

        Returns:
            Array whose rows are ``[iter, median grad norm, max grad norm,
            mean loss, mean return of the last 100 training episodes
            (including the current one), mean val return, mean test return,
            val action frequencies..., test action frequencies...]``. Grad norms
            are of the clipped gradients, so they never exceed ``clip_norm``.
        """
        q_losses = []
        q_vals = []
        grad_norms = []
        n_rep = 1
        res = []
        n_actions = env.action_space.n

        if seed is not None or not hasattr(self, 'replay_buf'):
            if seed is not None:
                np.random.seed(seed)
                env.seed(seed)
                val_env.seed(seed + 1)
                tst_env.seed(seed + 2)
            self.replay_buf = ReplayBuffer(buffer_size)

            self.iter = 0
            self.steps = 0
            self.ep_done = 0
            self.ep_rewards = [0.0]

            res.append(self._eval_row(val_env, tst_env, n_rep, max_episode_len, (0, 0, 0)))
            print(res[-1])

            self.phi = self.extract_state_feature(env.reset())
            self.a_next, self.cutoff, self.last_q = self.query_init(self.phi)
        else:
            # Resuming: keep only the window used for the running mean.
            self.ep_rewards = self.ep_rewards[-100:]

        st = time.time()
        while True:
            # --- act in the environment ---
            a = self.a_next
            q_vals.append(self.last_q)
            anneal = min(1, self.iter / explore_iter) if explore_iter > 0 else 1
            p_explore = 1 - anneal * (1 - epsilon)
            if np.random.rand() < p_explore:
                a = np.random.randint(n_actions)
            obs, r, done, _ = env.step(a)
            next_phi = self.extract_state_feature(obs)
            self.ep_rewards[-1] += r
            self.steps += 1
            self.cutoff = (self.cutoff - r) / self.gamma
            self.a_next, self.last_q = self.query(next_phi, self.cutoff)
            self.replay_buf.add(self.phi, one_hot(n_actions, a), np.array(r, dtype='float32'),
                                next_phi, np.array(done, dtype='float32'))
            if done or self.steps >= max_episode_len:
                self.phi = self.extract_state_feature(env.reset())
                self.steps = 0
                self.ep_done += 1
                self.ep_rewards.append(0.0)
                self.a_next, self.cutoff, self.last_q = self.query_init(self.phi)
            else:
                self.phi = next_phi

            # --- one gradient step on a replay minibatch ---
            phis, acts, rewards, next_phis, dones = self.replay_buf.sample(batch_size)
            grad, q_loss = self.update_q(tf.constant(phis), tf.constant(acts),
                                         tf.constant(rewards), tf.constant(next_phis),
                                         tf.constant(dones))
            q_losses.append(q_loss.numpy())
            grad_norms.append(tf.linalg.global_norm(grad).numpy())

            self.iter += 1
            if self.iter % self.target_update_freq == 0:
                self.update_target()
            if self.iter % eval_freq == 0:
                train_stats = (np.median(grad_norms), np.max(grad_norms), np.mean(q_losses))
                res.append(self._eval_row(val_env, tst_env, n_rep, max_episode_len, train_stats))
                q_losses = []
                grad_norms = []

                q_arr = np.array(q_vals)
                print(self.ep_done, '({:.3f} secs)'.format(time.time() - st), res[-1],
                      np.max(q_arr, axis=0), np.min(q_arr, axis=0),
                      np.sum([ex[4] for ex in self.replay_buf._storage]), flush=True)
                q_vals = []
                st = time.time()
            # Checked after the step so at least one iteration always runs.
            if self.iter >= max_iter:
                break
        return np.array(res)

