import numpy as np
from matplotlib import pyplot as plt
import os
import math
import cvxpy as cp
import argparse


# Command-line options: --testid is embedded in the output file name and
# --task selects which environment/hyper-parameter set the script runs.
parser = argparse.ArgumentParser()
parser.add_argument('--testid', type=int, default=0, help="index of test")
parser.add_argument('--task', default='goodstate', help="task name")
args = parser.parse_args()





def dist(a, b):
    """Return the squared Euclidean distance between ``a`` and ``b``.

    Computed as the sum of element-wise squared differences; no square root
    is taken.
    """
    delta = a - b
    return np.sum(delta * delta)


def best_response_mu(b, R):
    """Return a pure best response, as a one-hot vector, for the maximiser of ``R @ b``.

    Treating ``R`` as the row player's payoff matrix and ``b`` as the column
    player's mixed strategy, the row action with the largest value of
    ``R @ b`` is selected (ties go to the lowest index, as ``np.argmax``).

    Parameters
    ----------
    b : numpy.ndarray, shape (m,)
    R : numpy.ndarray, shape (n, m)

    Returns
    -------
    numpy.ndarray, shape (n,)
        Float one-hot vector.
    """
    q = R @ b
    # Size from q itself rather than a global action count, so the helper
    # works for any (possibly non-square) payoff matrix.
    ans = np.zeros(len(q))
    ans[np.argmax(q)] = 1.0
    return ans



def projection(y):
    """Euclidean projection of ``y`` onto the probability simplex.

    Returns the ``x`` minimising ``||x - y||_2`` subject to ``x >= 0`` and
    ``sum(x) == 1``.  The solution has the form ``x = max(0, y - lam)``; the
    threshold ``lam`` is located by binary search over the sorted entries of
    ``y``.

    Parameters
    ----------
    y : array_like, shape (d,)
        Point to project; ``d`` must be at least 1.

    Returns
    -------
    numpy.ndarray, shape (d,)
        Non-negative vector summing to 1.

    Raises
    ------
    ValueError
        If ``y`` is empty (the simplex in zero dimensions is empty).
    """
    y = np.asarray(y, dtype=float)
    d = len(y)
    if d == 0:
        raise ValueError("cannot project an empty vector onto the simplex")
    order = np.argsort(y)

    def excess(k):
        # f(y_(k)) = sum_j max(0, y_j - y_(k)) - 1 with y_(k) the k-th smallest
        # entry.  It is non-increasing in k and f(d - 1) == -1.
        tail = y[order[k:]]
        return np.sum(tail - y[order[k]]) - 1

    if excess(0) < 0:
        # Threshold lies below the smallest entry: every coordinate is active.
        k = 0
    else:
        # Invariant: excess(lo) >= 0 and excess(hi) <= 0.  Converges to the
        # smallest k with excess(k) <= 0, so the active set is order[k:].
        lo, hi = 0, d - 1
        while hi - lo > 1:
            mid = (lo + hi) // 2
            if excess(mid) > 0:
                lo = mid
            else:
                hi = mid
        k = hi

    active = y[order[k:]]
    lam = (np.sum(active) - 1) / len(active)
    return np.maximum(0, y - lam)


def best_response_nu(a, R):
    """Return a pure best response, as a one-hot vector, for the maximiser of ``a.T @ R``.

    Treating ``a`` as the row player's mixed strategy, the column action with
    the largest value of ``a.T @ R`` is selected (ties go to the lowest
    index, as ``np.argmax``).

    Parameters
    ----------
    a : numpy.ndarray, shape (n,)
    R : numpy.ndarray, shape (n, m)

    Returns
    -------
    numpy.ndarray, shape (m,)
        Float one-hot vector.
    """
    q = a.T @ R
    # Size from q itself rather than a global action count, so the helper
    # works for any (possibly non-square) payoff matrix.
    ans = np.zeros(len(q))
    ans[np.argmax(q)] = 1.0
    return ans



def vlearning_sgd(S, actions, K, H, env):
    """Run V-learning in which every agent updates its policy by projected SGD.

    Reads the module-level globals ``num_actions``, ``theta`` (weight of the
    uniform-exploration mixture) and ``step_size`` (SGD step), and calls
    ``projection``.

    Parameters
    ----------
    S : int
        Number of states.
    actions : sequence
        One entry per agent; only its length (the number of agents) is used.
    K : int
        Number of episodes.
    H : int
        Horizon (steps per episode).
    env : object
        Must provide ``R_max``, ``init_environment()``, ``get_current_state()``,
        ``step(joint_action) -> (next_state, reward)`` and
        ``get_reward(policies)``.

    Returns
    -------
    immediate_r : numpy.ndarray, shape (K,)
        Per episode, the sum over steps of ``env.get_reward`` evaluated on the
        current policies at the visited (h, s).
    pis : numpy.ndarray, shape (n_agents, K, H, S, num_actions)
        Every agent's policy at the start of each episode.
    visitation : list indexed [k][h][s]
        ``None`` until (h, s) has completed a stage.  After that, the list of
        episode indices that visited (h, s) during its most recently completed
        stage, as of episode k.
    """
    # Stage schedule.  Stage lengths start at H and grow by a factor
    # (1 + 1/H); a stage for (h, s) ends when its visit count hits one of the
    # cumulative sums.  Exact integer arithmetic is used instead of
    # ``int(L * (1 + 1.0 / H))``.  A float round-down there could undercount,
    # and at L == H it would append H forever.
    lengths = [H]
    while lengths[-1] < K:
        lengths.append(lengths[-1] * (H + 1) // H)
    stage_ends = set()
    running = 0
    for length in lengths:
        running += length
        stage_ends.add(running)

    n = len(actions)  # number of agents

    # Optimistic initial values: R_max for each remaining step.
    V = np.ones((H, S))
    for h in range(H):
        V[h, :] = env.R_max * (H - h) * np.ones(S)
    # V0[s, h, :] is a copy of V[h + 1], frozen on the first visit to (h, s)
    # in each stage; it serves as the bootstrap target for that stage.
    V0 = np.zeros((S, H, S))

    N = np.zeros((H, S))        # total visits to (h, s)
    N_check = np.zeros((H, S))  # visits in the current stage
    r_check = np.zeros((H, S))  # reward sum in the current stage
    v_check = np.zeros((H, S))  # bootstrap-target sum in the current stage

    pi = np.ones((n, H, S, num_actions)) / num_actions
    pis = [[] for _ in range(n)]
    visitation = []
    cur_visitation = [[[] for _ in range(S)] for _ in range(H)]
    immediate_r = np.zeros(K)
    # Uniform component mixed into every behaviour policy.
    explore = theta * np.ones(num_actions) / num_actions

    for k in range(K):
        if k % 100 == 0:
            print(k)
        for agent in range(n):
            pis[agent].append(pi[agent].copy())

        env.init_environment()
        s = env.get_current_state()

        # Episode k's snapshot starts as a copy of episode k - 1's.  Only the
        # per-h rows are copied.  Stored visit lists are never mutated in this
        # function, so sharing them avoids re-copying every list per episode.
        if k == 0:
            visitation.append([[None] * S for _ in range(H)])
        else:
            visitation.append([row.copy() for row in visitation[k - 1]])

        for h in range(H):
            N[h, s] += 1
            N_check[h, s] += 1
            if N_check[h, s] == 1 and h != H - 1:
                V0[s, h, :] = V[h + 1]

            # Behaviour policy: (1 - theta) * pi + theta * uniform.
            eps_pi = [(1 - theta) * pi[agent, h, s] + explore for agent in range(n)]
            a = [np.random.choice(num_actions, p=p) for p in eps_pi]
            cur_visitation[h][s].append(k)

            next_state, reward = env.step(a)
            immediate_r[k] += env.get_reward(pi[:, h, s])

            # Bootstrap target is the frozen V_{h+1}; nothing at the last step.
            target = V0[s, h, next_state] if h != H - 1 else 0.0
            r_check[h, s] += reward
            v_check[h, s] += target

            loss = env.R_max * (H - h) - reward - target
            for agent in range(n):
                # Importance-weighted loss estimate, non-zero only on the
                # sampled action.
                grad = np.zeros(num_actions)
                grad[a[agent]] = loss / eps_pi[agent][a[agent]]
                pi[agent, h, s] = projection(pi[agent, h, s] - step_size * grad)

            if N[h, s] in stage_ends:
                # Stage over: V becomes the stage average reward plus the
                # average frozen next value plus a bonus.  Then reset the
                # stage statistics and publish this stage's visit list.
                bonus = 0.01 * np.sqrt(H**2 * num_actions**3 / N_check[h, s])
                V[h, s] = r_check[h, s] / N_check[h, s] + v_check[h, s] / N_check[h, s] + bonus
                N_check[h, s] = 0
                r_check[h, s] = 0.0
                v_check[h, s] = 0.0
                visitation[k][h][s] = cur_visitation[h][s]
                cur_visitation[h][s] = []
            s = next_state

    return immediate_r, np.array(pis), visitation


def auxiliary(pis, visitation, env):
    """Roll out the mixture policy defined by the stored policy snapshots.

    For each episode index k, play one H-step episode in ``env``.  At step h
    every agent samples from its snapshot ``pis[agent, k_prime, h, s]``,
    starting with ``k_prime = k``.  After each step ``k_prime`` is redrawn
    uniformly from ``visitation[k_prime][h][s]``, or reset to 0 when that
    entry is ``None``.

    Parameters
    ----------
    pis : array_like, shape (n_agents, K, H, S, num_actions)
        Policy snapshots, as returned by ``vlearning_sgd``.  K, H, the number
        of agents and the number of actions are all taken from this shape.
    visitation : list indexed [k][h][s]
        ``None`` or a non-empty list of episode indices.
    env : object
        Must provide ``init_environment()``, ``get_current_state()``,
        ``step(joint_action) -> (next_state, reward)`` and
        ``get_reward(policies)``.

    Returns
    -------
    numpy.ndarray, shape (K,)
        Per episode, the sum over steps of ``env.get_reward`` on the
        snapshot policies that were used.
    """
    pis = np.asarray(pis)
    num_agents, num_episodes, horizon = pis.shape[:3]
    n_actions = pis.shape[-1]

    aux_rewards = np.zeros(num_episodes)
    for k in range(num_episodes):
        env.init_environment()
        s = env.get_current_state()
        k_prime = k
        for h in range(horizon):
            a = [np.random.choice(n_actions, p=pis[agent][k_prime][h][s])
                 for agent in range(num_agents)]

            next_state, reward = env.step(a)

            aux_rewards[k] += env.get_reward(pis[:, k_prime, h, s])
            if visitation[k_prime][h][s] is None:
                k_prime = 0
            else:
                k_prime = np.random.choice(visitation[k_prime][h][s])
            s = next_state
    return aux_rewards


if __name__ == "__main__":
    id = args.testid
    task = args.task
    
    if task == 'boxpushing':
        from env.boxpushing import Environment
        env = Environment(task + '.txt')
        H = 6
        K = 40000
        step_size = 2e-5
        gamma = 2e-3
        theta = 1e-2
    elif task == 'goodstate':
        from env.goodstate import Environment
        env = Environment()
        H = 10
        K = 5000
        step_size = 4e-4
        gamma = 2e-3
        theta = 1e-2
    elif task == 'goodstatemulti':
        from env.goodstatemulti import Environment
        env = Environment()
        H = 10
        K = 40000
        step_size = 6e-5
        gamma = 2e-3
        theta = 1e-2
    
    

    S = env.state_size
    num_agent = env.num_agent
    actions = env.action_size
    num_actions = actions[0]

    file = open(task + '_' + str(id) + '.txt', 'w')
    rewards, pis, visitation = vlearning_sgd(S, actions, K, H, env)
    aux_rewards = auxiliary(pis, visitation, env)
    for k in range(K):
        file.write(str(rewards[k]) + '\t' + str(aux_rewards[k]) + '\n')
    file.close()
