import numpy as np
from VMBPO import VMBPO
from stochastic_env import stochastic_env
from utils import compute_initial_state_value
from stochastic_cliff_env import stochastic_cliff
import matplotlib.pyplot as plt
import time

# Build the environment and fetch its two lookup tables. The training loop
# indexes both by [state, action]: p_s_given_sa[s, a] is used as the
# probability vector for sampling the next state, and r_sa[s, a] is the value
# handed to the learner's update step.
env = stochastic_env()
p_s_given_sa, r_sa = env.generate_env()
print(r_sa)

episodes = 40000   # training episodes per independent run
iterations = 200   # maximum number of steps taken inside one episode

# Initial step size; the training loop re-initialises it at the start of
# every run.
learning_rate = 0.1

# VMBPO
beta = 1
print('Beta:{0}'.format(beta))

# One row per evaluation checkpoint, one column per run (5 runs).
# A checkpoint is recorded for every e in range(episodes) with e % 20 == 0,
# i.e. ceil(episodes / 20) of them. len(range(0, episodes, 20)) counts exactly
# those, whereas int(episodes / 20) would come up one short whenever episodes
# is not a multiple of 20 and make the later column assignment fail.
average_return = np.zeros((len(range(0, episodes, 20)), 5))

# Train 5 independent VMBPO runs. Every 20th episode the current Q-table is
# scored with compute_initial_state_value, and the resulting curve is written
# into this run's column of average_return.
for i in range(5):
    # Reset per run. NOTE(review): no live code in this loop consumes
    # learning_rate; only the disabled beta update further down refers to it.
    learning_rate = 0.1
    b = beta
    print(i)
    start_time = time.time()
    algorithm = VMBPO(env.num_states, env.num_actions, env)
    return_per_episode = []
    adaptive_beta = []
    for e in range(episodes):
        # Halve the step size on the last episode of each 5000-episode window.
        if e % 5000 == 4999:
            learning_rate = learning_rate / 2
        # Progress indicator, printed once per 1000 episodes.
        if e % 1000 == 999:
            print(e)
        # Evaluation checkpoint, taken before this episode's updates.
        if e % 20 == 0:
            return_per_episode.append(
                compute_initial_state_value(
                    env.num_states, env.num_actions, algorithm.Q, p_s_given_sa, r_sa))
        s = 0  # every episode starts from state 0

        for t in range(iterations):
            a = algorithm.epsilon_greedy(s)
            # Sample the successor state from the transition table.
            next_s = np.random.choice(env.num_states, p=p_s_given_sa[s, a])

            algorithm.update_step(s, a, next_s, r_sa[s, a])
            # The episode ends once an action has been taken from state 9;
            # the update for that step is applied before leaving the loop.
            if s == 9:
                # print(e,found_end)
                break
            s = next_s
            # Disabled adaptive-beta update, kept for reference:
            # print(b)
            # if(e != 0):
            #     b -= learning_rate * (0.01 - kl[s,a])
            #     if(b < 0):
            #         b = 0.1
            # print(b)
            # While the update above stays disabled, b never changes, so every
            # entry logged here equals log10(beta).
            adaptive_beta.append(np.log10(b))
    average_return[:, i] += np.array(return_per_episode)
    print(time.time() - start_time)



# plt.imshow(np.max(algorithm.Q, axis=1).reshape(-1, 10))
# plt.show()
# Episode numbers 1, 21, 41, ... below `episodes`.
episode_axis = np.arange(1, episodes, 20)

# Collapse the per-run columns of average_return into one mean and one
# standard deviation per evaluation checkpoint (row).
mean = average_return.mean(axis=1)
std = average_return.std(axis=1)

# Persist both curves; np.save appends the ".npy" extension itself.
for filename, values in (('Q_mean', mean), ('Q_std', std)):
    np.save(filename, values)

# np.save('n_{0}_mean'.format(beta),mean)
# np.save('n_{0}_std'.format(beta), std)
# np.save('n_{0}_beta'.format(beta), adaptive_beta)
