import numpy as np
from matplotlib import pyplot as plt
import os
import cvxpy as cp
import math




def projection(b):
    """Return the Euclidean projection of ``b`` onto the probability simplex.

    Solves ``argmin_x ||x - b||^2  s.t.  x >= 0, sum(x) == 1`` exactly using the
    sort-based algorithm (Held et al. 1974; Duchi et al. 2008; Wang &
    Carreira-Perpinan 2013). No general QP solver is needed. The result is
    non-negative and sums to 1 up to floating-point rounding, and each call
    costs O(n log n).

    Args:
        b: 1-D array-like of length n, the point to project. Any length n >= 1
            is accepted; it is not tied to the global ``num_actions``.

    Returns:
        1-D float ``np.ndarray`` of length n lying on the probability simplex.

    Raises:
        ValueError: if ``b`` is not a non-empty 1-D vector.
    """
    v = np.asarray(b, dtype=float)
    if v.ndim != 1 or v.size == 0:
        raise ValueError("projection() requires a non-empty 1-D vector")
    u = np.sort(v)[::-1]                  # entries in descending order
    cssv = np.cumsum(u) - 1.0             # prefix sums minus the simplex mass
    k = np.arange(1, v.size + 1)
    # rho is the largest (0-based) index whose sorted entry stays positive after
    # the shift. Index 0 always qualifies (u[0] - (u[0] - 1) = 1 > 0), so the
    # set is never empty for finite input.
    rho = np.nonzero(u - cssv / k > 0)[0][-1]
    theta = cssv[rho] / (rho + 1)         # uniform shift that makes the mass 1
    return np.maximum(v - theta, 0.0)


def dist(a, b):
    """Return the squared Euclidean distance between ``a`` and ``b``.

    This is the sum of element-wise squared differences. No square root is
    taken.
    """
    diff = a - b
    return np.square(diff).sum()


def best_response_mu(b, payoff=None):
    """Return the row player's pure best response to column strategy ``b``.

    Puts all mass on the row with the largest entry of ``payoff @ b``. On ties,
    ``np.argmax`` picks the first such row.

    Args:
        b: column player's mixed strategy, a 1-D array with one entry per
            column of ``payoff``.
        payoff: payoff matrix. Defaults to the module-level ``R``.

    Returns:
        One-hot float array with one entry per row of ``payoff``. Its length is
        taken from the payoff matrix, not from the global ``num_actions``.
    """
    if payoff is None:
        payoff = R
    q = payoff @ b                      # expected payoff of each row
    ans = np.zeros(len(q))
    ans[np.argmax(q)] = 1.0
    return ans


def best_response_nu(a, payoff=None):
    """Return the column player's pure best response to row strategy ``a``.

    Puts all mass on the column with the largest entry of ``a @ payoff``. On
    ties, ``np.argmax`` picks the first such column. Like ``best_response_mu``,
    this maximises the same payoff matrix.

    Args:
        a: row player's mixed strategy, a 1-D array with one entry per row of
            ``payoff``.
        payoff: payoff matrix. Defaults to the module-level ``R``.

    Returns:
        One-hot float array with one entry per column of ``payoff``. Its length
        is taken from the payoff matrix, not from the global ``num_actions``.
    """
    if payoff is None:
        payoff = R
    q = np.asarray(a) @ payoff          # expected payoff of each column
    ans = np.zeros(len(q))
    ans[np.argmax(q)] = 1.0
    return ans



# Experiment configuration. The functions above and the run loop below read
# these as module-level globals.
T = 5000            # number of update steps per run
step_size = 2e-3    # projected-gradient step size
gamma = 1e-3        # added to the sampling probability in the loss estimate's denominator

# Payoff matrix. Both players' updates are driven by the same entry R[a, b].
R = np.array([
    [10.0, 0.0, -10.0],
    [0.0, 2.0, 0.0],
    [-10.0, 0.0, 10.0],
])
if R.shape[0] != R.shape[1]:
    # Both players share num_actions, so the game must be square.
    raise ValueError("payoff matrix R must be square, got shape %r" % (R.shape,))
# Derived from R so it cannot drift out of sync when R is edited.
num_actions = R.shape[0]
prefix = 'matrix1_'   # output files are named <prefix><run id>.txt
# Largest payoff, derived from R. This keeps loss = R_max - R[a, b]
# non-negative for every entry even if R is changed.
R_max = float(R.max())


def _gap(mu, nu):
    """Return the sum of each strategy's squared distance to a pure best response to the other.

    NOTE(review): this is logged as the "dual gap". It is not the usual
    duality/exploitability gap (a payoff difference); confirm this is the
    intended statistic.
    """
    return dist(mu, best_response_mu(nu)) + dist(nu, best_response_nu(mu))


def _as_distribution(p):
    """Clip negative entries of ``p`` and renormalise it to sum to 1.

    ``projection`` may return a solution that is only approximately on the
    simplex (for example, from a solver with finite tolerance).
    ``np.random.choice`` raises if its ``p`` does not sum to 1 within a small
    tolerance, so this cleans the vector up before it is sampled from.
    """
    p = np.maximum(np.asarray(p, dtype=float), 0.0)
    return p / p.sum()


def _run_trial(run_id):
    """Run one ``T``-step trial of the self-play dynamics and log per-step stats.

    Both players start from the uniform strategy. Each step, every player
    samples an action, observes ``loss = R_max - R[a, b]``, builds an
    importance-weighted estimate ``loss / (p[action] + gamma)`` on the sampled
    coordinate, and takes a projected gradient step of size ``step_size``.

    Before each step, one tab-separated line is appended to
    ``<prefix><run_id>.txt``. Its fields are: the expected reward
    ``mu @ R @ nu``, the running average of that reward, the ``_gap``
    statistic, and the running average of the gap.

    Returns:
        ``(immediate_r, cumsum_r, dual_gap, cumsum_dual_gap)``: the full
        histories, each of length ``T + 1``, e.g. for plotting.
    """
    mu = np.ones(num_actions) / num_actions
    nu = np.ones(num_actions) / num_actions

    r = mu @ R @ nu
    gap = _gap(mu, nu)
    immediate_r = [r]
    cumsum_r = [r]
    dual_gap = [gap]
    cumsum_dual_gap = [gap]
    # Keep running totals so each running average costs O(1) per step, instead
    # of re-averaging the whole history every iteration (O(T^2) overall).
    r_total = r
    gap_total = gap

    # NOTE(review): 'a' mode appends, so re-running the script adds another T
    # lines to an existing file rather than replacing it.
    with open(prefix + str(run_id) + '.txt', 'a') as out:
        for t in range(T):
            if t % 1000 == 0:
                print(t)
            out.write('\t'.join(str(v) for v in (
                immediate_r[-1], cumsum_r[-1], dual_gap[-1], cumsum_dual_gap[-1])) + '\n')

            a = np.random.choice(num_actions, p=mu)
            b = np.random.choice(num_actions, p=nu)
            reward = R[a, b]
            loss = R_max - reward

            # Only the sampled coordinate gets a (nonzero) loss estimate.
            grad_mu = np.zeros(num_actions)
            grad_nu = np.zeros(num_actions)
            grad_mu[a] = loss / (mu[a] + gamma)
            grad_nu[b] = loss / (nu[b] + gamma)

            mu = _as_distribution(projection(mu - step_size * grad_mu))
            nu = _as_distribution(projection(nu - step_size * grad_nu))

            r = mu @ R @ nu
            gap = _gap(mu, nu)
            r_total += r
            gap_total += gap
            immediate_r.append(r)
            dual_gap.append(gap)
            cumsum_r.append(r_total / len(immediate_r))
            cumsum_dual_gap.append(gap_total / len(dual_gap))

    return immediate_r, cumsum_r, dual_gap, cumsum_dual_gap


# Run 20 independent trials; trial i logs to <prefix>i.txt.
for run_id in range(20):
    _run_trial(run_id)



