import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize

# ---------------------------------------------------------------------------
# Experiment settings.
# Bradley-Terry preference model with a bilinear reward R(x, a) = x^T W a,
# where x is a context vector and a is an action feature vector.
# ---------------------------------------------------------------------------

# Problem dimensions.
context_dim = 5   # contexts live in R^5
action_dim = 5    # action feature vectors live in R^5
num_actions = 6   # size of the discrete action set

# Online-learning schedule.
turns = 5               # independent repetitions of the whole experiment
num_iterations = 20     # collect / fit / evaluate rounds per repetition
batch_size = 10         # fresh contexts drawn per round
action_batch_size = 5   # preference labels drawn per sampled action pair
num_eval_contexts = 20  # contexts used for each Monte Carlo average

# Hyper-parameters swept over; every (beta, eta) pair is run side by side.
beta = [0, 0.3, 0.5]  # exploration (optimism bonus) coefficients
eta = [1]             # reward scale in the policy logits; 1/eta weights the KL term
lambda_reg = 0.5      # ridge term added to the covariance matrix

# Candidate seeds; nothing in the visible part of this file reads this list.
seed = [10, 15, 25, 35, 38, 45]

# NOTE: the first call is superseded by the second before any random number
# is drawn, so the run is effectively seeded with 42.
np.random.seed(35)
np.random.seed(42)

# Number of free parameters in the flattened (context_dim x action_dim)
# reward matrix.
n_vars = context_dim * action_dim

# One identity design matrix per action (identity for numerical stability).
# NOTE(review): A is not referenced again in the visible part of this file.
A = [np.identity(context_dim) for _ in range(num_actions)]

# Reference policy pi_0: uniform over the action set.
pi0 = np.full(num_actions, 1.0 / num_actions)

# Random non-negative feature vector for every action. These are drawn before
# the true reward weights, so the order of the two draws must not be swapped.
action_vectors = [
    np.abs(np.random.rand(action_dim))
    for _ in range(num_actions)
]

# Box constraints for the optimiser, one (lower, upper) pair per weight:
# lower bound 1e-7, no upper bound (None).
bounds = [(1e-7, None) for _ in range(n_vars)]

# Latent reward matrix W* (unknown to the learner).
true_reward_weights = np.abs(np.random.rand(context_dim, action_dim))

def true_reward(context, action):
    """Return the latent reward x^T W* a for a context and an action index.

    ``action`` is an index into ``action_vectors``; ``W*`` is
    ``true_reward_weights``.
    """
    action_vec = action_vectors[action]
    return context @ true_reward_weights @ action_vec

def phi(x, a):
    """Joint feature map of a context ``x`` and an action vector ``a``.

    Returns the Kronecker product of ``x`` and ``a``, so that a bilinear
    reward can be written as an inner product:
    R(x, a) = x^T M a = <theta, phi(x, a)>.
    """
    features = np.kron(x, a)
    return features
    
def phi_pi0(x):
    """Return the pi0-weighted average of phi(x, a) over all actions.

    This is the expected feature vector of context ``x`` under the reference
    policy ``pi0`` (the per-context contribution to nu).
    """
    rows = [phi(x, action_vectors[idx]) for idx in range(num_actions)]
    feature_matrix = np.array(rows)  # one row of features per action
    return pi0 @ feature_matrix

def build_Sigma(preference_data):
    """Build the regularised covariance matrix of the feature differences.

    Computes ``lambda_reg * I + sum_i v_i v_i^T`` where, for every recorded
    comparison ``(x, a1, a2, pref)``, ``v_i = phi(x, a1) - phi(x, a2)``.
    The preference label itself is not used.
    """
    dim = action_dim * context_dim
    cov = lambda_reg * np.eye(dim)
    for ctx, first, second, _label in preference_data:
        feat_first = phi(ctx, action_vectors[first])
        feat_second = phi(ctx, action_vectors[second])
        delta = feat_first - feat_second
        cov += np.outer(delta, delta)
    return cov

def compute_bonus(preference_data, x, action, beta):
    """Compute the optimism bonus ``beta * ||phi(x, action) - nu||_{Sigma^-1}``.

    Args:
        preference_data: iterable of ``(context, a1, a2, pref)`` tuples; the
            covariance ``Sigma`` is rebuilt from it on every call.
        x: context vector.
        action: action feature vector (callers in this file pass
            ``action_vectors[i]``, not an index).
        beta: exploration coefficient; ``0`` yields the greedy policy.

    Returns:
        The non-negative bonus as a float.
    """
    # Greedy case: the bonus is identically zero, so skip building and
    # solving with Sigma altogether.
    if beta == 0:
        return 0.0

    # 1) Regularised covariance of the observed feature differences.
    Sigma = build_Sigma(preference_data)
    # 2) Feature difference relative to the reference-policy mean feature
    #    (module-level ``phipi0``).
    diff = phi(x, action) - phipi0
    # 3) Quadratic form diff^T Sigma^-1 diff. Solving the linear system is
    #    cheaper and numerically more stable than forming the explicit inverse.
    quad = diff @ np.linalg.solve(Sigma, diff)
    # Sigma is positive definite, so quad >= 0 in exact arithmetic; clamp to
    # guard against a tiny negative value caused by round-off.
    return beta * np.sqrt(max(quad, 0.0))

def sample_preference_bt(context, a1, a2):
    """Draw one preference label from the Bradley-Terry model.

    ``a1`` and ``a2`` are action indices. Returns 1 with probability
    ``exp(r1) / (exp(r1) + exp(r2))`` (i.e. ``a1`` is preferred) and 0
    otherwise, where ``r1`` and ``r2`` are the true rewards of the two
    actions. Consumes exactly one draw from NumPy's global RNG.
    """
    exp_first = np.exp(true_reward(context, a1))
    exp_second = np.exp(true_reward(context, a2))
    prob_first = exp_first / (exp_first + exp_second)
    draw = np.random.rand()
    if draw < prob_first:
        return 1
    return 0

def softmax(logits):
    """Numerically stable softmax.

    The maximum logit is subtracted before exponentiating. If the normaliser
    is zero or not finite (for example because of NaN input), a uniform
    distribution of the same length is returned instead.
    """
    shifted = logits - np.max(logits)
    weights = np.exp(shifted)
    total = np.sum(weights)
    if total != 0 and np.isfinite(total):
        return weights / total
    # Degenerate normaliser: fall back to the uniform distribution.
    return np.ones_like(shifted) / len(shifted)

def get_policy(context, reward_weights, eta):
    """Planning oracle: policy proportional to ``pi0(a) * exp(eta * R(x, a))``.

    ``R(x, a) = context^T reward_weights a`` is evaluated for every action and
    the resulting logits ``log(pi0) + eta * R`` are passed through softmax.
    """
    scores = np.fromiter(
        (context @ reward_weights @ action_vectors[idx] for idx in range(num_actions)),
        dtype=float,
        count=num_actions,
    )
    logits = np.log(pi0) + eta * scores
    return softmax(logits)
def get_policy_opt(context, c_reward_weights, preference_data, beta, eta):
    """Optimistic counterpart of the planning oracle.

    For every action the estimated reward ``context^T c_reward_weights a`` is
    increased by the exploration bonus from ``compute_bonus``; the absolute
    value of that sum is used as the score. The policy is then
    ``softmax(log(pi0) + eta * score)``.
    """
    def optimistic_score(idx):
        action_vec = action_vectors[idx]
        estimate = context @ c_reward_weights @ action_vec
        bonus = compute_bonus(preference_data, context, action_vec, beta)
        return np.abs(estimate + bonus)

    scores = np.fromiter(
        (optimistic_score(idx) for idx in range(num_actions)),
        dtype=float,
        count=num_actions,
    )
    logits = np.log(pi0) + eta * scores
    return softmax(logits)
    
# Result buckets, indexed as [beta_index][eta_index]. Each bucket receives one
# per-iteration curve for every turn of the experiment.
multi_suboptimal_gaps = [[[] for _ in eta] for _ in beta]
multi_cumulative_gaps = [[[] for _ in eta] for _ in beta]

# Monte Carlo estimate of nu = E_x[phi_pi0(x)]: the mean reference-policy
# feature vector over num_eval_contexts random contexts. compute_bonus reads
# the result through the module-level name phipi0.
# The accumulator is created once, before the loop; resetting it on every
# pass would keep only the last sample instead of the sum.
phi0 = np.zeros(context_dim * action_dim)
for _ in range(num_eval_contexts):
    x = np.random.rand(context_dim)
    phi0 = phi0 + phi_pi0(x)
phipi0 = phi0 / num_eval_contexts

# Run the whole online experiment `turns` times. Within a turn every
# (beta, eta) pair keeps its own dataset and its own reward estimate, but all
# pairs see the same freshly drawn context in each collection step.
for k in range(turns):
    # Preference data per (beta, eta) pair: list of (context, a1, a2, pref).
    preference_data = [[[] for _ in range(len(eta))] for _ in range(len(beta))]
    reference_reward_weights = np.ones((context_dim, action_dim))

    # Evaluation trackers, all indexed [beta_index][eta_index].
    j_gap = [[[] for _ in range(len(eta))] for _ in range(len(beta))]
    cumulative_gaps = [[[] for _ in range(len(eta))] for _ in range(len(beta))]
    cumulative_sum = np.zeros((len(beta), len(eta)))
    # Every pair starts from the same all-ones reward estimate.
    reward_weights = [[reference_reward_weights.copy() for _ in range(len(eta))] for _ in range(len(beta))]

    # Main loop (online setting).
    for iteration in range(num_iterations):
        # Step 1: collect new preference data.
        for _ in range(batch_size):
            context = np.random.rand(context_dim)
            for i in range(len(beta)):
                for j in range(len(eta)):
                    # a1 comes from the current optimistic policy, a2 from pi0.
                    pi1 = get_policy_opt(context, reward_weights[i][j], preference_data[i][j], beta[i], eta[j])
                    a1 = np.random.choice(num_actions, p=pi1)
                    a2 = np.random.choice(num_actions, p=pi0)
                    # The same pair is labelled action_batch_size times.
                    for _ in range(action_batch_size):
                        pref = sample_preference_bt(context, a1, a2)
                        preference_data[i][j].append((context, a1, a2, pref))

        # Steps 2 and 3: fit the reward weights by maximum likelihood under
        # the Bradley-Terry model, warm-started from the previous estimate.
        for i in range(len(beta)):
            for j in range(len(eta)):
                # The dataset is bound as a default argument so the closure
                # does not depend on the loop variables i and j.
                def loss_fn(flat_weights, data=preference_data[i][j]):
                    """Mean Bradley-Terry negative log-likelihood of ``data``."""
                    W = flat_weights.reshape(context_dim, action_dim)
                    total_loss = 0.0
                    for x, a1, a2, pref in data:
                        r1 = x @ W @ action_vectors[a1]
                        r2 = x @ W @ action_vectors[a2]
                        margin = r1 - r2
                        # -log sigmoid(margin) when a1 was preferred, otherwise
                        # -log sigmoid(-margin). The softplus form
                        # log(1 + exp(.)) cannot overflow in exp or hit log(0)
                        # for large rewards.
                        total_loss += np.logaddexp(0.0, -margin if pref == 1 else margin)
                    return total_loss / len(data)
                res = minimize(loss_fn, reward_weights[i][j].flatten(), method='SLSQP', bounds=bounds)
                reward_weights[i][j] = res.x.reshape(context_dim, action_dim)

        # Step 4: evaluate the regularised objective
        #   J(pi) = E_x[ E_{a ~ pi} R(x, a) - (1 / eta) * KL(pi || pi0) ]
        # for the optimal policy pi* (built from the true rewards) and for the
        # current optimistic policy of every (beta, eta) pair.
        j_pi = np.zeros((len(beta), len(eta)))
        j_star = np.zeros(len(eta))
        for _ in range(num_eval_contexts):
            context = np.random.rand(context_dim)
            rewards = np.array([true_reward(context, a) for a in range(num_actions)])
            for j in range(len(eta)):
                pi_star_logits = np.log(pi0) + eta[j] * rewards  # log(pi_0) + eta * R(x, a)
                pi_star = softmax(pi_star_logits)
                r_star = np.sum(pi_star * rewards)
                kl_star = np.sum(pi_star * (np.log(pi_star) - np.log(pi0)))
                j_star[j] += r_star - eta[j] ** -1 * kl_star
                for i in range(len(beta)):
                    pi = get_policy_opt(context, reward_weights[i][j], preference_data[i][j], beta[i], eta[j])
                    r_pi = np.sum(pi * rewards)
                    kl_pi = np.sum(pi * (np.log(pi) - np.log(pi0)))
                    j_pi[i][j] += r_pi - eta[j] ** -1 * kl_pi
        # Record the suboptimality gap J(pi*) - J(pi) and its running sum.
        for i in range(len(beta)):
            for j in range(len(eta)):
                avg_j_pi = j_pi[i][j] / num_eval_contexts
                avg_j_star = j_star[j] / num_eval_contexts
                gap = avg_j_star - avg_j_pi
                j_gap[i][j].append(gap)
                cumulative_sum[i][j] += gap
                cumulative_gaps[i][j].append(cumulative_sum[i][j])
    # Store this turn's curves for every (beta, eta) pair.
    for i in range(len(beta)):
        for j in range(len(eta)):
            multi_suboptimal_gaps[i][j].append(j_gap[i][j])
            multi_cumulative_gaps[i][j].append(cumulative_gaps[i][j])
# Plotting
def _plot_gap_curves(curves, ylabel):
    """Draw one mean +/- std error-bar curve per (beta, eta) pair.

    ``curves[i][j]`` holds one per-iteration curve for every turn; the mean
    and the standard deviation are taken across turns.
    """
    plt.figure(figsize=(8,6))
    iterations = np.arange(num_iterations)
    for i, beta_value in enumerate(beta):
        for j, eta_value in enumerate(eta):
            results = np.array(curves[i][j])  # Shape: (num_runs, num_iterations)
            means = results.mean(axis=0)
            stds = results.std(axis=0)  # For error bars
            plt.errorbar(iterations, means, yerr=stds, label = fr"$\beta$ = {beta_value},$\eta$ = {eta_value}", fmt='-o', capsize=4)
    plt.xlabel("Iteration")
    plt.ylabel(ylabel)
    plt.title(r"Impact of $\beta$ and $\eta$")
    plt.grid(True)
    plt.legend()
    plt.show()


# Figure 1: per-iteration suboptimality gap.
_plot_gap_curves(multi_suboptimal_gaps, "Suboptimality Gap Over Time")
# Figure 2: running sum of the gaps. It gets its own axis label so the two
# figures can be told apart.
_plot_gap_curves(multi_cumulative_gaps, "Cumulative Suboptimality Gap Over Time")
# Persist both result sets in a single .npz archive. Each array has shape
# (len(beta), len(eta), turns, num_iterations).
array1, array2 = (
    np.array(collected)
    for collected in (multi_suboptimal_gaps, multi_cumulative_gaps)
)
np.savez('data_BT.npz', array1=array1, array2=array2)

