import os

import numpy as np
from scipy.linalg import solve_triangular
from scipy.optimize import minimize_scalar

# Weights of the reference policy used by LSVI.get_policy.  The existence of
# "/code" is taken as the signal that we run on the server, where the data sits
# under a fixed absolute path; otherwise it is read from the working directory.
server=os.path.isdir("/code")
if server:
	# Pass the path rather than an open() handle so np.load closes the file itself.
	weights0=np.load("/path-to-data/weights.npy")
else:
	weights0=np.load("weights.npy")

class LSVI():

	def __init__(self,env,state_dim,action_dim,feature_dim,action_space,K,lamda,gamma,eta,soft=0,rho=None,div=None):
		self.env=env
		self.state_dim=state_dim
		self.action_dim=action_dim
		self.feature_dim=feature_dim
		self.action_space=action_space

		self.state,_=self.env.reset()
		self.episode=0
		self.train_reward=0
		self.phi_cache=None

		self.K=K
		self.lamda=lamda  # ridge regression parameter
		self.gamma=gamma  # discount factor
		self.eta=eta  # step size policy update
		self.H=1/(1-gamma)
		self.soft=soft  # soft policy
		self.rho=rho  # robust parameter
		self.div=div  # robust divergence

		if self.rho is not None:
			assert self.div in ["TV","KL","Chi2"]

		self.at=0
		self.phi_matrix=np.zeros((K,feature_dim))  # store phi
		self.reward_matrix=np.zeros(K)  # store reward
		self.phi_next=np.zeros((K,len(action_space),feature_dim))  # store V_next
		self.done_matrix=np.zeros(K,dtype=bool)  # store done

		self.Lambda=np.zeros((feature_dim,feature_dim))  # Lambda matrix
		self.Lambda_dec=np.zeros((feature_dim,feature_dim))  # Cholesky decomposition
		self.weights=np.zeros(feature_dim)

		# initialize Lambda matrix
		self.Lambda=lamda*np.eye(feature_dim)
		self.Lambda_dec=np.linalg.cholesky(self.Lambda)

	def insert(self,phi,reward,phi_next,done):
		self.phi_matrix[self.at]=phi
		self.reward_matrix[self.at]=reward
		self.phi_next[self.at]=phi_next
		self.done_matrix[self.at]=done
		self.at+=1

	def get_feature(self,state,feature_func):
		return np.array([feature_func(state,a) for a in self.action_space])

	def get_policy(self,phi_full):
		Q_h=phi_full@weights0
		exp_Q=np.exp(0.6*(Q_h-np.max(Q_h)))
		prob=exp_Q/np.sum(exp_Q)
		return prob

	def get_action(self,phi_full):
		Q_h=phi_full@self.weights
		if not self.soft:
			action=np.argmax(Q_h)
		elif self.soft==1:
			exp_Q=np.exp(self.eta*(Q_h-np.max(Q_h)))
			prob=exp_Q/np.sum(exp_Q)
			action=np.random.choice(len(self.action_space),p=prob)
		elif self.soft==2:
			exp_Q=self.get_policy(phi_full)*np.exp(self.eta*(Q_h-np.max(Q_h)))
			prob=exp_Q/np.sum(exp_Q)
			action=np.random.choice(len(self.action_space),p=prob)
		elif self.soft==3:
			eta_Q=self.eta*Q_h
			if np.abs(eta_Q[0]-eta_Q[1])>1:
				action=np.argmax(eta_Q)
			else:
				prob=np.array([(eta_Q[0]-eta_Q[1]+1)/2,(eta_Q[1]-eta_Q[0]+1)/2])
				action=np.random.choice(len(self.action_space),p=prob)
		else:
			exit(-1)
		return self.action_space[action],phi_full[action]

	def get_action_test(self,phi_full):
		Q_h=phi_full@self.weights
		if self.soft==2:
			exp_Q=self.get_policy(phi_full)*np.exp(self.eta*(Q_h-np.max(Q_h)))
			prob=exp_Q/np.sum(exp_Q)
			action=np.argmax(prob)
		else:
			action=np.argmax(Q_h)
		return self.action_space[action]

	def calc_matrix(self):
		phi_matrix=self.phi_matrix[:self.at]
		r_matrix=self.reward_matrix[:self.at]
		not_done=~self.done_matrix[:self.at]
		V_matrix=np.zeros(self.at)

		phi_next=self.phi_next[:self.at][not_done]
		Q_next=np.maximum(np.minimum(phi_next@self.weights,self.H),0)
		if not self.soft:
			V_matrix[not_done]=np.max(Q_next,axis=1)
		elif self.soft==1:
			exp_Q=np.exp(self.eta*(Q_next-np.max(Q_next,axis=1,keepdims=True)))
			prob=exp_Q/np.sum(exp_Q,axis=1,keepdims=True)
			V_matrix[not_done]=np.sum(prob*Q_next,axis=1)-np.sum(prob*np.log(2*prob),axis=1)/self.eta
		elif self.soft==2:
			prob0=self.get_policy(phi_next)
			exp_Q=prob0*np.exp(self.eta*(Q_next-np.max(Q_next,axis=1,keepdims=True)))
			prob=exp_Q/np.sum(exp_Q,axis=1,keepdims=True)
			V_matrix[not_done]=np.sum(prob*Q_next,axis=1)-np.sum(prob*np.log(prob/prob0),axis=1)/self.eta
		elif self.soft==3:
			eta_Q=self.eta*Q_next
			prob=np.maximum(np.minimum((eta_Q[:,0]-eta_Q[:,1]+1)/2,1),0)
			prob=np.stack((prob,1-prob),axis=1)
			V_matrix[not_done]=np.sum(prob*Q_next,axis=1)-np.sum(prob**2,axis=1)/(2*self.eta)
		else:
			exit(-1)
		return phi_matrix,r_matrix,V_matrix

	def cholupdate(self,L,x,sign):
		for k in range(self.feature_dim):
			if sign:
				r=np.sqrt(L[k,k]**2+x[k]**2)
			else:
				r=np.sqrt(L[k,k]**2-x[k]**2)
			c=r/L[k,k]
			s=x[k]/L[k,k]
			L[k,k]=r
			if sign:
				L[k+1:,k]=(L[k+1:,k]+s*x[k+1:])/c
			else:
				L[k+1:,k]=(L[k+1:,k]-s*x[k+1:])/c
			x[k+1:]=c*x[k+1:]-s*L[k+1:,k]

	def estimate_w(self):
		phi=self.phi_matrix[self.at-1]
		self.Lambda+=np.outer(phi,phi)
		self.cholupdate(self.Lambda_dec,phi.copy(),True)
		if not (self.at+1)%1000:
			self.Lambda_dec=np.linalg.cholesky(self.Lambda)

		phi_matrix,r_matrix,V_matrix=self.calc_matrix()
		Y=solve_triangular(self.Lambda_dec,phi_matrix.T,lower=True,check_finite=False)
		xi=solve_triangular(self.Lambda_dec.T,Y,lower=False,check_finite=False)

		if self.rho is None:
			self.weights=xi@(r_matrix+self.gamma*V_matrix)
		else:
			nu=np.zeros(self.feature_dim)

			for i in range(self.feature_dim):
				match self.div:
					case "TV":
						nu[i]=xi[i]@np.minimum(V_matrix,self.rho)
					case "KL":
						tmp=xi[i]@np.exp(-V_matrix/self.rho)
						nu[i]=-self.rho*np.log(np.maximum(tmp,1e-4))
					case "Chi2":
						def fun(alpha):
							z=xi[i]@np.minimum(V_matrix,alpha)
							z_2=xi[i]@np.minimum(V_matrix,alpha)**2
							return -(z+np.minimum(z**2-z_2,0)/(4*self.rho))

						res=minimize_scalar(fun,bounds=(0,self.H))
						nu[i]=-res.fun

			self.weights=xi@r_matrix+self.gamma*nu

	def explore(self,feature_func):
		if self.phi_cache is None:
			phi_step=self.get_feature(self.state,feature_func)
		else:
			phi_step=self.phi_cache
		action,phi=self.get_action(phi_step)
		state_next,reward,terminated,truncated,_=self.env.step(action)
		phi_next=self.get_feature(state_next,feature_func)
		self.train_reward+=reward
		self.insert(phi,reward,phi_next,terminated)
		if terminated or truncated:
			train_reward=self.train_reward
			self.episode+=1
			self.train_reward=0
			self.state,_=self.env.reset()
			self.phi_cache=None
			return train_reward
		else:
			self.state=state_next
			self.phi_cache=phi_next
			return None

	def train(self,feature_func):
		for k in range(self.K):
			self.explore(feature_func)
			self.estimate_w()
