import os
import numpy as np
import params as p
import matplotlib.pyplot as plt
import time
import matplotlib.animation as animation
import matplotlib.patches as patches
#from keras.models import Sequential
#from keras.layers import Dense
#from keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

#import priors_tabular as PR

def Qlearn_multirun_tab():
	"""Run ``p.Nruns`` Q-learning runs back to back and collect their logs.

	A small dense network (3 inputs -> 24 -> 24 -> 1, MSE loss, Adam) is built
	once and handed to every run, so whatever one run trains into it is carried
	over to the next run. The two trace buffers are reset for every run.

	Returns:
		Tuple ``(Q, retlog, betamatrix, bumpcountlog, tracemap, model,
		retveclog, globalQ)``. ``retlog`` and ``retveclog`` hold one row per
		run; with a single run they are that run's sequences, unstacked.
		``bumpcountlog`` is always an empty list (per-run bump counts are not
		aggregated). All other entries come from the last run only.
	"""
	model = Sequential()
	model.add(Dense(24, input_shape=(3,), activation="relu"))
	model.add(Dense(24, activation="relu"))
	model.add(Dense(1, activation="linear"))
	# `lr` is only a deprecated alias of `learning_rate` in tf.keras and is
	# rejected by recent Keras releases, so the supported keyword is used.
	model.compile(loss="mse", optimizer=Adam(learning_rate=0.001))

	retlog=[] # log of returns of all episodes, in all runs
	retveclog=[]
	bumpcountlog=[]
	run_rets=[]
	run_retvecs=[]
	for i in range(p.Nruns):
		print("Run no:",i)
		tracebuffer=[]
		tracebuffer_neg=[]
		Q,ret,betamatrix,bumpcountret,tracemap,model,tracebuffer,tracebuffer_neg,ret_vec,globalQ=main_Qlearning_tab(model,tracebuffer,tracebuffer_neg)#call Q learning
		print("retshape")
		print(np.shape(ret))
		# NOTE(review): pause kept from the original; it looks like a debugging
		# aid for reading the console output and can probably be dropped.
		time.sleep(3)
		run_rets.append(ret)
		run_retvecs.append(ret_vec)
		if (i+1)/p.Nruns==0.25:
			print('25% runs complete')
		elif (i+1)/p.Nruns==0.5:
			print('50% runs complete')
		elif (i+1)/p.Nruns==0.75:
			print('75% runs complete')
		elif (i+1)==p.Nruns:
			print('100% runs complete')
	# Stack once at the end instead of re-stacking after every run. A single
	# run is returned unstacked, exactly as before.
	if len(run_rets)==1:
		retlog=run_rets[0]
		retveclog=run_retvecs[0]
	elif len(run_rets)>1:
		retlog=np.vstack(run_rets)
		retveclog=np.vstack(run_retvecs)
	return Q, retlog,betamatrix,bumpcountlog,tracemap,model,retveclog,globalQ

def main_Qlearning_tab(model,tracebuffer,tracebuffer_neg):
	"""Run one complete learning run by calling ``Qtabular`` episode after episode.

	Fresh tables are allocated here and episodes are run until ``ret_vec``
	holds at least ``p.retvecsize`` entries.

	Args:
		model: trace model, passed into and received back from every episode.
		tracebuffer: buffer of successful traces, threaded through episodes.
		tracebuffer_neg: buffer of unsuccessful traces, threaded through episodes.

	Returns:
		Tuple ``(Q, ret, betamatrix, bumpcountret, tracemap, model,
		tracebuffer, tracebuffer_neg, ret_vec, globalQ)`` where ``ret`` and
		``ret_vec`` are cut to their first ``p.retvecsize`` entries and
		``bumpcountret`` lists the bump count of every episode.
	"""
	Q=np.zeros((p.a,p.b,p.A)) # initialize Q function as zeros
	tracemap=np.zeros((p.a,p.b,p.A))
	betamatrix=np.ones((p.a,p.b,p.A))
	visitmap=np.zeros((p.a,p.b,p.A))
	bumpcountret=[]
	ret=[]
	totcnt=0
	success_attempts=0
	success_cnt=0
	tot_attempts=0
	subgoal_targs=[]
	sub_pol=[]
	ret_vec=[]
	globalQ=np.zeros((p.a,p.b,p.A))
	i=0
	# The run length is set by the number of logged steps, not by p.episodes;
	# p.episodes only drives the progress messages below.
	while len(ret_vec)<p.retvecsize:
		if (i+1)/p.episodes==0.25:
			print('25% episodes done')
		elif (i+1)/p.episodes==0.5:
			print('50% episodes done')
		elif (i+1)/p.episodes==0.75:
			print('75% episodes done')
		elif (i+1)/p.episodes==1:
			print('100% episodes done')
		Q,ret,betamatrix,bumpcount,visitmap,totcnt,tracemap,model,tracebuffer,tracebuffer_neg,success_attempts,tot_attempts,success_cnt,subgoal_targs,sub_pol,ret_vec,globalQ=Qtabular(Q,i,betamatrix,visitmap,totcnt,tracemap,model,tracebuffer,tracebuffer_neg,success_attempts,tot_attempts,success_cnt,subgoal_targs,sub_pol,ret_vec,ret,globalQ)#call Q learning
		print(totcnt)
		bumpcountret.append(bumpcount)
		i+=1
	print("Successes:"+str(np.sum(ret_vec)))
	print(subgoal_targs)
	# NOTE(review): pause kept from the original; it looks like a debugging aid
	# for reading the printed sub-goals and can probably be dropped.
	time.sleep(10)
	return Q, ret[0:p.retvecsize],betamatrix,bumpcountret,tracemap,model,tracebuffer,tracebuffer_neg,ret_vec[0:p.retvecsize],globalQ

def build_tracebuffer(statelog,tracebuffer,success_flag,tracemap):
	"""Append one trajectory to a trace buffer, newest step first, with decaying trace values.

	The trajectory is walked backwards. The last step gets trace value 1 when
	``success_flag == 1`` (otherwise 0) and each earlier step gets the previous
	value multiplied by ``p.tracedecay``. Every step adds its trace value to
	``tracemap[x, y, a]`` in place and contributes a row
	``[x - 1, y - 1, a, trace]`` to the buffer.

	The buffer is always returned as a 2-D array with one row per step.
	Previously the very first row was stored 1-D, so ``len()`` reported 4
	(columns) instead of 1 (row); that skewed the capacity check and broke
	row indexing by consumers when only one row was present.

	Args:
		statelog: sequence of ``[x, y, action]`` entries, oldest first.
		tracebuffer: existing buffer (empty list or array of 4-column rows).
		success_flag: 1 for a successful trajectory, anything else for failure.
		tracemap: array indexed ``[x, y, action]``; updated in place.

	Returns:
		``(tracebuffer, tracemap)``. The buffer keeps at most
		``p.tracebuffersize + 1`` rows (oldest rows are dropped first), unless
		it was already longer on entry, in which case its length is kept.
	"""
	n_steps=np.shape(statelog)[0]
	if n_steps==0:
		return tracebuffer,tracemap

	lambd=1 if success_flag==1 else 0
	new_rows=[]
	for step in np.flipud(statelog):
		tracemap[step[0],step[1],step[2]]+=lambd
		# NOTE(review): the buffer stores x-1 / y-1 while tracemap (and the
		# model queries elsewhere in this file) use the unshifted coordinates;
		# kept as in the original - confirm the shift is intended.
		new_rows.append([step[0]-1,step[1]-1,step[2],lambd])
		lambd=lambd*p.tracedecay
	added=np.array(new_rows)

	if len(tracebuffer)==0:
		merged=added
		old_rows=0
	else:
		existing=np.atleast_2d(tracebuffer)
		old_rows=existing.shape[0]
		merged=np.vstack((existing,added))

	# Same eviction rule as appending row by row: grow up to
	# p.tracebuffersize + 1 rows, then drop the oldest row for every new one.
	capacity=max(p.tracebuffersize+1,1)
	keep=max(old_rows,min(merged.shape[0],capacity))
	merged=merged[merged.shape[0]-keep:]
	return merged,tracemap

def learn_trace_model(model,totcnt,tracebuffer,tracebuffer_neg):
	"""Fit the trace model on a random mini-batch drawn from both trace buffers.

	Nothing happens unless both buffers hold more than
	``p.tracebuffer_batchsize`` entries. Each sample picks the negative or the
	positive buffer with equal probability and then a uniformly random row of
	it; columns 0-2 of the row are the input and column 3 is the target.

	The batch is assembled as lists and converted once, so the input is always
	2-D ``(batch, 3)`` and the target ``(batch, 1)`` - including for a batch
	size of 1, where stacking row by row used to leave both arrays 1-D.

	Args:
		model: object with a Keras-style ``fit(x, y, verbose=0)`` method.
		totcnt: unused; kept for interface compatibility.
		tracebuffer: rows ``[x, y, a, trace]`` from successful trajectories.
		tracebuffer_neg: rows ``[x, y, a, trace]`` from failed trajectories.

	Returns:
		The same ``model`` object (fitted ``p.N_epochs`` times when training ran).
	"""
	batch_size=p.tracebuffer_batchsize
	if not (len(tracebuffer)>batch_size and len(tracebuffer_neg)>batch_size):
		return model

	inputs=[]
	targets=[]
	for _ in range(batch_size):
		# RNG order matters for reproducibility: first the buffer choice,
		# then the row index.
		source=tracebuffer_neg if np.random.rand()>0.5 else tracebuffer
		row=source[np.random.randint(len(source))]
		inputs.append([row[0],row[1],row[2]])
		targets.append([row[3]])
	x=np.array(inputs)
	y=np.array(targets)

	for _ in range(p.N_epochs):
		model.fit(x,y,verbose=0)

	return model

def explore_trace(model,state):
	"""Return the action whose ``[x, y, action]`` input gets the largest model output.

	One row is built per action in ``range(p.A)`` and all rows are scored with
	a single ``model.predict`` call. The input is always 2-D, including for
	``p.A == 1``, where stacking row by row used to leave it 1-D.

	Args:
		model: object with a Keras-style ``predict`` method.
		state: indexable whose first two entries are the x and y coordinates.

	Returns:
		``np.argmax`` of the predictions, i.e. the index of the best action.
	"""
	candidates=np.array([[state[0],state[1],act] for act in range(p.A)])
	vals=model.predict(candidates)
	return np.argmax(vals)

def Qtabular(Q,episode_no,betamatrix,visitmap,totcnt,tracemap,model,tracebuffer,tracebuffer_neg,success_attempts,tot_attempts,success_cnt,subgoal_targs,sub_pol,ret_vec,ret,globalQ):
	"""Run one episode: execute known sub-goal policies, look for new sub-goals, or take Q-learning steps.

	The episode loops at most ``p.breakthresh`` times. Each pass
	  1. runs ``execute_pol`` for every stored sub-goal/policy pair, in order;
	  2. then either calls ``findsubgoal`` (once the target has been reached at
	     least once and the current start point is farther than ``p.thresh``
	     from ``p.targ``) or takes a single epsilon-greedy Q-learning step.

	Args:
		Q: table indexed ``[x, y, action]``; updated in place and also handed to
			``execute_pol`` / ``findsubgoal`` as their ``globalQ`` argument.
		episode_no: episode index; exploration rate is
			``1 - p.epsilon_decay * episode_no``.
		betamatrix: not used here; returned unchanged.
		visitmap: counter array, incremented at the rounded next state/action.
		totcnt: running count of Q-learning steps taken by this function.
		tracemap: passed to ``build_tracebuffer`` (which updates it).
		model: trace model, used by ``findsubgoal`` and ``learn_trace_model``.
		tracebuffer, tracebuffer_neg: buffers of successful / failed traces.
		success_attempts: number of times a step here has reached the target.
		tot_attempts: incremented once per loop pass.
		success_cnt: value of ``totcnt`` at the most recent success.
		subgoal_targs, sub_pol: parallel lists of sub-goal positions and their
			Q tables; new entries are appended in place.
		ret_vec, ret: per-step success flags and rewards, appended in place.
		globalQ: not used here; returned unchanged.

	Returns:
		``(Q, ret, betamatrix, bumpcount, visitmap, totcnt, tracemap, model,
		tracebuffer, tracebuffer_neg, success_attempts, tot_attempts,
		success_cnt, subgoal_targs, sub_pol, ret_vec, globalQ)``.
	"""
	# Start at (2, 2); if that cell is an obstacle, resample uniformly until a
	# free cell is found.
	initial_state=np.array([2,2])
	rounded_initial_state=staterounding(initial_state)
	while p.world[rounded_initial_state[0],rounded_initial_state[1]]==1:
		initial_state=np.array([(p.a-1)*np.random.random_sample(), (p.b-1)*np.random.random_sample()])
		rounded_initial_state=staterounding(initial_state)
	state=staterounding(initial_state.copy())
	count=0
	breakflag=0
	eps_live=1-(p.epsilon_decay*episode_no)
	#eps_live=0.7
	bumpcount=0
	target_state=p.targ
	#while np.linalg.norm(state-target_state)>p.thresh:
	statelog=[]
	#ret=0
	#statelog.append(state)
	train_prob=1
	for i in range(p.breakthresh):

		# Stop as soon as the step log is over its budget.
		if len(ret_vec)>p.retvecsize:
			break
		s_sg=np.shape(subgoal_targs)
		if s_sg[0]>0:
			#execute subgoals till termination
			for kk in range(s_sg[0]):
				goal_temp=subgoal_targs[kk]
				Q_temp=sub_pol[kk]				
				# Q is passed in the globalQ slot, so execute_pol updates Q itself.
				state,ret_vec,ret,Q=execute_pol(Q_temp,goal_temp,state,ret_vec,ret,Q)
			# The last stored sub-goal becomes the start point for the next search.
			st_state=goal_temp.copy()
			if np.linalg.norm(np.array([state[0],state[1]])-p.targ)<=p.thresh:
				#ret_vec.append(1)
				#ret.append(p.highreward)
				break
			print("exited loop")

			#breakflag=1
		#find subgoal
		s_sg=np.shape(subgoal_targs)
		if s_sg[0]==0:
			st_state=initial_state.copy()
		#else:
		#	st_state=subgoal_targs[s_sg[0]-1]
		if success_attempts>0 and np.linalg.norm(np.array([st_state[0],st_state[1]])-p.targ)>p.thresh:
			#plot_trace(model)
			# NOTE(review): this inner test repeats the distance check of the
			# enclosing `if`, so it is always true here.
			if np.linalg.norm(np.array([st_state[0],st_state[1]])-p.targ)>p.thresh:
				print("finding subgoal")
				#print(np.array([st_state[0],st_state[1]]))
				Q_sub,sub_goal,ret_vec,ret,Q=findsubgoal(model,st_state,ret_vec,ret,Q)
				print(sub_goal)
				#mapQ(Q_sub)
				# Only keep a sub-goal that is farther than p.thresh from the start point.
				if np.linalg.norm(np.array([st_state[0],st_state[1]])-sub_goal)>p.thresh:
					subgoal_targs.append(sub_goal)
					sub_pol.append(Q_sub)
		else:
			# ---- one epsilon-greedy Q-learning step ----
			count=count+1
			totcnt+=1
			if breakflag==1:
				break
			# NOTE(review): count grows by at most 1 per pass of a loop that runs
			# p.breakthresh times, so this condition is never met and breakflag
			# stays 0.
			if count>p.breakthresh:
				breakflag=1
			if eps_live>np.random.sample():
				a=np.random.randint(p.A)
				#a=explore_trace(model,state)
			else:
				Qmax,Qmin,a=maxQ_tab(Q,state)

			next_state=transition(state,a)
			roundedstate=staterounding(state)
			roundednextstate=staterounding(next_state)
			# Drop the oldest entry once the log holds p.tracelim entries.
			if len(statelog)>=(p.tracelim):
				statelog=statelog[1:(p.tracelim)]
			statelog.append(np.array([state[0],state[1],a]))
			visitmap[roundednextstate[0],roundednextstate[1],a]=visitmap[roundednextstate[0],roundednextstate[1],a]+1
			# NOTE(review): p.world is indexed with next_state directly rather
			# than roundednextstate; presumably fine because states are integer
			# lists while transition noise is disabled - confirm.
			if p.world[next_state[0],next_state[1]]==0 and (next_state[0]<p.a and next_state[0]>0 and next_state[1]<p.b and next_state[1]>0):	
				if np.linalg.norm(next_state-target_state)<=p.thresh:
					# Target reached: log success and store the trajectory as a positive trace.
					R=p.highreward
					success_attempts+=1
					success_flag=1
					success_cnt=totcnt
					ret_vec.append(1)
					tracebuffer,tracemap=build_tracebuffer(statelog,tracebuffer,success_flag,tracemap)					
				elif np.linalg.norm(next_state-p.NT)<=p.thresh:
					R=p.NTreward
					ret_vec.append(0)
				else:
					R=p.livingpenalty
					ret_vec.append(0)
			else: 
				# Blocked or out-of-range move: penalise, stay in place, count the bump.
				R=p.penalty
				ret_vec.append(0)
				next_state=state.copy()
				bumpcount=bumpcount+1

			#ret=ret+R
			ret.append(R)
			if success_attempts>0:
				# Training probability decays with the number of steps since the last success.
				train_prob=p.spike_decay**(totcnt-success_cnt)
				#print(train_prob)
				#time.sleep(1)
				if np.random.sample()<train_prob:
					model=learn_trace_model(model,totcnt,tracebuffer,tracebuffer_neg)
			# Q-learning update; Qtarget holds the TD error (target minus current estimate).
			Qmaxnext,Qminnext,aoptnext=maxQ_tab(Q,next_state)
			Qtarget=R+(p.gamma*Qmaxnext)-Q[roundedstate[0],roundedstate[1],a]
			Q[roundedstate[0],roundedstate[1],a]=Q[roundedstate[0],roundedstate[1],a]+(p.alpha*Qtarget)
			if np.linalg.norm(next_state-target_state)<=p.thresh:
				break
			state=next_state.copy()
	
		# After p.breakthresh learning steps without reaching the target, store
		# the trajectory as a negative trace (all trace values are 0).
		if count==p.breakthresh:
			success_flag=0
			# tracemap_neg is not used afterwards; build_tracebuffer adds 0 to tracemap here.
			tracebuffer_neg,tracemap_neg=build_tracebuffer(statelog,tracebuffer_neg,success_flag,tracemap)
		tot_attempts+=1

	print("episode ended")
	return Q,ret,betamatrix,bumpcount,visitmap,totcnt,tracemap,model,tracebuffer,tracebuffer_neg,success_attempts,tot_attempts,success_cnt,subgoal_targs,sub_pol,ret_vec,globalQ

def execute_pol(sub_pol,subgoal_targ,state,ret_vec,ret,globalQ):
	"""Follow a stored sub-goal policy until the sub-goal is reached or the step log is full.

	Every step appends the reward to ``ret`` and a 0/1 success flag to
	``ret_vec``, and applies a Q-learning update to ``globalQ`` in place.
	The loop ends early when a step lands within ``p.thresh`` of ``p.targ``.

	Args:
		sub_pol: Q table of the sub-goal policy, indexed ``[x, y, action]``.
		subgoal_targ: indexable with the sub-goal x and y in its first two entries.
		state: current state; must support ``.copy()``.
		ret_vec: list of per-step success flags, appended in place.
		ret: list of per-step rewards, appended in place.
		globalQ: Q table learned from the environment reward, updated in place.

	Returns:
		``(state, ret_vec, ret, globalQ)`` after the last executed step.
	"""
	subgoal_xy=np.array([subgoal_targ[0],subgoal_targ[1]])
	while np.linalg.norm(state-subgoal_xy)>p.thresh and len(ret_vec)<p.retvecsize:

		# Action choice: 10% uniformly random; otherwise the sub-policy with
		# probability follow_sub (which shrinks as ret_vec fills up), else
		# the greedy action of the global table.
		follow_sub=1-(len(ret_vec)/p.retvecsize)
		if np.random.sample()<0.1:
			a=np.random.randint(p.A)
		elif np.random.sample()<follow_sub:
			a=maxQ_tab(sub_pol,state)[2]
		else:
			a=maxQ_tab(globalQ,state)[2]

		proposed=transition(state,a)
		here=staterounding(state)
		there=staterounding(proposed)
		cell_value=p.world[there[0],there[1]]

		# An obstacle or an out-of-range move leaves the agent where it is.
		out_of_range=proposed[0]>=p.a or proposed[0]<=0 or proposed[1]>=p.b or proposed[1]<=0
		if cell_value==1 or out_of_range:
			next_state=state.copy()
		else:
			next_state=proposed

		# Environment reward for the step.
		in_range=0<next_state[0]<p.a and 0<next_state[1]<p.b
		if cell_value==0 and in_range:
			if np.linalg.norm(np.array(next_state)-np.array(p.targ))<=p.thresh:
				step_reward=p.highreward
			elif np.linalg.norm(np.array(next_state)-np.array(p.NT))<=p.thresh:
				step_reward=p.NTreward
			else:
				step_reward=p.livingpenalty
		else:
			step_reward=p.penalty
			next_state=state.copy()

		ret.append(step_reward)

		# Q-learning update of the global table.
		best_next=maxQ_tab(globalQ,next_state)[0]
		td_error=step_reward+(p.gamma*best_next)-globalQ[here[0],here[1],a]
		globalQ[here[0],here[1],a]=globalQ[here[0],here[1],a]+(p.alpha*td_error)

		state=next_state.copy()
		if np.linalg.norm(next_state-p.targ)<=p.thresh:
			ret_vec.append(1)
			print('goal found in execute pol')
			break
		else:
			ret_vec.append(0)

	return state,ret_vec,ret,globalQ


def maxQ_tab(Q,state):
	"""Return the largest and smallest Q value at ``state`` and a greedy action.

	The state is first mapped to a grid cell with ``staterounding``. When
	several actions share the maximum value, one of them is drawn uniformly at
	random; with a unique maximum no random number is consumed.

	Args:
		Q: table indexed ``[x, y, action]`` with ``p.A`` actions.
		state: indexable with x and y in its first two entries.

	Returns:
		``(max_value, min_value, greedy_action)``.
	"""
	cell=staterounding(state)
	action_values=[Q[cell[0],cell[1],act] for act in range(p.A)]
	best=np.max(action_values)
	worst=np.min(action_values)
	# All actions tied for the best value, in increasing action order.
	tied=[act for act,value in enumerate(action_values) if best==value]
	if len(tied)>1:
		greedy=tied[np.random.randint(len(tied))]
	else:
		greedy=tied[0]
	return best,worst,greedy

def optpol_visualize(Qp):
	"""Scatter every free cell coloured by its greedy action, then draw the obstacle map.

	Colours: action 0 red, 1 green, 2 blue, 3 yellow. Any other action index
	is not drawn. Cells where ``p.world`` is not 0 are skipped.

	Args:
		Qp: Q table indexed ``[x, y, action]``.
	"""
	colour_of_action={0:'red',1:'green',2:'blue',3:'yellow'}
	for i,j in np.ndindex(p.a,p.b):
		if p.world[i,j]!=0:
			continue
		_,_,greedy=maxQ_tab(Qp,[i,j])
		if greedy in colour_of_action:
			plt.scatter(i,j,color=colour_of_action[greedy])

	plotmap(p.world)
	plt.show()

def transition(state,act):
	"""Move one cell in the direction given by ``act`` and clamp to the grid interior.

	Actions: 0 = up (y+1), 1 = right (x+1), 2 = down (y-1), 3 = left (x-1);
	any other value leaves the position unchanged. Each coordinate is then
	clamped: values at or above ``size-1`` become ``size-2`` and values at or
	below 0 become 1, with ``size`` being ``p.a`` for x and ``p.b`` for y.

	Args:
		state: current state; must support ``.copy()`` and item assignment.
		act: action index.

	Returns:
		A copy of ``state`` holding the new position; ``state`` is not modified.
	"""
	# Two noise samples are still drawn (and discarded) so the global NumPy
	# random stream advances exactly as before; the noise is not applied.
	np.random.uniform(low=-0.2, high=0.2, size=(1,))# x noise
	np.random.uniform(low=-0.2, high=0.2, size=(1,))# y noise

	step_x=0
	step_y=0
	if act==0:
		step_y=1#move up
	elif act==1:
		step_x=1#move right
	elif act==2:
		step_y=-1#move down
	elif act==3:
		step_x=-1#move left

	new_state=state.copy()
	new_state[0]=state[0]+step_x
	new_state[1]=state[1]+step_y

	# Keep the agent strictly inside the border cells on both axes.
	for axis,extent in enumerate((p.a,p.b)):
		if new_state[axis]>=extent-1:
			new_state[axis]=extent-2
		elif new_state[axis]<=0:
			new_state[axis]=1

	return new_state

########Additional functions for visualization######
def plotmap(worldmap):
	"""Draw a black marker at every cell whose map value is above 0, then show the figure.

	Args:
		worldmap: 2-D map indexed ``[x, y]`` covering ``p.a`` by ``p.b`` cells.
	"""
	for cell_x,cell_y in np.ndindex(p.a,p.b):
		if worldmap[cell_x,cell_y]>0:
			plt.scatter(cell_x,cell_y,color='black')
	plt.show()

def staterounding(state):
	"""Round a state to the nearest grid cell and clamp it to the grid interior.

	Each coordinate is rounded with ``np.around`` and converted to ``int``.
	Values at or above ``size-1`` become ``size-2`` and values below 1 become
	1, with ``size`` being ``p.a`` for x and ``p.b`` for y.

	Args:
		state: indexable with x and y in its first two entries.

	Returns:
		A new two-element list ``[x, y]`` of ints.
	"""
	cell=[int(np.around(state[0])),int(np.around(state[1]))]
	for axis,extent in enumerate((p.a,p.b)):
		if cell[axis]>=extent-1:
			cell[axis]=extent-2
		elif cell[axis]<1:
			cell[axis]=1
	return cell

def opt_pol(Q,state,goal_state):
	"""Animate a mostly-greedy rollout of ``Q`` from ``state`` towards ``goal_state``.

	The obstacle map is drawn on figure 0 in interactive mode. Each step takes
	the greedy action, replaced by a uniformly random one when a random draw
	exceeds 0.9, and plots the current position. A move into an obstacle cell
	leaves the agent in place. The rollout stops when the state is closer than
	1 to the goal or after 100 steps.

	Args:
		Q: Q table indexed ``[x, y, action]``.
		state: start state; must support ``.copy()`` and subtraction with
			``goal_state``.
		goal_state: position the rollout is heading for.

	Returns:
		``(statelog, pol)``: the visited states and the actions taken.
	"""
	plt.figure(0)
	plt.ion()
	for cell_x,cell_y in np.ndindex(p.a,p.b):
		if p.world[cell_x,cell_y]>0:
			plt.scatter(cell_x,cell_y,color='black')
	plt.show()

	actions=[]
	visited=[]
	step=1
	while np.linalg.norm(state-goal_state)>=1:
		a=maxQ_tab(Q,state)[2]
		if np.random.sample()>0.9:
			a=np.random.randint(p.A)
		proposed=transition(state,a)
		landing=staterounding(proposed)
		if p.world[landing[0],landing[1]]==1:
			next_state=state.copy()
		else:
			next_state=proposed
		actions.append(a)
		visited.append(state)
		print(state)
		plt.ylim(0, p.b)
		plt.xlim(0, p.a)
		# Marker size shrinks as the rollout progresses.
		plt.scatter(state[0],state[1],(60-step*0.4),color='blue')
		plt.draw()
		plt.pause(0.1)
		state=next_state.copy()
		print(step)
		if step>=100:
			break
		step+=1
	return visited,actions

def mapQ(Q):
	"""Show the per-cell sum of Q over all actions and return it normalised.

	The raw sums are displayed on figure 1 (rotated by 90 degrees) with a
	blocking ``plt.show()``. The returned map is shifted so its minimum is 0,
	divided by its maximum when that is positive, and rotated by 90 degrees.

	The original contained a bare ``plt.ion`` (attribute access without a
	call, i.e. a no-op). It is removed rather than called so that
	``plt.show()`` keeps blocking as it did before.

	Args:
		Q: Q table indexed ``[x, y, action]``; the first ``p.a`` x ``p.b``
			cells and ``p.A`` actions are used.

	Returns:
		2-D float array: the normalised, rotated map.
	"""
	plt.figure(1)
	# Sum over the action axis in one vectorised call.
	Qmap=np.sum(np.asarray(Q,dtype=float)[:p.a,:p.b,:p.A],axis=2)
	plt.imshow(np.rot90(Qmap))
	Qmap=Qmap-np.min(Qmap)
	peak=np.max(Qmap)
	if peak>0:
		Qmap=Qmap/peak
	Qmap=np.rot90(Qmap)
	plt.show()

	return Qmap

def maptrace(Qmap):
	"""Return ``Qmap`` shifted to a minimum of 0, scaled to a maximum of 1 and rotated by 90 degrees.

	Nothing is drawn here. Figure 1 is only created/selected, as before, so a
	caller that plots the result right afterwards still draws onto it. The
	original's bare ``plt.ion`` (a no-op attribute access) and its unused
	``fig`` binding are removed.

	Args:
		Qmap: 2-D array of values.

	Returns:
		The normalised, rotated array. The scaling step is skipped when the
		shifted maximum is not positive.
	"""
	plt.figure(1)
	Qmap=Qmap-np.min(Qmap)
	peak=np.max(Qmap)
	if peak>0:
		Qmap=Qmap/peak
	return np.rot90(Qmap)

def findsubgoal(model,init_state,ret_vec,ret,globalQ):
	"""Learn a local Q table from ``init_state`` and pick a sub-goal from the trace model.

	Runs ``p.episodes`` episodes of at most ``p.breakthresh`` steps, each
	restarting from the rounded ``init_state``. Two value tables are updated
	at every step:
	  * the local table ``Q`` (created here), driven by ``R``, which is
	    ``p.highreward`` only once a target/sub-goal is known and the step
	    lands within ``p.thresh`` of it;
	  * ``globalQ``, driven by the environment reward ``R_ret``, which is also
	    appended to ``ret``.

	A sub-goal is fixed either when a step lands within ``p.thresh`` of
	``p.targ`` (that landing state becomes the target) or, on the last step of
	an episode without a target, when the largest model prediction over the
	visited state-actions exceeds ``rew_trace * p.subgoal_thresh``.

	Args:
		model: trace model with a Keras-style ``predict`` method.
		init_state: start state; must support ``.copy()``.
		ret_vec: per-step success flags, appended in place.
		ret: per-step environment rewards, appended in place.
		globalQ: table indexed ``[x, y, action]``, updated in place.

	Returns:
		``(Q, target_state, ret_vec, ret, globalQ)``; ``target_state`` is
		``p.targ`` when no sub-goal was selected.
	"""
	Q=np.zeros((p.a,p.b,p.A))
	state=init_state
	# Reference trace value: model output at the start state for one random action.
	a=np.random.randint(p.A)
	state_sa=np.array([state[0],state[1],a])
	# The same row is stacked twice so that predict receives a 2-D input;
	# only the first output is used.
	dummy_state=np.vstack((state_sa,state_sa))
	rew_trace=model.predict(dummy_state)[0][0]
	targ_found=0
	target_state=p.targ
	j=0
	#while targ_found==0 and len(ret_vec)<p.retvecsize:
	for j in range(p.episodes):
		print(j)
		#ret=0

		states_sa=[]
		eps_live=1-(p.epsilon_decay*j)
		rounded_initial_state=staterounding(init_state)
		state=staterounding(init_state.copy())
		count=0
		breakflag=0
		for i in range(p.breakthresh):
			count=count+1
			if breakflag==1:
				break
			# NOTE(review): count equals i+1 here and never exceeds
			# p.breakthresh, so breakflag is never set.
			if count>p.breakthresh:
				breakflag=1
			#print(i)
			if eps_live>np.random.sample():
				a=np.random.randint(p.A)
			else:
				Qmax,Qmin,a=maxQ_tab(Q,state)
			next_state=transition(state,a)
			roundednextstate=staterounding(next_state)
			# Landing on the real target fixes the target for the local table.
			if np.linalg.norm(np.array(next_state)-p.targ)<=p.thresh:
				targ_found=1
				target_state=np.array([next_state[0],next_state[1]])
			if i==0:
				states_sa=np.array([state[0],state[1],a])
			# R: reward used for the local (sub-goal) table.
			if p.world[roundednextstate[0],roundednextstate[1]]==0 and (next_state[0]<p.a and next_state[0]>0 and next_state[1]<p.b and next_state[1]>0):
				if targ_found==1:
					#print(target_state)
					#time.sleep(0.1)
					if np.linalg.norm(np.array(next_state)-np.array(target_state))<=p.thresh:
						R=p.highreward
					else:
						R=p.livingpenalty
				else:
					R=p.livingpenalty
			else:
				R=p.penalty
				next_state=state.copy()
			
			# R_ret: environment reward, logged in ret and used for globalQ.
			if p.world[roundednextstate[0],roundednextstate[1]]==0 and (next_state[0]<p.a and next_state[0]>0 and next_state[1]<p.b and next_state[1]>0):
				if np.linalg.norm(np.array(next_state)-np.array(p.targ))<=p.thresh:
					R_ret=p.highreward
				elif np.linalg.norm(np.array(next_state)-np.array(p.NT))<=p.thresh:
					R_ret=p.NTreward
				else:
					R_ret=p.livingpenalty
			else:
				R_ret=p.penalty
				next_state=state.copy()
			#print("ret"+str(R_ret))
			ret.append(R_ret)
			roundedstate=staterounding(state)
			# Q-learning update of the global table with the environment reward.
			Qmaxnext_global,Qminnext_global,aoptnext_global=maxQ_tab(globalQ,next_state)
			Qtarget_global=R_ret+(p.gamma*Qmaxnext_global)-globalQ[roundedstate[0],roundedstate[1],a]
			globalQ[roundedstate[0],roundedstate[1],a]=globalQ[roundedstate[0],roundedstate[1],a]+(p.alpha*Qtarget_global)
			
			if np.linalg.norm(next_state-p.targ)<=p.thresh:
				ret_vec.append(1)
			else:
				ret_vec.append(0)
			# NOTE(review): this only leaves the inner loop; the episode loop
			# keeps going, so each remaining episode still takes one step.
			if len(ret_vec)>p.retvecsize:
				break
			# Q-learning update of the local table with the sub-goal reward.
			Qmaxnext,Qminnext,aoptnext=maxQ_tab(Q,next_state)
			Qtarget=R+(p.gamma*Qmaxnext)-Q[roundedstate[0],roundedstate[1],a]
			Q[roundedstate[0],roundedstate[1],a]=Q[roundedstate[0],roundedstate[1],a]+(p.alpha*Qtarget)
			
			if targ_found==1:
				# The episode ends once the (sub-)goal is reached.
				if np.linalg.norm(np.array(next_state)-np.array(target_state))<=p.thresh:
					if np.linalg.norm(p.targ-np.array(target_state))<=p.thresh:
						# NOTE(review): R_ret is not read again before it is
						# reassigned, so this assignment has no effect.
						R_ret=1
					break
			else:
				# NOTE(review): on the first step the row set at i==0 is stacked
				# again here, so the first state-action appears twice.
				states_sa=np.vstack((states_sa,np.array([state[0],state[1],a])))
				if i==p.breakthresh-1:
					# End of an episode without a target: take the visited
					# state with the highest predicted trace as the sub-goal,
					# if it exceeds the scaled reference value.
					traces=model.predict(states_sa)
					targ_trace=np.max(traces)
					if targ_trace>rew_trace*p.subgoal_thresh:
						ind=np.argmax(traces)
						max_sa=states_sa[ind]
						targ_found=1
						target_state=np.array([max_sa[0],max_sa[1]])
			state=next_state.copy()
		#j+=1
	print("done")
	#time.sleep(10)		
	return Q,target_state,ret_vec,ret,globalQ

def plot_trace(model):
	"""Show the model's prediction for every grid cell, averaged over actions.

	All ``p.a * p.b * p.A`` rows ``[x, y, action]`` are scored with a single
	``model.predict`` call. The original issued one call per row, each padded
	with a dummy ``[0, 0, 0]`` row, and assigned the resulting size-1 array to
	a scalar slot, relying on an implicit array-to-scalar conversion that
	NumPy 1.25 deprecates.

	Args:
		model: object with a Keras-style ``predict`` method that returns one
			value per input row.
	"""
	# Rows are in C order (action index changes fastest), which matches the
	# reshape back to (p.a, p.b, p.A) below.
	grid=np.indices((p.a,p.b,p.A)).reshape(3,-1).T
	tracepred=np.asarray(model.predict(grid)).reshape(p.a,p.b,p.A)
	#tracepred[p.world==1]=-1
	meantrace=np.mean(tracepred,axis=2)
	plt.imshow(np.rot90(meantrace))
	plt.show()


#######################################
if __name__=="__main__":
	# Script entry point: run all learning runs, save the logs, and plot the
	# success rate averaged over runs in consecutive 100-step windows.
	#w,Qimall=Qlearn_main_vid()

	Q,retlog,betamatrix,bumpcountlog,tracemap,model,retveclog,globalQ=Qlearn_multirun_tab()

	np.savez("trace_explore"+str(p.Nruns)+"_runs.npy.npz",retlog,bumpcountlog,betamatrix,Q,tracemap,retveclog)
	# A single run comes back as a 1-D sequence; promote it to (runs, steps)
	# so that the mean over runs is still a per-step vector.
	rvl=np.mean(np.atleast_2d(retveclog),axis=0)
	# Use as many complete windows as the data holds. The original looped over
	# a fixed 4000 windows, which gave NaN means for windows past the data.
	window=100
	n_windows=len(rvl)//window
	roll_avg=rvl[:n_windows*window].reshape(n_windows,window).mean(axis=1)
	plt.plot(roll_avg)
	plt.show()
	# The string literal below is inert; it holds a disabled snippet.
	'''
	for i in range(p.a):
		for j in range(p.b):
			if cnt==0:
				states=np.array([i,j])
			else:
				states=np.vstack((states,np.array([i,j])))
			cnt+=1		
	model.predict(states)
	'''