import numpy as np
from tqdm.notebook import tqdm
import random
import time

class ModelBasedGreedy:
    """Shared scaffolding for tabular model-based agents that act greedily.

    The class owns the empirical statistics (visit counts, reward sums,
    transition counts) and the interaction loop in :meth:`run`. Subclasses
    provide :meth:`computepolicy`, which must fill ``self.policy`` (an
    ``(H, S)`` integer table of actions) before every episode.

    Parameters
    ----------
    env
        Environment object. ``env.epLen``, ``env.nState`` and ``env.nAction``
        are read here; :meth:`run` additionally uses ``env.reset()``,
        ``env.state``, ``env.timestep``, ``env.advance(action)`` (unpacked as
        ``(reward, next_state, done)``) and ``env.evaluate(policy)``.
    K
        Number of episodes played by :meth:`run`.
    delta
        Stored as ``self.delta``; subclasses use it inside their logarithmic
        confidence terms.
    uniform
        Switches the value bound returned by :meth:`Vmax`.
    """

    def __init__(self, env, K, delta, uniform):
        self.env = env
        self.H = env.epLen
        self.S = env.nState
        self.A = env.nAction
        self.K = K
        self.delta = delta
        self.uniform = uniform
        self.currentepisode = 0

        # Empirical statistics, all indexed by (state, action).
        self.N = np.zeros((self.S, self.A), np.int32)        # visit counts
        self.Rsum = np.zeros((self.S, self.A))               # sum of rewards
        self.R2sum = np.zeros((self.S, self.A))              # sum of squared rewards
        self.transitions = np.zeros((self.S, self.A, self.S), np.int32)

        # Planning buffers that subclasses overwrite in computepolicy().
        self.Q = np.zeros((self.S, self.A))
        self.V = np.zeros(self.S)
        self.policy = np.zeros((self.H, self.S), np.int32)

    def Vmax(self, h, nextstep=False):
        """Return the value bound used at step ``h``.

        When ``self.uniform`` is falsy the bound is ``H`` for every step.
        Otherwise it is ``H - h``, or ``H - h - 1`` when ``nextstep`` is true
        (i.e. the bound for the step after ``h``).
        """
        if not self.uniform:
            return self.H
        if nextstep:
            return self.H - h - 1
        return self.H - h

    def computepolicy(self):
        """Fill ``self.policy`` for the upcoming episode (subclass hook).

        The base class has no planning rule; without this stub, calling
        :meth:`run` on it would fail with an opaque ``AttributeError``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement computepolicy()"
        )

    def run(self):
        """Play ``K`` episodes and return the list of per-episode evaluations.

        Before each episode the policy is recomputed from the statistics
        gathered so far. After the episode, ``env.evaluate(self.policy)[0, 0]``
        is appended to the result (presumably the policy's value at the
        first step and first state -- confirm against the environment).
        """
        episode_return = []
        self.currentepisode = 0
        for _ in range(self.K):
            self.currentepisode += 1
            self.env.reset()
            done = 0

            self.computepolicy()

            while not done:
                # Read state/timestep *before* advance() mutates the env.
                state = self.env.state
                timestep = self.env.timestep
                action = self.policy[timestep, state]
                r, next_state, done = self.env.advance(action)
                self.N[state, action] += 1
                self.Rsum[state, action] += r
                self.R2sum[state, action] += r**2
                self.transitions[state, action, next_state] += 1

            R = self.env.evaluate(self.policy)[0, 0]
            episode_return.append(R)

        return episode_return

    def argmax(self, b):
        """Return an index of a maximal entry of ``b``, ties broken at random.

        The index is drawn with ``random.randrange`` instead of handing the
        NumPy candidate array to ``random.choice``: some CPython 3.11
        releases evaluate ``not seq`` inside ``random.choice``, which raises
        on a multi-element array and mis-reports a one-element array holding
        index 0 as empty (CPython issue gh-100805). ``randrange(n)`` and
        ``choice`` on a length-``n`` sequence consume the generator the same
        way, so seeded runs are unchanged.

        Raises
        ------
        IndexError
            If no entry compares equal to the maximum (e.g. ``b`` holds NaN).
        """
        candidates = np.where(b == b.max())[0]
        if len(candidates) == 0:
            raise IndexError("Cannot choose from an empty sequence")
        return candidates[random.randrange(len(candidates))]
            
class UCRL2(ModelBasedGreedy):
    """Agent that plans with extended value iteration over confidence sets.

    Rewards and transition probabilities are inflated by count-based
    confidence radii; the optimistic transition model found by the value
    iteration is then used for an ``H``-step backward induction that fills
    ``self.policy``.
    """

    def __init__(self, env, K, delta, uniform):
        super().__init__(env, K, delta, uniform)
        # Holds the most recent value iterate during extended value iteration.
        self.V_next = np.zeros(self.S)

    def computepolicy(self):
        """Run extended value iteration, then extract an ``H``-step greedy policy."""
        counts = np.maximum(self.N, 1)  # avoid division by zero for unvisited pairs
        tk = (self.currentepisode - 1) * self.H + 1  # steps taken so far, plus one

        # Confidence radius for transitions and optimistic (capped at 1) rewards.
        d = np.sqrt(14 * self.S * np.log(2 * self.A * tk / self.delta) / counts)
        reward_radius = np.sqrt(3.5 * np.log(2 * self.S * self.A * tk / self.delta) / counts)
        r_tilde = np.minimum(self.Rsum / counts + reward_radius, 1)
        Phat = self.transitions / counts.reshape(self.S, self.A, 1)

        tolerance = np.sqrt(1 / tk)
        self.V.fill(0)
        self.V_next.fill(0)
        while True:
            ascending = self.V.argsort()
            best = ascending[-1]

            # Push extra mass onto the currently best state ...
            Ptilde = Phat.copy()
            Ptilde[:, :, best] = np.minimum(1.0, Ptilde[:, :, best] + d / 2)

            # ... then remove the surplus from the lowest-valued states first
            # until each row sums to (at most) one again.
            for s in range(self.S):
                for a in range(self.A):
                    row = Ptilde[s, a]  # view: writes go straight into Ptilde
                    total = np.sum(row)
                    for low in ascending:
                        if total < 1 + 1e-9:
                            break
                        reduced = max(0.0, 1.0 - total + row[low])
                        total += reduced - row[low]
                        row[low] = reduced

            self.Q = r_tilde + Ptilde @ self.V
            self.V_next = np.max(self.Q, axis=-1)
            V_diff = self.V_next - self.V
            # Stop once the span of the update is small enough.
            if np.max(V_diff) - np.min(V_diff) <= tolerance:
                break
            self.V = self.V_next

        # Backward induction over the horizon using the last optimistic model.
        self.V.fill(0)
        for h in range(self.H - 1, -1, -1):
            self.Q = r_tilde + Ptilde @ self.V
            for s, q_row in enumerate(self.Q):
                self.policy[h, s] = self.argmax(q_row)
            self.V = np.max(self.Q, axis=-1)

class UCBVI_Azar(ModelBasedGreedy):
    """Optimistic value iteration with a variance-based exploration bonus.

    Besides the bonus, every Q estimate is clipped by the estimate computed
    for the same step in the previous call (``self.prevQ``), so the upper
    bounds never increase from one episode to the next.
    """

    def __init__(self, env, K, delta, uniform):
        super().__init__(env, K, delta, uniform)
        self.L = np.log(5 * self.S * self.A * self.H * self.K / self.delta)
        # Start the running upper bounds at the trivial bound of each step.
        self.prevQ = np.ndarray((self.H, self.S, self.A))
        for step in range(self.H):
            self.prevQ[step] = self.Vmax(step)

    def bonus(self, Var, mxN1, Vmax):
        """Return the exploration bonus per (state, action).

        ``Var`` is the empirical variance of the next-step value, ``mxN1``
        the visit counts clamped to at least one and ``Vmax`` the bound on
        the next-step value.
        """
        variance_term = np.sqrt(8 * self.L * Var / mxN1)
        range_term = (14 / 3) * Vmax * self.L / mxN1
        slack_term = Vmax * np.sqrt(8 / mxN1)
        return variance_term + range_term + slack_term

    def computepolicy(self):
        """Backward induction on the empirical model with optimism bonuses."""
        self.V.fill(0)
        counts = np.maximum(self.N, 1)
        Phat = self.transitions / counts.reshape(self.S, self.A, 1)
        rhat = self.Rsum / counts
        unvisited = self.N == 0
        for h in range(self.H - 1, -1, -1):
            meanV = Phat @ self.V
            # Empirical variance of V under Phat, floored at zero against round-off.
            VarV = np.maximum(Phat @ (self.V**2) - meanV**2, 0)
            estimate = rhat + meanV + self.bonus(VarV, counts, self.Vmax(h, True))
            # Unvisited pairs get the trivial bound; everything is then capped
            # by the bound from the previous call.
            optimistic = np.where(unvisited, self.Vmax(h), estimate)
            self.Q = np.minimum(optimistic, self.prevQ[h])
            self.prevQ[h] = self.Q

            for s, q_row in enumerate(self.Q):
                self.policy[h, s] = self.argmax(q_row)
            self.V = self.Q.max(axis=-1)
    
class MVP(ModelBasedGreedy):
    """Optimistic value iteration with a variance- and reward-dependent bonus.

    The bonus combines a next-value variance term, a term scaling with the
    empirical mean reward and a ``Vmax / N`` term. The constants
    (460/9, sqrt(8), 544/9) presumably follow the MVP algorithm of Zhang,
    Ji and Du -- confirm against the paper before changing them.
    """

    def __init__(self, env, K, delta, uniform):
        super().__init__(env, K, delta, uniform)
        log2KH = np.log2(self.K * self.H)
        self.L = np.log(
            (10 * self.S**2 * self.A * self.H * (log2KH + 2) + 6 * (log2KH + 1) * log2KH + 1)
            / self.delta
        )

    def bonus(self, VarV, rhat, mxN1, Vmax):
        """Return the exploration bonus per (state, action).

        ``VarV`` is the empirical variance of the next-step value, ``rhat``
        the empirical mean reward, ``mxN1`` the visit counts clamped to at
        least one and ``Vmax`` the bound on the next-step value. ``rhat`` is
        placed under a square root, so negative mean rewards yield NaN.
        """
        return (
            (460 / 9) * np.sqrt(self.L * VarV / mxN1)
            + np.sqrt(8 * rhat * self.L / mxN1)
            + (544 / 9) * Vmax * self.L / mxN1
        )

    def computepolicy(self):
        """Backward induction on the empirical model with optimism bonuses.

        Unvisited pairs are pinned to ``Vmax(h)``, as the other agents in
        this file do with their ``N == 0`` masks. Without the pin, an
        unvisited pair has ``rhat == 0`` and its bonus is proportional to
        ``Vmax(h, True)``; with ``uniform`` set, that is 0 at the last step,
        so the pair would score 0 and could never be explored once any other
        action showed a positive reward. Where the ``Vmax / N`` bonus term is
        at least ``Vmax(h)`` (every other step and mode, given a log term of
        at least 1), the clipped estimate already equals ``Vmax(h)`` and the
        pin changes nothing.
        """
        self.V.fill(0)
        mxN1 = np.maximum(self.N, 1)
        rhat = self.Rsum / mxN1
        Phat = self.transitions / mxN1.reshape(self.S, self.A, 1)
        unvisited = self.N == 0

        for h in range(self.H - 1, -1, -1):
            vmax_h = self.Vmax(h)
            meanV = Phat @ self.V
            # Empirical variance of V under Phat, floored at zero against round-off.
            VarV = np.maximum(Phat @ (self.V**2) - meanV**2, 0)
            estimate = rhat + meanV + self.bonus(VarV, rhat, mxN1, self.Vmax(h, True))
            self.Q = np.where(unvisited, vmax_h, np.minimum(estimate, vmax_h))
            for s in range(self.S):
                self.policy[h, s] = self.argmax(self.Q[s])
            self.V = self.Q.max(axis=-1)

class EQO_Kaware(ModelBasedGreedy):
    """Greedy agent with a ``c / N`` bonus whose scale is fixed from the known ``K``.

    The coefficient ``self.c`` is computed once at construction from the
    total number of episodes; at planning time it is multiplied by the
    step's value bound.
    """

    def __init__(self, env, K, delta, uniform):
        super().__init__(env, K, delta, uniform)
        log_term = np.log(24 * self.H * self.S * self.A / self.delta)
        lamb = min(
            1,
            5 * np.sqrt(self.S * self.A * np.log(1 + self.K * self.H / (self.S * self.A)) * log_term / self.K),
        )
        # H or H - h is multiplied in computepolicy(), depending on self.uniform.
        self.c = 7 * log_term / lamb

    def computepolicy(self):
        """Backward induction on the empirical model with the ``c * Vmax / N`` bonus."""
        self.V.fill(0)
        counts = np.maximum(self.N, 1)
        rhat = self.Rsum / counts
        b = self.c / counts
        Phat = self.transitions / counts.reshape(self.S, self.A, 1)
        unvisited = self.N == 0
        for h in range(self.H - 1, -1, -1):
            cap = self.Vmax(h)
            clipped = np.minimum(rhat + b * cap + Phat @ self.V, cap)
            # Pairs never tried get the trivial bound outright.
            self.Q = np.where(unvisited, cap, clipped)
            for s, q_row in enumerate(self.Q):
                self.policy[h, s] = self.argmax(q_row)
            self.V = self.Q.max(axis=-1)
                
class EQO_Kunaware(ModelBasedGreedy):
    """Greedy agent with a ``c / N`` bonus whose scale is refreshed on a doubling schedule.

    Instead of using the total number of episodes, the coefficient
    ``self.c`` is recomputed from the current episode index whenever that
    index reaches 1, 2, 4, 8, ...
    """

    def __init__(self, env, K, delta, uniform):
        super().__init__(env, K, delta, uniform)
        self.c = 0.0
        self.nxtupd = 1  # episode index of the next refresh: 1 -> 2 -> 4 -> ...
        self.fk = 0  # log_2 of nxtupd

    def computepolicy(self):
        """Refresh ``self.c`` if due, then plan by backward induction."""
        if self.currentepisode == self.nxtupd:
            self.fk += 1
            self.nxtupd *= 2
            episodes = self.currentepisode
            log_term = np.log(24 * self.H * self.S * self.A * self.fk ** 2 / self.delta)
            lamb = min(
                1,
                5 * np.sqrt(self.S * self.A * np.log(1 + episodes * self.H / (self.S * self.A)) * log_term / episodes),
            )
            # H or H - h is multiplied below, depending on self.uniform.
            self.c = 7 * log_term / lamb

        self.V.fill(0)
        counts = np.maximum(self.N, 1)
        rhat = self.Rsum / counts
        b = self.c / counts
        Phat = self.transitions / counts.reshape(self.S, self.A, 1)
        unvisited = self.N == 0
        for h in range(self.H - 1, -1, -1):
            cap = self.Vmax(h)
            clipped = np.minimum(rhat + b * cap + Phat @ self.V, cap)
            # Pairs never tried get the trivial bound outright.
            self.Q = np.where(unvisited, cap, clipped)
            for s, q_row in enumerate(self.Q):
                self.policy[h, s] = self.argmax(q_row)
            self.V = self.Q.max(axis=-1)

class Euler(ModelBasedGreedy):
    """Optimistic value iteration that tracks both upper and lower value estimates.

    ``self.V`` / ``self.Q`` hold upper estimates and ``self.Vdown`` /
    ``self.Qdown`` lower estimates; the gap between them enters the bonus.
    """

    def __init__(self, env, K, delta, uniform):
        super().__init__(env, K, delta, uniform)
        self.Vdown = np.zeros(self.S)
        self.Qdown = np.zeros((self.S, self.A))
        self.L = np.log(28 * self.S * self.A * self.H * self.K / self.delta)

    def bonus(self, VarR, VarV, VarVdown, mxN2, l2, Vmax):
        """Return ``(upper_bonus, lower_bonus)`` per (state, action).

        ``VarR`` is the empirical reward variance, ``VarV`` / ``VarVdown``
        the empirical variances of the upper / lower next-step values,
        ``mxN2`` the visit counts clamped to at least two (so ``mxN2 - 1``
        is never zero), ``l2`` the expected squared gap between upper and
        lower values and ``Vmax`` the bound on the next-step value.
        """
        # Lower-order term shared by both transition bonuses.
        range_term = Vmax * self.L / 3 / (mxN2 - 1)
        phi = np.sqrt(2 * self.L * VarV / mxN2) + range_term
        phidown = np.sqrt(2 * self.L * VarVdown / mxN2) + range_term
        br = np.sqrt(2 * self.L * VarR / mxN2) + (7 / 3) * self.L / (mxN2 - 1)
        bcommon = (
            (4 * Vmax * (self.L - np.log(3)) + Vmax * np.sqrt(2 * self.L)) / mxN2
            + np.sqrt(2 * self.L * l2 / mxN2)
        )
        upper = br + phi + bcommon
        lower = br + phidown + bcommon
        return upper, lower

    def computepolicy(self):
        """Backward induction producing the greedy policy and both value bounds."""
        self.V.fill(0)
        self.Vdown.fill(0)
        counts = np.maximum(self.N, 2)
        Phat = self.transitions / counts.reshape(self.S, self.A, 1)
        rhat = self.Rsum / counts
        VarR = np.maximum(self.R2sum / counts - rhat**2, 0)
        # Estimates built from fewer than two samples are replaced by the
        # trivial bounds below.
        mask = self.N <= 1
        states = np.arange(self.S)
        for h in range(self.H - 1, -1, -1):
            meanV = Phat @ self.V
            VarV = np.maximum(Phat @ (self.V**2) - meanV**2, 0)
            meanVdown = Phat @ self.Vdown
            VarVdown = np.maximum(Phat @ (self.Vdown**2) - meanVdown**2, 0)
            l2norm = Phat @ ((self.V - self.Vdown) ** 2)

            cap = self.Vmax(h)
            up, down = self.bonus(VarR, VarV, VarVdown, counts, l2norm, self.Vmax(h, True))
            self.Q = np.where(mask, cap, np.minimum(rhat + meanV + up, cap))
            self.Qdown = np.where(mask, 0, np.maximum(rhat + meanVdown - down, 0))
            for s, q_row in enumerate(self.Q):
                self.policy[h, s] = self.argmax(q_row)
            self.V = self.Q.max(axis=-1)
            # The lower value follows the action the greedy policy actually takes.
            self.Vdown = self.Qdown[states, self.policy[h]]

class ORLC(ModelBasedGreedy):
    """Optimistic value iteration with upper/lower estimates and min-of-several bonuses.

    ``self.V`` / ``self.Q`` hold upper estimates and ``self.Vdown`` /
    ``self.Qdown`` lower estimates. Each bonus is the elementwise minimum of
    several candidate confidence widths.
    """

    def __init__(self, env, K, delta, uniform):
        super().__init__(env, K, delta, uniform)
        self.Vdown = np.zeros(self.S)
        self.Qdown = np.zeros((self.S, self.A))
        self.L = np.log(26 * self.S * self.A * (self.H + self.S + 1) / self.delta)

    def bonus(self, VarV, mxN1, l1, PVdiff, sqrtPVdiff, PVdiffsq, Vmax):
        """Return ``(upper_bonus, lower_bonus)`` per (state, action).

        ``VarV`` is the empirical variance of the upper next-step value,
        ``mxN1`` the visit counts clamped to at least one, ``l1`` the scalar
        sum of absolute gaps between upper and lower values, ``PVdiff`` /
        ``PVdiffsq`` the expected gap / squared gap under the empirical
        model, ``sqrtPVdiff`` the gap weighted by the square roots of the
        empirical probabilities and ``Vmax`` the bound on the next-step value.
        """
        phi2 = np.minimum(1, 0.52 * (1.4 * np.log(np.maximum(1, np.log(mxN1))) + self.L) / mxN1)
        phi = np.sqrt(phi2)

        std2 = np.sqrt(12 * (VarV + PVdiffsq))
        common = (1 + np.sqrt(12 * VarV)) * phi + PVdiff / self.H

        # Candidate widths for the upper estimate; the tightest one wins.
        upper_candidates = [
            (Vmax + 1) * phi,
            (1 + std2) * phi + 8.13 * Vmax * phi2,
            common + 20.13 * self.H * l1 * phi2,
        ]
        psiup = np.minimum.reduce(upper_candidates)

        # Candidate widths for the lower estimate; the tightest one wins.
        lower_candidates = [
            (2 * np.sqrt(self.S) * Vmax + 1) * phi,
            (Vmax + 1 + 2 * sqrtPVdiff) * phi + 4.66 * l1 * phi2,
            (std2 + 1 + 2 * sqrtPVdiff) * phi + (8.13 * Vmax + 4.66 * l1) * phi2,
            common + (8.13 * Vmax + (32 * self.H + 4.66) * l1) * phi2,
        ]
        psidown = np.minimum.reduce(lower_candidates)
        return psiup, psidown

    def computepolicy(self):
        """Backward induction producing the greedy policy and both value bounds."""
        self.V.fill(0)
        self.Vdown.fill(0)
        counts = np.maximum(self.N, 1)
        Phat = self.transitions / counts.reshape(self.S, self.A, 1)
        sqrtPhat = np.sqrt(Phat)
        rhat = self.Rsum / counts
        # Pairs seen at most once are replaced by the trivial bounds below.
        mask = self.N <= 1
        states = np.arange(self.S)
        for h in range(self.H - 1, -1, -1):
            meanV = Phat @ self.V
            VarV = np.maximum(Phat @ (self.V**2) - meanV**2, 0)
            meanVdown = Phat @ self.Vdown
            Vdiff = self.V - self.Vdown
            l1norm = np.sum(np.abs(Vdiff))
            PVdiff = Phat @ Vdiff
            PVdiffsq = Phat @ (Vdiff**2)
            sqrtPVdiff = sqrtPhat @ Vdiff

            cap = self.Vmax(h)
            up, down = self.bonus(VarV, counts, l1norm, PVdiff, sqrtPVdiff, PVdiffsq, self.Vmax(h, True))
            self.Q = np.where(mask, cap, np.minimum(rhat + meanV + up, cap))
            self.Qdown = np.where(mask, 0, np.maximum(rhat + meanVdown - down, 0))
            for s, q_row in enumerate(self.Q):
                self.policy[h, s] = self.argmax(q_row)
            self.V = self.Q.max(axis=-1)
            # The lower value follows the action the greedy policy actually takes.
            self.Vdown = self.Qdown[states, self.policy[h]]