"""
Based on 
https://github.com/aa14k/Exploration-in-RL.git
"""

import numpy as np

class LSVI_PHE(object):
    """Tabular least-squares value iteration with perturbed targets.

    State-action pairs are encoded with one-hot features (``phi`` is the
    identity of size ``S * A``). After every episode, ``update`` runs a
    backward pass over the horizon: for each stage it adds the latest
    transition to a regularised Gram matrix and to a running regression
    target, draws ``M`` Gaussian perturbations of that target, solves for
    ``M`` parameter vectors and sets ``Q`` to the largest of the ``M``
    predictions, clipped to ``[0, H - h]``.

    The environment object is expected to expose what this class reads:
    ``nState``, ``nAction``, ``epLen``, ``states`` (with ``keys()``),
    ``state``, ``timestep``, ``reset()``, ``advance(a)`` returning
    ``(reward, next_state, done)``, ``argmax(values)`` and, when
    ``reward_type == "known"``, ``R[s, a][0]``.

    Args:
        env: Environment providing the attributes listed above.
        K: Number of episodes executed by ``run``.
        M: Number of perturbed parameter vectors drawn per stage.
        sigma: Noise scale. The perturbation of feature ``i`` at stage ``h``
            has standard deviation ``sqrt(T[h, i]) * sigma`` and the
            regression solution is multiplied by ``1 / sigma**2``.
        reward_type: ``"known"`` reads rewards from ``env.R``; ``"unknown"``
            uses the running mean of the rewards observed so far.

    Raises:
        ValueError: If ``reward_type`` is neither ``"known"`` nor
            ``"unknown"``.
    """

    _REWARD_TYPES = ("known", "unknown")

    def __init__(self, env, K, M=2, sigma=1.0, reward_type="known"):
        # Fail fast: an unsupported value would otherwise only surface as an
        # unbound local variable in the middle of the first update.
        if reward_type not in self._REWARD_TYPES:
            raise ValueError(
                "reward_type must be one of {}, got {!r}".format(
                    self._REWARD_TYPES, reward_type
                )
            )

        self.env = env
        self.K = K
        self.M = M
        self.sigma = sigma
        self.S = self.env.nState
        self.A = self.env.nAction
        self.H = self.env.epLen
        self.d = self.S * self.A

        # Q has one extra stage (index H) that stays zero, so that
        # Q[h + 1] is always a valid lookup during the backward pass.
        self.Q = np.zeros((self.H + 1, self.S, self.A))
        self.theta_tilde = np.zeros((self.H, self.M, self.d))

        # One ridge-regularised Gram matrix per stage, initialised to lam * I.
        self.lam = 1.0
        self.Sigma = np.zeros((self.H, self.d, self.d))
        self.Sigma[:] = self.lam * np.identity(self.d)

        # Running (unperturbed) regression targets, one vector per stage.
        self.target = np.zeros((self.H, self.d))
        # Per-stage, per-feature counters; they start at one and are
        # incremented on every update of that feature.
        self.T = np.ones((self.H, self.d), dtype=int)
        # Latest transition (s, a, r, s') per stage; -1 encodes "no next
        # state". The integer dtype truncates a fractional reward, but the
        # stored reward is not read back by update().
        self.buffer = np.zeros((self.H, 4), dtype=int)
        # One-hot features: row s * A + a is the feature of (s, a).
        self.phi = np.identity(self.d)
        self.init_rand = np.zeros(self.d)

        # The attributes below are not read by any method of this class;
        # they are kept so code that inspects them keeps working.
        self.b = np.zeros((self.K, self.H, 4), dtype=int)
        self.V = np.zeros((self.H, self.S, self.A))
        self.Z = np.zeros((self.H, self.M, self.d))
        self.det = 0

        # Empirical reward: running mean and visit count per (s, a).
        self.reward_type = reward_type
        self.experienced_rewards = {
            (s, a): 0.0
            for s in self.env.states.keys()
            for a in range(self.env.nAction)
        }
        self.experienced_count = {
            (s, a): 0.0
            for s in self.env.states.keys()
            for a in range(self.env.nAction)
        }

    def act(self, s, h):
        """Return the action chosen by ``env.argmax`` on ``Q[h, s, :]``."""
        return self.env.argmax(self.Q[h, s, :])

    def init_perturb(self):
        """Add one Gaussian draw per state-action feature to ``init_rand``."""
        # NOTE(review): np.random.normal takes a standard deviation, yet
        # sigma**2 is passed here while update() uses sigma - confirm intent.
        for row in range(self.S * self.A):
            draw = np.random.normal(0.0, self.sigma**2)
            self.init_rand = self.init_rand + self.phi[row, :] * draw

    def update_buffer(self, s, a, r, s_, h, k):
        """Record the transition observed at stage ``h``.

        Args:
            s: State in which the action was taken.
            a: Action taken.
            r: Observed reward.
            s_: Next state, or ``None`` when there is none.
            h: Stage (time step) of the transition.
            k: Episode index; accepted for interface compatibility, unused.
        """
        if s_ is None:
            # The buffer is an integer array, so "no next state" is stored
            # as the sentinel -1.
            s_ = -1
        self.buffer[h, 0] = s
        self.buffer[h, 1] = a
        self.buffer[h, 2] = r
        self.buffer[h, 3] = s_

        # Incremental mean of the rewards seen for (s, a).
        seen = self.experienced_count[(s, a)]
        previous_mean = self.experienced_rewards[(s, a)]
        self.experienced_rewards[(s, a)] = (previous_mean * seen + r) / (seen + 1)
        self.experienced_count[(s, a)] += 1

    def _mean_reward(self, s, a):
        """Return the reward used in the regression target for ``(s, a)``."""
        if self.reward_type == "known":
            return self.env.R[s, a][0]
        if self.reward_type == "unknown":
            return self.experienced_rewards[(s, a)]
        # Only reachable when reward_type was reassigned after construction.
        raise ValueError("unsupported reward_type: {!r}".format(self.reward_type))

    def update(self, k):
        """Run one backward pass over all stages and refresh ``Q``.

        Args:
            k: Episode index; accepted for interface compatibility, unused.
        """
        inv_noise_var = 1 / self.sigma**2

        # NOTE(review): every stage is updated from whatever is currently in
        # the buffer, so a stage not visited in the latest episode reuses an
        # older (or all-zero) transition - confirm episodes always last H
        # steps.
        for h in range(self.H - 1, -1, -1):
            s, a, s_ = self.buffer[h, 0], self.buffer[h, 1], self.buffer[h, 3]

            # Feature of the visited (s, a) pair.
            row = s * self.A + a
            feature = self.phi[row, :]
            self.T[h, row] += 1

            # Rank-one update of the Gram matrix, then its inverse.
            self.Sigma[h] += np.outer(feature, feature)
            Sigma_inv = np.linalg.inv(self.Sigma[h])

            r = self._mean_reward(s, a)

            # A transition with no next state (sentinel -1) has nothing to
            # bootstrap from. Indexing Q with None would insert a new axis
            # instead of selecting a state, so use a value of zero.
            if s_ == -1:
                next_value = 0.0
            else:
                next_value = np.max(self.Q[h + 1, s_, :])
            self.target[h] += feature * (r + next_value)

            # Both quantities are identical for all M samples; compute once.
            solver = inv_noise_var * Sigma_inv
            noise_std = np.sqrt(self.T[h]) * self.sigma

            for m in range(self.M):
                # One Gaussian draw per feature, mapped through phi.
                Z = np.dot(np.random.normal(0.0, noise_std), self.phi)
                self.theta_tilde[h, m, :] = np.dot(solver, self.target[h] + Z)

            # Optimistic Q: for each (s, a) keep the largest prediction over
            # the M sampled parameter vectors, then clip to [0, H - h].
            # `initial` keeps the reduction defined when M == 0.
            predictions = np.dot(self.theta_tilde[h], self.phi.T)
            optimistic = np.max(predictions, axis=0, initial=-np.inf)
            self.Q[h] = np.clip(optimistic, 0, self.H - h).reshape(self.S, self.A)

    def run(self):
        """Run the agent for ``K`` episodes.

        Returns:
            list: Total reward collected in each episode, in order.
        """
        print("PHE-LSVI")
        # Stores the per-episode returns for plotting.
        E_return = []

        for k in range(self.K):
            R = 0
            self.env.reset()
            done = 0

            while not done:
                s = self.env.state
                h = self.env.timestep
                a = self.act(s, h)
                r, s_, done = self.env.advance(a)
                self.update_buffer(s, a, r, s_, h, k)
                R += r

            # Refresh Q from the transitions of the episode just finished.
            self.update(k)
            E_return.append(R)
        return E_return
