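"""
UC-MatrixRL agent for finite-horizon MDPs, using one-hot (tabular)
state-action features phi and one-hot state features psi.

Based on
https://github.com/aa14k/Exploration-in-RL.git
"""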

from turtle import st
import numpy as np

class UC_MatrixRL():
    """
    UC-MatrixRL agent with one-hot state-action features phi (length
    nState*nAction) and one-hot state features psi (length nState).

    The transition model is estimated as a core matrix M, and an optimistic
    Q-function is recomputed by backward induction after every episode.
    """

    def __init__(self, K, env, c, lam):
        """
        :param K: number of episodes that `run` executes.
        :param env: environment object. This class reads env.nState, env.nAction,
            env.epLen, the keys of env.states (used as state identifiers) and
            env.R[s, a][0] (presumably the mean reward -- confirm against env).
            It also calls env.reset() / env.advance(a) and reads env.state and
            env.timestep.
        :param c: scale factor of the exploration bonus.
        :param lam: ridge regularizer; the Gram matrix starts at lam * I.
        """
        self.env = env
        self.K = K
        self.d1 = self.env.nState * self.env.nAction
        self.d2 = self.env.nState

        # phi features: a distinct one-hot vector of length d1 per (state, action).
        self.phi = {(s, a): np.zeros(self.d1)
                    for s in self.env.states.keys() for a in range(self.env.nAction)}
        for i, key in enumerate(self.phi):
            self.phi[key][i] = 1

        # psi features: a distinct one-hot vector of length d2 per state.
        self.psi = {s: np.zeros(self.d2) for s in self.env.states.keys()}
        for j, key in enumerate(self.psi):
            self.psi[key][j] = 1

        # Gram matrix of the psi features, plus a tiny ridge (1e-5 * I) so it is
        # always invertible.
        self.psi_mat = np.identity(self.d2) * 1e-5
        for s in self.psi:
            self.psi_mat = self.psi_mat + np.outer(self.psi[s], self.psi[s])
        self.psi_inv = np.linalg.inv(self.psi_mat)

        # Q-values for every (timestep, state, action), initialised to 0.
        self.Q = {(h, s, a): 0.0 for h in range(self.env.epLen) for s in self.psi
                  for a in range(self.env.nAction)}

        # Gram matrix of the phi features and its inverse (kept in sync by
        # update_core_matrix).
        self.lam = lam
        self.A = self.lam * np.identity(self.d1)
        self.Ainv = np.linalg.inv(self.A)

        # Estimated transition core matrix.
        self.M = np.zeros((self.d1, self.d2))

        # Confidence-radius scale.
        self.c = c

        # Running sum of phi(s, a) psi(s_)^T over all observed transitions.
        self.sums = np.zeros((self.d1, self.d2))
        # Not read anywhere within this class.
        self.delta = 1.0 / self.K

    def act(self, s, h):
        """
        Return the greedy action argmax_a Q[(h, s, a)] (ties go to the lowest index).
        """
        return np.argmax(np.array([self.Q[(h, s, a)] for a in range(self.env.nAction)]))

    def proj(self, x, lo, hi):
        '''Projects the value of x into the [lo,hi] interval'''
        return max(min(x, hi), lo)

    def compute_Q(self, k):
        """
        Compute optimistic Q-values by backward induction (step 6, Eq. 4 and 8).

        For h = H-1, ..., 0 and every (s, a):
            Q_h(s, a) = clip_[0, H]( r(s, a) + phi(s, a)^T M psi_mat V_{h+1}
                                     + 2 c H sqrt(d2 * Beta(k)) ||phi(s, a)||_{A^-1} )
            V_h(s)    = max_a Q_h(s, a)
        where H = env.epLen. The result replaces self.Q.

        :param k: current episode index (1-based); enters the confidence radius.
        """
        H = self.env.epLen
        n_actions = self.env.nAction
        Q = {(h, s, a): 0.0 for h in range(H) for s in self.psi for a in range(n_actions)}
        V = {h: np.zeros(self.env.nState) for h in range(H + 1)}

        # V vectors are consumed positionally (through the psi one-hot layout),
        # so store V_h(s) at the position of psi(s), not at the raw key s.
        state_index = {s: j for j, s in enumerate(self.psi)}

        # Loop-invariant quantities, hoisted out of the triple loop.
        P_hat = np.matmul(self.M, self.psi_mat)
        bonus_scale = 2 * self.c * H * np.sqrt(self.d2 * self.Beta(k))

        for h in range(H - 1, -1, -1):
            for s in self.psi:
                for a in range(n_actions):
                    feat = self.phi[(s, a)]
                    r = self.env.R[s, a][0]
                    val = np.dot(np.dot(feat, P_hat), V[h + 1])
                    w_kh = np.sqrt(np.dot(np.dot(feat, self.Ainv), feat))
                    Q[h, s, a] = self.proj(r + val + bonus_scale * w_kh, 0, H)
                # Use the Q just computed in this sweep (not the previous
                # episode's self.Q) so optimism propagates through all H steps.
                V[h][state_index[s]] = max(Q[h, s, a] for a in range(n_actions))
        self.Q = Q

    def Beta(self, k):
        """Confidence-radius term log(k * epLen) * d1 * d2."""
        return np.log(k * self.env.epLen) * self.d1 * self.d2

    def update_core_matrix(self, s, a, r, s_):
        """
        Incorporate one transition (s, a) -> s_ (steps 12 and 13).

        Updates the Gram matrix A, its inverse (Sherman-Morrison), the running
        sum of phi(s, a) psi(s_)^T, and M = A^-1 (sum phi psi^T) psi_mat^-1.

        :param r: observed reward; unused here (compute_Q reads rewards from env.R).
        """
        feat = self.phi[(s, a)]
        # Sherman-Morrison: (A + x x^T)^-1 = A^-1 - (A^-1 x)(x^T A^-1) / (1 + x^T A^-1 x).
        # Forming the rank-one outer product directly costs O(d1^2) rather than
        # the O(d1^3) of multiplying an outer product by Ainv.
        Ainv_x = np.dot(self.Ainv, feat)
        xT_Ainv = np.dot(feat, self.Ainv)
        self.Ainv = self.Ainv - np.outer(Ainv_x, xT_Ainv) / (1 + np.dot(xT_Ainv, feat))
        self.A = self.A + np.outer(feat, feat)
        # Estimated core matrix.
        self.sums = self.sums + np.outer(feat, self.psi[s_])
        self.M = np.matmul(np.matmul(self.Ainv, self.sums), self.psi_inv)

    def run(self):
        """
        Run K episodes, updating the model after each step and recomputing Q
        after each episode.

        :return: list with the summed reward of each episode.
        """
        print("UC-MatrixRL")
        episode_return = []

        for k in range(1, self.K + 1):
            self.env.reset()
            done = False
            R = 0

            while not done:
                s = self.env.state
                h = self.env.timestep
                a = self.act(s, h)
                r, s_, done = self.env.advance(a)

                R += r
                self.update_core_matrix(s, a, r, s_)

            episode_return.append(R)
            self.compute_Q(k)

        return episode_return