import numpy as np


def _target_policy(w, behav_policy):
    """Return pi(a|s) proportional to w(s, a) * behav_policy(a|s), normalised per row.

    A row whose weighted mass is exactly zero is divided by 1e-9 instead of 0.
    Assuming non-negative weights, such a row is all zeros and stays all zeros;
    the rollout then falls back to uniform action sampling for that state.
    """
    weighted = w * behav_policy
    row_mass = weighted.sum(axis=1, keepdims=True)
    # Only zero-mass rows get the epsilon, mirroring an element-wise guard.
    row_mass = np.where(row_mass == 0, row_mass + 1e-9, row_mass)
    return weighted / row_mass


def _run_episode(env, pi_s_a):
    """Roll out one episode under ``pi_s_a``; return True iff it ends at the goal."""
    n_actions = env.action_size
    observation = env.reset()
    time_step = 0
    while time_step < env.test_time_step:
        probs = pi_s_a[observation]
        if probs.sum() == 0:
            # No usable preference for this state: act uniformly at random.
            action = np.random.choice(n_actions)
        else:
            action = np.random.choice(n_actions, p=probs)
        # ``env.step`` is used as a lookup table indexed [state][action], not called.
        observation, _reward, _cost, done, hole = env.step[observation][action]
        if hole:
            # Checked before ``done`` so an env that also sets done on a hole
            # does not have that failure counted as a success.
            return False
        if done:
            return True
        time_step += 1
    return False


def importance_test(env, w_s_a, test_episode, behav_policy=None):
    """Estimate the goal-reaching rate of the policy induced by ``w_s_a``.

    The evaluated policy is pi(a|s) proportional to w(s, a) * behav_policy(a|s),
    normalised over actions for each state. States with zero weighted mass are
    acted in uniformly at random.

    Args:
        env: Tabular environment exposing ``state_size``, ``action_size``,
            ``test_time_step`` (per-episode step cap) and ``reset()`` returning a
            state index. It must also expose a transition table
            ``env.step[s][a]`` yielding ``(next_state, reward, cost, done, hole)``.
        w_s_a: Array-like with ``state_size * action_size`` entries, flat or
            2-D. Presumably state-action importance weights (name-based; confirm
            with callers). Non-negative values are assumed so that each row
            normalises to a valid distribution.
        test_episode: Number of episodes to roll out. ``0`` raises
            ``ZeroDivisionError``.
        behav_policy: Behaviour policy mu(a|s), flat or shaped
            ``(state_size, action_size)``. Defaults to uniform over actions.

    Returns:
        Fraction of episodes that reached ``done`` without falling into a hole.
    """
    # Float dtype so normalised probabilities are not truncated for int weights.
    w = np.asarray(w_s_a, dtype=float).reshape((env.state_size, env.action_size))
    if behav_policy is None:
        behav_policy = np.ones_like(w) / env.action_size
    else:
        behav_policy = np.asarray(behav_policy, dtype=float).reshape(w.shape)

    pi_s_a = _target_policy(w, behav_policy)

    goal_count = sum(1 for _ in range(test_episode) if _run_episode(env, pi_s_a))
    return goal_count / test_episode