import sys
sys.path.append('./')

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.misc.optimizers import adam
from autograd.scipy.misc import logsumexp
import autograd.scipy.stats as stats
import random
import matplotlib.pyplot as plt 
import seaborn as sns

# Raise the interpreter's recursion limit far above its default.
# NOTE(review): presumably required by the rejection-sampling retry logic further
# down wherever it retries by calling itself -- confirm before lowering/removing.
# This only lifts Python's own guard; it does not enlarge the native stack.
sys.setrecursionlimit(10**8)

#### Data Generation ####

def init_model_params(Dx, Dy, alpha, r, obs, rs = npr.RandomState(0)):
	"""
	Build the parameters of a linear-Gaussian state-space model.

	Args:
		Dx: latent-state dimension.
		Dy: observation dimension.
		alpha: base of the transition matrix, A[i, j] = alpha ** (|i - j| + 1).
		r: scalar multiplying the identity to form the observation covariance R.
		obs: 'sparse' puts an identity in the leading Dy x Dy block of C;
			any other value draws C densely from ``rs.normal``.
		rs: random state, consumed only for the dense C.

	Returns:
		Tuple ``(mu0, Sigma0, A, Q, C, R)``.
	"""
	# Zero-mean, identity-covariance initial state and identity transition noise.
	mu0 = np.zeros(Dx)
	Sigma0 = np.eye(Dx)
	Q = np.eye(Dx)

	# Fill A one row at a time; entries decay geometrically away from the diagonal.
	A = np.zeros((Dx,Dx))
	for row in range(Dx):
		A[row,:] = [alpha**(abs(row-col)+1) for col in range(Dx)]

	if obs == 'sparse':
		C = np.zeros((Dy,Dx))
		C[:Dy,:Dy] = np.eye(Dy)
	else:
		C = rs.normal(size=(Dy,Dx))

	R = r * np.eye(Dy)
	return (mu0, Sigma0, A, Q, C, R)
	
def init_prop_params(T, Dx, scale = 0.5, rs = npr.RandomState(0)):
	"""
	Draw random initial proposal parameters, one triple per time step.

	Args:
		T: number of time steps (length of the returned list).
		Dx: length of each parameter vector.
		scale: multiplier applied to every standard-normal draw.
		rs: random state supplying the draws.

	Returns:
		List of T tuples ``(bias, linear, log_var)``; ``linear`` is offset by 1.
	"""
	proposal_params = []
	for _ in range(T):
		# Keep this draw order (bias, linear, log-variance) so the random
		# stream is consumed exactly as before.
		bias = scale * rs.randn(Dx)
		linear = 1. + scale * rs.randn(Dx)  # Linear times A/mu0
		log_var = scale * rs.randn(Dx)
		proposal_params.append((bias, linear, log_var))
	return proposal_params

def generate_data(model_params, T, rs = npr.RandomState(0)):
	"""
	Simulate latent states and observations from the linear-Gaussian model.

	Args:
		model_params: tuple ``(mu0, Sigma0, A, Q, C, R)``.
		T: number of time steps to simulate.
		rs: random state supplying the Gaussian draws.

	Returns:
		Tuple ``(x_true, y_true)`` of arrays with T rows each; the column
		counts are taken from ``mu0.shape[0]`` and ``R.shape[0]``.
	"""
	mu0, Sigma0, A, Q, C, R = model_params
	states = np.zeros((T, mu0.shape[0]))
	observations = np.zeros((T, R.shape[0]))

	for t in range(T):
		# The first state comes from the prior, later ones from the transition.
		if t == 0:
			state_mean, state_cov = mu0, Sigma0
		else:
			state_mean, state_cov = np.dot(A,states[t-1,:]), Q
		states[t,:] = rs.multivariate_normal(state_mean,state_cov)
		# The observation is drawn right after its state (same stream order).
		observations[t,:] = rs.multivariate_normal(np.dot(C,states[t,:]),R)

	return states, observations
	
def log_marginal_likelihood(model_params, T, y_true):
	"""
	Log marginal likelihood of the first T observations, computed with a
	Kalman filter (predict / update recursion).

	Args:
		model_params: tuple ``(mu0, Sigma0, A, Q, C, R)``.
		T: number of rows of ``y_true`` to process.
		y_true: observations, indexed as ``y_true[t, :]``.

	Returns:
		Sum over t of the Gaussian log-density of each innovation.
	"""
	mu0, Sigma0, A, Q, C, R = model_params
	gauss_const = R.shape[1]*np.log(2.*np.pi)

	total = 0.
	# The first prediction is the prior itself.
	pred_mean, pred_cov = mu0, Sigma0

	for t in range(T):
		if t > 0:
			# Predict: push the previous filtered estimate through the dynamics.
			pred_mean = np.dot(A,filt_mean)
			pred_cov = np.dot(A,np.dot(filt_cov,A.T)) + Q

		# Update: innovation, its covariance S, and the gain.
		innovation = y_true[t,:] - np.dot(C,pred_mean)
		obs_cross = np.dot(C,pred_cov)
		S = np.dot(C,np.dot(pred_cov,C.T)) + R
		gain = np.linalg.solve(S,obs_cross).T
		filt_mean = pred_mean + np.dot(gain,innovation)
		filt_cov = pred_cov - np.dot(gain,obs_cross)

		# Gaussian log-density of the innovation under N(0, S).
		_, logdet = np.linalg.slogdet(S)
		quad_form = np.sum(innovation*np.linalg.solve(S,innovation))
		total += -0.5*(quad_form + logdet + gauss_const)

	return total


#### SMC algorithm ####

class smc_prc:
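	"""
	Collection of the routines used by the SMC_PRC algorithm for the
	linear-Gaussian state-space model defined above.

	It bundles the model log-densities (log_transition, log_emission), the
	Gaussian proposal (log_proposal, generate_proposal), the accept/reject
	step (PRC_step), the index sampler used for resampling (dice_factory),
	the acceptance-threshold estimator (T_value) and the objective (VI_Loss).

	NOTE: besides the values passed to __init__, methods of this class read
	module-level names (at least ``y_true`` and ``K``), so those must be
	defined in the module before the methods are called.
	"""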
	def __init__(self, T, Dx, Dy, N,use_dice_factory,model_params):
		"""
		Record the problem sizes and the model on the instance.

		Args:
			T: number of time steps.
			Dx: latent-state dimension.
			Dy: observation dimension.
			N: number of particles.
			use_dice_factory: the resampling code compares this against 1 to
				choose dice_factory over plain multinomial resampling.
			model_params: tuple ``(mu0, Sigma0, A, Q, C, R)``.
		"""
		self.model_params = model_params
		self.use_dice_factory = use_dice_factory
		self.T, self.Dx, self.Dy, self.N = T, Dx, Dy, N
		


	def log_transition(self,x_t,x_t_1,A,Q):
		"""Log-density of ``x_t - A x_t_1`` under a zero-mean Gaussian with covariance Q."""
		residual = x_t - np.dot(A,x_t_1)
		zero_mean = np.zeros((Q.shape[0]))
		return stats.multivariate_normal.logpdf(residual, zero_mean, Q)


	def log_emission(self,y_t,x_t,C,R):
		"""Log-density of ``y_t - x_t C^T`` under a zero-mean Gaussian with covariance R."""
		innovation = y_t - np.dot(x_t,C.T)
		# Zero vector with the shape of y_t, built the same way as before.
		origin = y_t*0
		return stats.multivariate_normal.logpdf(innovation, origin, R)

	def log_proposal(self,x_t,x_t_1,A,Q,mu,sig):
		"""
		Log-density of ``x_t - mu - A x_t_1`` under a zero-mean Gaussian with
		covariance ``sig``.  ``Q`` is accepted but not used.
		"""
		centred = x_t - mu - np.dot(A,x_t_1)
		zero_mean = np.zeros((sig.shape[0]))
		return stats.multivariate_normal.logpdf(centred, zero_mean, sig)

	def generate_proposal(self,x_t_1,A,Q,K,mu,sig):
		"""
		Draw K proposals from N(mu + A x_t_1, sig).

		Args:
			x_t_1: state the proposal is centred on (through ``A``).
			A: transition matrix.
			Q: accepted for signature compatibility; not used.
			K: number of samples to draw.
			mu: mean offset added to ``A x_t_1``.
			sig: proposal covariance matrix (must admit a Cholesky factor).

		Returns:
			Array of shape (K, dim) with one proposal per row, where dim is
			``sig.shape[0]``.
		"""
		# Take the dimension from the covariance itself instead of a module-level Dx.
		dim = sig.shape[0]
		sig_chol = np.linalg.cholesky(sig)
		# Same call shape as before ([K, 1] draws from a standard normal), so the
		# random stream is consumed identically.
		Z = np.random.multivariate_normal(np.zeros(dim),np.eye(dim),[K,1])
		# Rows of Z are samples, so the lower factor L (sig = L L^T) is applied as
		# z L^T; for diagonal sig this equals the untransposed product.
		return mu + np.dot(A,x_t_1) + np.dot(Z[:,0,:],sig_chol.T)


	def PRC_step(self,x_t_1,t,c_val,mu,sig):
		"""
		Draw one accepted particle for time step t by accept/reject sampling.

		Proposals are generated in batches of five around the parent state
		(``x_t_1`` for t > 0, ``mu0`` for t == 0).  A proposal with log-weight
		logF = log f(x_t | parent) + log g(y_t | x_t) - log q(x_t | parent)
		is accepted with probability 1 / (1 + exp(-(logF + c_val))).  Batches
		are redrawn until one contains an accepted proposal; the first accepted
		proposal of that batch is returned.

		Args:
			x_t_1: parent state vector; ignored when t == 0.
			t: time index, selects the observation row ``y_true[t, :]``.
			c_val: threshold added to logF (``np.inf`` accepts every proposal).
			mu: proposal mean offset.
			sig: proposal covariance matrix.

		Returns:
			Tuple ``(log_weight, sample)`` where ``log_weight`` is logF minus the
			log acceptance probability of the returned sample.

		Note:
			Observations are read from the module-level ``y_true``; the model
			comes from ``self.model_params``.
		"""
		mu0, Sigma0, A, Q, C, R = self.model_params
		# At t == 0 there is no previous particle: proposals are centred via mu0.
		parent = x_t_1 if t>0 else mu0
		y_t = y_true[t,:]
		batch_size = 5

		# Retry in a loop rather than by recursion so that a long run of
		# rejections cannot grow the call stack.
		while True:
			sample = self.generate_proposal(parent,A,Q,batch_size,mu,sig)
			logF = (self.log_transition(sample,parent,A,Q)
					+ self.log_emission(y_t,sample,C,R)
					- self.log_proposal(sample,parent,A,Q,mu,sig))

			# log(1 / (1 + exp(-score))) via logaddexp: no overflow in exp() when
			# the score is very negative.
			log_accept_prob = -np.logaddexp(0., -(logF + c_val))
			U = np.random.uniform(0,1,batch_size)
			accepted = np.log(U) < log_accept_prob

			if np.any(accepted):
				first = int(np.argmax(accepted))  # index of the first accepted proposal
				return ((logF[first]-log_accept_prob[first]),sample[first,:])



	def dice_factory(self,Unnorm_weights,logW,x_t_1,t,c_val,mu0,A,Q,C,R,mu,sig):
		"""
		Sample one ancestor index in ``range(self.N)``.

		Each round first draws a uniform number.  With probability 0.999 a
		candidate index is drawn proportionally to ``exp(Unnorm_weights)``, one
		proposal is generated from that candidate's parent (``mu0`` when
		t == 0, ``x_t_1[candidate, :]`` otherwise) and the candidate is
		returned if the proposal passes the same accept test as PRC_step;
		otherwise a new round starts.  With the remaining probability an index
		is drawn proportionally to ``exp(logW)`` and returned immediately.

		Args:
			Unnorm_weights: log-weights used for drawing candidates.
			logW: log-weights used by the fallback draw.
			x_t_1: array of parent states, one row per index; unused when t == 0.
			t: time index, selects the observation row ``y_true[t, :]``.
			c_val: threshold added to the proposal's log-weight in the accept test.
			mu0, A, Q, C, R: model parameters.
			mu, sig: proposal mean offset and covariance.

		Returns:
			The selected index (an element of ``range(self.N)``).

		Note:
			Observations are read from the module-level ``y_true``.
		"""
		beta_prob = 0.999
		indices = range(self.N)
		# Loop-invariant: the candidate distribution does not change between rounds.
		candidate_probs = np.exp(Unnorm_weights-logsumexp(Unnorm_weights))

		# Retry in a loop rather than by recursion so that a long run of
		# rejections cannot grow the call stack.
		while True:
			u_beta = np.random.uniform(0,1,1)
			if not (u_beta<beta_prob):
				fallback_probs = np.exp(logW-logsumexp(logW))
				return random.choices(population=indices, weights=fallback_probs, k=1)[0]

			C_t = random.choices(population=indices, weights=candidate_probs, k=1)[0]
			parent = mu0 if t==0 else x_t_1[C_t,:]

			x_new = self.generate_proposal(parent,A,Q,1,mu,sig)   # Generate one sample
			logF = (self.log_transition(x_new,parent,A,Q)
					+ self.log_emission(y_true[t,:],x_new,C,R)
					- self.log_proposal(x_new,parent,A,Q,mu,sig))

			# log(1 / (1 + exp(-score))), overflow-free for very negative scores.
			log_accept_prob = -np.logaddexp(0., -(logF + c_val))
			U = np.random.uniform(0,1,1)
			if np.log(U)<log_accept_prob:
				return C_t


	def T_value(self,x,gamma):
		"""
		Estimate the acceptance thresholds for every time step and particle.

		A particle filter is run once.  For particle n at time t, K1 = 5 pilot
		proposals are drawn, and the threshold is the ``gamma``-quantile of
		their negated log-weights ``-logF``.  The particle itself is then drawn
		with PRC_step using that threshold, and the particles are resampled
		(dice_factory when ``self.use_dice_factory == 1``, otherwise multinomial
		resampling on the normalised weights) to provide parents for t + 1.

		Args:
			x: parameter vector; ``x[:Dx]`` is the proposal mean offset and
				``x[Dx:2*Dx]`` the log of the diagonal proposal variances.
			gamma: quantile level passed to ``np.quantile``.

		Returns:
			Array of shape (T, N) with the thresholds.

		Note:
			Sizes and model come from the instance (``self.T``, ``self.N``,
			``self.Dx``, ``self.model_params``); observations are read from the
			module-level ``y_true``.
		"""
		T, N, Dx = self.T, self.N, self.Dx
		mu = x[0:Dx]
		sig = np.diag(np.exp(x[Dx:(2*Dx)]))
		mu0, Sigma0, A, Q, C, R = self.model_params
		K1 = 5

		logW = np.zeros((T,N))
		log_Unnormalized_W = np.zeros((T,N))
		Particles = np.zeros((T,N,Dx))
		thresholds = np.zeros((T,N))
		x_t_1 = []  # parents of the current step; none exist at t == 0

		for t in range(T):
			for n in range(N):
				prev = x_t_1[n,:] if t>0 else []
				centre = prev if t>0 else mu0

				# Pilot proposals used only to set the threshold and estimate
				# the acceptance rate.
				pilot = self.generate_proposal(centre,A,Q,K1,mu,sig)
				logF = (self.log_transition(pilot,centre,A,Q)
						+ self.log_emission(y_true[t,:],pilot,C,R)
						- self.log_proposal(pilot,centre,A,Q,mu,sig))
				thresholds[t,n] = np.quantile(-logF,gamma)

				log_accept_prob = -np.logaddexp(0., -(logF + thresholds[t,n]))
				# Average acceptance probability over the K1 pilots actually drawn.
				log_Z_R = logsumexp(log_accept_prob-np.log(K1))

				log_unnorm_weight,sample = self.PRC_step(prev,t,thresholds[t,n],mu,sig)
				Particles[t,n,:] = sample
				logW[t,n] = log_unnorm_weight + log_Z_R
				log_Unnormalized_W[t,n] = log_unnorm_weight

			# Nothing follows the last step, so no parents are needed (this also
			# makes T == 1 work).
			if t >= T-1:
				break

			if self.use_dice_factory==1:
				ancestors = [self.dice_factory(log_Unnormalized_W[t,:],logW[t,:],x_t_1,t,thresholds[t,n],mu0,A,Q,C,R,mu,sig)
							 for n in range(N)]
			else:
				ancestors = random.choices(population=range(N), weights=np.exp(logW[t,:]-logsumexp(logW[t,:])), k=N)

			# Resampled particles become the parents of the next step.
			x_t_1 = Particles[t,np.array(ancestors, dtype=int),:]

		return thresholds



	def VI_Loss(self,x,T_value):
		"""
		One stochastic evaluation of the variational objective.

		A particle filter is run with PRC_step drawing each particle.  For
		particle n at time t, K pilot proposals estimate the average acceptance
		probability Z_R under the threshold ``T_value[t, n]``; the particle's
		log-weight is the PRC_step log-weight plus log Z_R.  The objective is the
		sum over t of ``logsumexp(log_weights) - log N``.  Between steps the
		particles are resampled (dice_factory when ``self.use_dice_factory == 1``,
		otherwise multinomial resampling on the normalised weights).

		Args:
			x: parameter vector; ``x[:Dx]`` is the proposal mean offset and
				``x[Dx:2*Dx]`` the log of the diagonal proposal variances.
			T_value: array of thresholds indexed ``[t, n]``.

		Returns:
			The scalar objective value.

		Note:
			Sizes and model come from the instance (``self.T``, ``self.N``,
			``self.Dx``, ``self.model_params``).  The number of pilot proposals
			``K`` and the observations ``y_true`` are read from module level.
		"""
		T, N, Dx = self.T, self.N, self.Dx
		mu = x[0:Dx]
		sig = np.diag(np.exp(x[Dx:(2*Dx)]))
		mu0, Sigma0, A, Q, C, R = self.model_params
		loss = 0.0
		x_t_1 = []  # parents of the current step; none exist at t == 0

		for t in range(T):
			log_weights = []
			particles = []
			log_unnorm_weights = []
			for n in range(N):
				prev = x_t_1[n,:] if t>0 else []
				centre = prev if t>0 else mu0

				# Pilot proposals used only to estimate the acceptance rate.
				pilot = self.generate_proposal(centre,A,Q,K,mu,sig)    # Generate K sample
				logF = (self.log_transition(pilot,centre,A,Q)
						+ self.log_emission(y_true[t,:],pilot,C,R)
						- self.log_proposal(pilot,centre,A,Q,mu,sig))

				# log(1 / (1 + exp(-score))), overflow-free for very negative scores.
				log_accept_prob = -np.logaddexp(0., -(logF + T_value[t,n]))
				log_Z_R = logsumexp(log_accept_prob-np.log(K))

				log_unnorm_weight,sample = self.PRC_step(prev,t,T_value[t,n],mu,sig)
				particles.append(sample)
				log_unnorm_weights.append(log_unnorm_weight)
				log_weights.append(log_unnorm_weight + log_Z_R)

			log_weights = np.array(log_weights)
			log_unnorm_weights = np.array(log_unnorm_weights)
			particles = np.array(particles)
			loss = loss+logsumexp(log_weights-np.log(N))

			# Nothing follows the last step, so resampling there would be wasted.
			if t >= T-1:
				break

			if self.use_dice_factory==1:
				ancestors = [self.dice_factory(log_unnorm_weights,log_weights,x_t_1,t,T_value[t,n],mu0,A,Q,C,R,mu,sig)
							 for n in range(N)]
				z_new = np.array(ancestors).reshape((N,1))  # Ancestor indices
			else:
				z_new = np.array(random.choices(population=range(N), weights=np.exp(log_weights-logsumexp(log_weights)), k=N)).reshape((N,1))

			# Resampled particles become the parents of the next step.
			x_t_1 = particles[z_new,:].reshape((N,Dx))

		return (loss)


# Command-line arguments, all positional:
#   1: T      number of time steps
#   2: Dx     latent-state dimension
#   3: Dy     observation dimension
#   4: sparse 1 selects the 'sparse' observation matrix, anything else the dense one
#   5: gamma  quantile level handed to smc_prc.T_value
#   6: N      number of particles
#   7: K      number of pilot proposals per particle used in smc_prc.VI_Loss
# NOTE: methods of smc_prc read module-level names defined in this section
# (at least y_true and K), so these names must not be renamed or made local.
T = int(sys.argv[1])
Dx = int(sys.argv[2])
Dy = int(sys.argv[3])
sparse = int(sys.argv[4])
gamma = float(sys.argv[5])
N = int(sys.argv[6])
K = int(sys.argv[7])

alpha = 0.42
r = 1.0 #.1
# init_model_params only tests for the exact string 'sparse'; any other value
# (here 'parse') selects the dense, randomly drawn observation matrix.
if sparse==1:
	obs = 'sparse'
else:
	obs = 'parse'

# One RandomState is shared by model construction and data generation, so the
# data depend on whether the dense C consumed draws first.
data_seed = npr.RandomState(0)
model_params = init_model_params(Dx, Dy, alpha, r, obs, data_seed)

print("Generating data...")
x_true, y_true = generate_data(model_params, T, data_seed)

print('y',y_true)

# Exact log-marginal likelihood, used below as the reference line in the plot.
lml = log_marginal_likelihood(model_params, T, y_true)
print("True log-marginal likelihood: "+str(lml))

# NOTE(review): `seed` is not referenced anywhere else in the visible source.
seed = npr.RandomState(0)



#### PRC-SMC with Bernoulli factory ####


# Filter with use_dice_factory=1; autograd differentiates VI_Loss with respect
# to its first argument (the parameter vector).
lgss_smc_obj = smc_prc(T=T,Dx= Dx,Dy= Dy,N= N,use_dice_factory=1,model_params=model_params)
new_grad=grad(lgss_smc_obj.VI_Loss)


# Parameter layout (see VI_Loss): first Dx entries are the proposal mean offset,
# last Dx entries are log-variances (initialised to -10).
# NOTE(review): this draw uses the global NumPy generator, which is never seeded
# in the visible source.
param=np.zeros((1,2*Dx))
param[:,0:Dx]=np.random.normal(0,0.1,[Dx])
param[:,Dx:(2*Dx)]=-10
# Infinite thresholds make the acceptance probability 1, i.e. every proposal is
# accepted until the thresholds are first estimated below.
T_value=np.inf*np.ones((T,N))
Loss_our=[]
# Adam state: first/second moment estimates and step counter t.
m1 = 0
m2 = 0
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
t = 0
learning_rate =0.1
for epoch in range(500):


	# Halve the learning rate every 50 epochs, but never below 0.001.
	if (epoch+1)%50==0:
		learning_rate=max(0.001,learning_rate/2.0)

	t += 1

	# Re-estimate the acceptance thresholds every 10 epochs.
	if (epoch+1)%10==0:
		T_value=lgss_smc_obj.T_value(param[0],gamma)



	# Every 20 epochs record the objective averaged over 10 evaluations.
	if (epoch)%20==0:
		loss_val=0.0
		for j in range(10):
			loss_val=lgss_smc_obj.VI_Loss(param[0],T_value)+ loss_val
		print('our',loss_val/10.0)
		Loss_our.append(loss_val/10.0)


	# Adam step with bias-corrected moments; `+=` means the objective is maximised.
	gradient=new_grad(param[0],T_value)
	m1 = beta1 * m1 + (1 - beta1) * gradient
	m2 = beta2 * m2 + (1 - beta2) * gradient**2
	m1_hat = m1 / (1 - beta1**t)
	m2_hat = m2 / (1 - beta2**t)
	param += learning_rate * m1_hat / (np.sqrt(m2_hat) + epsilon)




# Horizontal reference line at the exact log-marginal likelihood, then the recorded curve.
plt.plot(range(len(Loss_our)),lml*np.ones((len(Loss_our))) ,label=r'log $p_{\theta}(x_{1:T})$',linewidth=4)
plt.plot(range(len(Loss_our)),Loss_our,label='vsmc-prc',linewidth=4)




#### VSMC Method

# Baseline run: rebinding the module-level K changes the number of pilot
# proposals that smc_prc.VI_Loss draws from here on.
K=1
# use_dice_factory=0 selects plain multinomial resampling inside the filter.
lgss_smc_obj = smc_prc(T=T,Dx= Dx,Dy= Dy,N= N,use_dice_factory=0,model_params=model_params)
new_grad=grad(lgss_smc_obj.VI_Loss)


# Same parameter layout and initialisation as the previous run.
param=np.zeros((1,2*Dx))
param[:,0:Dx]=np.random.normal(0,0.1,[Dx])
param[:,Dx:(2*Dx)]=-10
# The thresholds stay infinite for the whole run (they are never re-estimated
# in this loop), so every proposal is accepted.
T_value=np.inf*np.ones((T,N))
Loss_vi=[]
# Adam state: first/second moment estimates and step counter t.
m1 = 0
m2 = 0
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
t = 0
learning_rate =0.1
for epoch in range(500):


	# Halve the learning rate every 50 epochs, but never below 0.001.
	if (epoch+1)%50==0:
		learning_rate=max(0.001,learning_rate/2.0)

	t += 1
	# Unlike the previous loop, the gradient is evaluated before the monitoring
	# block, so random numbers are consumed in a different order.
	gradient=new_grad(param[0],T_value)

	# Every 20 epochs record the objective averaged over 10 evaluations.
	if (epoch)%20==0:
		loss_val=0.0
		for j in range(10):
			loss_val=lgss_smc_obj.VI_Loss(param[0],T_value)+ loss_val
		print('vsmc',loss_val/10.0)
		Loss_vi.append(loss_val/10.0)

	# Adam step with bias-corrected moments; `+=` means the objective is maximised.
	m1 = beta1 * m1 + (1 - beta1) * gradient
	m2 = beta2 * m2 + (1 - beta2) * gradient**2
	m1_hat = m1 / (1 - beta1**t)
	m2_hat = m2 / (1 - beta2**t)
	param += learning_rate * m1_hat / (np.sqrt(m2_hat) + epsilon)



plt.plot(range(len(Loss_vi)),Loss_vi,label='vsmc',linewidth=4)
# Title reports the dimensions (Dx shown as d_z, Dy shown as d_x) and which
# observation matrix was used.
if sparse == 1:
	plt.title( r'($d_{z} = $'+str(Dx)+r'$, d_{x} = $'+ str(Dy)+ ', C sparse)',fontsize=20)

else:	
	plt.title(r'($d_{z} = $'+str(Dx)+r'$, d_{x} = $'+ str(Dy)+ ', C dense)',fontsize=20)

# NOTE(review): the curves hold one point per 20 epochs (values are appended
# when epoch % 20 == 0), so the x axis counts recordings, not the 500 epochs.
plt.xlabel('Iterations(500)',fontsize=20)
plt.legend( fontsize = 20)
# Output names are built by plain concatenation without separators or an
# extension. NOTE(review): different (Dx, Dy) pairs can therefore collide,
# e.g. (1, 12) and (11, 2) both give 'fig112'.
plt.savefig('fig'+str(Dx)+str(Dy))
np.savetxt('vsmc_prc'+str(Dx)+str(Dy)+str(gamma),Loss_our)
np.savetxt('vsmc'+str(Dx)+str(Dy),Loss_vi)