import numpy as np
from numpy.random import default_rng
import itertools
from PI import policy_iteration
import cvxpy as cp
import warnings
warnings.filterwarnings("ignore")

# Upper bound of the characteristic time
def U(pi, H, H_star, omega):
    """Evaluate the upper bound on the characteristic time and a subgradient.

    The bound is ``max_{s, a != pi[s]} H[s, a] / omega[s, a]
    + max_s H_star / omega[s, pi[s]]``.

    Parameters
    ----------
    pi : sequence of int
        ``pi[s]`` is the action singled out in state ``s``.
    H : ndarray, indexed ``[s, a]``
        Hardness index of every state-action pair.
    H_star : float
        Hardness term shared by all pairs ``(s, pi[s])``.
    omega : ndarray, indexed ``[s, a]``
        Allocation at which the bound is evaluated.

    Returns
    -------
    bound : float
        Value of the bound; ``inf`` as soon as ``omega`` has a zero entry.
    G : ndarray, same shape as ``omega``
        Subgradient of the bound at ``omega``; only the entries attaining
        the two maxima are filled in.
    """
    Ns = omega.shape[0]
    Na = omega.shape[1]
    M = 0
    M_star = 0
    smax = 0
    amax = 0
    tilde_smax = 0
    # Zero weights make the ratios below singular on purpose; keep NumPy quiet
    # here instead of relying on a module-wide warning filter.
    with np.errstate(divide="ignore", invalid="ignore"):
        for s in range(Ns):
            x = H_star / omega[s, pi[s]]
            if x > M_star:
                M_star = x
                tilde_smax = s
            for a in range(Na):
                if a == pi[s]:
                    continue
                y = H[s, a] / omega[s, a]
                if y > M:
                    M = y
                    smax = s
                    amax = a
        bound = M + M_star  # value of the upper bound function
        if (omega == 0).any():
            # np.float was removed from NumPy; the builtin float is equivalent.
            bound = float("inf")
        G = np.zeros((Ns, Na))  # subgradient
        G[smax, amax] = -H[smax, amax] / (omega[smax, amax] ** 2)
        G[tilde_smax, pi[tilde_smax]] = -H_star / (omega[tilde_smax, pi[tilde_smax]] ** 2)
    return bound, G


def FrankWolfe(G, P):  # Frank-Wolfe step: G = gradient, P = transition kernel for the constraints
    """Solve the linear sub-problem of one Frank-Wolfe iteration.

    Minimises ``<G, omega>`` over non-negative vectors ``omega`` (one entry
    per state-action pair, pairs ordered state-major) that sum to one and
    satisfy, for every state ``s``,
    ``sum_a omega[s, a] == sum_{x, a} P[x, a, s] * omega[x, a]``.

    Parameters
    ----------
    G : ndarray
        Gradient; flattened before use.
    P : ndarray, indexed ``[x, a, s]``
        Transition kernel used to build the balance constraints.

    Returns
    -------
    The value of the cvxpy variable after solving, a column of
    ``Ns * Na`` entries.
    """
    n_states, n_actions = P.shape[0], P.shape[1]
    n_pairs = n_states * n_actions
    cost = G.flatten()
    weights = cp.Variable((n_pairs, 1))

    # Row s selects every pair (s, a): mass sitting in state s.
    mass_in_state = np.repeat(np.eye(n_states), n_actions, axis=1)
    # Row s holds P[x, a, s] for every pair (x, a): mass flowing into s.
    mass_into_state = P.reshape(n_pairs, n_states).T

    all_ones = np.ones((n_pairs, 1))
    problem = cp.Problem(
        cp.Minimize(cost.T @ weights),
        [
            (mass_in_state - mass_into_state) @ weights == 0,
            all_ones.T @ weights == 1,
            weights >= 0,
        ],
    )
    problem.solve(solver='CVXOPT')
    return weights.value

def omega_star_o(mdp,pi, V, Q, omega_0 = None, N_iter=100):
    """Approximate the allocation minimising the bound ``U`` by Frank-Wolfe.

    Parameters
    ----------
    mdp : object exposing ``Ns``, ``Na`` and ``P``
    pi, V, Q : forwarded to ``HHstar`` (``pi`` is also forwarded to ``U``).
    omega_0 : ndarray of shape (Ns, Na), optional
        Starting allocation. When omitted, row 0 of the 20th power of the
        action-averaged kernel is spread evenly over the actions and
        normalised.
    N_iter : int
        Number of Frank-Wolfe iterations.

    Returns
    -------
    omega : final allocation, normalised to sum to one.
    bound : value of ``U`` recorded during the last iteration, i.e. at the
        iterate that entered that iteration (before its update).
    H, H_star : hardness indices returned by ``HHstar``.
    """
    n_states = mdp.Ns
    n_actions = mdp.Na
    kernel = mdp.P
    H, H_star = HHstar(mdp, pi, V, Q)

    if omega_0 is None:
        averaged_kernel = np.mean(kernel, axis=1)
        state_weights = np.linalg.matrix_power(averaged_kernel, 20)[0]
        omega_0 = (1 / n_actions) * np.repeat(state_weights, n_actions).reshape(n_states, n_actions)
        omega_0 = omega_0 / np.sum(omega_0)

    omega = omega_0
    history = np.zeros((N_iter, 2))  # columns: iteration index, bound value
    for step in range(N_iter):
        bound, subgrad = U(pi, H, H_star, omega)
        subgrad = subgrad / np.sqrt(np.sum(subgrad ** 2))
        history[step, 0] = step
        history[step, 1] = bound
        step_size = 0.5 / (step + 1)
        vertex = FrankWolfe(subgrad, kernel).reshape(n_states, n_actions)
        omega = (1 - step_size) * omega + step_size * vertex
        # Renormalise so that rounding errors do not pile up over iterations.
        omega = omega / np.sum(omega)
    return omega, history[-1, 1], H, H_star

def oracle_policy(omega):
    """Turn an allocation over state-action pairs into a randomised policy.

    Each row of ``omega`` (one row per state) is divided by its own sum, so
    the returned array has the same shape and rows that sum to one.
    """
    mass_per_state = np.sum(omega, axis=1, keepdims=True)
    return omega / mass_per_state

def C_navigation(s, sum_oracle_pi, t, alpha = 0.9, rng = None):
    """Draw an action in state ``s`` from the averaged oracle policies.

    At ``t == 0`` the action is uniform. Otherwise it is drawn from
    ``eps * uniform + (1 - eps) * sum_oracle_pi / (t + 1)`` with
    ``eps = (t + 1) ** -alpha``.

    Parameters
    ----------
    s : int
        Current state (row index).
    sum_oracle_pi : ndarray, indexed ``[s, a]``
        Running sum of oracle policies; divided by ``t + 1`` here.
    t : int
        Time step.
    alpha : float
        Decay exponent of the forced-exploration rate.
    rng : numpy.random.Generator, optional
        Source of randomness. Pass a seeded generator for reproducible
        draws; a fresh unseeded one is created when omitted.

    Returns
    -------
    Index of the sampled action.
    """
    if rng is None:
        rng = default_rng()
    Na = sum_oracle_pi.shape[1]
    if t == 0:
        probs = np.full(Na, 1 / Na)
    else:
        epsilon = 1 / ((t + 1) ** alpha)
        # Only the row of the current state is needed to sample.
        probs = epsilon * (1 / Na) + (1 - epsilon) * sum_oracle_pi[s, :] / (t + 1)
    draw = rng.multinomial(1, probs)
    return np.where(draw == 1)[0][0]

def C_navigation2(s, sum_oracle_omega, t, alpha = 0.9, rng = None):
    """Draw an action in state ``s`` from the policy of the summed allocations.

    At ``t == 0`` the action is uniform. Otherwise it is drawn from
    ``eps * uniform + (1 - eps) * oracle_policy(sum_oracle_omega)`` with
    ``eps = (t + 1) ** -alpha``.

    Parameters
    ----------
    s : int
        Current state (row index).
    sum_oracle_omega : ndarray, indexed ``[s, a]``
        Running sum of oracle allocations.
    t : int
        Time step.
    alpha : float
        Decay exponent of the forced-exploration rate.
    rng : numpy.random.Generator, optional
        Source of randomness. Pass a seeded generator for reproducible
        draws; a fresh unseeded one is created when omitted.

    Returns
    -------
    Index of the sampled action.
    """
    if rng is None:
        rng = default_rng()
    Na = sum_oracle_omega.shape[1]
    if t == 0:
        # The accumulator is not used on the first step, so it is not
        # normalised here (an all-zero accumulator would give 0/0).
        probs = np.full(Na, 1 / Na)
    else:
        epsilon = 1 / ((t + 1) ** alpha)
        piOfMean = oracle_policy(sum_oracle_omega)
        probs = epsilon * (1 / Na) + (1 - epsilon) * piOfMean[s, :]
    draw = rng.multinomial(1, probs)
    return np.where(draw == 1)[0][0]

def D_navigation(s, oracle_pi, t, alpha = 0.9, rng = None):
    """Draw an action in state ``s`` from the current oracle policy.

    The action is drawn from ``eps * uniform + (1 - eps) * oracle_pi`` with
    ``eps = (t + 1) ** -alpha`` (so the draw is uniform at ``t == 0``).

    Parameters
    ----------
    s : int
        Current state (row index).
    oracle_pi : ndarray, indexed ``[s, a]``
        Oracle policy; row ``s`` is used as action probabilities.
    t : int
        Time step.
    alpha : float
        Decay exponent of the forced-exploration rate.
    rng : numpy.random.Generator, optional
        Source of randomness. Pass a seeded generator for reproducible
        draws; a fresh unseeded one is created when omitted.

    Returns
    -------
    Index of the sampled action.
    """
    if rng is None:
        rng = default_rng()
    Na = oracle_pi.shape[1]
    epsilon = 1 / ((t + 1) ** alpha)
    # Only the row of the current state is needed to sample.
    probs = epsilon * (1 / Na) + (1 - epsilon) * oracle_pi[s, :]
    draw = rng.multinomial(1, probs)
    return np.where(draw == 1)[0][0]

## COMPUTING HARDNESS INDICES OF STATE-ACTION PAIRS USED IN THE FORMULA OF THE OPTIMAL ALLOCATION VECTOR
def HHstar(mdp, pi, V, Q):
    """Compute the hardness indices ``H[s, a]`` and the scalar ``H_star``.

    Parameters
    ----------
    mdp : object exposing ``Ns``, ``gamma`` and ``P``
        ``P`` is indexed ``[s, a, s']`` (the last axis is summed against
        the value function).
    pi : sequence of int
        ``pi[s]`` is the action singled out in state ``s``.
    V : ndarray
        Value function; column 0 is used, and ``V - Q`` must broadcast to
        the shape of ``Q`` (presumably V is (Ns, 1) -- confirm with caller).
    Q : ndarray, indexed ``[s, a]``
        State-action values.

    Returns
    -------
    H : ndarray, indexed ``[s, a]``
        ``T1 + T2``; entries whose gap is zero are set to 0.
    H_star : float
        ``Ns * (T3 + T4)``.
    """
    Ns = mdp.Ns
    gamma = mdp.gamma
    P = mdp.P

    # Matrix of gaps rounded to 6 decimals
    delta_mcal = np.round(V - Q, 6)
    # Minimum gap over the non-null gaps. Null gaps are replaced by
    # 2/(1-gamma) before taking the minimum, so delta_min is never 0 and
    # equals 2/(1-gamma) when every gap is null.
    delta_min = np.min(np.where(delta_mcal != 0, delta_mcal, 2 / (1 - gamma)))

    # Mean and variance of the next-state value for every pair (s, a)
    v_next = V[:, 0]
    mean_next = np.sum(P * v_next, axis=2)
    Var = np.sum(P * (v_next ** 2), axis=2) - mean_next ** 2
    # Maximum deviation of the next-state value (similar to the span, see def in [MP20])
    MD = np.maximum(np.max(V) - mean_next, mean_next - np.min(V))

    def _zero_where_singular(term):
        # Pairs with a null gap produce +inf or NaN; they contribute 0.
        return np.where(np.isnan(term) | (term == float("inf")), 0.0, term)

    # Bound terms. Null gaps make these ratios singular on purpose, so the
    # floating-point warnings are silenced locally.
    with np.errstate(divide="ignore", invalid="ignore"):
        # First term T_1(s,a) used in the sampling and stopping rules
        T1 = _zero_where_singular(2 / (delta_mcal ** 2))
        # Second term T_2(s,a) used in the sampling and stopping rules
        Z = np.stack((16 * Var / (delta_mcal ** 2), 6 * (MD / delta_mcal) ** (4 / 3)))
        T2 = _zero_where_singular(np.max(Z, axis=0))
    # Third term used in the sampling and stopping rules
    T3 = 2 / (((1 - gamma) * delta_min) ** 2)

    # Fourth term used in the sampling and stopping rules
    Var_max = 0  # maximum variance over the pairs (s, pi[s])
    MD_max = 0  # maximum maximum-deviation over the pairs (s, pi[s])
    for s in range(Ns):
        if Var[s, pi[s]] > Var_max:
            Var_max = Var[s, pi[s]]
        if MD[s, pi[s]] > MD_max:
            MD_max = MD[s, pi[s]]

    V1 = 27 / ((delta_min ** 2) * ((1 - gamma) ** 3))
    V2 = max(16 * Var_max / ((delta_min * (1 - gamma)) ** 2),
             6 * (MD_max / (delta_min * (1 - gamma))) ** (4 / 3))
    T4 = min(V1, V2)

    H = T1 + T2  # state-action hardness indices, used to compute optimal sampling weights
    H_star = (T3 + T4) * Ns

    return H, H_star


def omega_star_g(mdp,pi, V, Q): 
    """Closed-form allocation built from the hardness indices of ``HHstar``.

    With ``S = sum(H)`` and ``C = sqrt(S * H_star)``, a pair with
    ``H[s, a] != 0`` receives ``H[s, a] / (S + C)`` and a pair with
    ``H[s, a] == 0`` receives ``C / ((S + C) * Ns)``.

    Returns
    -------
    omega : ndarray of shape (Ns, Na)
    U_g : ``S + H_star + 2 * C``, the theoretical upper bound of the
        algorithm (see proof of Corollary 1 in [MP20]).
    H, H_star : hardness indices returned by ``HHstar``.

    When ``S == 0`` the uniform allocation is returned together with three
    one-element arrays holding ``inf``.
    """
    n_states = mdp.Ns
    n_actions = mdp.Na
    H, H_star = HHstar(mdp, pi, V, Q)

    hardness_total = np.sum(H)
    if hardness_total == 0:
        # Pathological case: infinite T1/T2 values were replaced by zero,
        # so every hardness index vanished.
        uniform = (1 / (n_states * n_actions)) * np.ones((n_states, n_actions))
        infinite = [np.array([float("inf")]) for _ in range(3)]
        return uniform, infinite[0], infinite[1], infinite[2]

    cross_term = np.sqrt(hardness_total * H_star)
    denominator = hardness_total + cross_term  # denominator in the formula of optimal omega
    # H[s, a] == 0 means the gap of (s, a) is zero; such pairs share the
    # cross term evenly over the states.
    omega = np.where(H == 0, cross_term / (denominator * n_states), H / denominator)

    U_g = hardness_total + H_star + 2 * cross_term
    return omega, U_g, H, H_star


# C-tracking rule
def C_tracking(sum_omegas,visits, t):
    """Pick the next state-action pair to sample under C-tracking.

    A pair is starving when its visit count is below
    ``sqrt(t) - Ns * Na / 2``. The first starving pair in row-major order
    is returned with ``starved=True``. When no pair is starving, the pair
    maximising ``sum_omegas - visits`` is returned with ``starved=False``.

    Returns
    -------
    (s, a, starved)
    """
    n_states, n_actions = visits.shape[0], visits.shape[1]
    starvation_level = t ** (1 / 2) - (n_states * n_actions) / 2

    for pair in itertools.product(range(n_states), range(n_actions)):
        if visits[pair] < starvation_level:
            return pair[0], pair[1], True

    # Nobody is starving: sample the pair lagging furthest behind its weight.
    lag = sum_omegas - visits
    s, a = np.unravel_index(np.argmax(lag, axis=None), lag.shape)
    return s, a, False

# D-tracking rule (in practice yields the same results as C-tracking)
def D_tracking(omega,visits, t):
    """Pick the next state-action pair to sample under D-tracking.

    A pair is starving when its visit count is below
    ``sqrt(t) - Ns * Na / 2``. The first starving pair in row-major order
    is returned with ``starved=True``. When no pair is starving, the pair
    maximising ``t * omega - visits`` is returned with ``starved=False``.

    Returns
    -------
    (s, a, starved)
    """
    n_states, n_actions = visits.shape[0], visits.shape[1]
    starvation_level = t ** (1 / 2) - (n_states * n_actions) / 2

    for pair in itertools.product(range(n_states), range(n_actions)):
        if visits[pair] < starvation_level:
            return pair[0], pair[1], True

    # Nobody is starving: sample the pair lagging furthest behind its weight.
    lag = t * omega - visits
    s, a = np.unravel_index(np.argmax(lag, axis=None), lag.shape)
    return s, a, False
