import numpy as np
from matplotlib import pyplot as plt
import os
import math
import cvxpy as cp


def transition(s, a, b):
    """Sample the next state (0 or 1) for the joint action ``(a, b)``.

    A single uniform draw on [0, 1) is compared with the module-level
    ``prob_threshold``:

    * for the joint action ``a == 0, b == 1`` the next state is 1 when the
      draw falls below the threshold;
    * for every other joint action the next state is 1 when the draw is at
      or above the threshold.

    The current state ``s`` is accepted but not used by the dynamics.
    """
    draw = np.random.uniform(0, 1)
    if a == 0 and b == 1:
        lands_in_one = draw < prob_threshold
    else:
        lands_in_one = draw >= prob_threshold
    return 1 if lands_in_one else 0


def initial_state():
    """Draw the starting state: 0 or 1, each with probability one half."""
    return 0 if np.random.uniform(0, 1) < 0.5 else 1




def projection(b):
    """Return the Euclidean projection of ``b`` onto the probability simplex.

    Solves ``min_x ||x - b||^2  s.t.  x >= 0, sum(x) == 1`` exactly with the
    sort-and-threshold method: the solution has the form
    ``max(b - theta, 0)`` where ``theta`` is chosen so the entries sum to 1.
    This replaces a call to a generic QP solver, so the result is exactly
    non-negative (and sums to 1 up to floating-point rounding) instead of
    being feasible only up to solver tolerance, and no solve can fail.

    Args:
        b: 1-D array-like of finite real numbers (any positive length).

    Returns:
        A float ``numpy.ndarray`` with the same length as ``b``.

    Raises:
        ValueError: if ``b`` is not 1-D, is empty, or has non-finite entries.
    """
    v = np.asarray(b, dtype=float)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("projection expects a non-empty 1-D vector")
    if not np.all(np.isfinite(v)):
        raise ValueError("projection expects finite entries")

    # Sort descending and find the largest prefix whose entries stay strictly
    # positive after subtracting the shift that makes that prefix sum to 1.
    sorted_desc = np.sort(v)[::-1]
    shifted_cumsum = np.cumsum(sorted_desc) - 1.0
    ranks = np.arange(1, v.size + 1)
    keeps_positive = sorted_desc - shifted_cumsum / ranks > 0
    # ranks[0] always qualifies (the expression equals 1 there), so this is
    # never empty for finite input.
    support_size = ranks[keeps_positive][-1]
    theta = shifted_cumsum[support_size - 1] / support_size
    return np.maximum(v - theta, 0.0)


def dist(a, b):
    """Return the squared Euclidean distance between ``a`` and ``b``."""
    delta = a - b
    return np.sum(delta * delta)


def best_response_mu(b, s=0):
    """Return the one-hot row policy with the largest expected payoff against ``b``.

    The payoff tensor ``R`` is indexed as ``R[state][a][b]``.  The expected
    payoff of each row action is computed from the payoff matrix of a single
    state, ``R[s] @ b``, so that ``argmax`` yields an action index.  (Applying
    ``argmax`` to ``R @ b`` over the full 3-D tensor returns a flattened
    (state, action) index, which is only a valid action index by coincidence.)

    Args:
        b: mixed strategy of the column player, length equal to the number of
            column actions.
        s: state whose payoff matrix is used.  Defaults to 0, the state for
            which callers in this file evaluate best responses.

    Returns:
        A length-``num_actions`` array with 1.0 at the maximising action
        (first maximiser on ties) and 0.0 elsewhere.
    """
    q = R[s] @ b
    ans = np.zeros(num_actions)
    ans[np.argmax(q)] = 1.0
    return ans


def best_response_nu(a, s=0):
    """Return the one-hot column policy with the largest expected payoff against ``a``.

    The payoff tensor ``R`` is indexed as ``R[state][a][b]``.  The expected
    payoff of each column action is computed from the payoff matrix of a
    single state, ``a.T @ R[s]``, so that ``argmax`` yields an action index.
    (Applying ``argmax`` to ``a.T @ R`` over the full 3-D tensor returns a
    flattened (state, action) index, which is only a valid action index by
    coincidence.)

    Args:
        a: mixed strategy of the row player, length equal to the number of
            row actions.
        s: state whose payoff matrix is used.  Defaults to 0, the state for
            which callers in this file evaluate best responses.

    Returns:
        A length-``num_actions`` array with 1.0 at the maximising action
        (first maximiser on ties) and 0.0 elsewhere.
    """
    q = a.T @ R[s]
    ans = np.zeros(num_actions)
    ans[np.argmax(q)] = 1.0
    return ans



def vlearning_sgd(S, A, B, K, H):
    """Run K episodes of stage-based value learning with projected-gradient policy updates.

    Both players keep a per-(step, state) mixed policy that is updated after
    every visited step by a projected gradient step on an importance-weighted
    loss.  Value estimates ``V[h][s]`` are refreshed only when the visit count
    of ``(h, s)`` reaches the end of a stage (see ``L`` below).

    Besides its arguments, the function reads the module-level names ``R``,
    ``R_max``, ``num_actions``, ``gamma`` and ``step_size`` and calls
    ``initial_state``, ``transition``, ``projection``, ``dist`` and the two
    ``best_response_*`` helpers.  It prints progress every 100 episodes.

    Args:
        S: number of states.
        A: divisor used to initialise ``mu`` (number of row actions).
        B: divisor used to initialise ``nu`` (number of column actions).
        K: number of episodes.
        H: number of steps per episode.

    Returns:
        A tuple ``(immediate_r, dual_gap, mus, nus, visitation)`` where
        ``immediate_r[k]`` is the sum over the steps of episode k of
        ``mu^T R[s] nu`` at the visited states, ``dual_gap[k]`` is the sum,
        over steps visited in state 0, of the squared distances between each
        policy and the best response to the other, ``mus`` / ``nus`` are
        arrays holding a snapshot of the policies taken at the start of each
        episode, and ``visitation[k][h][s]`` is the list of episode indices
        recorded in the most recently completed stage of ``(h, s)`` as of
        episode k (``None`` if no stage has completed yet).
    """
    # Stage schedule: the first stage lasts H visits and each subsequent
    # stage is int(previous * (1 + 1/H)) visits long ...
    L = [H]
    while L[-1] < K:
        L.append(int(L[-1] * (1 + 1.0 / H)))
    # ... then a running sum turns stage lengths into the cumulative visit
    # counts at which a stage ends.
    for i in range(1, len(L)):
        L[i] += L[i - 1]
    # NOTE(review): accumulated below but never read or returned.
    episode_rewards = np.zeros(K)

    # Value estimates start at R_max * (remaining steps) for every state.
    V = np.ones((H, S))
    for h in range(H):
        V[h, :] = R_max * (H - h) * np.ones(S)

    # N counts all visits to (h, s); the *_check arrays accumulate visit
    # count, reward and next-step value within the current stage only and
    # are reset when the stage ends.
    N = np.zeros((H, S))
    N_check = np.zeros((H, S))
    r_check = np.zeros((H, S))
    v_check = np.zeros((H, S))

    # NOTE(review): both policies have num_actions entries but are divided by
    # A and B respectively; they are uniform only if A == B == num_actions.
    mu = np.ones((H, S, num_actions)) / A
    nu = np.ones((H, S, num_actions)) / B
    mus = []
    nus = []
    visitation = [[[None for s in range(S)] for h in range(H)] for k in range(K)]
    # Episode indices seen at (h, s) during the stage currently in progress.
    cur_visitation = [[[] for s in range(S)] for h in range(H)]
    zero = np.zeros(num_actions)

    immediate_r = np.zeros(K)
    dual_gap = np.zeros(K)
    # NOTE(review): never used after this assignment.
    cumsum_r = np.zeros(K)
    
    
    for k in range(K): 
        if k % 100 == 0:
            print(k)
            print(mu[0][0])
            print(nu[0][0])
        # Snapshot the policies as they are before this episode's updates.
        mus.append(mu.copy())
        nus.append(nu.copy())
        s = initial_state()
        # Carry the previous episode's completed-stage lists forward; entries
        # for stages that end during this episode are overwritten below.
        if k > 0:
            for h in range(H):
                for x in range(S):
                    if visitation[k - 1][h][x] is None:
                        visitation[k][h][x] = None
                    else: 
                        visitation[k][h][x] = visitation[k - 1][h][x].copy()
        for h in range(H):
            # Sample both players' actions from the current policies.
            a = np.random.choice(num_actions, p=mu[h][s])
            b = np.random.choice(num_actions, p=nu[h][s])
            cur_visitation[h][s].append(k)
            next_state= transition(s, a, b)
            reward = R[s][a][b]
            episode_rewards[k] += reward
            # Expected one-step payoff under the policies before this step's update.
            immediate_r[k] += mu[h][s].T @ R[s] @ nu[h][s]
            # The gap metric is only accumulated for visits to state 0.
            if s == 0:
                dual_gap[k] += dist(mu[h][s], best_response_mu(nu[h][s])) + dist(nu[h][s], best_response_nu(mu[h][s]))
            

            # Accumulate this stage's statistics; there is no next-step value
            # at the final step.
            r_check[h][s] += reward
            if h != H - 1:
                v_check[h][s] += V[h + 1][next_state]
            N[h][s] += 1
            N_check[h][s] += 1

            
            # Loss = upper bound R_max * (H - h) minus the sampled reward and
            # (except at the last step) the next-step value estimate.
            loss = R_max * (H - h) - reward
            if h != H - 1:
                loss -= V[h + 1][next_state]

            
            # Gradient estimate: non-zero only at the sampled action, scaled
            # by 1 / (action probability + gamma).
            grad_mu = np.zeros(num_actions)
            grad_nu = np.zeros(num_actions)
            grad_mu[a] = loss / (mu[h][s][a] + gamma)
            grad_nu[b] = loss / (nu[h][s][b] + gamma)

            # Gradient step followed by projection; the same loss is used for
            # both players.
            mu[h][s] = np.array(projection(mu[h][s] - step_size * grad_mu))
            nu[h][s] = np.array(projection(nu[h][s] - step_size * grad_nu))
            # Guard against negative entries in the projected vectors.
            mu[h][s] = np.maximum(mu[h][s], zero)
            nu[h][s] = np.maximum(nu[h][s], zero)


            # End of a stage for (h, s): refresh V from the stage averages
            # plus a bonus, reset the stage accumulators, and publish the
            # list of episodes that made up the finished stage.
            if N[h][s] in L:
                bonus = 0.01 * np.sqrt(H**2 * num_actions**3 / N_check[h][s])
                V[h][s] = r_check[h][s] / N_check[h][s] + v_check[h][s] / N_check[h][s] + bonus
                N_check[h][s] = 0
                r_check[h][s] = 0.0
                v_check[h][s] = 0.0
                visitation[k][h][s] = cur_visitation[h][s].copy()
                cur_visitation[h][s] = []
            s = next_state

    return immediate_r, dual_gap, np.array(mus), np.array(nus), visitation


def auxiliary(mus, nus, visitation):
    """Evaluate the policy obtained by resampling past episodes' policies.

    For every episode index ``k`` one trajectory is simulated.  It starts
    with the policy snapshot of episode ``k``; after each step the snapshot
    index is replaced by an episode drawn uniformly from
    ``visitation[k_prime][h][s]`` (the episodes of the last completed stage
    at the visited ``(h, s)``), or by 0 when that entry is ``None``.

    The number of episodes and the horizon are taken from ``mus`` itself so
    the evaluation always matches the run that produced the snapshots.

    Reads the module-level names ``R`` and ``num_actions`` and calls
    ``initial_state``, ``transition``, ``dist`` and the ``best_response_*``
    helpers.

    Args:
        mus: per-episode row-policy snapshots, indexed ``[k][h][s]``; each
            entry must be a numpy probability vector.
        nus: per-episode column-policy snapshots, same layout as ``mus``.
        visitation: ``visitation[k][h][s]`` is a list of episode indices or
            ``None``.

    Returns:
        ``(aux_rewards, aux_gap)``: for each ``k``, the sum over steps of
        ``mu^T R[s] nu`` at the visited states, and the sum over steps
        visited in state 0 of the squared distances between each policy and
        the best response to the other.
    """
    num_episodes = len(mus)
    horizon = len(mus[0]) if num_episodes else 0
    aux_rewards = np.zeros(num_episodes)
    aux_gap = np.zeros(num_episodes)
    for k in range(num_episodes):
        s = initial_state()
        k_prime = k
        for h in range(horizon):
            mu_hs = mus[k_prime][h][s]
            nu_hs = nus[k_prime][h][s]
            # Keep the sampling order (row action, column action, transition,
            # then snapshot index) so seeded runs are reproducible.
            a = np.random.choice(num_actions, p=mu_hs)
            b = np.random.choice(num_actions, p=nu_hs)
            next_state = transition(s, a, b)

            aux_rewards[k] += mu_hs.T @ R[s] @ nu_hs
            # The gap metric is only accumulated for visits to state 0.
            if s == 0:
                aux_gap[k] += dist(mu_hs, best_response_mu(nu_hs)) + dist(nu_hs, best_response_nu(mu_hs))

            stage_episodes = visitation[k_prime][h][s]
            if stage_episodes is None:
                # No completed stage yet: fall back to the first snapshot.
                k_prime = 0
            else:
                k_prime = np.random.choice(stage_episodes)
            s = next_state
    return aux_rewards, aux_gap



# Experiment configuration.  These names are read as module-level globals by
# the functions defined above.
# Payoff tensor indexed as R[state][a][b]; state 1 pays zero for every joint action.
R = np.array([[[-2.0, 5.0], [2.0, -2.0]], [[0.0, 0.0], [0.0, 0.0]]])
R_max = 5.0  # largest entry of R; scales the initial value estimates
prob_threshold = 0.1  # compared with a uniform draw in transition()
K = 5000  # episodes per run
H = 10  # steps per episode
step_size = 4e-4  # step size of the projected policy update
gamma = 2e-3  # added to the sampled action's probability in the gradient denominator
A = 2
B = 2
S = 2
num_actions = A


# Run 20 independent repetitions; each writes one tab-separated file with a
# row per episode: reward, auxiliary reward, gap, auxiliary gap.
for run_id in range(20):
    # The file is opened (and truncated) before the run starts, as before;
    # the context manager guarantees it is closed even if the run raises.
    with open('markov1_' + str(run_id) + '.txt', 'w') as out_file:
        rewards, dual_gap, mus, nus, visitation = vlearning_sgd(S, A, B, K, H)
        aux_rewards, aux_gap = auxiliary(mus, nus, visitation)
        out_file.writelines(
            '\t'.join(str(value) for value in row) + '\n'
            for row in zip(rewards, aux_rewards, dual_gap, aux_gap)
        )
    # plt.plot(range(K), rewards)
    # plt.plot(range(K), aux_rewards)
    # plt.show()

