from congestion_games import *
import matplotlib.pyplot as plt
import itertools
import numpy as np
import copy
import statistics
import seaborn as sns; sns.set()
from time import process_time
from tqdm import tqdm
import random

# Seed both RNGs used in this script (Python's `random` and NumPy's legacy
# global RNG, which `pick_action` samples from) so runs are repeatable.
SEED = 43
random.seed(SEED)
np.random.seed(SEED)
# CPU-time stamp taken at import; not referenced elsewhere in the visible source.
myp_start = process_time()

def projection_simplex_sort(v, z=1):
    """Euclidean projection of ``v`` onto the simplex ``{w : w >= 0, sum(w) == z}``.

    Sort-based algorithm, courtesy of EdwardRaff/projection_simplex.py.

    Args:
        v: 1-D array-like of numbers (lists are accepted).
        z: positive radius of the simplex; ``z=1`` is the probability simplex.

    Returns:
        A 1-D NumPy array on the scaled simplex. If ``v`` (as an array) already
        lies on it exactly, that array is returned unchanged.
    """
    v = np.asarray(v)
    # np.alltrue was removed in NumPy 2.0; np.all is the supported spelling.
    if v.sum() == z and np.all(v >= 0):
        return v
    n_features = v.shape[0]
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(n_features) + 1
    # rho is the largest index whose sorted entry stays positive after the shift.
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    w = np.maximum(v - theta, 0)
    return w

# Define the two states of the game (0 = safe, 1 = bad) and shared sizes.
N = 8  # number of agents
harm = - 100 * N  # penalty for being in the bad state
# NOTE(review): `harm` is not referenced in the visible source; bad_state below
# hard-codes -100 per facility instead — confirm which penalty is intended.

# CongGame comes from the `congestion_games` star import; presumably the list
# holds one [coefficient, offset] reward pair per facility — confirm there.
safe_state = CongGame(N,1,[[1,0],[2,0],[4,0],[6,0]])
bad_state = CongGame(N,1,[[1,-100],[2,-100],[4,-100],[6,-100]])
state_dic = {0: safe_state, 1: bad_state}

M = safe_state.num_actions  # number of actions available to each agent
D = safe_state.m  # number of facilities
S = 2  # number of states (keys of state_dic)

# Reward cache keyed by tuple(action indices of all agents + [state]);
# filled by the sampling routines.
selected_profiles = {}

# Dictionary associating each action (value) to an integer index (key)
act_dic = {}
counter = 0
for act in safe_state.actions:
	act_dic[counter] = act 
	counter += 1

def get_next_state(state, actions):
    """Return the state reached after the joint action ``actions`` is played in ``state``.

    ``actions`` holds one integer action index per agent. Each index is mapped
    through ``act_dic``, and the resulting profile is passed to the current
    state's ``get_counts``. The transition is deterministic:
      * from state 0, go to 1 when the largest count exceeds N/2;
      * from state 1, stay in 1 when the largest count exceeds N/4;
      * return 0 otherwise.
    """
    profile = [act_dic[idx] for idx in actions]
    peak = max(state_dic[state].get_counts(profile))

    if state == 0:
        crowded = peak > N / 2
    elif state == 1:
        crowded = peak > N / 4
    else:
        crowded = False
    return 1 if crowded else 0

def pick_action(prob_dist):
    """Draw one action index from the categorical distribution ``prob_dist``.

    Sampling uses NumPy's global (legacy) RNG, so results follow ``np.random.seed``.
    """
    draw = np.random.choice(len(prob_dist), size=1, p=prob_dist)
    return draw[0]

def visit_dist(state, policy, gamma, T, samples):
    """Monte-Carlo estimate of the discounted state-visitation measure from ``state``.

    ``samples`` trajectories of length ``T`` are rolled out under ``policy``.
    The result is a list with one entry per state ``s``:
    ``sum_t gamma**t * (fraction of trajectories in s at step t)``.

    It is unnormalized: every trajectory occupies exactly one state per step,
    so the entries sum to ``sum_{t<T} gamma**t = (1 - gamma**T) / (1 - gamma)``.
    """
    counts = {st: np.zeros(T) for st in range(S)}
    for _ in range(samples):
        current = state
        for step in range(T):
            counts[current][step] += 1
            joint = [pick_action(policy[current, agent]) for agent in range(N)]
            current = get_next_state(current, joint)
    discounts = gamma ** np.arange(T)
    return [np.dot(row / samples, discounts) for row in counts.values()]

def generate_visit_buffer(state, policy, T, samples):
    """Roll out ``samples`` length-``T`` trajectories from ``state`` and log each visit.

    Returns a flat list of ``(visited_state, t)`` pairs in rollout order,
    ``samples * T`` entries in total.
    """
    visits = []
    for _ in range(samples):
        current = state
        for step in range(T):
            visits.append((current, step))
            joint = [pick_action(policy[current, agent]) for agent in range(N)]
            current = get_next_state(current, joint)
    return visits

def visit_dist_from_buffer(state, policy, gamma, T, samples, buffer=None):
    """Discounted visitation estimate computed from a ``(state, t)`` visit log.

    ``buffer`` is a list of pairs as produced by ``generate_visit_buffer``.
    When it is None, a fresh log is sampled from ``state`` under ``policy``.
    ``samples`` is the number of trajectories the log represents and is used
    as the normalizer. Returns one value per state.
    """
    log = generate_visit_buffer(state, policy, T, samples) if buffer is None else buffer
    tallies = {st: np.zeros(T) for st in range(S)}
    for visited, step in log:
        tallies[visited][step] += 1
    discounts = gamma ** np.arange(T)
    return [np.dot(row / samples, discounts) for row in tallies.values()]

def value_function(policy, gamma, T,samples):
    """Monte-Carlo estimate of every agent's discounted value from every start state.

    For each of ``samples`` repetitions and each start state ``s``, a
    ``T``-step trajectory is rolled out under ``policy``, and
    ``gamma**t * r_i`` is accumulated for every agent ``i``.

    Returns:
        dict mapping ``(s, i)`` to the average discounted return of agent ``i``
        from start state ``s``.
    """
    value_fun = {(s, i): 0 for s in range(S) for i in range(N)}
    for _ in range(samples):
        for start in range(S):
            curr_state = start
            for t in range(T):
                actions = [pick_action(policy[curr_state, i]) for i in range(N)]
                profile_key = tuple(actions + [curr_state])
                # Compute the reward only on a cache miss: setdefault() would
                # evaluate get_reward() eagerly on every call, defeating the cache.
                if profile_key not in selected_profiles:
                    selected_profiles[profile_key] = get_reward(
                        state_dic[curr_state], [act_dic[a] for a in actions])
                rewards = selected_profiles[profile_key]
                discount = gamma ** t
                for i in range(N):
                    value_fun[start, i] += discount * rewards[i]
                curr_state = get_next_state(curr_state, actions)
    return {key: total / samples for key, total in value_fun.items()}

def generate_V_buffer(policy, gamma, T, samples):
    """Roll out trajectories for value estimation and record every step.

    For each of ``samples`` repetitions and each start state, a ``T``-step
    trajectory is sampled under ``policy``. Steps are appended in rollout order
    as ``(state, actions, rewards, t)``. Each trajectory is therefore a
    contiguous run whose first entry has ``t == 0``.

    ``gamma`` is not used; it is accepted for call-site compatibility.
    """
    buffer = []
    for _ in range(samples):
        for start in range(S):
            curr_state = start
            for t in range(T):
                actions = [pick_action(policy[curr_state, i]) for i in range(N)]
                profile_key = tuple(actions + [curr_state])
                # Compute the reward only on a cache miss: setdefault() would
                # evaluate get_reward() eagerly on every call, defeating the cache.
                if profile_key not in selected_profiles:
                    selected_profiles[profile_key] = get_reward(
                        state_dic[curr_state], [act_dic[a] for a in actions])
                rewards = selected_profiles[profile_key]
                buffer.append((curr_state, actions, rewards, t))
                curr_state = get_next_state(curr_state, actions)
    return buffer

def value_from_buffer(policy, gamma, T, samples, buffer=None, num_states=None, num_agents=None):
    """Estimate per-agent discounted values V_i(s) from a rollout buffer.

    Args:
        policy: used only to sample a fresh buffer when ``buffer`` is None.
        gamma: discount factor.
        T: horizon, used only when sampling a fresh buffer.
        samples: number of trajectories per start state in ``buffer``; the
            accumulated returns are divided by it.
        buffer: list of ``(state, actions, rewards, t)`` entries as produced by
            ``generate_V_buffer``. Trajectories are stored contiguously, each
            starting with an entry whose ``t == 0``.
        num_states, num_agents: size of the returned table. They default to
            the module-level ``S`` and ``N``.

    Returns:
        dict mapping ``(s, i)`` to agent ``i``'s average discounted return
        from start state ``s``. This is the quantity ``value_function``
        estimates.
    """
    if buffer is None:
        buffer = generate_V_buffer(policy, gamma, T, samples)
    n_states = S if num_states is None else num_states
    n_agents = N if num_agents is None else num_agents
    value_fun = {(s, i): 0 for s in range(n_states) for i in range(n_agents)}

    start_state = None
    for state, _actions, rewards, t in buffer:
        # A new trajectory begins at t == 0. Every reward in it is credited to
        # that start state; crediting the currently visited state instead
        # would not estimate V(s).
        if t == 0 or start_state is None:
            start_state = state
        discount = gamma ** t
        for i in range(n_agents):
            value_fun[start_state, i] += discount * rewards[i]

    for key in value_fun:
        value_fun[key] /= samples
    return value_fun


def Q_function(agent, state, action, policy, gamma, value_fun, samples):
    """One-step Monte-Carlo estimate of ``agent``'s Q value for ``action`` in ``state``.

    Each of ``samples`` draws samples a joint action from ``policy`` and
    overrides ``agent``'s entry with ``action``. It then scores
    ``r_agent + gamma * value_fun[next_state, agent]``. The mean score is
    returned.
    """
    tot_reward = 0
    for _ in range(samples):
        # All N agents are sampled (the RNG advances N times per draw), then
        # the evaluated agent's entry is overwritten with `action`.
        actions = [pick_action(policy[state, j]) for j in range(N)]
        actions[agent] = action
        profile_key = tuple(actions + [state])
        # Compute the reward only on a cache miss: setdefault() would
        # evaluate get_reward() eagerly on every call, defeating the cache.
        if profile_key not in selected_profiles:
            selected_profiles[profile_key] = get_reward(
                state_dic[state], [act_dic[a] for a in actions])
        rewards = selected_profiles[profile_key]
        tot_reward += rewards[agent] + gamma * value_fun[get_next_state(state, actions), agent]
    return tot_reward / samples

def generate_Q_buffer(agent, state, action, policy, samples):
    """Sample one-step transitions for estimating ``agent``'s Q(state, action).

    Each of ``samples`` draws samples a joint action from ``policy`` and fixes
    ``agent``'s entry to ``action``. It records
    ``(next_state, reward_of_agent)``.

    Returns:
        list of ``(next_state, reward)`` pairs, consumed by
        ``Q_function_frombuffer``.
    """
    buffer = []
    for _ in range(samples):
        actions = [pick_action(policy[state, j]) for j in range(N)]
        actions[agent] = action
        profile_key = tuple(actions + [state])
        # Compute the reward only on a cache miss: setdefault() would
        # evaluate get_reward() eagerly on every call, defeating the cache.
        if profile_key not in selected_profiles:
            selected_profiles[profile_key] = get_reward(
                state_dic[state], [act_dic[j] for j in actions])
        rewards = selected_profiles[profile_key]
        next_state = get_next_state(state, actions)
        buffer.append((next_state, rewards[agent]))
    return buffer

def Q_function_frombuffer(agent, state, action, policy, gamma, value_fun, samples, buffer=None):
    """Estimate ``agent``'s Q value as the mean of ``r + gamma * V(next_state)`` over a buffer.

    Args:
        agent: index of the evaluated agent (also the agent index into ``value_fun``).
        state, action, policy, samples: used only to sample a fresh buffer
            via ``generate_Q_buffer`` when ``buffer`` is None.
        gamma: discount factor.
        value_fun: mapping ``(state, agent) -> value``.
        buffer: iterable of ``(next_state, reward)`` pairs.

    Returns:
        The average over buffer entries. It is normalised by the buffer
        length, not by ``samples``.
    """
    # `is None` rather than `== None`: the latter is elementwise for arrays.
    if buffer is None:
        buffer = generate_Q_buffer(agent, state, action, policy, samples)
    tot_reward = 0
    for next_state, reward in buffer:
        tot_reward += reward + gamma * value_fun[next_state, agent]
    return tot_reward / len(buffer)




def policy_accuracy(policy_pi, policy_star):
    """L1 distance between two joint policies, averaged over agents.

    For every agent, the L1 distances between its action distributions in
    each state are summed. The mean of these per-agent sums is returned.
    """
    per_agent = [
        sum(np.sum(np.abs(policy_pi[st, ag] - policy_star[st, ag])) for st in range(S))
        for ag in range(N)
    ]
    return np.sum(per_agent) / N


def Natural_PG(policy, gradient):
    """Multiplicative-weights (softmax natural policy gradient) step.

    Returns ``policy * exp(g) / sum(policy * exp(g))``, where ``g`` is
    ``gradient`` shifted to zero mean. The shift cancels in the normalisation;
    it only keeps the exponent small.

    ``gradient`` is not modified, so callers may reuse it. Lists and integer
    arrays are accepted.
    """
    gradient = np.asarray(gradient)
    centred = gradient - np.mean(gradient)
    new_policy = np.asarray(policy) * np.exp(centred)
    return new_policy / np.sum(new_policy)
    
def Natural_PG_entropy(policy, gradient, tau):
    """Entropy-regularised multiplicative-weights step.

    Returns ``policy**tau * exp(g)``, renormalised to sum to 1, where ``g`` is
    ``gradient`` shifted to zero mean. The shift cancels in the normalisation.

    ``gradient`` is not modified, so callers may reuse it. Lists and integer
    arrays are accepted.
    """
    gradient = np.asarray(gradient)
    centred = gradient - np.mean(gradient)
    new_policy = np.power(policy, tau) * np.exp(centred)
    return new_policy / np.sum(new_policy)

def PG_update(policy, gradient):
    """Exponentiated-gradient step: ``policy * exp(gradient)``, renormalised to sum to 1.

    Unlike ``Natural_PG`` the gradient is not centred first (mathematically
    equivalent after normalisation).
    """
    weighted = np.multiply(policy, np.exp(gradient))
    return weighted / weighted.sum()


def NPG_policy_gradient(mu, max_iters, gamma, eta, T, samples, K):
    """Independent natural policy gradient with fresh Monte-Carlo estimates.

    Every iteration re-samples, under the current policy:
      * visitation weights (``visit_dist``);
      * values (``value_from_buffer``);
      * Q values (``Q_function_frombuffer``).
    It then applies one ``Natural_PG`` step per (state, agent) on the
    advantage ``Q_bar - <Q_bar, policy>``.

    Args:
        mu: initial-state weights, one per state.
        max_iters: maximum number of iterations.
        gamma: discount factor.
        eta: step size.
        T: rollout horizon.
        samples: Monte-Carlo samples per estimate.
        K: unused; kept for signature parity with the buffered variants.

    Returns:
        List of policy snapshots (dict (state, agent) -> action probabilities),
        starting with the uniform policy. It may stop early when the
        convergence test fires.
    """
    policy = {(s, i): [1/M]*M for s in range(S) for i in range(N)}
    policy_hist = [copy.deepcopy(policy)]

    for t in tqdm(range(max_iters)):
        # One weight per state; sizing this list by M (number of actions)
        # would raise IndexError whenever M < S.
        # NOTE(review): b_dist[st] = sum_s' mu[s'] * d_st(s') weights the visits
        # made *from* st. The usual d_mu(st) = sum_s0 mu[s0] * d_s0(st) indexes
        # the other way. Confirm intent.
        b_dist = S * [0]
        for st in range(S):
            a_dist = visit_dist(st, policy, gamma, T, samples)
            b_dist[st] = np.dot(a_dist, mu)

        grads = np.zeros((N, S, M))  # not consumed by the NPG update below
        Q_bar = np.zeros((N, S, M))
        value_fun = value_from_buffer(policy, gamma, T, samples)

        for agent in range(N):
            for st in range(S):
                for act in range(M):
                    # Sample Q once and reuse it. Sampling it separately for
                    # each array would double the cost and give them
                    # independent noise.
                    q_est = Q_function_frombuffer(agent, st, act, policy, gamma, value_fun, samples)
                    Q_bar[agent, st, act] = q_est
                    grads[agent, st, act] = b_dist[st] * q_est

        for agent in range(N):
            for st in range(S):
                advantage = Q_bar[agent, st] - np.dot(Q_bar[agent, st], policy[st, agent])
                policy[st, agent] = Natural_PG(policy[st, agent], eta * advantage)
        policy_hist.append(copy.deepcopy(policy))

        # Convergence test on hist[t] vs hist[t-1]. At t = 0 that compares the
        # initial policy with the snapshot just appended; afterwards it checks
        # the step before the newest one. Note 10e-16 == 1e-15.
        if policy_accuracy(policy_hist[t], policy_hist[t-1]) < 10e-16:
            return policy_hist

    return policy_hist


def NPG_policy_gradient_buffer(mu, max_iters, gamma, eta, T, samples, K, X, clip=3):
    """
    Independent natural policy gradient driven by reused sample buffers.

    Every K iterations (including t = 0):
      * a reference policy ``policy_old`` is snapshotted from the current policy;
      * fresh value/Q buffers are sampled under it;
      * ``value_fun`` and ``Q_bar_last`` are recomputed.
    Between refreshes the same Q buffers are reused. Each Q estimate is
    multiplied by the importance ratio
    policy[s, i][a] / policy_old[s, i][a], capped at ``clip``.

    Every X iterations (including t = 0), after the current iteration's Q_bar
    has been formed:
      * ``policy_old`` itself takes one NPG step using ``Q_bar_last`` from the
        latest K-refresh;
      * ``b_dist``, the buffers and ``value_fun`` are regenerated under the
        updated ``policy_old``.

    Args:
        mu: initial-state weights, one per state.
        max_iters: number of iterations (no early stopping).
        gamma: discount factor.
        eta: step size.
        T: rollout horizon.
        samples: Monte-Carlo samples per estimate.
        K: buffer-refresh period.
        X: period of the reference-policy update.
        clip: upper bound on the importance weight.

    Returns:
        List of max_iters + 1 policy snapshots (dict (state, agent) -> action
        probabilities), starting with the uniform policy.
    """
    policy = {(s, i): np.ones(M) / M for s in range(S) for i in range(N)}
    policy_hist = [copy.deepcopy(policy)]

    value_buffer = []  # buffer for value function
    Q_buffer = []      # buffer for Q function
    value_fun = {(s, i): 0 for s in range(S) for i in range(N)}

    for t in tqdm(range(max_iters)):

        # Periodic refresh: snapshot the reference policy and resample buffers under it.
        if t % K == 0:
            policy_old = copy.deepcopy(policy)
            value_buffer = generate_V_buffer(policy, gamma,  T, samples)
            Q_buffer = [[[None for _ in range(M)] for _ in range(S)] for _ in range(N)]
            for agent in range(N):
                for state in range(S):
                    for action in range(M):
                        Q_buffer[agent][state][action] = generate_Q_buffer(agent, state, action, policy, samples)
            value_fun = value_from_buffer(policy, gamma, T, samples, buffer=value_buffer)

            # Un-weighted Q estimates at refresh time; used by the X-branch update of policy_old.
            Q_bar_last = np.zeros((N, S, M))
            for agent in range(N):
                for st in range(S):
                    for act in range(M):
                        Q_bar_last[agent, st, act] = Q_function_frombuffer(agent, None, None, policy, gamma, value_fun, samples,
                                                    buffer=Q_buffer[agent][st][act])


        # Visitation weights: computed once here at t == 0, later only in the X branch.
        # NOTE(review): b_dist is sized by M although it is indexed by state (needs M >= S).
        if t == 0:
            b_dist = M * [0]
            for st in range(S):
                a_dist = visit_dist(st, policy, gamma, T, samples)
                #a_dist = visit_dist_from_buffer(st, policy, gamma, T, samples, buffer=visit_buffer[st])

                b_dist[st] = np.dot(a_dist, mu)

        
        grads = np.zeros((N, S, M))  # not consumed by the policy update below
        Q_bar = np.zeros((N, S, M))
        #print("t:", t)
        for agent in range(N):
            for st in range(S):
                for act in range(M):
                    Q_raw = Q_function_frombuffer(agent, None, None, policy, gamma, value_fun, samples,
                                                  buffer=Q_buffer[agent][st][act])
                    # NOTE(review): the ratio uses the agent's own action `act`, which
                    # generate_Q_buffer fixes rather than samples; the other agents'
                    # actions, sampled from the buffer's policy, are not reweighted.
                    pi_new = policy[st, agent][act]
                    pi_old = policy_old[st, agent][act]
                    IS_weight = pi_new / (pi_old + 1e-10)
                    IS_weight = min(IS_weight, clip)  # optional clip
                    
                    
                    Q_bar[agent, st, act] = Q_raw * IS_weight
                    grads[agent, st, act] = b_dist[st] * Q_bar[agent, st, act]

        # Periodic update of the reference policy, then resample everything under it.
        if t % X == 0:
            
            for agent in range(N):
                for st in range(S):
                    policy_old[st, agent] = Natural_PG(policy_old[st, agent],
                                                    eta  * (Q_bar_last[agent, st] - 
                                                            np.dot(Q_bar_last[agent, st], policy_old[st, agent])))
            b_dist = M * [0]
            for st in range(S):
                a_dist = visit_dist(st, policy_old, gamma, T, samples)
                #a_dist = visit_dist_from_buffer(st, policy, gamma, T, samples, buffer=visit_buffer[st])

                b_dist[st] = np.dot(a_dist, mu)
            value_buffer = generate_V_buffer(policy_old, gamma,  T, samples)
            #visit_buffer = [None for _ in range(S)]
            #visit_buffer = [generate_visit_buffer(st, policy_old, T, samples) for st in range(S)]
            Q_buffer = [[[None for _ in range(M)] for _ in range(S)] for _ in range(N)]
            for agent in range(N):
                for state in range(S):
                    for action in range(M):
                        Q_buffer[agent][state][action] = generate_Q_buffer(agent, state, action, policy_old, samples)
            value_fun = value_from_buffer(policy_old, gamma, T, samples, buffer=value_buffer)

        # NPG step on the (importance-weighted) advantage Q_bar - <Q_bar, policy>.
        for agent in range(N):
            for st in range(S):
                policy[st, agent] = Natural_PG(policy[st, agent], eta * (Q_bar[agent,st] - np.dot(Q_bar[agent,st], policy[st, agent])))
                # policy[st, agent] = Natural_PG_entropy(policy[st, agent], eta/(1-gamma) * Q_bar[agent,st], eta*0.1)
        policy_hist.append(copy.deepcopy(policy))

        
        

    return policy_hist


def NPG_policy_gradient_buffer_noIS(mu, max_iters, gamma, eta, T, samples, K, X, clip=3):
    """
    Buffered independent NPG without importance weighting.

    It has the same schedule as ``NPG_policy_gradient_buffer``:
      * every K iterations (including t = 0), snapshot ``policy_old``,
        resample value/Q buffers under it and recompute ``value_fun`` and
        ``Q_bar_last``;
      * every X iterations (including t = 0), take one NPG step on
        ``policy_old`` with ``Q_bar_last`` and regenerate ``b_dist``, the
        buffers and ``value_fun`` under it.
    The importance weight is fixed to 1 (``min(1, clip)``), so reused buffers
    are used as-is.

    Args:
        mu: initial-state weights, one per state.
        max_iters: number of iterations (no early stopping).
        gamma: discount factor.
        eta: step size.
        T: rollout horizon.
        samples: Monte-Carlo samples per estimate.
        K: buffer-refresh period.
        X: period of the reference-policy update.
        clip: cap applied to the (constant 1) weight.

    Returns:
        List of max_iters + 1 policy snapshots, starting with the uniform policy.
    """
    policy = {(s, i): np.ones(M) / M for s in range(S) for i in range(N)}
    policy_hist = [copy.deepcopy(policy)]

    value_buffer = []  # buffer for value function
    Q_buffer = []      # buffer for Q function
    value_fun = {(s, i): 0 for s in range(S) for i in range(N)}

    for t in tqdm(range(max_iters)):

        # Periodic refresh: snapshot the reference policy and resample buffers under it.
        if t % K == 0:
            policy_old = copy.deepcopy(policy)
            value_buffer = generate_V_buffer(policy, gamma,  T, samples)
            Q_buffer = [[[None for _ in range(M)] for _ in range(S)] for _ in range(N)]
            for agent in range(N):
                for state in range(S):
                    for action in range(M):
                        Q_buffer[agent][state][action] = generate_Q_buffer(agent, state, action, policy, samples)
            value_fun = value_from_buffer(policy, gamma, T, samples, buffer=value_buffer)

            # Q estimates at refresh time; used by the X-branch update of policy_old.
            Q_bar_last = np.zeros((N, S, M))
            for agent in range(N):
                for st in range(S):
                    for act in range(M):
                        Q_bar_last[agent, st, act] = Q_function_frombuffer(agent, None, None, policy, gamma, value_fun, samples,
                                                    buffer=Q_buffer[agent][st][act])


        
        # b_dist: computed once here at t == 0, later only in the X branch.
        # NOTE(review): b_dist is sized by M although it is indexed by state (needs M >= S).
        if t == 0:
            b_dist = M * [0]
            for st in range(S):
                a_dist = visit_dist(st, policy, gamma, T, samples)
                #a_dist = visit_dist_from_buffer(st, policy, gamma, T, samples, buffer=visit_buffer[st])

                b_dist[st] = np.dot(a_dist, mu)

        
        grads = np.zeros((N, S, M))  # not consumed by the policy update below
        Q_bar = np.zeros((N, S, M))
        #print("t:", t)
        for agent in range(N):
            for st in range(S):
                for act in range(M):
                    Q_raw = Q_function_frombuffer(agent, None, None, policy, gamma, value_fun, samples,
                                                  buffer=Q_buffer[agent][st][act])
                    # pi_new / pi_old are read but unused: this variant disables IS.
                    pi_new = policy[st, agent][act]
                    pi_old = policy_old[st, agent][act]
                    IS_weight = 1
                    IS_weight = min(IS_weight, clip)  
                    
                    
                    Q_bar[agent, st, act] = Q_raw * IS_weight
                    grads[agent, st, act] = b_dist[st] * Q_bar[agent, st, act]

        # Periodic update of the reference policy, then resample everything under it.
        if t % X == 0:
            
            for agent in range(N):
                for st in range(S):
                    policy_old[st, agent] = Natural_PG(policy_old[st, agent],
                                                    eta  * (Q_bar_last[agent, st] - 
                                                            np.dot(Q_bar_last[agent, st], policy_old[st, agent])))
            b_dist = M * [0]
            for st in range(S):
                a_dist = visit_dist(st, policy_old, gamma, T, samples)
                #a_dist = visit_dist_from_buffer(st, policy, gamma, T, samples, buffer=visit_buffer[st])

                b_dist[st] = np.dot(a_dist, mu)
            value_buffer = generate_V_buffer(policy_old, gamma,  T, samples)
            #visit_buffer = [None for _ in range(S)]
            #visit_buffer = [generate_visit_buffer(st, policy_old, T, samples) for st in range(S)]
            Q_buffer = [[[None for _ in range(M)] for _ in range(S)] for _ in range(N)]
            for agent in range(N):
                for state in range(S):
                    for action in range(M):
                        Q_buffer[agent][state][action] = generate_Q_buffer(agent, state, action, policy_old, samples)
            value_fun = value_from_buffer(policy_old, gamma, T, samples, buffer=value_buffer)

        # NPG step on the advantage Q_bar - <Q_bar, policy>.
        for agent in range(N):
            for st in range(S):
#                 policy[st, agent] = projection_simplex_sort(np.add(policy[st, agent], eta * grads[agent,st]), z=1)
                # policy[st, agent] = projection_simplex_sort(np.add(policy[st, agent], eta * Q_bar[agent,st]), z=1)
                # policy[st, agent] = Natural_PG(policy[st, agent], eta * grads[agent,st])
                policy[st, agent] = Natural_PG(policy[st, agent], eta * (Q_bar[agent,st] - np.dot(Q_bar[agent,st], policy[st, agent])))
                # policy[st, agent] = Natural_PG_entropy(policy[st, agent], eta/(1-gamma) * Q_bar[agent,st], eta*0.1)
        policy_hist.append(copy.deepcopy(policy))

        
        

    return policy_hist


def NPG_policy_gradient_IS_value(mu, max_iters, gamma, eta, T, samples, K):
    """
    NPG variant that reuses one rollout buffer for K iterations.

    Every K iterations (including t = 0) the buffer is re-collected under the
    current policy. It holds ``samples`` trajectories of length ``T`` from
    each state, stored as flat ``(state, actions, rewards)`` steps, and
    ``value_fun`` is rebuilt from it. Each iteration then estimates ``Q_bar``
    from the buffer and takes one ``Natural_PG`` step per (state, agent) on
    the advantage ``Q_bar - <Q_bar, policy>``.

    Args:
        mu: per-state weights used in b_dist.
        max_iters: number of iterations (no early stopping).
        gamma: discount factor (applied to value_fun in the one-step Q target).
        eta: step size.
        T: rollout horizon.
        samples: trajectories per start state; also the normaliser of value_fun and Q_bar.
        K: buffer-refresh period; also selects the old snapshot for IS weights.

    Returns:
        List of max_iters + 1 policy snapshots, starting with the uniform policy.
    """
    policy = {(s, i): np.ones(M) / M for s in range(S) for i in range(N)}
    policy_hist = [copy.deepcopy(policy)]
    buffer = []  # store past samples

    for t in tqdm(range(max_iters)):

        # Re-collect the rollout buffer under the current policy.
        if t % K == 0:
            buffer = []
            for s in range(S):
                for _ in range(samples):
                    curr_state = s
                    traj = []
                    for step in range(T):
                        actions = [pick_action(policy[curr_state, i]) for i in range(N)]
                        rewards = selected_profiles.setdefault(
                            tuple(actions + [curr_state]),
                            get_reward(state_dic[curr_state], [act_dic[i] for i in actions])
                        )
                        traj.append((curr_state, actions.copy(), rewards.copy()))
                        curr_state = get_next_state(curr_state, actions)
                    buffer.extend(traj)

            # NOTE(review): unlike value_function, this sums *undiscounted* rewards and
            # credits them to the visited state s (from any start state), then divides
            # by samples. Confirm this is the intended value estimate.
            value_fun = {(s, i): 0 for s in range(S) for i in range(N)}
            for s, actions, rewards in buffer:
                for agent in range(N):
                    value_fun[s, agent] += rewards[agent]
            value_fun.update((k, v / samples) for k, v in value_fun.items())

        # b_dist: mu[st] times the fraction of buffer steps spent in st.
        b_dist = np.zeros(S)
        for st in range(S):
            count = sum(1 for (s, _, _) in buffer if s == st)
            b_dist[st] = mu[st] * (count / len(buffer))

        # Q_bar and grads (grads is not consumed by the policy update below)
        Q_bar = np.zeros((N, S, M))
        grads = np.zeros((N, S, M))

        for agent in range(N):
            for st in range(S):
                for act in range(M):
                    Q_sum = 0
                    grad_sum = 0
                    for s, actions, rewards in buffer:
                        if s != st:
                            continue
                        # Snapshot policy_hist[-K] once at least K+1 snapshots exist, else the latest.
                        # NOTE(review): the numerator is the probability of `act`, the denominator
                        # that of the logged action actions[agent] — confirm intended ratio.
                        old_prob = policy_hist[-K if K <= len(policy_hist)-1 else -1][st, agent][actions[agent]]
                        w = policy[st, agent][act] / (old_prob + 1e-12)  # IS weight

                        actions_copy = actions.copy()
                        actions_copy[agent] = act
                        next_state = get_next_state(st, actions_copy)
                        # NOTE(review): reward is the logged one, not recomputed for `act`;
                        # only next_state reflects the counterfactual action.
                        Q_val = rewards[agent] + gamma * value_fun[next_state, agent]

                        grad_sum += w * Q_val
                        Q_sum += Q_val

                    # NOTE(review): normalised by samples, not by the number of buffer steps in st.
                    Q_bar[agent, st, act] = Q_sum / samples
                    grads[agent, st, act] = b_dist[st] * grad_sum / samples

        # NPG step on the advantage Q_bar - <Q_bar, policy>.
        for agent in range(N):
            for st in range(S):
                advantage = Q_bar[agent, st] - np.dot(Q_bar[agent, st], policy[st, agent])
                policy[st, agent] = Natural_PG(policy[st, agent], eta * advantage)

        policy_hist.append(copy.deepcopy(policy))

        
        # if policy_accuracy(policy_hist[t], policy_hist[t-1]) < 1e-16:
        #     return policy_hist

    return policy_hist



def PG_policy_gradient(mu, max_iters, gamma, eta, T, samples):
    """Independent (non-natural) policy gradient with exponentiated-gradient updates.

    Each iteration re-estimates, by fresh Monte-Carlo sampling:
      * visitation weights (``visit_dist``);
      * values (``value_function``);
      * Q values (``Q_function``).
    It forms ``theta_grads = b_dist[st] * (Q - <Q, pi>)`` per agent and state
    and applies ``PG_update`` with step ``eta``.

    Returns:
        List of policy snapshots, starting with the uniform policy. It may stop
        early when the convergence test fires.
    """
    policy = {(s, i): [1/M]*M for s in range(S) for i in range(N)}
    policy_hist = [copy.deepcopy(policy)]

    for t in tqdm(range(max_iters)):
        # One weight per state; sizing this list by M (number of actions)
        # would raise IndexError whenever M < S.
        # NOTE(review): b_dist[st] = sum_s' mu[s'] * d_st(s') weights visits made
        # *from* st. visit_dist's entries always sum to sum_t gamma**t, so with
        # uniform mu every state gets the same weight. The usual
        # d_mu(st) = sum_s0 mu[s0] * d_s0(st) indexes the other way. Confirm intent.
        b_dist = S * [0]
        for st in range(S):
            a_dist = visit_dist(st, policy, gamma, T, samples)
            b_dist[st] = np.dot(a_dist, mu)

        grads = np.zeros((N, S, M))  # not consumed by the update below
        theta_grads = np.zeros((N, S, M))
        Q_bar = np.zeros((N, S, M))
        value_fun = value_function(policy, gamma, T, samples)

        for agent in range(N):
            for st in range(S):
                for act in range(M):
                    Q_bar[agent, st, act] = Q_function(agent, st, act, policy, gamma, value_fun, samples)
                    grads[agent, st, act] = b_dist[st] * Q_bar[agent, st, act]
                theta_grads[agent, st] = b_dist[st] * (Q_bar[agent, st] - np.dot(Q_bar[agent, st], policy[st, agent]))

        for agent in range(N):
            for st in range(S):
                policy[st, agent] = PG_update(policy[st, agent], eta * theta_grads[agent, st])
        policy_hist.append(copy.deepcopy(policy))

        # Convergence test on hist[t] vs hist[t-1]. At t = 0 that compares the
        # initial policy with the snapshot just appended; afterwards it checks
        # the step before the newest one. Note 10e-16 == 1e-15.
        if policy_accuracy(policy_hist[t], policy_hist[t-1]) < 10e-16:
            return policy_hist

    return policy_hist

def PGA_policy_gradient(mu, max_iters, gamma, eta, T, samples):
    """Independent projected policy-gradient ascent.

    Each iteration re-estimates visitation weights, values and Q values by
    fresh sampling. For every agent and state it then sets
    ``policy <- Proj_simplex(policy + eta * b_dist[st] * Q)``.

    Returns:
        List of max_iters + 1 policy snapshots (no early stopping), starting
        with the uniform policy.
    """
    policy = {(s, i): [1/M]*M for s in range(S) for i in range(N)}
    policy_hist = [copy.deepcopy(policy)]

    for t in tqdm(range(max_iters)):
        # One weight per state; sizing this list by M (number of actions)
        # would raise IndexError whenever M < S.
        # NOTE(review): b_dist[st] = sum_s' mu[s'] * d_st(s') weights visits made
        # *from* st. The usual d_mu(st) = sum_s0 mu[s0] * d_s0(st) indexes the
        # other way. Confirm intent.
        b_dist = S * [0]
        for st in range(S):
            a_dist = visit_dist(st, policy, gamma, T, samples)
            b_dist[st] = np.dot(a_dist, mu)

        grads = np.zeros((N, S, M))
        Q_bar = np.zeros((N, S, M))  # kept for inspection; the update uses grads
        value_fun = value_function(policy, gamma, T, samples)

        for agent in range(N):
            for st in range(S):
                for act in range(M):
                    # One Q sample feeds both arrays. Estimating it twice would
                    # double the cost and give them independent noise.
                    q_est = Q_function(agent, st, act, policy, gamma, value_fun, samples)
                    grads[agent, st, act] = b_dist[st] * q_est
                    Q_bar[agent, st, act] = q_est

        for agent in range(N):
            for st in range(S):
                policy[st, agent] = projection_simplex_sort(np.add(policy[st, agent], eta * grads[agent, st]), z=1)

        policy_hist.append(copy.deepcopy(policy))

    return policy_hist

def PQA_policy_gradient(mu, max_iters, gamma, eta, T, samples):
    """Independent projected Q ascent.

    Each iteration re-estimates Q values by fresh sampling. For every agent
    and state it then sets ``policy <- Proj_simplex(policy + eta * Q)``; the
    step is not weighted by visitation.

    Returns:
        List of max_iters + 1 policy snapshots (no early stopping), starting
        with the uniform policy.
    """
    policy = {(s, i): [1/M]*M for s in range(S) for i in range(N)}
    policy_hist = [copy.deepcopy(policy)]

    for t in tqdm(range(max_iters)):
        # One weight per state; sizing this list by M (number of actions)
        # would raise IndexError whenever M < S. Not used by the Q-ascent update.
        b_dist = S * [0]
        for st in range(S):
            a_dist = visit_dist(st, policy, gamma, T, samples)
            b_dist[st] = np.dot(a_dist, mu)

        grads = np.zeros((N, S, M))  # kept for inspection; the update uses Q_bar
        Q_bar = np.zeros((N, S, M))
        value_fun = value_function(policy, gamma, T, samples)

        for agent in range(N):
            for st in range(S):
                for act in range(M):
                    # One Q sample feeds both arrays. Estimating it twice would
                    # double the cost and give them independent noise.
                    q_est = Q_function(agent, st, act, policy, gamma, value_fun, samples)
                    grads[agent, st, act] = b_dist[st] * q_est
                    Q_bar[agent, st, act] = q_est

        for agent in range(N):
            for st in range(S):
                policy[st, agent] = projection_simplex_sort(np.add(policy[st, agent], eta * Q_bar[agent, st]), z=1)

        policy_hist.append(copy.deepcopy(policy))

    return policy_hist

def evaluate_policy(policy, gamma, T, samples):
    """
    Estimate each agent's expected discounted return when play starts in state 0.

    Args:
        policy: dict, current policy {(state, agent): probability distribution over actions}
        gamma: float, discount factor
        T: int, trajectory length
        samples: int, number of trajectories (all starting in state 0)

    Returns:
        avg_rewards: np.array of shape (N,), average discounted return per agent
    """
    total_rewards = np.zeros(N)
    start_state = 0
    for _ in range(samples):
        curr_state = start_state
        for t in range(T):
            # Sample actions for all agents according to the current policy.
            actions = [pick_action(policy[curr_state, i]) for i in range(N)]
            profile_key = tuple(actions + [curr_state])
            # Compute the reward only on a cache miss: setdefault() would
            # evaluate get_reward() eagerly on every call, defeating the cache.
            if profile_key not in selected_profiles:
                selected_profiles[profile_key] = get_reward(
                    state_dic[curr_state], [act_dic[i] for i in actions])
            rewards = selected_profiles[profile_key]
            total_rewards += (gamma ** t) * np.array(rewards)
            curr_state = get_next_state(curr_state, actions)

    # Only one start state is simulated, so average over the `samples`
    # trajectories; also dividing by S would scale the estimate by 1/S.
    return total_rewards / samples

def nash_gap(policy, gamma, T, samples):
    """Estimate the largest gain any single agent obtains from a pure unilateral deviation.

    For each agent and each pair of pure actions (a0 in state 0, a1 in
    state 1):
      * the agent's policy is replaced by that deterministic policy;
      * its value is re-estimated with ``value_from_buffer``;
      * the improvement of the state-averaged value over the baseline is
        recorded.
    Returns the maximum improvement over agents and deviations, floored at 0.0.
    """
    base_value = value_from_buffer(policy, gamma, T, samples)
    n_acts = len(act_dic)
    gap = 0.0
    for agent in range(N):
        # The baseline average does not depend on the deviation, so compute it once.
        base_avg = np.mean([base_value[s, agent] for s in range(S)])
        best_gain = -np.inf
        for a0, a1 in itertools.product(range(n_acts), repeat=2):
            # Shallow copy is enough: the deviating entries are replaced, not mutated.
            deviation = policy.copy()
            deviation[(0, agent)] = np.zeros(n_acts)
            deviation[(0, agent)][a0] = 1.0
            deviation[(1, agent)] = np.zeros(n_acts)
            deviation[(1, agent)][a1] = 1.0
            dev_value = value_from_buffer(deviation, gamma, T, samples)
            dev_avg = np.mean([dev_value[s, agent] for s in range(S)])
            best_gain = max(best_gain, dev_avg - base_avg)
        gap = max(gap, best_gain)
    return gap


def get_accuracies(policy_hist):
    """Return the ``policy_accuracy`` distance of every snapshot in ``policy_hist`` to the last one."""
    final_policy = policy_hist[-1]
    return [policy_accuracy(snapshot, final_policy) for snapshot in policy_hist]

def full_experiment(runs,iters,eta,T,samples, algorithm):
    """Train ``algorithm`` ``runs`` times, then save reward, accuracy and density plots.

    Args:
        runs: number of independent training runs.
        iters: iterations per run.
        eta: step size.
        T: rollout horizon.
        samples: Monte-Carlo samples per estimate.
        algorithm: one of "NPG_noIS", "NPG_buffer", "NPG_BPP", "PQA".

    Side effects:
        Writes "rewards_stats_NPGbuffer_noIS.txt",
        "reward_trend_noIS_mean3_{N}.png" and "individual_runs_n{N}.png" to
        the working directory. When runs > 1 it also writes
        "avg_runs_n{N}.png" and "facilities_n{N}.png".

    Returns:
        ``(plot_accuracies, fig1, fig2, fig3)``. When ``runs == 1`` it
        returns ``(plot_accuracies, fig2, None, None)`` instead.

    Raises:
        NotImplementedError: for an unknown ``algorithm``.
    """
    densities = np.zeros((S, M))
    all_rewards = []
    raw_accuracies = []
    for _ in range(runs):
        if algorithm == "NPG_noIS":
            policy_hist = NPG_policy_gradient_buffer_noIS([0.5, 0.5], iters, 0.99, eta, T, samples, K=30, X=100000, clip=3)
        elif algorithm == "NPG_buffer":
            policy_hist = NPG_policy_gradient_buffer([0.5, 0.5], iters, 0.99, eta, T, samples, K=30, X=100000, clip=3)
        elif algorithm == "NPG_BPP":
            # NOTE(review): NPG_policy_gradient_buffer requires X, which is not
            # passed here, so this branch raises TypeError. Confirm the intended X.
            policy_hist = NPG_policy_gradient_buffer([0.5, 0.5], iters, 0.99, eta, T, samples, K=30)
        elif algorithm == "PQA":
            # NOTE(review): dispatches to PGA_policy_gradient although a
            # PQA_policy_gradient exists. Confirm which one is meant.
            policy_hist = PGA_policy_gradient([0.5, 0.5], iters, 0.99, eta, T, samples)
        else:
            raise NotImplementedError

        raw_accuracies.append(get_accuracies(policy_hist))

        # Sum the final action distributions over agents, per state.
        converged_policy = policy_hist[-1]
        for i in range(N):
            for s in range(S):
                densities[s] += converged_policy[s, i]

        # Evaluate every 10th snapshot (iterations 0, 10, 20, ...).
        rewards_per_iter = []
        for t, policy_iter in enumerate(policy_hist):
            if t % 10 == 0:
                print("Iteration {}".format(t))
                avg_rewards = evaluate_policy(policy_iter, gamma=0.99, T=T, samples=100)
                rewards_per_iter.append(np.mean(avg_rewards))
        all_rewards.append(rewards_per_iter)

    densities = densities / runs
    all_rewards = np.array(all_rewards)  # shape (runs, num_eval_points)
    mean_rewards = np.mean(all_rewards, axis=0)
    std_rewards = np.std(all_rewards, axis=0)
    iterations = np.arange(0, iters + 1, 10)[:len(mean_rewards)]

    # NOTE(review): the output file names do not depend on `algorithm`, so
    # experiments with different algorithms overwrite each other's files.
    results = np.vstack([mean_rewards, std_rewards]).T
    np.savetxt("rewards_stats_NPGbuffer_noIS.txt", results, fmt="%.4f")

    fig_reward = plt.figure(figsize=(6, 4))
    plt.plot(iterations, mean_rewards, color='b', linewidth=2, label="Mean reward")
    # The band is mean +/- one standard deviation, so label it as such.
    plt.fill_between(iterations, mean_rewards - std_rewards, mean_rewards + std_rewards,
                     color='b', alpha=0.2, label="1-standard deviation")
    plt.xlabel("Iteration")
    plt.ylabel("Average reward per agent")
    plt.title("Reward trend over iterations")
    plt.legend()
    plt.grid(True)
    fig_reward.savefig('reward_trend_noIS_mean3_{}.png'.format(N), bbox_inches='tight')

    # Pad runs of different length with NaN so they stack into one array.
    plot_accuracies = np.array(list(itertools.zip_longest(*raw_accuracies, fillvalue=np.nan))).T
    clrs = sns.color_palette("husl", 3)
    piters = list(range(plot_accuracies.shape[1]))
    # Raw string: "\e" in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
    title_fmt = r'Policy Gradient: agents = {}, runs = {}, $\eta$ = {}'

    fig2 = plt.figure(figsize=(6, 4))
    for run_acc in plot_accuracies:
        plt.plot(piters, run_acc)
    plt.grid(linewidth=0.6)
    plt.gca().set(xlabel='Iterations', ylabel='L1-accuracy', title=title_fmt.format(N, runs, eta))
    fig2.savefig('individual_runs_n{}.png'.format(N), bbox_inches='tight')

    plot_accuracies = np.nan_to_num(plot_accuracies)
    if runs == 1:
        # statistics.stdev needs at least two data points per column.
        return plot_accuracies, fig2, None, None
    pmean = list(map(statistics.mean, zip(*plot_accuracies)))
    pstdv = list(map(statistics.stdev, zip(*plot_accuracies)))

    fig1 = plt.figure(figsize=(6, 4))
    # seaborn >= 0.12 accepts x and y only as keyword arguments.
    ax = sns.lineplot(x=piters, y=pmean, color=clrs[0], label='Mean L1-accuracy')
    ax.fill_between(piters, np.subtract(pmean, pstdv), np.add(pmean, pstdv),
                    alpha=0.3, facecolor=clrs[0], label="1-standard deviation")
    ax.legend()
    plt.grid(linewidth=0.6)
    plt.gca().set(xlabel='Iterations', ylabel='L1-accuracy', title=title_fmt.format(N, runs, eta))
    fig1.savefig('avg_runs_n{}.png'.format(N), bbox_inches='tight')

    fig3, ax = plt.subplots()
    index = np.arange(D)
    bar_width = 0.35
    opacity = 1
    plt.bar(index, densities[0], bar_width, alpha=.7 * opacity, color='b', label='Safe state')
    plt.bar(index + bar_width, densities[1], bar_width, alpha=opacity, color='r', label='Distancing state')
    plt.gca().set(xlabel='Facility', ylabel='Average number of agents', title=title_fmt.format(N, runs, eta))
    plt.xticks(index + bar_width / 2, ('A', 'B', 'C', 'D'))
    plt.legend()
    fig3.savefig('facilities_n{}.png'.format(N), bbox_inches='tight')

    return plot_accuracies, fig1, fig2, fig3


# Experiment configuration: 4 independent runs of 500 iterations each, with
# eta = 0.001, horizon T = 20 and 600 Monte-Carlo samples per estimate.
runs = 4
max_iters = 500
# NOTE(review): this runs at import time; consider an `if __name__ == "__main__":` guard.
plot_accuracies, fig1, fig2, fig3 = full_experiment(runs,max_iters,0.001,20,600, algorithm="NPG_noIS")