from turtle import st
import numpy as np

class UC_HRL():
    """Optimistic value-iteration agent with one-hot (s, a) features.

    Transition estimates are kept per representative state ``eq_s =
    env.eq_states[s]``; the transitions of ``s`` are modeled as those of
    ``eq_s`` cyclically shifted by ``s - eq_s`` (see ``compute_Q``).

    The environment object must provide the attributes read here:
    ``nState``, ``nAction``, ``epLen``, ``states`` (dict whose keys are the
    states), ``eq_states`` (state -> representative state), ``R`` (indexed by
    ``(s, a)``; element 0 is used as the reward), ``reset()``,
    ``advance(a) -> (reward, next_state, done)``, ``state`` and ``timestep``.
    """

    def __init__(self, env, K, c, lam):
        """Set up features, Q table and per-representative statistics.

        Args:
            env: environment object (see class docstring).
            K: number of episodes executed by ``run``.
            c: scale factor of the confidence radius ``Beta``.
            lam: regularization; each gram matrix starts as ``lam * I``.
        """
        self.env = env
        self.K = K
        self.delta = 1.0 / self.K  # not read anywhere in this class
        self.d = self.env.nState * self.env.nAction

        # One-hot feature vector per (s, a) pair, in dict-iteration order.
        # NOTE(review): assumes len(env.states) == nState, otherwise index i
        # can exceed d - 1.
        pairs = [(s, a) for s in self.env.states.keys() for a in range(self.env.nAction)]
        self.phi = {}
        for i, key in enumerate(pairs):
            feat = np.zeros(self.d)
            feat[i] = 1
            self.phi[key] = feat

        # Optimistic Q table, keyed by (timestep, state, action).
        self.Q = {(h, s, a): 0.0 for h in range(self.env.epLen + 1)
                  for s in self.env.states.keys() for a in range(self.env.nAction)}

        # Regularized gram matrix and its inverse, one per eq_states key.
        self.lam = lam
        self.A = {i: self.lam * np.identity(self.d) for i in self.env.eq_states.keys()}
        self.Ainv = {i: np.linalg.inv(self.A[i]) for i in self.env.eq_states.keys()}

        # Estimated transition measure (mu = Ainv @ sums) and the running
        # sum of feature / next-state outer products it is built from.
        self.mu = {i: np.zeros((self.d, self.env.nState)) for i in self.env.eq_states.keys()}
        self.sums = {i: np.zeros((self.d, self.env.nState)) for i in self.env.eq_states.keys()}

        # Confidence radius parameter.
        self.c = c

    def act(self, s, h):
        """Return the greedy action for state ``s`` at timestep ``h``.

        Ties are broken towards the lowest action index (``np.argmax``).
        """
        return np.argmax(np.array([self.Q[(h, s, a)] for a in range(self.env.nAction)]))

    def proj(self, x, lo, hi):
        """Clamp ``x`` into the interval ``[lo, hi]``."""
        return max(min(x, hi), lo)

    def compute_Q(self, k):
        """Recompute the optimistic Q-values by backward induction (step 4, Eq 5).

        For each timestep h (from epLen - 1 down to 0), state s and action a::

            Q[h, s, a] = clip(r(s, a) + P(.|s, a) . V[h + 1]
                              + Beta(k) * sqrt(phi^T Ainv phi), 0, epLen)
            V[h][s]    = max_a Q[h, s, a]

        where ``phi`` is the feature of ``(eq_s, a)`` and ``P(.|s, a)`` is the
        representative's estimate ``phi . mu[eq_s]`` rolled by ``s - eq_s``.
        ``V[epLen]`` is zero. The result replaces ``self.Q``.

        Args:
            k: current episode index (>= 1), passed to ``Beta``.
        """
        epLen = self.env.epLen
        nAction = self.env.nAction
        beta = self.Beta(k)

        # The shifted transition estimate and the exploration bonus depend
        # only on (s, a), so compute them once rather than once per timestep.
        P = {}
        bonus = {}
        for s in self.env.states.keys():
            eq_s = self.env.eq_states[s]
            shift = s - eq_s
            assert shift >= 0
            Ainv = self.Ainv[eq_s]
            for a in range(nAction):
                feat = self.phi[(eq_s, a)]
                P[(s, a)] = np.roll(np.dot(feat, self.mu[eq_s]), shift)
                bonus[(s, a)] = beta * np.sqrt(np.dot(np.dot(feat, Ainv), feat))

        # Entries for h = epLen remain 0.0, matching the table from __init__.
        Q = {(h, s, a): 0.0 for h in range(epLen + 1)
             for s in self.env.states.keys() for a in range(nAction)}
        V = {h: np.zeros(self.env.nState) for h in range(epLen + 1)}

        for h in range(epLen - 1, -1, -1):
            for s in self.env.states.keys():
                for a in range(nAction):
                    r = self.env.R[s, a][0]
                    val = np.dot(P[(s, a)], V[h + 1])
                    Q[(h, s, a)] = self.proj(r + val + bonus[(s, a)], 0, epLen)
                # Use the Q-values just computed for step h (not the previous
                # episode's self.Q) so the backward induction is consistent.
                V[h][s] = max(Q[(h, s, a)] for a in range(nAction))
        self.Q = Q

    def Beta(self, k):
        """Confidence radius for episode ``k``: c * H * sqrt(d) * log(d * k * H)."""
        return self.c * self.env.epLen * np.sqrt(self.d) * np.log(self.d * k * self.env.epLen)

    def update_matrix(self, s, a, r, ns):
        """Incorporate one observed transition (s, a) -> ns (steps 10 and 11).

        Statistics of the representative ``eq_s = eq_states[s]`` are updated:
        the gram matrix and its inverse (Sherman-Morrison), the running sum of
        feature / next-state outer products, and ``mu = Ainv @ sums``. The
        next state is stored relative to ``eq_s`` by subtracting the shift
        ``s - eq_s``; a negative index wraps around, which matches the cyclic
        ``np.roll`` used in ``compute_Q``.

        Args:
            s: state the action was taken in.
            a: action index.
            r: observed reward (not used by this update).
            ns: observed next state.
        """
        eq_s = self.env.eq_states[s]
        feat = self.phi[(eq_s, a)]
        Ainv = self.Ainv[eq_s]

        # Keep the gram matrix in sync with its inverse.
        self.A[eq_s] = self.A[eq_s] + np.outer(feat, feat)
        # Sherman-Morrison: (A + f f^T)^-1 = Ainv - (Ainv f)(f^T Ainv) / (1 + f^T Ainv f)
        Ainv_f = np.dot(Ainv, feat)
        f_Ainv = np.dot(feat, Ainv)
        self.Ainv[eq_s] = Ainv - np.outer(Ainv_f, f_Ainv) / (1 + np.dot(f_Ainv, feat))

        # Estimated transition measure.
        shift = s - eq_s
        assert shift >= 0
        onehot_eq_ns = np.zeros(self.env.nState)
        onehot_eq_ns[ns - shift] = 1
        self.sums[eq_s] = self.sums[eq_s] + np.outer(feat, onehot_eq_ns)
        self.mu[eq_s] = np.matmul(self.Ainv[eq_s], self.sums[eq_s])

    def run(self):
        """Run ``K`` episodes, learning after every step.

        After each episode, the Q-values are recomputed.

        Returns:
            List with the summed reward of each episode, in order.
        """
        print("UC_HRL")
        episode_return = []

        for k in range(1, self.K + 1):
            self.env.reset()
            done = 0
            R = 0

            while not done:
                s = self.env.state
                h = self.env.timestep
                a = self.act(s, h)
                r, s_, done = self.env.advance(a)

                R += r
                self.update_matrix(s, a, r, s_)

            episode_return.append(R)
            self.compute_Q(k)

        return episode_return