from environment import MDP, generate_mdp, initial_estimate_online,update
from PI import policy_iteration
from sampling import omega_star_o, oracle_policy, C_navigation,C_navigation2, D_navigation
from stopping import stop
from utils import Mdist, Odist, Pidist
import time
import numpy as np

def MDP_NaS(mdp, Navigation="D", delta=1e-1, period=1e3):
    """Run the MDP-NaS sampling/stopping loop on ``mdp`` and return progress logs.

    Ground-truth quantities (``pi_star`` and the oracle allocation ``omega_o``)
    are computed from ``mdp`` only to measure progress. The algorithm itself
    works on the online estimate ``M_hat``, which is built from the samples
    returned by ``mdp.play``.

    Args:
        mdp: Environment exposing ``Ns``, ``Na``, ``current()`` and ``play(a)``.
        Navigation: Sampling rule, one of:
            ``"D"``: ``D_navigation`` on the current estimated oracle policy.
            ``"C"``: ``C_navigation`` on the running sum of oracle policies.
            ``"C2"``: ``C_navigation2`` on the running sum of oracle allocations.
        delta: Confidence parameter forwarded to ``stop`` (presumably the
            target error probability; confirm against ``stopping.stop``).
        period: Number of steps between refreshes. A refresh recomputes the
            estimated oracle and evaluates the stopping rule. Refreshes happen
            at ``t = 1, 1 + period, 1 + 2*period, ...``. Must be positive.

    Returns:
        dict with keys:
            ``"time"``: the step ``t`` of each refresh.
            ``"U"``, ``"beta"``: the 2nd and 3rd outputs of ``stop`` at each
            refresh.
            ``"MDP_d"``, ``"pi_d"``, ``"omega_d"``: the ``Mdist``, ``Pidist``
            and ``Odist`` distances to ground truth at each refresh.
            ``"EARLIEST"``: the first refresh step at which
            ``pi_hat == pi_star``; stays ``[]`` if that never happened.
            ``"pi_hat"``: the last estimated optimal policy.

    Raises:
        ValueError: If ``Navigation`` is not recognised or ``period`` is not
            positive.
    """
    # Validate before the costly setup. An unknown rule would otherwise only
    # surface later as an UnboundLocalError on `a`, and period <= 0 would
    # either divide by zero or never trigger the stopping rule.
    if Navigation not in ("D", "C", "C2"):
        raise ValueError(
            "Navigation must be 'D', 'C' or 'C2', got {!r}".format(Navigation))
    if not period > 0:
        raise ValueError("period must be positive, got {!r}".format(period))

    ## GROUNDTRUTH VARIABLES FOR COMPARING, NOT USED IN THE ALGORITHM
    pi_star, V, Q = policy_iteration(mdp)
    omega_o, _, _, _ = omega_star_o(mdp, pi_star, V, Q, omega_0=None, N_iter=100)
    pi_o = oracle_policy(omega_o)

    #### BEGINNING OF THE ALGO
    Ns = mdp.Ns
    Na = mdp.Na
    M_hat = initial_estimate_online(mdp)
    pi_hat, V_hat, Q_hat = policy_iteration(M_hat)
    omega_o_hat, _, H, H_star = omega_star_o(M_hat, pi_hat, V_hat, Q_hat, omega_0=None, N_iter=100)
    pi_o_hat = oracle_policy(omega_o_hat)
    sum_oracle_pi = np.zeros((Ns, Na))     # running sum of oracle policies (rule "C")
    sum_oracle_omega = np.zeros((Ns, Na))  # running sum of oracle allocations (rule "C2")
    visits = np.zeros((Ns, Na))            # per-(state, action) sample counts
    t = 0
    STOP = False
    CORRECT = False
    s = mdp.current()
    # "EARLIEST" starts as [] and is overwritten with an int on first success;
    # kept as-is for compatibility with existing callers.
    LOGS = {"time": [], "U": [], "beta": [], "MDP_d": [], "pi_d": [],
            "omega_d": [], "EARLIEST": [], "pi_hat": pi_hat}
    while not STOP:
        # Refresh at t = 1, 1 + period, ... This is written as
        # (t - 1) % period == 0 instead of t % period == 1, so that
        # period <= 1 (refresh every step) still reaches the stopping rule
        # instead of looping forever. The t >= 1 guard keeps visits / t
        # well defined.
        if t >= 1 and (t - 1) % period == 0:
            pi_hat, V_hat, Q_hat = policy_iteration(M_hat)
            omega_o_hat, _, H, H_star = omega_star_o(M_hat, pi_hat, V_hat, Q_hat, omega_0=None, N_iter=100)
            pi_o_hat = oracle_policy(omega_o_hat)

            # Compute each diagnostic once and reuse it for both the log
            # and the printout.
            mdp_dist = Mdist(M_hat, mdp)
            omega_dist = Odist(visits / t, omega_o)
            LOGS["time"].append(t)
            LOGS["MDP_d"].append(mdp_dist)
            LOGS["pi_d"].append(Pidist(pi_o_hat, pi_o))
            LOGS["omega_d"].append(omega_dist)
            print("dist omega = {}, dist M = {}".format(omega_dist, mdp_dist))
            LOGS["pi_hat"] = pi_hat

            if (not CORRECT) and (pi_star == pi_hat).all():
                CORRECT = True
                LOGS["EARLIEST"] = t

            STOP, U_hat, beta = stop(pi_hat, H, H_star, visits, delta, t)
            LOGS["U"].append(U_hat)
            LOGS["beta"].append(beta)
        sum_oracle_pi += pi_o_hat
        sum_oracle_omega += omega_o_hat

        if Navigation == "D":
            a = D_navigation(s, pi_o_hat, t, alpha=1/2)
        elif Navigation == "C":
            a = C_navigation(s, sum_oracle_pi, t, alpha=1/2)
        else:  # "C2" -- the only remaining value after validation above
            a = C_navigation2(s, sum_oracle_omega, t, alpha=1/2)

        rewards, transitions, s_prime = mdp.play(a)
        # Passes the visit count of (s, a) *before* this sample is counted.
        # NOTE(review): the literal 1 is forwarded unchanged to update();
        # its meaning is defined in environment.update -- confirm there.
        M_hat = update(M_hat, s, a, rewards, transitions, 1, visits[s, a])
        visits[s, a] += 1
        t += 1
        s = s_prime

    return LOGS