import numpy as np
from matplotlib import pyplot as plt
import os
import math


def transition(s, a, b, *, threshold=None):
    """Sample the next state (0 or 1) after joint action (a, b) in state s.

    The current state ``s`` does not influence the draw. For the joint action
    ``(a=0, b=1)`` the next state is 1 with probability ``threshold``; for
    every other joint action it is 1 with probability ``1 - threshold``.
    Exactly one ``np.random.uniform(0, 1)`` draw is consumed per call.

    Args:
        s: Current state (unused by these dynamics; kept so all environment
            helpers share the ``(s, a, b)`` signature).
        a: Action of the first player.
        b: Action of the second player.
        threshold: Probability cut-off. When omitted, the module-level
            ``prob_threshold`` is read at call time (original behaviour).

    Returns:
        The sampled next state, 0 or 1.
    """
    if threshold is None:
        threshold = prob_threshold
    # Draw before branching so the RNG stream is consumed identically in
    # both branches.
    x = np.random.uniform(0, 1)
    if a == 0 and b == 1:
        return 1 if x < threshold else 0
    return 1 if x >= threshold else 0


def initial_state():
    """Sample the starting state of an episode, 0 or 1, with equal odds.

    Consumes one ``np.random.uniform(0, 1)`` draw: values below 0.5 map to
    state 0, all others to state 1.
    """
    draw = np.random.uniform(0, 1)
    return 0 if draw < 0.5 else 1



def get_reward(s, a, b):
    """Return the reward for joint action (a, b) taken in state s.

    Looks the value up in the module-level reward table ``R``, which is
    indexed as ``R[s, a, b]``.
    """
    index = (s, a, b)
    return R[index]


def qlearning_omd(S, A, B, K, H, bonus_scale=0.01):
    """Run optimistic, stage-based Q-learning on the module's two-player game.

    The two players' actions are treated as one joint action with flat index
    ``a * B + b`` (``a`` in ``range(A)`` for player one, ``b`` in ``range(B)``
    for player two). At every step the greedy joint action of the current
    optimistic Q table is played. Reward and next-step value are accumulated
    per (step, state, joint action) over a stage; when the total visit count
    hits a stage boundary, Q is lowered to the stage average plus an
    exploration bonus (it never increases), V is refreshed, and the stage
    accumulators are reset.

    NOTE(review): despite the name, the visible code selects actions greedily
    via ``argmax``; no mirror-descent step appears here.

    Environment and constants come from module level: ``initial_state``,
    ``transition``, ``get_reward`` and ``R_max``.

    Args:
        S: Number of states.
        A: Number of actions of the first player.
        B: Number of actions of the second player.
        K: Number of episodes.
        H: Episode horizon (steps per episode).
        bonus_scale: Multiplier on the ``sqrt(H**2 / n)`` exploration bonus.

    Returns:
        numpy array of length ``K`` holding each episode's total reward.
    """
    # Stage lengths start at H and grow by a factor (1 + 1/H); stage_ends are
    # their running sums, i.e. the visit counts at which an update fires.
    # Integer arithmetic gives floor(length * (H + 1) / H) exactly (no float
    # rounding), and since length >= H each new length is >= length + 1, so
    # the loop always terminates.
    lengths = [H]
    while lengths[-1] < K:
        lengths.append(lengths[-1] * (H + 1) // H)
    stage_ends = set(np.cumsum(lengths).tolist())  # O(1) membership tests

    n_joint = A * B
    episode_rewards = np.zeros(K)

    # Optimistic initialisation: at step h, Q and V start at R_max * (H - h),
    # an upper bound on the remaining return if R_max bounds each reward.
    Q = np.empty((H, S, n_joint))
    V = np.empty((H, S))
    for h in range(H):
        Q[h] = R_max * (H - h)
        V[h] = R_max * (H - h)

    visits = np.zeros((H, S, n_joint))        # total visits, never reset
    stage_visits = np.zeros((H, S, n_joint))  # visits in the current stage
    stage_reward = np.zeros((H, S, n_joint))  # summed rewards this stage
    stage_value = np.zeros((H, S, n_joint))   # summed V[h+1](s') this stage

    for k in range(K):
        state = initial_state()
        for h in range(H):
            # Optimism-driven exploration: act greedily on the optimistic Q
            # table (argmax breaks ties towards the lowest index).
            action = int(np.argmax(Q[h, state]))
            # Decode the flat index a * B + b; the divisor must be B (player
            # two's action count) for this to be correct when A != B.
            a, b = divmod(action, B)
            next_state = transition(state, a, b)
            reward = get_reward(state, a, b)
            episode_rewards[k] += reward

            stage_reward[h, state, action] += reward
            if h != H - 1:
                # The last step of an episode has no successor value.
                stage_value[h, state, action] += V[h + 1, next_state]
            visits[h, state, action] += 1
            stage_visits[h, state, action] += 1

            if visits[h, state, action] in stage_ends:
                n = stage_visits[h, state, action]
                bonus = bonus_scale * np.sqrt(H**2 / n)
                target = (stage_reward[h, state, action] / n
                          + stage_value[h, state, action] / n
                          + bonus)
                Q[h, state, action] = min(Q[h, state, action], target)
                V[h, state] = np.max(Q[h, state])
                # Start a fresh stage for this (step, state, action).
                stage_visits[h, state, action] = 0
                stage_reward[h, state, action] = 0.0
                stage_value[h, state, action] = 0.0
            state = next_state

    return episode_rewards



# Reward table indexed as R[s, a, b]: state s, first player's action a,
# second player's action b.
R = np.array([
    [[-2.0, 5.0],    # state 0, a = 0
     [2.0, -2.0]],   # state 0, a = 1
    [[0.0, 0.0],     # state 1: every joint action pays 0
     [0.0, 0.0]],
])
# Optimistic per-step reward bound used to initialise Q and V in qlearning_omd.
R_max = 5.0
# Probability cut-off used by transition() when sampling the next state.
prob_threshold = 0.1
K, H = 5000, 10  # number of episodes, steps per episode


# Run 20 independent repetitions. Each writes its K per-episode rewards, one
# per line, to markov1_<run>centr.txt. Append mode is kept from the original,
# so re-running the script appends to existing files rather than replacing.
# (For a quick look at one run: plt.plot(range(K), rewards); plt.show().)
for run_id in range(20):
    # Simulate first so a failed run does not leave a freshly opened file.
    rewards = qlearning_omd(2, 2, 2, K, H)
    # `with` guarantees the handle is closed even if a write fails.
    with open('markov1_' + str(run_id) + 'centr.txt', 'a') as out:
        out.writelines(str(value) + '\n' for value in rewards)
