# %%
import jax
import jax.numpy as jnp
import optax
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from scipy.optimize import curve_fit
import os

# %%
# matplotlib and seaborn settings
# Apply the seaborn style FIRST: sns.set_style() writes its own "font.family"
# (sans-serif) into rcParams, so calling it after the update below would
# silently discard the serif font selection made here.
sns.set_style("darkgrid")
plt.rcParams.update({
    "text.usetex": False,
    "font.family": "serif",
    "font.serif": ["Palatino"],
    "axes.labelsize": 12,
    "font.size": 12,
    "legend.fontsize": 10,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
})

# %%
# Number of states and actions
num_states = 5
num_actions = 2

# Discount factor
gamma = 0.9

# All state indices, and the evaluation ("driver's test") states
states = np.arange(num_states)
eval_states = [0, 2]  # s1 and s3 (stored as 0-based indices 0 and 2)

# Action indices, derived from num_actions so the two cannot drift apart
actions = list(range(num_actions))

# Initial state distribution (uniform) as a JAX array
mu = jnp.ones(num_states) / num_states

# %%
# Similarity function f(s, s')
def similarity(s, s_prime):
    """Return the similarity score between states ``s`` and ``s_prime``.

    The rules are checked in order and the first match wins:
      * 1.0 -- the two states are equal;
      * 0.8 -- both states are listed in the module-level ``eval_states``;
      * 0.5 -- the state indices differ by exactly 1;
      * 0.2 -- anything else.
    """
    if s == s_prime:
        return 1.0
    if s in eval_states and s_prime in eval_states:
        return 0.8
    return 0.5 if abs(s - s_prime) == 1 else 0.2

# Build the similarity matrix: f[s, j] = similarity(s, eval_states[j]),
# i.e. one row per state and one column per evaluation state.
f_list = [
    [similarity(s, s_prime) for s_prime in eval_states]
    for s in range(num_states)
]
f = jnp.array(f_list)

# %%
# Number of trials averaged per beta, and the beta grids.
NUM_TRIALS = 1000


def _beta_grid(step):
    """Return beta = 0 followed by 10**x for x in np.arange(-1, 3.01, step)."""
    exponents = np.arange(-1, 3.01, step)
    return jnp.array([0] + [10**x for x in exponents])


# Coarse grid (exponent step 0.1): used for the experiment and the fits.
beta_values = _beta_grid(0.1)
# Fine grid (exponent step 0.01): in the visible code it is only referenced
# from commented-out fit-overlay plotting.
finer_betas = _beta_grid(0.01)

# Results are collected after the experiment as arrays of shape
# (num_beta, NUM_TRIALS); they are then converted to NumPy arrays and the
# mean and SEM are computed from them.
# Because the trials are vectorized with vmap, no explicit storage arrays
# need to be preallocated.

# Define the loss function.
def loss_fn(theta, R, P, f, mu, beta):
    """Negated penalised objective of the softmax policy defined by ``theta``.

    The policy is ``softmax(theta)`` over the action axis. Its value function
    ``V`` is obtained by solving ``(I - gamma * P_pi) V = r_pi``, where
    ``P_pi`` and ``r_pi`` are the policy-averaged transitions and rewards.
    Uses the module-level ``num_states``, ``gamma`` and ``eval_states``.

    Args:
        theta: policy logits, indexed [state, action].
        R: rewards, indexed [state, action].
        P: transition weights, indexed [state, action, next_state].
        f: similarity weights, indexed [state, evaluation state].
        mu: per-state weights used to average ``V`` into the expected return.
        beta: coefficient of the predictability penalty.

    Returns:
        ``(loss, (e, expected_reward))`` where ``e = avg_V - V`` is the
        per-state gap between the similarity-weighted average of ``V`` over
        the evaluation states and ``V`` itself, ``expected_reward`` is
        ``dot(mu, V)``, and ``loss = -(expected_reward - beta * sum(e**2))``.
    """
    policy = jax.nn.softmax(theta, axis=1)

    # Policy-averaged dynamics and rewards:
    #   transition_pi[s, p] = sum_a policy[s, a] * P[s, a, p]
    #   reward_pi[s]        = sum_a policy[s, a] * R[s, a]
    transition_pi = jnp.einsum('sa,sap->sp', policy, P)
    reward_pi = jnp.sum(policy * R, axis=1)

    # Policy evaluation: solve (I - gamma * P_pi) V = r_pi for V.
    system = jnp.eye(num_states) - gamma * transition_pi
    V = jnp.linalg.solve(system, reward_pi)

    # Similarity-weighted average of V over the evaluation states
    # (one entry per state; each row of f is normalised by its own sum).
    eval_values = V[jnp.array(eval_states)]
    avg_V = jnp.matmul(f, eval_values) / jnp.sum(f, axis=1)

    # Gap between the weighted average and the actual value, and its penalty.
    e = avg_V - V
    predictability_penalty = beta * jnp.sum(e ** 2)

    # Expected return under the state weights mu.
    expected_reward = jnp.dot(mu, V)

    # The optimizer minimises, so return the negated objective.
    loss = -(expected_reward - predictability_penalty)
    return loss, (e, expected_reward)

# Value-and-gradient of loss_fn with respect to theta (its first argument);
# has_aux=True carries loss_fn's (e, expected_reward) alongside the loss.
value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# Define an optimizer update function.
def update_step(theta, opt_state, R, P, f, mu, beta, optimizer):
    """Apply one optimizer step to ``theta``.

    Returns:
        ``(theta, opt_state, loss_val, e, expected_reward)``: the updated
        parameters and optimizer state, plus the loss and auxiliary outputs
        evaluated at the parameters *before* the update.
    """
    (loss_val, aux), grads = value_and_grad_fn(theta, R, P, f, mu, beta)
    e, expected_reward = aux
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_theta = optax.apply_updates(theta, updates)
    return new_theta, new_opt_state, loss_val, e, expected_reward

# Define the training loop function.
def train_loop(theta, opt_state, R, P, f, mu, beta, num_epochs, optimizer):
    """Run ``num_epochs`` calls of ``update_step`` and return the final ``theta``.

    Args:
        theta: initial policy logits.
        opt_state: optimizer state matching ``theta``.
        R, P, f, mu, beta: forwarded unchanged to ``update_step`` every step.
        num_epochs: number of update steps (static under jit).
        optimizer: object passed through to ``update_step`` (static under jit).

    Returns:
        The parameters after the last update step.
    """
    def step_fn(carry, _):
        theta, opt_state = carry
        theta, opt_state, _loss_val, _e, _expected_reward = update_step(
            theta, opt_state, R, P, f, mu, beta, optimizer
        )
        # Only the final parameters are returned, so emit no per-step output;
        # returning the metrics here would make scan stack them for every
        # epoch just to throw them away.
        return (theta, opt_state), None

    (theta, _), _ = jax.lax.scan(step_fn, (theta, opt_state), None, length=num_epochs)
    return theta

# Mark num_epochs and optimizer as static.
train_loop = jax.jit(train_loop, static_argnames=("num_epochs", "optimizer"))

# Define a function that runs one trial for a given beta.
def run_trial(beta, key):
    """Sample a random problem from ``key``, train a policy, report final metrics.

    Args:
        beta: predictability-penalty coefficient, passed unchanged to
            ``train_loop`` and ``loss_fn``.
        key: JAX PRNG key from which the rewards and transitions are drawn.

    Returns:
        ``(predictability_gap_squared, expected_reward)``: the sum of squared
        entries of ``e`` and the expected return, both from ``loss_fn``
        evaluated at the trained parameters.
    """
    # Rewards: one uniform sample per (state, action).
    key, reward_key = jax.random.split(key)
    R = jax.random.uniform(reward_key, shape=(num_states, num_actions))

    # Transitions: uniform samples normalised over the next-state axis so
    # that every P[s, a, :] sums to 1.
    key, transition_key = jax.random.split(key)
    raw_P = jax.random.uniform(transition_key, shape=(num_states, num_actions, num_states))
    P = raw_P / jnp.sum(raw_P, axis=2, keepdims=True)

    # All-zero logits and an Adam optimizer (learning rate 0.1).
    optimizer = optax.adam(learning_rate=0.1)
    theta_init = jnp.zeros((num_states, num_actions))
    opt_state = optimizer.init(theta_init)

    # Train for a fixed number of epochs.
    num_epochs = 500
    theta_final = train_loop(theta_init, opt_state, R, P, f, mu, beta, num_epochs, optimizer)

    # Re-evaluate the objective at the trained parameters for the metrics.
    _, (e, expected_reward) = loss_fn(theta_final, R, P, f, mu, beta)
    return jnp.sum(e**2), expected_reward


# One PRNG key per trial, all derived from a fixed seed. run_for_beta reads
# this same trial_keys array for every beta.
main_key = jax.random.PRNGKey(0)
trial_keys = jax.random.split(main_key, NUM_TRIALS)

def run_for_beta(beta):
    """Run every trial for a single ``beta``.

    Returns ``(predictability_gaps, expected_rewards)``, each of shape
    (NUM_TRIALS,).
    """
    # Map over the trial keys only; beta is shared by all trials.
    return jax.vmap(run_trial, in_axes=(None, 0))(beta, trial_keys)

# Vectorize over the beta grid as well. results is a tuple
# (predictability_gaps, expected_rewards), each of shape
# (len(beta_values), NUM_TRIALS).
results = jax.vmap(run_for_beta)(beta_values)
all_predictability_gap_squared_vec, all_expected_rewards_vec = results

# Convert results to NumPy arrays for the statistics and plots.
all_predictability_gap_squared_np = np.array(all_predictability_gap_squared_vec)
all_expected_rewards_np = np.array(all_expected_rewards_vec)

# %%
# Mean and standard error of the mean over the trial axis (axis 1).
# SEM = sample standard deviation (ddof=1) / sqrt(NUM_TRIALS).
mean_predictability_gaps_squared = all_predictability_gap_squared_np.mean(axis=1)
sem_predictability_gaps_squared = (
    all_predictability_gap_squared_np.std(axis=1, ddof=1) / np.sqrt(NUM_TRIALS)
)
mean_expected_rewards = all_expected_rewards_np.mean(axis=1)
sem_expected_rewards = (
    all_expected_rewards_np.std(axis=1, ddof=1) / np.sqrt(NUM_TRIALS)
)

# --- model for expected return (called "logistic" in this script) --------
def logistic(beta, J_min, J_max, beta_half, h):
    """Saturating power-law ("Hill") curve, labelled "logistic" in this script.

    Evaluates ``J_min + (J_max - J_min) / (1 + (beta / beta_half) ** h)``.
    At ``beta == beta_half`` the result lies midway between ``J_min`` and
    ``J_max``.
    """
    scaled = (beta / beta_half) ** h
    span = J_max - J_min
    return J_min + span / (1 + scaled)

# --- initial guesses for [J_min, J_max, beta_half, h] --------------------
# J_min ~ lowest mean, J_max ~ highest mean, beta_half = 10, h = 1.
p0_j = [
    np.min(mean_expected_rewards),
    np.max(mean_expected_rewards),
    10.0,
    1.0,
]

# --- fit logistic to data -----------------------------------------------
# beta_values is a JAX array. curve_fit only converts list/tuple/ndarray
# xdata and hands anything else to the model unchanged, so the model would be
# evaluated by JAX (32-bit floats by default). curve_fit's default
# finite-difference step is sized for float64 and is lost to float32
# rounding, which can leave the fit stuck at p0. Fit on float64 NumPy copies.
popt_j, pcov_j = curve_fit(
    logistic,
    np.asarray(beta_values, dtype=np.float64),
    np.asarray(mean_expected_rewards, dtype=np.float64),
    p0=p0_j,
    maxfev=10000,
)
J_min, J_max, beta_half_j, h_j = popt_j

# --- report fitted parameters --------------------------------------------
print("Expected‑return logistic fit:")
print(f"  J_min     = {J_min:.4f}")
print(f"  J_max     = {J_max:.4f}")
print(f"  beta_half = {beta_half_j:.3g}")
print(f"  h         = {h_j:.3f}")

# --- coefficient of determination: R^2 = 1 - SS_res / SS_tot -------------
fitted_j = logistic(beta_values, *popt_j)
res_j = mean_expected_rewards - fitted_j
ss_res = np.sum(res_j**2)
ss_tot = np.sum((mean_expected_rewards - mean_expected_rewards.mean())**2)
r2_j = 1 - ss_res/ss_tot
print(f"  R²        = {r2_j:.4f}")

# --- plot expected return vs. beta (mean ± SEM) --------------------------
# Make sure the output directory exists before saving.
os.makedirs('results', exist_ok=True)

y = mean_expected_rewards
yerr = sem_expected_rewards
fig, ax = plt.subplots(figsize=(6,4))
ax.errorbar(beta_values, y, yerr=yerr, fmt='s', ms=4, capsize=3, color='orange', alpha=0.8)
# Optional overlay of the fitted curve (disabled):
# ax.plot(finer_betas, logistic(finer_betas, *popt_j), lw=2.5, label='Logistic fit')

# symlog on x so that beta = 0 can be shown; log scale on y.
ax.set(
    xscale='symlog',
    yscale='log',
    xlabel=r'$\beta$',
    ylabel=r'$J_\pi$',
)
# ax.set_title(f'Expected Return (mean ± SEM, {NUM_TRIALS} trials)')
# ax.legend()
fig.tight_layout()
fig.savefig('results/beta_vs_expected_reward.pdf')
plt.show()


# --- x- and y-data for the predictability-gap fit (already computed) -----
beta, y, yerr = (
    beta_values,                        # 1-D array with one entry per beta
    mean_predictability_gaps_squared,   # mean over trials, same length as beta
    sem_predictability_gaps_squared,    # SEM over trials, used for error bars
)

# --- model ----------------------------------------------------------------
def hill(beta, G_min, G_max, beta_half, h):
    """Hill curve: ``G_min + (G_max - G_min) / (1 + (beta / beta_half) ** h)``.

    At ``beta == beta_half`` the result lies midway between ``G_min`` and
    ``G_max``.
    """
    scaled = (beta / beta_half) ** h
    span = G_max - G_min
    return G_min + span / (1 + scaled)

# --- initial guesses for [G_min, G_max, beta_half, h] ---------------------
p0 = [0.05, 0.32, 10.0, 1.0]

# --- fit ------------------------------------------------------------------
# beta aliases a JAX array. curve_fit only converts list/tuple/ndarray xdata
# and hands anything else to the model unchanged, so the model would be
# evaluated by JAX (32-bit floats by default). curve_fit's default
# finite-difference step is sized for float64 and is lost to float32
# rounding, which can leave the fit stuck at p0. Fit on float64 NumPy copies.
popt, pcov = curve_fit(
    hill,
    np.asarray(beta, dtype=np.float64),
    np.asarray(y, dtype=np.float64),
    p0=p0,
    maxfev=10000,
)
G_min, G_max, beta_half, h = popt
print(f"Predictability logistic fit:\n"
      f"  G_min      = {G_min:.4f}\n"
      f"  G_max      = {G_max:.4f}\n"
      f"  beta_half  = {beta_half:.3g}\n"
      f"  h          = {h:.3f}")

# Goodness of fit: coefficient of determination, R^2 = 1 - SS_res / SS_tot.
fitted = hill(beta, *popt)
residuals = y - fitted
ss_res = np.sum(residuals**2)          # residual sum of squares
ss_tot = np.sum((y - y.mean())**2)     # total sum of squares about the mean
r2 = 1 - ss_res/ss_tot
print(f"R² = {r2:.4f}")


# %%
# Plot beta vs. squared predictability gap (mean ± SEM).
# y and yerr are the predictability-gap mean and SEM assigned before the fit.
fig, ax = plt.subplots(figsize=(6, 4))
ax.errorbar(beta_values, y, yerr=yerr, fmt='o', ms=4, capsize=3, alpha=0.8)
# Optional overlay of the fitted curve (disabled):
# beta_fine = np.logspace(np.log10(beta.min()), np.log10(beta.max()), 400)
# ax.plot(finer_betas, logistic(finer_betas, *popt), lw=2.5, label='Logistic fit')

# symlog on x so that beta = 0 can be shown; log scale on y.
ax.set(
    xscale='symlog',
    yscale='log',
    xlabel=r'$\beta$',
    ylabel=r'$G_\pi^2$',
)
# ax.set_title(f'Predictability Gap (mean ± SEM, {NUM_TRIALS} trials)')
# ax.legend()
fig.tight_layout()
fig.savefig('results/beta_vs_predictability_gap.pdf')
plt.show()

# %%
# Plot Beta vs. Expected Reward (mean ± SEM)
# plt.figure(figsize=(8, 6))
# plt.errorbar(beta_values, mean_expected_rewards, yerr=sem_expected_rewards, marker='s', color='orange', capsize=4)
# plt.xlabel('$\\beta$')
# plt.ylabel('Expected Return ($J_{\\pi}$)')
# plt.grid(True)
# plt.xscale('symlog')
# plt.yscale('log')
# plt.title(f'Expected Return (mean ± SEM, {NUM_TRIALS} trials)')
# plt.savefig('results/beta_vs_expected_reward.pdf')
# plt.show()
