import numpy as np
from numpy.random import default_rng
from decimal import Decimal
import itertools
def VRQL(mdp, U_o, pi_b, hyperparameters, epsilon, delta = 0.1): #variance reduced Q-learning
    """Run variance-reduced Q-learning on ``mdp`` along a behaviour-policy trajectory.

    The number of epochs ``M``, the epoch length ``t_epoch``, the number of
    reference samples per epoch ``N`` and the step size ``eta`` are derived
    from ``hyperparameters``, ``epsilon``, ``delta`` and ``mdp.gamma``. Each
    epoch is delegated to ``vu_Q_RUN_EPOCH``.

    Args:
        mdp: environment exposing ``Ns``, ``Na`` and ``gamma`` here, plus
            ``current()`` / ``play(a)`` used by ``vu_Q_RUN_EPOCH``.
        U_o: sample-budget scale. The run is aborted (HALTED) without learning
            when the required number of samples exceeds
            ``10 * U_o * log(1/delta)``.
        pi_b: behaviour policy, indexed as ``pi_b[s, :]`` (action probabilities).
        hyperparameters: dict with keys ``"c0"``, ``"c1"``, ``"c2"``, ``"c3"``
            and ``"mu_min"``.
        epsilon: target accuracy.
        delta: confidence parameter.

    Returns:
        Tuple ``(HALTED, N_samples as a "10E"-formatted string,
        greedy policy argmax_a Q[s, a], Q)``.
    """
    Ns = mdp.Ns
    Na = mdp.Na
    gamma = mdp.gamma
    c0, c1, c2, c3 = hyperparameters["c0"], hyperparameters["c1"], hyperparameters["c2"], hyperparameters["c3"]
    mu_min = hyperparameters["mu_min"]
    log_eps = np.log(1/(epsilon*(1-gamma)**2))
    M = c3*log_eps
    # We run experiments on MDPs with gamma > 0.5, so the t_mix terms become negligible.
    t_epoch = (c2/mu_min)*(1/(1-gamma)**3)*log_eps*np.log(Ns*Na/delta)
    # Uses the real-valued t_epoch (before rounding up), as in the original schedule.
    log_term = np.log(Ns*Na*t_epoch/delta)
    N = (c1/mu_min)*(1/(min(1,epsilon**2)*(1-gamma)**3))*log_term
    eta = (c0/log_term)*((1-gamma)/gamma)**2
    # Round every schedule quantity up to a positive integer.
    M = int(M)+1
    t_epoch = int(t_epoch)+1
    N = int(N)+1
    rng = default_rng()
    Qepoch = np.zeros((Ns,Na))
    HALTED = False
    # Each epoch draws N reference samples and then t_epoch update samples.
    N_samples = (t_epoch+N)*M
    samples_str = format(N_samples, "10E")
    print("Number of samples = {}".format(samples_str))
    if N_samples > 10*U_o*np.log(1/delta):
        HALTED = True
        return HALTED, samples_str, np.argmax(Qepoch, axis = 1), Qepoch
    # Run exactly M epochs so the samples consumed equal N_samples, the
    # quantity that was checked against the budget above.
    for m in range(M):
        Qepoch = vu_Q_RUN_EPOCH(mdp, Qepoch, pi_b, N, t_epoch, eta, gamma, rng)
        print("Iteration n° {}".format((t_epoch+N)*(m+1)))
    return HALTED, samples_str, np.argmax(Qepoch, axis = 1), Qepoch
def vu_Q_RUN_EPOCH(mdp, Qbar, pi_b, N, t_epoch, eta, gamma, rng):
    """Run one epoch of variance-reduced Q-learning along a single trajectory.

    Phase 1 follows ``pi_b`` for ``N`` steps. It estimates the reference
    operator ``TQbar(s, a)`` as the empirical mean of
    ``r + gamma * max_a' Qbar(s', a')`` over the visits to ``(s, a)``.
    Pairs that are never visited keep ``TQbar = 0``.

    Phase 2 continues the same trajectory for ``t_epoch`` steps and applies
    the variance-reduced update
    ``Q(s,a) <- (1-eta) Q(s,a) + eta (gamma [max Q(s') - max Qbar(s')] + TQbar(s,a))``.

    Args:
        mdp: environment with ``current()`` returning the current state and
            ``play(a)`` returning ``(rewards, _, next_state)``. Only
            ``rewards[0]`` is used.
        Qbar: reference Q table of shape ``(Ns, Na)``. It is not modified.
        pi_b: behaviour policy, indexed as ``pi_b[s, :]``.
        N: number of samples used to estimate ``TQbar``.
        t_epoch: number of Q updates in this epoch.
        eta: step size.
        gamma: discount factor.
        rng: NumPy ``Generator`` used to sample actions.

    Returns:
        The updated Q table (a new array).
    """
    Ns, Na = Qbar.shape

    def _draw_action(state):
        # multinomial(1, p) is a one-hot vector; its argmax is the sampled action.
        return int(np.argmax(rng.multinomial(1, pi_b[state, :])))

    s = mdp.current()
    # targets[s][a] collects r + gamma * max_a' Qbar(s', a') for each visit to (s, a).
    targets = [[[] for _ in range(Na)] for _ in range(Ns)]
    for _ in range(N):
        a = _draw_action(s)
        rewards, _, s_prime = mdp.play(a)
        targets[s][a].append(rewards[0] + gamma*np.max(Qbar[s_prime, :]))
        s = s_prime

    # Baseline operator estimate. Use distinct loop names so the current
    # trajectory state `s` is not overwritten before phase 2.
    TQbar = np.zeros((Ns, Na))
    for si, ai in itertools.product(range(Ns), range(Na)):
        if targets[si][ai]:
            TQbar[si, ai] = np.mean(targets[si][ai])

    Q = Qbar.copy()
    for _ in range(t_epoch):
        a = _draw_action(s)
        _, _, s_prime = mdp.play(a)
        # The reward terms cancel between Q and Qbar, so only the max difference remains.
        diff = np.max(Q[s_prime, :]) - np.max(Qbar[s_prime, :])
        Q[s, a] = (1-eta)*Q[s, a] + eta*(gamma*diff + TQbar[s, a])
        s = s_prime
    return Q

#useful functions for the hyperparameters of VRQL
def Delta_min(V,Q, gamma):
    """Return the smallest non-zero sub-optimality gap ``V(s) - Q(s, a)``.

    Gaps are rounded to 6 decimals before being compared with zero, so
    numerically-zero gaps (optimal actions) are ignored.

    Args:
        V: state values, either of shape ``(Ns,)``, ``(Ns, 1)``, or already
            broadcast to ``(Ns, Na)``.
        Q: state-action values of shape ``(Ns, Na)``.
        gamma: discount factor, used for the fallback and cap values.

    Returns:
        The minimum non-zero gap. If at least one gap is zero, the result is
        capped at ``2/(1-gamma)``. In the pathological case where every gap is
        zero, returns ``1/(1-gamma)``.
    """
    Q = np.asarray(Q)
    V = np.asarray(V)
    if V.ndim == 1 and V.shape[0] == Q.shape[0]:
        # V[s] must be subtracted from the whole row Q[s, :]. A bare (Ns,)
        # array would otherwise broadcast along the action axis instead.
        V = V[:, None]
    gaps = np.round(V-Q, 6)
    non_zero_mask = gaps != 0
    if not np.any(non_zero_mask):  # pathological case where all gaps are zero
        return 1/(1-gamma)
    # Zero gaps are replaced by 2/(1-gamma) so they never win the minimum.
    return np.min(np.where(non_zero_mask, gaps, 2/(1-gamma)))

def Mu_min(mdp, pi_b, n_steps=20):
    """Return ``min_{s,a} mu(s) * pi_b(a|s)``, the smallest state-action occupancy.

    ``mu`` is approximated by the distribution reached after ``n_steps``
    transitions of the behaviour-policy chain, starting from state 0
    (row 0 of ``Pb ** n_steps``). This matches the true stationary
    distribution only if the chain has mixed by then.

    Args:
        mdp: object exposing ``Ns``, ``Na`` and a transition array ``P``
            indexed as ``P[s, a, s_prime]``.
        pi_b: behaviour policy indexed as ``pi_b[s, a]``.
        n_steps: number of chain steps used to approximate ``mu``
            (default 20, as before).

    Returns:
        The minimum of ``mu[s] * pi_b[s, a]``, or ``inf`` when there are no
        state-action pairs.
    """
    Ns = mdp.Ns
    Na = mdp.Na
    # Only the first Ns states / Na actions are read, as in the element-wise definition.
    pi = np.asarray(pi_b, dtype=float)[:Ns, :Na]
    P = np.asarray(mdp.P, dtype=float)[:Ns, :Na, :Ns]
    # Transition kernel of the behaviour policy: Pb[s, s'] = sum_a pi(a|s) P(s'|s,a).
    Pb = np.einsum("sa,sat->st", pi, P)
    mu = np.linalg.matrix_power(Pb, n_steps)[0]  # approximate stationary distribution
    occupancy = mu[:, None]*pi
    if occupancy.size == 0:
        return float("inf")
    return np.min(occupancy)