import numpy as np
import random
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import sympy as sp
from scipy.optimize import root_scalar


# Settings: Bradley-Terry preference model with a bilinear reward
# R(x, a) = x^T W a, where x is a context and a is an action vector.
context_dim = 5  # contexts live in R^5
num_actions = 10  # size of the finite action set
action_dim = 5  # dimension of each action vector
num_iterations = 20  # online iterations per turn
batch_size = 5  # contexts drawn per iteration
action_batch_size = 5  # action pairs sampled per context
eta = [1]  # regularisation parameters; rewards are scaled by eta in the policy solvers
num_eval_contexts = 10  # random contexts used when estimating phipi0 and when evaluating
# NOTE(review): this seed is overridden by np.random.seed(42) further down
# before any random number is drawn, so it has no effect.
np.random.seed(35)
lambda_reg = 0.5  # ridge coefficient of the covariance matrix (see build_Sigma)
# Strategy codes interpreted by the main loop: -1 = f-based sampling
# ("Function"), -2 = uniform sampling from pi0 ("Vanilla"), 0 = greedy,
# any other value = optimism with that bonus coefficient.
beta = [-1,0.1,0,-2] 
turns = 5  # number of repetitions of the whole experiment
x = sp.Symbol('x')
y = sp.Symbol('y')
# Generator function f of the divergence regulariser (see f_div).
# Alternative choices that were tried:
#f = (x-1)**2+x*sp.log(x)
#f = x * sp.log(x)
f = x * sp.log(x) -sp.log(x)
# NOTE(review): `seed` is not referenced anywhere else in the visible source.
seed = [42,15,25,35,20,45]
np.random.seed(42)  # the seed that actually governs NumPy's global RNG

# Initialize design matrices (identity for numerical stability)
# NOTE(review): `A` is not referenced anywhere else in the visible source.
A = [np.eye(context_dim) for _ in range(num_actions)]
n_vars = context_dim * action_dim  # number of entries of the flattened weight matrix
pi0 = np.ones(num_actions)/num_actions  # uniform reference policy over actions

# One random action vector per action; np.random.rand already yields values
# in [0, 1), so np.abs leaves them unchanged.
action_vectors = [np.abs(np.random.rand(action_dim)) for _ in range(num_actions)]
# Bounds for each variable: (lower_bound, upper_bound)
bounds = [(1e-7, None)] * n_vars  # lower bound is 1e-7; None means no upper bound
# Latent reward model (unknown to learner)
true_reward_weights = np.abs(np.random.rand(context_dim, action_dim))

def true_reward(context, action):
    """Return the latent reward of the action with index ``action``.

    The reward is the bilinear form ``context @ true_reward_weights @ a``
    where ``a`` is ``action_vectors[action]``.
    """
    action_vec = action_vectors[action]
    return context @ true_reward_weights @ action_vec
    
def phi(x, a):
    """Return the joint feature of context ``x`` and action vector ``a``.

    The feature is the Kronecker product ``np.kron(x, a)``, so that a
    bilinear reward x^T M a can be written as <theta, phi(x, a)>.
    """
    joint_feature = np.kron(x, a)
    return joint_feature
    
def phi_pi0(x):
    """Return the pi0-weighted average of phi(x, a) over all actions (nu)."""
    # One row per action: the joint feature of x with that action's vector.
    feature_rows = np.array(
        [phi(x, action_vectors[idx]) for idx in range(num_actions)]
    )
    return pi0 @ feature_rows

def build_Sigma(preference_data):
    """Return the regularised covariance matrix of the feature differences.

    Starts from ``lambda_reg * I`` and adds the outer product of
    ``phi(x, a1) - phi(x, a2)`` with itself for every record in
    ``preference_data``. Each record is a 5-tuple
    ``(context, a1, a2, pref, weight)``; only the first three entries are
    used here.
    """
    dim = action_dim * context_dim
    Sigma = lambda_reg * np.eye(dim)
    for ctx, first, second, _pref, _weight in preference_data:
        delta = phi(ctx, action_vectors[first]) - phi(ctx, action_vectors[second])
        Sigma += np.outer(delta, delta)
    return Sigma

def compute_bonus(preference_data, x, action,beta):
    """Return the optimism bonus for an action vector in context ``x``.

    The bonus is ``beta * sqrt(d^T Sigma^{-1} d)`` with
    ``d = phi(x, action) - phipi0`` and ``Sigma`` built from
    ``preference_data`` by ``build_Sigma``. ``action`` is the action
    *vector* itself (it is passed straight to ``phi``), not an index.

    When ``beta == 0`` (greedy) the function returns 0 without building
    the covariance matrix.
    """
    if beta == 0:
        return 0
    # Inverse of the regularised covariance of the collected data.
    cov_inverse = np.linalg.inv(build_Sigma(preference_data))
    # Feature of (x, action) centred at the reference feature phipi0.
    centred = phi(x, action) - phipi0
    # Norm of the centred feature w.r.t. Sigma^{-1}, scaled by beta.
    return beta * np.sqrt(centred @ cov_inverse @ centred)

# Preference sampling using Bradley-Terry model
def sample_preference_bt(context, a1, a2):
    """Sample a Bradley-Terry preference between two action indices.

    Args:
        context: context vector passed to ``true_reward``.
        a1: index of the first action.
        a2: index of the second action.

    Returns:
        1 if ``a1`` is preferred over ``a2`` in this draw, otherwise 0.
        ``a1`` is preferred with probability
        ``exp(r1) / (exp(r1) + exp(r2))``, where ``r1`` and ``r2`` are the
        true rewards of the two actions.
    """
    r1 = true_reward(context, a1)
    r2 = true_reward(context, a2)
    # exp(r1) / (exp(r1) + exp(r2)) == 1 / (1 + exp(r2 - r1)).  Evaluating it
    # through logaddexp avoids the inf / inf = NaN that the direct formula
    # produces when the rewards are large.
    p = np.exp(-np.logaddexp(0.0, r2 - r1))
    return 1 if np.random.rand() < p else 0

# Numerically stable softmax
def softmax(logits):
    """Return the softmax of ``logits``, shifted by the maximum for stability.

    If the normalising sum is zero or not finite, the uniform distribution
    over the entries is returned instead.
    """
    shifted = logits - np.max(logits)
    weights = np.exp(shifted)
    normaliser = np.sum(weights)
    if normaliser != 0 and np.isfinite(normaliser):
        return weights / normaliser
    # Degenerate normaliser: fall back to the uniform distribution.
    return np.ones_like(shifted) / len(shifted)

def get_policy(context,reward_weights,eta,f):
    """Return the planning-oracle policy for ``context`` under a reward model.

    The reward of every action is ``context @ reward_weights @ a`` with
    ``a`` the action's vector; the resulting reward vector is handed to
    ``optimal_policy`` together with ``f`` and ``eta``.
    """
    # context @ reward_weights is shared by all actions; compute it once.
    projected = context @ reward_weights
    Rxa = np.fromiter(
        (projected @ action_vectors[idx] for idx in range(num_actions)),
        dtype=float,
    )
    return optimal_policy(f, Rxa, eta)

def get_policy_opt(context,c_reward_weights,preference_data,beta,eta,f):
    """Return the policy for ``context`` under a bonus-augmented reward.

    For every action the estimated reward ``context @ c_reward_weights @ a``
    is increased by ``compute_bonus(preference_data, context, a, beta)`` and
    the absolute value of the sum is taken. The resulting vector is passed
    to ``optimal_policy`` with ``f`` and ``eta``.
    """
    projected = context @ c_reward_weights
    optimistic = []
    for idx in range(num_actions):
        action_vec = action_vectors[idx]
        estimate = projected @ action_vec
        bonus = compute_bonus(preference_data, context, action_vec, beta)
        optimistic.append(np.abs(estimate + bonus))
    Rxa = np.array(optimistic, dtype=float)
    return optimal_policy(f, Rxa, eta)

def get_sample_policy(context,reward_weights,eta,f):
    """Return the sampling distributions for ``context`` under a reward model.

    Builds the reward vector ``context @ reward_weights @ a`` over all
    actions and forwards it to ``sample_policy``; the return value is
    whatever ``sample_policy`` returns.
    """
    projected = context @ reward_weights
    Rxa = np.zeros(num_actions)
    for idx, action_vec in zip(range(num_actions), action_vectors):
        Rxa[idx] = projected @ action_vec
    return sample_policy(f, Rxa, eta)
    

def optimal_policy(f, reward,eta):
    """Compute the regularised policy for one reward vector.

    The policy has the form
    ``pi[a] = pi0[a] * (f')^{-1}(eta * reward[a] - eta * lambda)``,
    where ``(f')^{-1}`` is the first solution SymPy returns for
    ``f'(x) = y`` and ``lambda`` is found numerically so that the entries
    sum to 1.

    Args:
        f: SymPy expression in the symbol ``x``.
        reward: iterable with one reward per action (paired with ``pi0``).
        eta: scalar multiplying both the rewards and ``lambda``.

    Returns:
        1-D float ``numpy`` array with one entry per action.

    Raises:
        RuntimeError: if ``sympy.nsolve`` fails from every initial guess.
    """
    lam = sp.Symbol('lambda')
    x = sp.symbols('x')
    df = sp.diff(f, x)
    # Invert f' symbolically and keep the first branch SymPy returns.
    inverse = sp.solve(sp.Eq(df, y), x)
    df_inv = inverse[0]
    # Normalisation constraint: sum_a pi0[a] * (f')^{-1}(eta*r_a - eta*lam) = 1.
    expr = sum(a * df_inv.subs(y, eta*r - eta*lam) for a, r in zip(pi0, reward))
    equation = sp.Eq(expr, 1)

    # nsolve is sensitive to its starting point, so try a range of guesses
    # and keep the first one that converges.
    sol = None
    last_error = None
    for guess in [6,5,4,3.5,3,2.5,2, 1.5, 1,0.5, 0, -1, -2,-3,-4,-5]:
        try:
            sol = sp.nsolve(equation, lam, guess)
        except Exception as err:  # failure type depends on the solver internals
            last_error = err
        else:
            break
    if sol is None:
        # Fail loudly here instead of hitting an unbound `sol` below.
        raise RuntimeError("All guesses failed. Could not solve for λ.") from last_error

    lam_value = sp.re(sol)
    pi = np.array([
        (a * df_inv.subs(y, eta*r - eta*lam_value)).evalf()
        for a, r in zip(pi0, reward)
    ], dtype=float)
    return pi


def sample_policy(f, reward,eta):
    """Compute the action-sampling distributions for one reward vector.

    Solves the same normalisation equation as ``optimal_policy`` for
    ``lambda`` and then evaluates the derivative of ``(f')^{-1}`` at
    ``eta * reward[a] - eta * lambda`` for each action.

    Args:
        f: SymPy expression in the symbol ``x``.
        reward: 1-D array with one reward per action (paired with ``pi0``).
        eta: scalar multiplying both the rewards and ``lambda``.

    Returns:
        Tuple ``(pi_prime, pi_po, pi_ne, probab, wx)`` where, with
        ``pi = |pi0 * d/dy (f')^{-1}|`` evaluated per action,
        ``Tx = sum(pi)`` and
        ``S = sum(pi * exp(reward)) * sum(pi * exp(-reward))``:

        * ``pi_prime`` is ``pi / Tx``;
        * ``pi_po`` is ``pi * exp(reward)`` normalised to sum to 1;
        * ``pi_ne`` is ``pi * exp(-reward)`` normalised to sum to 1;
        * ``probab`` is ``S / (1 + S)``;
        * ``wx`` is ``Tx * (1 + S)``.

    Raises:
        RuntimeError: if ``sympy.nsolve`` fails from every initial guess.
    """
    lam = sp.Symbol('lambda')
    x = sp.symbols('x')
    df = sp.diff(f, x)
    # Invert f' symbolically and keep the first branch SymPy returns.
    inverse = sp.solve(sp.Eq(df, y), x)
    df_inv = inverse[0]
    expr = sum(a * df_inv.subs(y, eta*r - eta*lam) for a, r in zip(pi0, reward))
    equation = sp.Eq(expr, 1)

    # nsolve is sensitive to its starting point, so try a range of guesses
    # and keep the first one that converges.
    sol = None
    last_error = None
    for guess in [6,5,4,3.5,3,2.5,2, 1.5, 1,0.5, 0, -1, -2,-3,-4,-5]:
        try:
            sol = sp.nsolve(equation, lam, guess)
        except Exception as err:  # failure type depends on the solver internals
            last_error = err
        else:
            break
    if sol is None:
        # Fail loudly here instead of hitting an unbound `sol` below.
        raise RuntimeError("All guesses failed. Could not solve for λ.") from last_error

    lam_value = sp.re(sol)
    # Derivative of (f')^{-1} with respect to its argument.
    dh = sp.diff(df_inv,y)
    pi = np.array([
        (a * dh.subs(y, eta*r - eta*lam_value)).evalf()
        for a, r in zip(pi0, reward)
    ], dtype=float)
    pi = np.abs(pi)
    Tx = pi.sum()
    pi_prime = pi / Tx
    # Un-normalised tilts of pi towards high / low reward.
    pi_po = pi * np.exp(reward)
    pi_ne = pi * np.exp(-reward)
    spread = pi_po.sum() * pi_ne.sum()
    probab = spread / (1+spread)
    wx = Tx * (1+spread)
    pi_po = pi_po / pi_po.sum()
    pi_ne = pi_ne / pi_ne.sum()

    return pi_prime, pi_po, pi_ne, probab, wx


def f_div(pi,f):
    """Return ``sum_i pi0[i] * f(pi[i] / pi0[i])`` as a Python float.

    ``f`` is a SymPy expression in the symbol ``x``; it is compiled to a
    NumPy function so it can be applied to the whole ratio array at once.
    """
    sym = sp.symbols('x')
    f_numeric = sp.lambdify(sym, f, modules='numpy')
    weighted = pi0 * f_numeric(pi / pi0)
    return float(np.sum(weighted))

def optimal_gap(pi,f,eta):
    """Return J(pi_star) - J(pi) for the reward vector ``r_1``.

    ``J(p) = sum(p * r_1) - (1 / eta) * f_div(p, f)`` and ``pi_star`` is
    ``optimal_policy(f, r_1, eta)``.

    NOTE(review): ``r_1`` is read as a global but is not defined in the
    visible source — confirm where it is supposed to come from.
    """
    pi_star = optimal_policy(f, r_1, eta)
    penalty_star = f_div(pi_star, f)
    best_value = np.sum(pi_star * r_1) - (1/eta) * penalty_star
    current_value = np.sum(pi * r_1) - (1/eta) * f_div(pi, f)
    return best_value - current_value

def objective(f,eta):
    """Return the expected reward of the optimal policy minus ``target``.

    The expected reward is ``dot(r_1, optimal_policy(f, r_1, eta))``; if it
    is NaN, ``np.inf`` is returned instead.

    NOTE(review): ``r_1`` and ``target`` are read as globals but are not
    defined in the visible source — confirm where they come from.
    """
    opt_pi = optimal_policy(f, r_1, eta)
    expected_reward = np.dot(r_1, opt_pi)
    return np.inf if np.isnan(expected_reward) else expected_reward - target
    
# Result containers indexed as [beta index][eta index]; each innermost list
# collects one curve per turn.
multi_suboptimal_gaps = [[[] for _ in eta] for _ in beta]
multi_cumulative_gaps = [[[] for _ in eta] for _ in beta]

# Monte-Carlo estimate of the reference feature: the average of phi_pi0(x)
# over num_eval_contexts random contexts.  The accumulator is initialised
# once, before the loop; resetting it inside the loop would discard every
# sample except the last one.
phi0 = 0
for _ in range(num_eval_contexts):
    x = np.random.rand(context_dim)
    phi0 = phi0 + phi_pi0(x)
phipi0 = phi0/num_eval_contexts
#print("nu",phipi0)


# Repeat the whole online experiment `turns` times.  Every turn starts from
# empty preference data and all-ones reward weights, and appends one
# suboptimality curve and one cumulative curve per (beta, eta) pair to the
# multi_* result lists.
# NOTE(review): NumPy's RNG is not re-seeded per turn here.
for k in range(turns):
    # Preference records, one list per (beta, eta) configuration.
    # Each record is (context, a1, a2, pref, weight).
    preference_data = [[[] for _ in range(len(eta))] for _ in range(len(beta))]
    reference_reward_weights = np.ones((context_dim, action_dim))

    # Evaluation tracker
    j_gap =  [[[] for _ in range(len(eta))] for _ in range(len(beta))]
    cumulative_gaps =  [[[] for _ in range(len(eta))] for _ in range(len(beta))]
    cumulative_sum = np.zeros((len(beta),len(eta)))
    # Each configuration gets its own copy of the initial weight estimate.
    reward_weights = [[reference_reward_weights.copy() for _ in range(len(eta))] for _ in range(len(beta))]
    # Main loop
    # NOTE(review): pi1 is a single variable shared by every (beta, eta)
    # configuration and every context that reaches the greedy/optimism
    # branch, so "the previous policy" may come from another configuration.
    pi1=pi0
    for iteration in range(num_iterations):# online setting
        # Step 1: Collect new data
        for _ in range(batch_size):
            context = np.random.rand(context_dim)
            for i in range(len(beta)):
                for j in range(len(eta)):
                    if beta[i] == -1:  # f-based sampling ("Function")
                        pi_prime, pi_po, pi_ne, probab, wx = get_sample_policy(context, reward_weights[i][j],eta[j],f)
                        for l in range(action_batch_size):
                            # NOTE(review): this uses Python's `random` module,
                            # which np.random.seed does not seed.
                            p_signal = 1 if random.random() < probab else 0
                            if p_signal == 1:
                                # With probability probab: one action from each tilted policy.
                                a1 = np.random.choice(num_actions, p=pi_po)
                                a2 = np.random.choice(num_actions, p=pi_ne)# sample from different pi
                            else:
                                # Otherwise: both actions from pi_prime.
                                a1 = np.random.choice(num_actions, p=pi_prime)
                                a2 = np.random.choice(num_actions, p=pi_prime)
                            # Three independent preference labels per sampled pair,
                            # each stored with weight wx.
                            for r in range(3):
                                pref = sample_preference_bt(context, a1, a2)
                                preference_data[i][j].append((context, a1, a2, pref,wx))
                    elif beta[i]==-2: #uniform sampling
                        for t in range(action_batch_size):
                            a1 = np.random.choice(num_actions, p=pi0)# sample from different pi
                            a2 = np.random.choice(num_actions, p=pi0)# sample from different pi
                            for r in range(3):
                                pref = sample_preference_bt(context, a1, a2)
                                preference_data[i][j].append((context, a1, a2, pref,1))
                    else:# beta=0 is greedy and beta=0.1 is optimism
                        # a1 comes from the previously computed policy, a2 from the new one.
                        pi11 = pi1
                        pi1 = get_policy_opt(context, reward_weights[i][j],preference_data[i][j],beta[i],eta[j],f)
                        # NOTE(review): this loop variable rebinds the outer turn
                        # index `k`; the outer `for` is unaffected, but the name
                        # no longer holds the turn number afterwards.
                        for k in range(action_batch_size):
                            a1 = np.random.choice(num_actions, p=pi11)# sample from different pi
                            a2 = np.random.choice(num_actions, p=pi1)# sample from different pi
                            for r in range(3):
                                pref = sample_preference_bt(context, a1, a2)
                                preference_data[i][j].append((context, a1, a2, pref,1))

        # Step 2: Define negative log-likelihood (BT loss) and Step 3: Optimize weights using MLE
        for i in range(len(beta)):
            for j in range(len(eta)):
                # loss_fn reads i and j from the enclosing scope; it is only
                # used by the minimize call right below, inside the same
                # loop iteration.
                def loss_fn(flat_weights):
                    W = flat_weights.reshape(context_dim, action_dim)
                    total_loss = 0
                    for x, a1, a2, pref, w in preference_data[i][j]:
                        # Sanity check: record weights are expected to be non-negative.
                        if w<0:
                            print("wrong")
                        r1 = x @ W @ action_vectors[a1]
                        r2 = x @ W @ action_vectors[a2]
                        # Bradley-Terry probability that a1 is preferred under W.
                        p_bt = (np.exp(r1)) / (np.exp(r1)+ np.exp(r2))
                        total_loss -= w * np.log(p_bt if pref == 1 else (1 - p_bt))
                    # Weighted negative log-likelihood averaged over the records.
                    return total_loss / len(preference_data[i][j])
                # Warm start from the current estimate, subject to `bounds`.
                res = minimize(loss_fn, reward_weights[i][j].flatten(), method='SLSQP', bounds=bounds)
                reward_weights[i][j] = res.x.reshape(context_dim, action_dim)

        # Step 4: Evaluate J(pi*) - J(pi) = E_x E_a
        # Sums over fresh random evaluation contexts; divided by
        # num_eval_contexts further down.
        j_pi = np.zeros((len(beta),len(eta)))
        j_star = np.zeros(len(eta))
        for _ in range(num_eval_contexts):
            context = np.random.rand(context_dim)
            rewards = np.array([true_reward(context, a) for a in range(num_actions)])
            for j in range(len(eta)):
                # Regularised value of the policy computed from the true reward weights.
                pi_star = get_policy(context,true_reward_weights,eta[j],f)
                div = f_div(pi_star, f)
                j_star[j] += np.sum(pi_star * rewards)-(1/eta[j])* div

                for i in range(len(beta)):
                    # Regularised value of the policy computed from the learned weights.
                    pi = get_policy(context,reward_weights[i][j],eta[j],f)
                    r_pi = np.sum(pi * rewards)
                    div = f_div(pi, f)
                    j_pi[i][j] += r_pi - eta[j] ** -1 * div
                    # Diagnostic: the running sum for the learned policy should
                    # not exceed the running sum for pi_star.
                    if j_pi[i][j]>j_star[j]:
                        print(beta[i],pi,"\n r:",r_pi,"\n J:",j_pi[i][j],"\n J_s",j_star[j])
        for i in range(len(beta)):
            for j in range(len(eta)):
                avg_j_pi = j_pi[i][j] / num_eval_contexts
                avg_j_star = j_star[j] / num_eval_contexts
                gap = avg_j_star - avg_j_pi
                # Diagnostic: a negative gap is reported but still recorded.
                if gap<0:
                    print("error")
                j_gap[i][j].append(gap)
                cumulative_sum[i][j] += gap
                cumulative_gaps[i][j].append(cumulative_sum[i][j])
    # Store this turn's curves for every configuration.
    for i in range(len(beta)):
        for j in range(len(eta)):
            multi_suboptimal_gaps[i][j].append(j_gap[i][j])
            multi_cumulative_gaps[i][j].append(cumulative_gaps[i][j])
# Plotting
# Figure 1: per-iteration suboptimality gap, mean and standard deviation
# across turns, one curve per (beta, eta) pair.
color = ['tab:blue', 'tab:orange', 'tab:gray', 'tab:purple']
named_strategies = {0: "Greedy", -1: "Function", -2: "Vanilla"}
plt.figure(figsize=(8, 6))
for i, beta_value in enumerate(beta):
    # Any beta without a reserved meaning is an optimism run.
    legend_label = named_strategies.get(beta_value, "Optimism")
    for j in range(len(eta)):
        results = np.array(multi_suboptimal_gaps[i][j])  # one row per turn
        means = results.mean(axis=0)
        stds = results.std(axis=0)  # used as error bars
        x = np.arange(num_iterations)
        plt.errorbar(
            x, means, yerr=stds, color=color[i],
            label=legend_label, fmt='-o', capsize=4,
        )
plt.xlabel("Iteration")
plt.ylabel("Suboptimality Gap Over Time")
plt.title(r"Performance of different algorithm")
plt.grid(True)
plt.legend()
plt.show()

# Figure 2: cumulative gap per iteration, mean and standard deviation
# across turns, one curve per (beta, eta) pair.
plt.figure(figsize=(8, 6))
for i, beta_value in enumerate(beta):
    if beta_value == 0:
        curve_label = "Greedy"
    elif beta_value == -1:
        curve_label = "Function"
    elif beta_value == -2:
        curve_label = "Vanilla"
    else:
        curve_label = fr"Optimism with $\beta$ = {beta[i]}"
    for j in range(len(eta)):
        results = np.array(multi_cumulative_gaps[i][j])  # one row per turn
        means = results.mean(axis=0)
        stds = results.std(axis=0)  # used as error bars
        x = np.arange(num_iterations)
        plt.errorbar(
            x, means, yerr=stds, color=color[i],
            label=curve_label, fmt='-o', capsize=4,
        )
plt.xlabel("Iteration")
plt.ylabel("Suboptimality Gap Over Time")
plt.title(r"Impact of $\beta$ and $\eta$")
plt.grid(True)
plt.legend()
plt.show()
# Persist both result collections to a single .npz archive under the keys
# "array1" (per-iteration gaps) and "array2" (cumulative gaps).
array1, array2 = (
    np.array(history)
    for history in (multi_suboptimal_gaps, multi_cumulative_gaps)
)
np.savez('data_f.npz', array1=array1, array2=array2)

