# Implement optimism with pi_2 being the enhancer
import itertools

import matplotlib.pyplot as plt
import nashpy as nash
import numpy as np
from scipy.optimize import minimize
from scipy.special import logsumexp, rel_entr, softmax

# basic parameters
# Global seed used only for generating the shared problem instance below
# (true_W, reference_W, action_vectors); each run later reseeds with seed[k].
np.random.seed(35)
batch_size = 10  # contexts drawn per iteration for data collection
context_dim = 5  # length of a context vector (first axis of W)
action_dim = 5  # length of each action feature vector (last two axes of W)
number_actions = 6  # size of the discrete action set
num_iterations = 20  # outer iterations per run: collect data, refit W, evaluate
# Values of eta swept over; each is passed as `eta` to solve_fixed_point /
# calculate_minizer_value (eta -> 0 pulls the policy toward pi_0).
kl_coefficient = [1]
action_batch_size =5  # preference labels drawn per sampled (context, a1, a2) triple
num_eval = 30  # contexts used to estimate the suboptimality gap each iteration
# Tournament sizes for the enhancer pi_2; 0 means pi_2 samples from pi_0.
num_tour = [0,1,3]
turns = 5  # number of independent runs of the whole experiment
seed = [10,20,25,30,35,40,45]  # per-run seeds; only the first `turns` entries are used
# Generate true parameter tensor, shape (context_dim, action_dim, action_dim).
# np.random.rand is already in [0, 1), so np.abs is a no-op here.
true_W = np.abs(np.random.rand(context_dim, action_dim, action_dim))
# Scale of the perturbation below. Despite the name, the noise is
# Uniform[0, noise_std) (np.random.rand), not Gaussian.
noise_std = 0.5
# Noisy starting estimate: true_W plus nonnegative uniform noise. Every learned_W
# in the experiment loop is initialized from a copy of this.
reference_W = np.abs(true_W + np.random.rand(context_dim, action_dim, action_dim) * noise_std)

# Disabled alternative: Gaussian action features folded to be nonnegative.
#action_vectors = [np.abs(np.random.randn(action_dim)) for _ in range(number_actions)]
# number_actions feature vectors with i.i.d. Uniform[0, 1) entries.
action_vectors = [np.abs(np.random.rand(action_dim)) for _ in range(number_actions)]
# Uniform reference policy: the KL anchor in solve_fixed_point and
# calculate_minizer_value, and pi_2's sampling policy when the tournament size is 0.
pi_0 = np.ones(number_actions)/number_actions # we use pi_0 as the initial policy


# L-BFGS-B box constraints on every entry of the flattened W: a small positive
# lower bound keeps the fitted W strictly positive; no upper bound.
bounds = [(1e-6, None)] * (context_dim * action_dim * action_dim)

# Build a preference matrix P from its upper-triangle entries, enforcing P[i, j] + P[j, i] = 1
def create_P_matrix(theta):
    """Assemble a number_actions x number_actions preference matrix.

    theta supplies the strictly-upper-triangle entries P[r, c] (r < c) in
    row-major order: row 0 left to right, then row 1, and so on. Each mirrored
    entry is the complement, P[c, r] = 1 - P[r, c], and the diagonal is 0.5.
    Indexing theta directly means a too-short theta raises IndexError.
    """
    n = number_actions
    # Start with every cell at 0.5; all off-diagonal cells are overwritten below.
    P = np.full((n, n), 0.5)
    upper_cells = [(r, c) for r in range(n) for c in range(r + 1, n)]
    for k, (r, c) in enumerate(upper_cells):
        P[r, c] = theta[k]
        P[c, r] = 1 - theta[k]
    return P


# Sample context
def sample_context():
    """Draw a context vector of length context_dim with i.i.d. Uniform[0, 1) entries."""
    # Same legacy global-RNG call that np.random.rand(context_dim) delegates to.
    return np.random.random_sample((context_dim,))

# Compute preference probability using general model
def preference_prob(x, a1_vec, a2_vec, W):
    """Probability that action a1_vec is preferred over a2_vec in context x.

    The context is contracted against the first axis of W, giving the matrix
    M = sum_k x[k] * W[k]. The two bilinear scores s12 = a1^T M a2 and
    s21 = a2^T M a1 are then normalized:

        P(a1 > a2) = s12 / (s12 + s21)

    so P(a1 > a2) + P(a2 > a1) = 1 by construction. There is no guard against
    s12 + s21 == 0.
    """
    M = np.tensordot(x, W, axes=(0, 0))
    forward = a1_vec @ M @ a2_vec
    backward = a2_vec @ M @ a1_vec
    # Plain normalization of the two scores (not a regularizer).
    return forward / (forward + backward)


def solve_fixed_point(context, W, eta, tol=1e-6, max_iter=1000):
    """Approximate the KL-regularized Nash policy for one context by fixed-point iteration.

    Starting from the uniform policy, repeatedly applies

        pi <- pi_0 * exp(eta * P @ pi) / Z

    where P[a, b] = preference_prob(context, action_vectors[a], action_vectors[b], W),
    so (P @ pi)[a] is the probability that action a beats a draw from pi.

    Args:
        context: context vector passed to preference_prob.
        W: preference parameter tensor passed to preference_prob.
        eta: inverse regularization strength; eta = 0 returns pi_0.
        tol: stop once the Euclidean change between consecutive iterates is below this.
        max_iter: maximum number of updates.

    Returns:
        The latest iterate, a probability vector over the number_actions actions.
        Convergence is not guaranteed; after max_iter updates the last iterate
        is returned as-is.
    """
    K = number_actions
    P = np.array([
        [preference_prob(context, action_vectors[a], action_vectors[b], W) for b in range(K)]
        for a in range(K)
    ])
    pi = np.ones(K) / K
    for _ in range(max_iter):
        new_pi = pi_0 * np.exp(eta * (P @ pi))
        new_pi /= np.sum(new_pi)
        converged = np.linalg.norm(new_pi - pi) < tol
        # Always keep the newest iterate, including on the converging step.
        pi = new_pi
        if converged:
            break
    return pi
def calculate_minizer_value(context, W, pi_hat, eta):
    """Return the KL-regularized value of pi_hat against its best-responding opponent.

    With P[i, a] = preference_prob(context, action_vectors[i], action_vectors[a], W)
    and w[a] = sum_i pi_hat[i] * P[i, a] (probability pi_hat beats action a), computes

        -1/eta * (KL(pi_hat || pi_0) + log sum_a pi_0[a] * exp(-eta * w[a]))

    By the Gibbs variational principle this equals
    min over pi' of  <pi', w> - KL(pi_hat||pi_0)/eta + KL(pi'||pi_0)/eta.

    Args:
        context: context vector passed to preference_prob.
        W: preference parameter tensor (the caller evaluates with the true model).
        pi_hat: probability vector over the number_actions actions.
        eta: nonzero inverse regularization strength.

    Returns:
        The scalar value described above.
    """
    K = number_actions
    P = np.array([
        [preference_prob(context, action_vectors[i], action_vectors[a], W) for a in range(K)]
        for i in range(K)
    ])
    # rel_entr uses the 0 * log(0) = 0 convention, so zero-probability actions
    # contribute 0 instead of NaN.
    kl = np.sum(rel_entr(pi_hat, pi_0))
    win_prob = pi_hat @ P
    # log Z in log-space: exp(-eta * w) underflows to 0 for large eta,
    # which would make a plain log(sum(...)) return -inf.
    log_z = logsumexp(-eta * win_prob, b=pi_0)
    return -(kl + log_z) / eta

# Generate one preference sample
def sample_preference(x, a_1, a_2):
    """Draw one Bernoulli preference label from the true model.

    Args:
        x: context vector.
        a_1: index into action_vectors of the first action.
        a_2: index into action_vectors of the second action.

    Returns:
        1 (first action preferred) with probability
        preference_prob(x, action_vectors[a_1], action_vectors[a_2], true_W), else 0.
    """
    # Use the arguments; the old body read the module-level a1/a2 instead,
    # silently ignoring whatever indices the caller passed.
    p = preference_prob(x, action_vectors[a_1], action_vectors[a_2], true_W)
    return 1 if np.random.rand() < p else 0


def get_pi_2_action(context, current_pi1, number_competitor):
    """Choose the enhancer (pi^2) action by a best-of-n round-robin tournament.

    With number_competitor == 0 the action is drawn from the reference policy pi_0.
    Otherwise number_competitor actions are sampled i.i.d. from current_pi1.
    Every pair plays one noisy comparison, drawn from the true preference model
    via sample_preference. The sample with the most wins is returned; ties go to
    the earliest sample, as np.argmax does.

    Args:
        context: context vector used for every comparison.
        current_pi1: probability vector over actions for the main player pi^1.
        number_competitor: number of samples entering the tournament.

    Returns:
        Index of the chosen action.
    """
    if number_competitor == 0:
        return np.random.choice(number_actions, p=pi_0)

    n = number_competitor
    sampled_responses = np.random.choice(number_actions, size=n, p=current_pi1)
    wins = np.zeros(n)
    for i, j in itertools.combinations(range(n), 2):
        # Compare under this call's own context, not a module-level variable.
        if sample_preference(context, sampled_responses[i], sampled_responses[j]) == 1:
            wins[i] += 1
        else:
            wins[j] += 1

    best_idx = np.argmax(wins)
    return sampled_responses[best_idx]
def add_list(list1, list2, a, b, c):
    """Return the element-wise combination (a * list1 + b * list2) / c.

    Pairs elements positionally and stops at the shorter input.
    """
    combined = []
    for left, right in zip(list1, list2):
        combined.append((a * left + b * right) / c)
    return combined

# Cross-run containers, one per (tournament size, eta) setting: multi_*[i][j]
# collects, for num_tour[i] and kl_coefficient[j], one per-iteration curve per run.
multi_suboptimal_gaps = [[[] for _ in range(len(kl_coefficient))] for _ in range(len(num_tour))]
multi_cumulative_gaps = [[[] for _ in range(len(kl_coefficient))] for _ in range(len(num_tour))]

# Repeat the whole experiment `turns` times. Run k reseeds the global RNG with
# seed[k], so `turns` must not exceed len(seed). The problem instance (true_W,
# reference_W, action_vectors) was generated once earlier and is shared by all runs.
for k in range(turns):
    np.random.seed(seed[k])
# Initialize learned parameter: every setting starts from a copy of reference_W.
# (Disabled alternative: start from an all-ones tensor.)
#learned_W = [[np.ones((context_dim, action_dim, action_dim)) for _ in range(len(kl_coefficient))] for _ in range(len(num_tour))]
    learned_W = [[reference_W.copy() for _ in range(len(kl_coefficient))] for _ in range(len(num_tour))]

# Per-setting records for this run:
#   suboptimal_gaps - per-iteration average gap
#   cum_gap         - running sum of those averages
#   cum_gaps        - history of the running sum
    suboptimal_gaps = [[[] for _ in range(len(kl_coefficient))] for _ in range(len(num_tour))]
    cum_gaps = [[[] for _ in range(len(kl_coefficient))] for _ in range(len(num_tour))]
    cum_gap= np.zeros((len(num_tour),len(kl_coefficient)))

# Optimization loop. preference_data[i][j] is the dataset for setting (i, j). It is
# only ever appended to, so each refit below uses every label collected so far in this run.
    preference_data = [[[] for _ in range(len(kl_coefficient))] for _ in range(len(num_tour))]
    for t in range(num_iterations):
        # Data collection: batch_size fresh contexts, each shared by all settings.
        for _ in range(batch_size):
            x = np.random.rand(context_dim)
            for i in range(len(num_tour)):
                for j in range(len(kl_coefficient)):
                    # pi^1: fixed-point policy under this setting's current estimate of W.
                    pi1 = solve_fixed_point(x,learned_W[i][j],kl_coefficient[j])
                    a1 = np.random.choice(number_actions, p=pi1)# a1 ~ pi1 (main player's action)
                    # pi^2: enhancer action from a tournament over num_tour[i] samples of pi1.
                    a2 = get_pi_2_action(x,pi1,num_tour[i])
                    # action_batch_size independent labels for the same (x, a1, a2).
                    for _ in range(action_batch_size):
                        pref = sample_preference(x, a1, a2)
                        preference_data[i][j].append((x, a1, a2, pref))

    # MLE refit: for each setting, minimize the average negative log-likelihood
    # with L-BFGS-B, warm-started from the previous estimate; `bounds` keeps W positive.
        for i in range(len(num_tour)):
            for j in range(len(kl_coefficient)):
                # Average NLL of setting (i, j)'s labels. It closes over the loop
                # variables i and j. That is safe only because minimize() calls it
                # immediately, before i and j change.
                def loss_fn(flat_weights):
                    W = flat_weights.reshape(context_dim, action_dim, action_dim)
                    total_loss = 0
                    for x, a1, a2, pref in preference_data[i][j]:
                        a1_vec, a2_vec = action_vectors[a1], action_vectors[a2]
                        p = preference_prob(x, a1_vec, a2_vec, W)
                        # Clipping disabled: with positive W and nonnegative action features,
                        # p presumably stays strictly inside (0, 1) unless a score is exactly 0.
                        # p = np.clip(p, 1e-9, 1 - 1e-9)
                        total_loss -= np.log(p if pref == 1 else 1 - p)
                    return total_loss / len(preference_data[i][j])
                # No `jac` is passed, so SciPy approximates the gradient by finite
                # differences (about one loss evaluation per parameter per gradient).
                # maxiter=50 caps the optimizer iterations per refit.
                res = minimize(loss_fn, learned_W[i][j].flatten(), method='L-BFGS-B', bounds=bounds, options={"maxiter": 50})
                learned_W[i][j] = res.x.reshape(context_dim, action_dim, action_dim)


    # Evaluate suboptimality gap on num_eval fresh contexts. For each setting, gap
    # accumulates j_star - J(pi_hat), where J is calculate_minizer_value under the
    # TRUE model (true_W) and pi_hat is the fixed-point policy of the learned W.
        gap = np.zeros((len(num_tour),len(kl_coefficient)))
        # NOTE(review): this array is never read; j_pi is rebound to a scalar below.
        j_pi= np.zeros((len(num_tour),len(kl_coefficient)))
        # Reference value for the gap: preference_prob gives P(a>b) + P(b>a) = 1,
        # so the game is symmetric and its regularized equilibrium value is 1/2.
        j_star=1/2
        for _ in range(num_eval):
            x = sample_context()
            for j in range(len(kl_coefficient)):
            #pi_star = solve_fixed_point(x,true_W,kl_coefficient[j])
                for i in range(len(num_tour)):
                    pi_hat = solve_fixed_point(x,learned_W[i][j],kl_coefficient[j]) # learned policy
                    j_pi = calculate_minizer_value(x, true_W, pi_hat,kl_coefficient[j])
                    gap[i][j] += j_star - j_pi

        # Average over the evaluation contexts and extend this run's curves.
        for i in range(len(num_tour)):
            for j in range(len(kl_coefficient)):
                suboptimal_gaps[i][j].append(gap[i][j] / num_eval)
                cum_gap[i][j] += (gap[i][j] / num_eval)
                cum_gaps[i][j].append(cum_gap[i][j])
    # Store this run's curves (num_iterations entries each) for aggregation across runs.
    for i in range(len(num_tour)):
        for j in range(len(kl_coefficient)):
            multi_suboptimal_gaps[i][j].append(suboptimal_gaps[i][j])
            multi_cumulative_gaps[i][j].append(cum_gaps[i][j])
def _plot_gap_curves(curves_by_setting, ylabel):
    """Plot mean +/- std across runs of one per-iteration metric, one line per setting.

    Args:
        curves_by_setting: curves_by_setting[i][j] is a list of per-run curves, each
            with num_iterations entries, for tournament size num_tour[i] and
            eta = kl_coefficient[j].
        ylabel: y-axis label for the figure.
    """
    plt.figure(figsize=(8,6))
    iterations = np.arange(num_iterations)
    for i, tour in enumerate(num_tour):
        for j, eta in enumerate(kl_coefficient):
            results = np.array(curves_by_setting[i][j])  # Shape: (num_runs, num_iterations)
            plt.errorbar(
                iterations,
                results.mean(axis=0),
                yerr=results.std(axis=0),  # spread across runs as error bars
                label=fr"tournament_num = {tour},$\eta$ = {eta}",
                fmt='-o',
                capsize=4,
            )
    plt.xlabel("Iteration",fontsize=14)
    plt.ylabel(ylabel,fontsize=14)
    plt.title(r"Impact of tournament number and $\eta$",fontsize=14)
    plt.grid(True)
    plt.legend(fontsize=14)
    plt.show()


# Plot suboptimal gap and regret (the helper keeps its locals out of module scope,
# so the module-level `x` is no longer overwritten here).
_plot_gap_curves(multi_suboptimal_gaps, "Suboptimality Gap Over Time")
_plot_gap_curves(multi_cumulative_gaps, "Cumulative Gap Over Time")
# Save data. Both arrays are indexed [tournament setting, eta, run, iteration]:
# array1 holds per-iteration suboptimality gaps, array2 their cumulative sums.
array1, array2 = map(np.array, (multi_suboptimal_gaps, multi_cumulative_gaps))
np.savez('data_general.npz', array1=array1, array2=array2)  # uncompressed .npz archive


