import numpy as np
from matplotlib import pyplot as plt
import os
import math




def _stage_ends(H, K):
    """Return the set of cumulative visit counts at which a stage ends.

    Stage lengths start at ``H`` and grow by a factor of ``1 + 1/H``
    (truncated to int) until a single stage length reaches ``K``. The
    returned thresholds are the running sums of those lengths.
    """
    lengths = [H]
    while lengths[-1] < K:
        # Grow by at least 1 so float truncation can never stall this loop.
        lengths.append(max(lengths[-1] + 1, int(lengths[-1] * (1 + 1.0 / H))))
    return set(np.cumsum(lengths).tolist())


def _decode_joint_action(index, num_actions, num_agent):
    """Split a joint-action index into per-agent actions (agent 0 is the least significant digit)."""
    per_agent = []
    for _ in range(num_agent):
        per_agent.append(index % num_actions)
        index //= num_actions
    return per_agent


def _encode_joint_action(per_agent, num_actions):
    """Inverse of ``_decode_joint_action``: pack per-agent actions into one joint index."""
    index = 0
    for act in reversed(per_agent):
        index = index * num_actions + act
    return index


def qlearning_omd(S, actions, K, H, env, epsilon=0.1, bonus_scale=0.01):
    """Run optimistic, stage-based Q-learning over a centralised joint action space.

    All agents are controlled jointly: the Q-table is indexed by one joint
    action index in ``[0, num_actions ** num_agent)``, where ``num_actions``
    is ``actions[0]`` and ``num_agent`` is ``env.num_agent``. A (h, s, a)
    estimate is only updated when its total visit count hits a stage end,
    using the averages collected during that stage plus a bonus.

    Args:
        S: number of states; states from ``env`` are used as indices in ``[0, S)``.
        actions: per-agent action counts. Only ``actions[0]`` is used, so every
            agent is assumed to have the same number of actions.
        K: number of episodes.
        H: horizon (steps per episode).
        env: object providing ``R_max``, ``num_agent``, ``init_environment()``,
            ``get_current_state()`` and ``step(per_agent_actions) -> (next_state, reward)``.
        epsilon: probability of taking a uniformly random joint action at a step.
        bonus_scale: multiplier on the ``sqrt(H**2 / n)`` bonus term.

    Returns:
        ``np.ndarray`` of shape ``(K,)`` holding the total reward of each episode.
    """
    num_actions = actions[0]
    num_agent = env.num_agent
    n_joint = num_actions ** num_agent
    stage_ends = _stage_ends(H, K)
    episode_rewards = np.zeros(K)

    # Optimistic initialisation: R_max for every remaining step of the episode.
    max_return = env.R_max * (H - np.arange(H, dtype=float))
    Q = np.tile(max_return[:, None, None], (1, S, n_joint))
    V = np.tile(max_return[:, None], (1, S))

    N = np.zeros((H, S, n_joint))        # total visits of (h, s, a)
    N_check = np.zeros((H, S, n_joint))  # visits within the current stage
    r_check = np.zeros((H, S, n_joint))  # reward sum within the current stage
    v_check = np.zeros((H, S, n_joint))  # next-step value sum within the current stage

    for k in range(K):
        env.init_environment()
        state = env.get_current_state()
        if k % 100 == 0:
            print(k)
        for h in range(H):
            # Epsilon-greedy over the optimistic Q-values.
            x = np.random.uniform(0, 1)
            if x >= epsilon:
                action = np.argmax(Q[h][state])
                a = _decode_joint_action(action, num_actions, num_agent)
            else:
                a = [np.random.choice(num_actions) for _ in range(num_agent)]
                # Credit the statistics to the joint action actually executed.
                action = _encode_joint_action(a, num_actions)

            next_state, reward = env.step(a)
            episode_rewards[k] += reward

            r_check[h][state][action] += reward
            if h != H - 1:
                v_check[h][state][action] += V[h + 1][next_state]
            N[h][state][action] += 1
            N_check[h][state][action] += 1

            if N[h][state][action] in stage_ends:
                n = N_check[h][state][action]
                bonus = bonus_scale * np.sqrt(H ** 2 / n)
                target = r_check[h][state][action] / n + v_check[h][state][action] / n + bonus
                # Estimates only ever move down from the optimistic start.
                Q[h][state][action] = min(Q[h][state][action], target)
                V[h][state] = np.max(Q[h][state])
                # Start a fresh stage for this (h, s, a).
                N_check[h][state][action] = 0
                r_check[h][state][action] = 0.0
                v_check[h][state][action] = 0.0
            state = next_state

    return episode_rewards




if __name__ == "__main__":
    task = 'goodstatemulti'
    # Each task selects its environment class, episode budget K and horizon H.
    if task == 'boxpushing':
        from env.boxpushing import Environment
        env = Environment(task + '.txt')
        K = 40000
        H = 6
    elif task == 'goodstate':
        from env.goodstate import Environment
        env = Environment()
        K = 5000
        H = 10
    elif task == 'goodstatemulti':
        from env.goodstatemulti import Environment
        env = Environment()
        H = 10
        K = 40000
    else:
        raise ValueError("unknown task: " + repr(task))

    S = env.state_size
    num_agent = env.num_agent
    actions = env.action_size
    # Only actions[0] is used: the joint-action encoding uses one action count for every agent.
    num_actions = actions[0]

    for run_id in range(20):
        rewards = qlearning_omd(S, actions, K, H, env)
        # Open the output only after the run finishes so a failure cannot leave a truncated file,
        # and let `with` guarantee the handle is closed.
        with open(task + '_' + str(run_id) + 'centr.txt', 'w') as out:
            out.writelines(str(rewards[k]) + '\n' for k in range(K))
        # plt.plot(range(K), rewards)
        # plt.show()


