import numpy as np
from matplotlib import pyplot as plt
import os
import math


def transition(s, a, b):
    """Sample the next state (0 or 1) for the joint action ``(a, b)``.

    One uniform draw on [0, 1) is compared against the module-level
    ``prob_threshold``:

    * joint action ``(0, 1)``: the next state is 1 when the draw falls
      below the threshold, otherwise 0;
    * any other joint action: the next state is 1 when the draw is at or
      above the threshold, otherwise 0.

    The current state ``s`` is accepted but not read.
    """
    draw = np.random.uniform(0, 1)
    if a == 0 and b == 1:
        return int(draw < prob_threshold)
    return int(draw >= prob_threshold)


def initial_state():
    """Draw the starting state: 0 or 1, each with probability one half."""
    # A single uniform draw on [0, 1); the lower half maps to state 0.
    return 0 if np.random.uniform(0, 1) < 0.5 else 1



def get_reward(s, a, b):
    """Look up the entry of the module-level reward table ``R`` at ``(s, a, b)``."""
    index = (s, a, b)
    return R[index]


def qlearning_omd(S, A, B, K, H):
    """Run two independent stage-based tabular Q-learners for K episodes.

    Both agents observe the same state and the same scalar reward, but each
    keeps its own tables indexed only by its own action. Estimates start at
    the optimistic value ``R_max * (H - h)`` and are only ever lowered.

    Relies on the module-level ``R``, ``R_max``, ``initial_state`` and
    ``transition``.

    Args:
        S: number of states (size of the state axis of every table).
        A: number of actions available to agent 0.
        B: number of actions available to agent 1.
        K: number of episodes to simulate.
        H: number of steps per episode.

    Returns:
        A length-``K`` NumPy array holding the summed reward of each episode.
    """
    # Stage schedule: the first stage has H visits and every later stage is
    # floor((1 + 1/H) * previous). Exact integer arithmetic is used so that
    # floating-point rounding can neither shorten a stage nor stall growth.
    stage_lengths = [H]
    while stage_lengths[-1] < K:
        previous = stage_lengths[-1]
        stage_lengths.append(previous + previous // H)

    # Cumulative visit counts at which a stage ends; a set gives O(1) lookup
    # inside the inner loop.
    stage_ends = set()
    visits_so_far = 0
    for length in stage_lengths:
        visits_so_far += length
        stage_ends.add(visits_so_far)

    episode_rewards = np.zeros(K)

    # Both agents share one array per quantity, so the action axis must be
    # wide enough for the larger action set. Each agent only ever reads or
    # writes its own first A (agent 0) or B (agent 1) columns.
    action_counts = (A, B)
    width = max(A, B)

    Q = np.empty((2, H, S, width))
    V = np.empty((2, H, S))
    for h in range(H):
        optimistic = R_max * (H - h)
        Q[:, h] = optimistic
        V[:, h] = optimistic

    N = np.zeros((2, H, S, width))        # total visits
    N_check = np.zeros((2, H, S, width))  # visits in the current stage
    r_check = np.zeros((2, H, S, width))  # reward sum in the current stage
    v_check = np.zeros((2, H, S, width))  # next-state value sum in the current stage

    for k in range(K):
        state = initial_state()
        for h in range(H):
            # Greedy w.r.t. each agent's own Q; with probability 0.05 both
            # agents play a uniformly random action instead.
            a = np.argmax(Q[0, h, state, :A])
            b = np.argmax(Q[1, h, state, :B])
            if np.random.uniform(0, 1) < 0.05:
                a = np.random.choice(A)
                b = np.random.choice(B)
            actions = (a, b)

            next_state = transition(state, a, b)
            reward = R[state][a][b]
            episode_rewards[k] += reward

            for agent_id in range(2):
                idx = (agent_id, h, state, actions[agent_id])

                # Accumulate this step into the agent's current stage.
                r_check[idx] += reward
                if h != H - 1:
                    # The last step has no successor value to bootstrap from.
                    v_check[idx] += V[agent_id, h + 1, next_state]
                N[idx] += 1
                N_check[idx] += 1

                # Counts are whole numbers stored as floats; compare as ints.
                if int(N[idx]) in stage_ends:
                    stage_visits = N_check[idx]  # >= 1: incremented just above
                    bonus = 0.01 * np.sqrt(H**2 / stage_visits)
                    target = (
                        r_check[idx] / stage_visits
                        + v_check[idx] / stage_visits
                        + bonus
                    )
                    # Estimates are only ever lowered.
                    Q[idx] = min(Q[idx], target)
                    V[agent_id, h, state] = np.max(
                        Q[agent_id, h, state, :action_counts[agent_id]]
                    )
                    # Start a fresh stage for this (step, state, action).
                    N_check[idx] = 0
                    r_check[idx] = 0.0
                    v_check[idx] = 0.0
            state = next_state

    return episode_rewards



# Reward table indexed as R[state][a][b]: only state 0 yields non-zero reward.
R = np.array([[[-2.0, 5.0], [2.0, -2.0]], [[0.0, 0.0], [0.0, 0.0]]])
R_max = 5.0            # largest entry of R; used for the optimistic initial values
prob_threshold = 0.1   # cut-off compared against a uniform draw in transition()
K = 5000               # episodes per run
H = 10                 # steps per episode
A = 2                  # actions of agent 0
B = 2                  # actions of agent 1
S = 2                  # states

# 20 independent runs; each appends one reward per line to its own file.
for run_id in range(20):
    rewards = qlearning_omd(S, A, B, K, H)
    # Append mode is kept from the original: re-running the script extends
    # existing files. The context manager guarantees the handle is closed
    # even if writing fails.
    with open('markov1_' + str(run_id) + 'indep.txt', 'a') as out:
        out.writelines(str(reward) + '\n' for reward in rewards)
    # plt.plot(range(K), rewards)
    # plt.show()
