import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
from docplex.mp.model import Model
import seaborn as sns; sns.set()

#np.random.seed(40)
# One NE-gap curve per seed is appended to each of these lists, one list per
# estimator variant run inside the seed loop.
base_policy_prediction_list = []
no_is = []
naive_is = []
for seed in range(40, 45):
    np.random.seed(seed)

    # Game definition: one reward tensor with one axis per agent.
    action_spaces = np.array([10, 10, 10])
    n = len(action_spaces)
    reward = np.random.uniform(0, 0.2, size=tuple(action_spaces))

    # Every agent starts from the uniform policy (renormalised in place).
    initial_pi = [np.ones(i) / i for i in action_spaces]
    for p in initial_pi:
        p /= np.sum(p)

    # Hyper-parameters of the first variant.
    eta = 0.1           # step size applied to the logits
    tau = 0             # not referenced anywhere in the visible code
    N = 5000            # number of iterations
    sample_times = 20   # opponent profiles drawn per agent at each resampling
    T = 500             # period at which the base policy is reset to pi
    pi_star = np.copy(initial_pi)       # not referenced in the visible code
    new_pi_star = np.copy(initial_pi)   # not referenced in the visible code

    # One zero vector per agent, shaped like that agent's policy.
    r_old_storage = [np.zeros_like(p) for p in initial_pi]

    # Current policy and base (sampling) policy, with their logits.
    pi = [np.copy(p) for p in initial_pi]
    pi_old = [np.copy(p) for p in pi]
    theta = [np.log(p) for p in initial_pi]
    old_theta = [np.copy(t) for t in theta]

    # Period at which opponent samples are redrawn.
    grad_update_interval = 100
    def pi_minus_i_prob(pi_dist, a_minus_i, agent_idx):
        """Return the probability of an opponent action profile.

        Multiplies, for every agent except ``agent_idx``, the probability
        that agent's distribution in ``pi_dist`` assigns to its action in
        ``a_minus_i``.

        Args:
            pi_dist: Sequence with one action distribution per agent.
            a_minus_i: Actions of the other agents, in increasing agent
                order with ``agent_idx`` skipped.
            agent_idx: Index of the agent that is left out of the product.

        Returns:
            The product of the selected probabilities (``1.0`` when there
            are no other agents).
        """
        prob = 1.0
        idx = 0
        # The agent count comes from pi_dist itself, not an outer variable.
        for j, dist in enumerate(pi_dist):
            if j == agent_idx:
                continue
            prob *= dist[a_minus_i[idx]]
            idx += 1
        return prob

    def sample_q(pi_old, agent_idx):
        """Draw one action index for ``agent_idx`` from its row of ``pi_old``."""
        dist = pi_old[agent_idx]
        num_actions = len(dist)
        return np.random.choice(num_actions, p=dist)
    # Opponent action profiles used by each agent's payoff estimate; they are
    # redrawn from pi_old every `grad_update_interval` steps.
    sampled_actions = {agent: [] for agent in range(n)}

    NE_gaps = []

    for step in tqdm(range(N)):
        NE_gap = []
        next_pi = [np.copy(p) for p in pi]
        next_theta = [np.copy(t) for t in theta]

        if step % grad_update_interval == 0:
            for agent in range(n):
                sampled_actions[agent] = [
                    [sample_q(pi_old, j) for j in range(n) if j != agent]
                    for _ in range(sample_times)
                ]

        for agent in range(n):
            action_dim = int(action_spaces[agent])
            estimated_r = np.zeros(action_dim)

            for a_minus_i in sampled_actions[agent]:
                # Importance weight of the current policies relative to
                # pi_old for this opponent profile, floored at 1e-7.
                w = max(
                    pi_minus_i_prob(pi, a_minus_i, agent)
                    / pi_minus_i_prob(pi_old, a_minus_i, agent),
                    1e-7,
                )
                # Fix the opponents' actions and slice along this agent's
                # axis to read the payoff of all of its actions at once.
                payoff_index = list(a_minus_i)
                payoff_index.insert(agent, slice(None))
                estimated_r += w * reward[tuple(payoff_index)]

            estimated_r /= len(sampled_actions[agent])

            # Sampled payoff of each action minus its average under pi[agent].
            r = estimated_r - np.dot(estimated_r, pi[agent])

            # Same quantity computed exactly: contract the reward tensor with
            # every other agent's current policy. The first transpose puts
            # lower-indexed agents on the trailing axes so matmul removes
            # them first; the second restores the order for the rest.
            r_true = np.transpose(reward)
            for left_agent in range(agent):
                r_true = np.matmul(r_true, pi[left_agent])
            r_true = np.transpose(r_true)
            for right_agent in range(n - 1, agent, -1):
                r_true = np.matmul(r_true, pi[right_agent])
            r_true = r_true - np.dot(r_true, pi[agent])
            NE_gap.append(np.max(r_true))

            if step % T == 0:
                r_old_storage[agent] = np.copy(r)

            # Logit step followed by a max-shifted softmax.
            next_theta[agent] += eta * r
            logits = next_theta[agent] - np.max(next_theta[agent])
            exp_logits = np.exp(logits)
            next_pi[agent] = exp_logits / (np.sum(exp_logits) + 1e-10)

        pi = [np.copy(p) for p in next_pi]
        theta = [np.copy(t) for t in next_theta]

        if step % T == 0:
            # Reset the base policy to the freshly updated policy.
            pi_old = [np.copy(p) for p in pi]
            old_theta = [np.copy(t) for t in theta]
        elif step % grad_update_interval == 0:
            # Advance every agent's base policy with the stored direction.
            # This must loop over all agents: reusing the index left over
            # from the loop above would only touch the last agent.
            for agent in range(n):
                old_theta[agent] += eta * r_old_storage[agent]
                logits = old_theta[agent] - np.max(old_theta[agent])
                exp_logits = np.exp(logits)
                pi_old[agent] = exp_logits / (np.sum(exp_logits) + 1e-10)

        NE_gaps.append(NE_gap)

    # Sum the per-agent gaps at each step and record this seed's curve.
    NE_gaps = np.array(NE_gaps)
    NPG_NE_gaps = np.sum(NE_gaps, axis=1)
    base_policy_prediction_list.append(NPG_NE_gaps)
    

    # Setup for the second variant: restart from the uniform policy.
    initial_pi = [np.ones(i) / i for i in action_spaces]
    for p in initial_pi:
        p /= np.sum(p)

    eta = 0.1            # step size applied to the logits
    tau = 0              # not referenced anywhere in the visible code
    N = 5000             # number of iterations
    sample_times = 100   # opponent profiles drawn per agent at each resampling
    T = 500              # period at which the base policy is reset to pi
    pi_star = np.copy(initial_pi)       # not referenced in the visible code
    new_pi_star = np.copy(initial_pi)   # not referenced in the visible code

    # One zero vector per agent, shaped like that agent's policy.
    r_old_storage = [np.zeros_like(p) for p in initial_pi]

    # Current policy and base (sampling) policy, with their logits.
    pi = [np.copy(p) for p in initial_pi]
    pi_old = [np.copy(p) for p in pi]
    theta = [np.log(p) for p in initial_pi]
    old_theta = [np.copy(t) for t in theta]

    # Period at which opponent samples are redrawn (equal to T here).
    grad_update_interval = 500
    def pi_minus_i_prob(pi_dist, a_minus_i, agent_idx):
        """Return the probability of an opponent action profile.

        Multiplies, for every agent except ``agent_idx``, the probability
        that agent's distribution in ``pi_dist`` assigns to its action in
        ``a_minus_i``.

        Args:
            pi_dist: Sequence with one action distribution per agent.
            a_minus_i: Actions of the other agents, in increasing agent
                order with ``agent_idx`` skipped.
            agent_idx: Index of the agent that is left out of the product.

        Returns:
            The product of the selected probabilities (``1.0`` when there
            are no other agents).
        """
        prob = 1.0
        idx = 0
        # The agent count comes from pi_dist itself, not an outer variable.
        for j, dist in enumerate(pi_dist):
            if j == agent_idx:
                continue
            prob *= dist[a_minus_i[idx]]
            idx += 1
        return prob

    def sample_q(pi_old, agent_idx):
        """Draw one action index for ``agent_idx`` from its row of ``pi_old``."""
        dist = pi_old[agent_idx]
        num_actions = len(dist)
        return np.random.choice(num_actions, p=dist)
    # Opponent action profiles used by each agent's payoff estimate; they are
    # redrawn from pi_old every `grad_update_interval` steps.
    sampled_actions = {agent: [] for agent in range(n)}

    NE_gaps = []

    for step in tqdm(range(N)):
        NE_gap = []
        next_pi = [np.copy(p) for p in pi]
        next_theta = [np.copy(t) for t in theta]

        if step % grad_update_interval == 0:
            for agent in range(n):
                sampled_actions[agent] = [
                    [sample_q(pi_old, j) for j in range(n) if j != agent]
                    for _ in range(sample_times)
                ]

        for agent in range(n):
            action_dim = int(action_spaces[agent])
            estimated_r = np.zeros(action_dim)

            for a_minus_i in sampled_actions[agent]:
                # Importance weight of the current policies relative to
                # pi_old for this opponent profile, floored at 1e-7.
                w = max(
                    pi_minus_i_prob(pi, a_minus_i, agent)
                    / pi_minus_i_prob(pi_old, a_minus_i, agent),
                    1e-7,
                )
                # Fix the opponents' actions and slice along this agent's
                # axis to read the payoff of all of its actions at once.
                payoff_index = list(a_minus_i)
                payoff_index.insert(agent, slice(None))
                estimated_r += w * reward[tuple(payoff_index)]

            estimated_r /= len(sampled_actions[agent])

            # Sampled payoff of each action minus its average under pi[agent].
            r = estimated_r - np.dot(estimated_r, pi[agent])

            # Same quantity computed exactly: contract the reward tensor with
            # every other agent's current policy. The first transpose puts
            # lower-indexed agents on the trailing axes so matmul removes
            # them first; the second restores the order for the rest.
            r_true = np.transpose(reward)
            for left_agent in range(agent):
                r_true = np.matmul(r_true, pi[left_agent])
            r_true = np.transpose(r_true)
            for right_agent in range(n - 1, agent, -1):
                r_true = np.matmul(r_true, pi[right_agent])
            r_true = r_true - np.dot(r_true, pi[agent])
            NE_gap.append(np.max(r_true))

            if step % T == 0:
                r_old_storage[agent] = np.copy(r)

            # Logit step followed by a max-shifted softmax.
            next_theta[agent] += eta * r
            logits = next_theta[agent] - np.max(next_theta[agent])
            exp_logits = np.exp(logits)
            next_pi[agent] = exp_logits / (np.sum(exp_logits) + 1e-10)

        pi = [np.copy(p) for p in next_pi]
        theta = [np.copy(t) for t in next_theta]

        if step % T == 0:
            # Reset the base policy to the freshly updated policy.
            pi_old = [np.copy(p) for p in pi]
            old_theta = [np.copy(t) for t in theta]
        elif step % grad_update_interval == 0:
            # Only reachable when grad_update_interval is not a multiple of
            # T. Advances every agent's base policy; reusing the index left
            # over from the loop above would only touch the last agent.
            for agent in range(n):
                old_theta[agent] += eta * r_old_storage[agent]
                logits = old_theta[agent] - np.max(old_theta[agent])
                exp_logits = np.exp(logits)
                pi_old[agent] = exp_logits / (np.sum(exp_logits) + 1e-10)

        NE_gaps.append(NE_gap)

    # Sum the per-agent gaps at each step and record this seed's curve.
    NE_gaps = np.array(NE_gaps)
    NPG_regularization_NE_gaps = np.sum(NE_gaps, axis=1)
    naive_is.append(NPG_regularization_NE_gaps)

    

    # Setup for the third variant: restart from the uniform policy.
    initial_pi = [np.ones(i) / i for i in action_spaces]
    for p in initial_pi:
        p /= np.sum(p)

    eta = 0.1            # step size applied to the logits
    tau = 0              # not referenced anywhere in the visible code
    N = 5000             # number of iterations
    sample_times = 100   # opponent profiles drawn per agent at each resampling
    T = 500              # period at which the base policy is reset to pi
    pi_star = np.copy(initial_pi)       # not referenced in the visible code
    new_pi_star = np.copy(initial_pi)   # not referenced in the visible code

    # One zero vector per agent, shaped like that agent's policy.
    r_old_storage = [np.zeros_like(p) for p in initial_pi]

    # Current policy and base (sampling) policy, with their logits.
    pi = [np.copy(p) for p in initial_pi]
    pi_old = [np.copy(p) for p in pi]
    theta = [np.log(p) for p in initial_pi]
    old_theta = [np.copy(t) for t in theta]

    # Period at which opponent samples are redrawn (equal to T here).
    grad_update_interval = 500
    def pi_minus_i_prob(pi_dist, a_minus_i, agent_idx):
        """Return the probability of an opponent action profile.

        Multiplies, for every agent except ``agent_idx``, the probability
        that agent's distribution in ``pi_dist`` assigns to its action in
        ``a_minus_i``.

        Args:
            pi_dist: Sequence with one action distribution per agent.
            a_minus_i: Actions of the other agents, in increasing agent
                order with ``agent_idx`` skipped.
            agent_idx: Index of the agent that is left out of the product.

        Returns:
            The product of the selected probabilities (``1.0`` when there
            are no other agents).
        """
        prob = 1.0
        idx = 0
        # The agent count comes from pi_dist itself, not an outer variable.
        for j, dist in enumerate(pi_dist):
            if j == agent_idx:
                continue
            prob *= dist[a_minus_i[idx]]
            idx += 1
        return prob

    def sample_q(pi_old, agent_idx):
        """Draw one action index for ``agent_idx`` from its row of ``pi_old``."""
        dist = pi_old[agent_idx]
        num_actions = len(dist)
        return np.random.choice(num_actions, p=dist)
    # Opponent action profiles used by each agent's payoff estimate; they are
    # redrawn from pi_old every `grad_update_interval` steps.
    sampled_actions = {agent: [] for agent in range(n)}

    NE_gaps = []

    for step in tqdm(range(N)):
        NE_gap = []
        next_pi = [np.copy(p) for p in pi]
        next_theta = [np.copy(t) for t in theta]

        if step % grad_update_interval == 0:
            for agent in range(n):
                sampled_actions[agent] = [
                    [sample_q(pi_old, j) for j in range(n) if j != agent]
                    for _ in range(sample_times)
                ]

        for agent in range(n):
            action_dim = int(action_spaces[agent])
            estimated_r = np.zeros(action_dim)

            for a_minus_i in sampled_actions[agent]:
                # No importance sampling in this variant: every stored
                # profile has weight 1. Fix the opponents' actions and slice
                # along this agent's axis to read all of its payoffs at once.
                payoff_index = list(a_minus_i)
                payoff_index.insert(agent, slice(None))
                estimated_r += reward[tuple(payoff_index)]

            estimated_r /= len(sampled_actions[agent])

            # Sampled payoff of each action minus its average under pi[agent].
            r = estimated_r - np.dot(estimated_r, pi[agent])

            # Same quantity computed exactly: contract the reward tensor with
            # every other agent's current policy. The first transpose puts
            # lower-indexed agents on the trailing axes so matmul removes
            # them first; the second restores the order for the rest.
            r_true = np.transpose(reward)
            for left_agent in range(agent):
                r_true = np.matmul(r_true, pi[left_agent])
            r_true = np.transpose(r_true)
            for right_agent in range(n - 1, agent, -1):
                r_true = np.matmul(r_true, pi[right_agent])
            r_true = r_true - np.dot(r_true, pi[agent])
            NE_gap.append(np.max(r_true))

            if step % T == 0:
                r_old_storage[agent] = np.copy(r)

            # Logit step followed by a max-shifted softmax.
            next_theta[agent] += eta * r
            logits = next_theta[agent] - np.max(next_theta[agent])
            exp_logits = np.exp(logits)
            next_pi[agent] = exp_logits / (np.sum(exp_logits) + 1e-10)

        pi = [np.copy(p) for p in next_pi]
        theta = [np.copy(t) for t in next_theta]

        if step % T == 0:
            # Reset the base policy to the freshly updated policy.
            pi_old = [np.copy(p) for p in pi]
            old_theta = [np.copy(t) for t in theta]
        elif step % grad_update_interval == 0:
            # Only reachable when grad_update_interval is not a multiple of
            # T. Advances every agent's base policy; reusing the index left
            # over from the loop above would only touch the last agent.
            for agent in range(n):
                old_theta[agent] += eta * r_old_storage[agent]
                logits = old_theta[agent] - np.max(old_theta[agent])
                exp_logits = np.exp(logits)
                pi_old[agent] = exp_logits / (np.sum(exp_logits) + 1e-10)

        NE_gaps.append(NE_gap)

    # Sum the per-agent gaps at each step and record this seed's curve.
    NE_gaps = np.array(NE_gaps)
    PG_log_barrier_NE_gaps = np.sum(NE_gaps, axis=1)
    no_is.append(PG_log_barrier_NE_gaps)


end = 10000  # not referenced anywhere in the visible code
plt.figure(figsize=(8, 5))

# Stack the per-seed curves into arrays with one row per seed.
base_policy_prediction_list = np.array(base_policy_prediction_list)
no_is = np.array(no_is)
naive_is = np.array(naive_is)

# Take the x-axis length from the recorded curves instead of hard-coding the
# iteration count, so the shaded bands stay aligned if the horizon changes.
x = range(base_policy_prediction_list.shape[1])
def save_mean_std(arr, filename):
    """Write the mean and standard deviation over axis 0 of ``arr`` to a file.

    Each output row holds the mean and then the standard deviation of one
    column of ``arr``, formatted with six decimals.

    Args:
        arr: Array-like whose first axis is averaged over.
        filename: Destination passed to ``np.savetxt``.
    """
    data = np.array(arr)
    column_mean = data.mean(axis=0)
    column_std = data.std(axis=0)
    table = np.stack((column_mean, column_std), axis=1)
    np.savetxt(filename, table, fmt="%.6f")

# Dump the mean/std curve of every variant to its own text file.
for curves, out_name in (
    (base_policy_prediction_list, "base_policy_prediction_list.txt"),
    (no_is, "no_is.txt"),
    (naive_is, "naive_is.txt"),
):
    save_mean_std(curves, out_name)

# Draw each variant's mean curve with a band of +/- 0.4 standard deviations.
for curves, label in (
    (base_policy_prediction_list, "Base Policy Prediction"),
    (no_is, "No Importance Sampling"),
    (naive_is, "Naive Importance Sampling"),
):
    mean = curves.mean(axis=0)
    std = curves.std(axis=0)
    plt.plot(mean, label=label)
    lower = mean - 0.4 * std
    upper = mean + 0.4 * std
    plt.fill_between(x, lower, upper, alpha=0.2)

plt.xlabel("Iterations")
plt.ylabel("NE-gap")
plt.legend()
plt.tight_layout()
plt.savefig("my_figure_5.png", dpi=300)
#plt.show()