import jax
import jax.numpy as jnp
import numpy as np
import optax
import matplotlib.pyplot as plt
import seaborn as sns
import os

# matplotlib and seaborn settings
# Apply the seaborn style first: set_style() writes its own style dictionary
# into rcParams, including "font.family" (sans-serif), so calling it after the
# update below would silently discard the serif/Palatino selection.
sns.set_style("darkgrid")
plt.rcParams.update({
    "text.usetex": False,
    "font.family": "serif",
    "font.serif": ["Palatino"],
    "axes.labelsize": 12,
    "font.size": 12,
    "legend.fontsize": 10,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
})


# -- problem constants --
# Sizes of the tabular problem: rewards are (num_states, num_actions) and
# transitions are (num_states, num_actions, num_states).
num_states = 5
num_actions = 2
# Discount factor used in the linear system (I - gamma * P_pi) V = r_pi.
gamma = 0.9
# Indices of the states whose values are averaged to form avg_V.
driver_states = jnp.asarray([0, 2])
# Uniform weights over states, used to form J = mu @ V.
mu = jnp.ones((num_states,)) / num_states

# build similarity matrix once on host
def _sim(s, sp):
    """Return the similarity weight between states ``s`` and ``sp``.

    The rules are checked in order and the first match wins:

    * identical states            -> 1.0
    * both are driver states      -> 0.8
    * indices differ by exactly 1 -> 0.5
    * anything else               -> 0.2

    Both arguments must be concrete values (not traced), because the
    comparisons are used in ordinary Python ``if`` tests.
    """
    if s == sp:
        return 1.0

    both_are_drivers = (s in driver_states) and (sp in driver_states)
    if both_are_drivers:
        return 0.8

    return 0.5 if jnp.abs(s - sp) == 1 else 0.2

# Similarity weights, shape (num_states, len(driver_states)): row s holds the
# similarity of state s to each driver state, in driver_states order.
f = jnp.array(
    [[_sim(s, sp) for sp in driver_states] for s in range(num_states)]
)

# JIT-compiled single trial
@jax.jit
def run_one_trial(beta, key):
    """Run one randomized trial: sample an MDP, fit a policy, report metrics.

    A reward table ``R`` (states x actions) and a transition tensor ``P``
    (states x actions x next states, normalized over the last axis) are drawn
    uniformly at random from ``key``. A tabular softmax policy starting from
    zero logits is trained for 500 Adam steps (learning rate 0.1) to maximize
    ``J - beta * sum(e**2)``, where

    * ``V`` solves ``(I - gamma * P_pi) V = r_pi`` for the current policy,
    * ``J = mu @ V``,
    * ``avg_V`` is the ``f``-weighted average of ``V`` at the driver states,
    * ``e = avg_V - V``.

    Args:
        beta: Weight of the squared-gap penalty in the training objective.
        key: JAX PRNG key used to sample ``R`` and ``P``.

    Returns:
        Tuple ``(sum(e**2), J, mu @ avg_V)`` evaluated at the trained policy.
    """
    # Keep the original split sequence so sampled problems are unchanged.
    key, reward_key = jax.random.split(key)
    _, transition_key = jax.random.split(key)

    R = jax.random.uniform(reward_key, (num_states, num_actions))
    P = jax.random.uniform(
        transition_key, (num_states, num_actions, num_states)
    )
    P = P / P.sum(axis=2, keepdims=True)  # each (s, a) row sums to one

    theta = jnp.zeros((num_states, num_actions))
    opt = optax.adam(1e-1)
    opt_state = opt.init(theta)

    def loss_fn(theta):
        """Return the negated penalized objective and ``(e, J, avg_V)``."""
        pi = jax.nn.softmax(theta, axis=1)
        # Policy evaluation: (I - gamma * P_pi) V = r_pi.
        A = jnp.eye(num_states) - gamma * jnp.einsum('sa,sap->sp', pi, P)
        b = (pi * R).sum(axis=1)
        V = jnp.linalg.solve(A, b)
        avg_V = (f @ V[driver_states]) / f.sum(axis=1)
        e = avg_V - V
        penalty = beta * jnp.sum(e * e)
        J = mu @ V
        return -(J - penalty), (e, J, avg_V)

    grad_fn = jax.grad(loss_fn, has_aux=True)

    def step(carry, _):
        """Apply one Adam update; the scan input is unused."""
        params, state = carry
        grads, _aux = grad_fn(params)
        updates, state = opt.update(grads, state)
        return (optax.apply_updates(params, updates), state), None

    # No inner jax.jit is needed: this whole function is already jitted.
    (theta_opt, _), _ = jax.lax.scan(
        step, (theta, opt_state), None, length=500
    )

    _, (e, J, avg_V) = loss_fn(theta_opt)
    return jnp.sum(e * e), J, (mu @ avg_V)


# One beta, many trials: beta is broadcast, the PRNG keys are mapped.
run_trials = jax.jit(jax.vmap(run_one_trial, in_axes=(None, 0)))
# Many betas, same trials: betas are mapped, the key batch is shared.
run_all = jax.jit(jax.vmap(run_trials, in_axes=(0, None)))

# One PRNG key per trial; every beta is evaluated on the same set of keys.
seeds = 1000
main_key = jax.random.PRNGKey(0)
trial_keys = jax.random.split(main_key, num=seeds)

# beta = 0 followed by powers of ten with exponents -1, -0.9, ..., 3.
beta_values = jnp.array([0, *(10**x for x in np.arange(-1, 3.01, 0.1))])

# Evaluate every (beta, trial) pair in a single call.
# Each output has shape (n_betas, n_trials).
all_G2, all_J, all_Jhat = run_all(beta_values, trial_keys)

# --- compute mean and SEM ---
# Statistics are taken across trials (axis 1), giving one value per beta.
n_trials = len(trial_keys)

# G^2: mean and standard error of the mean
mean_G2 = all_G2.mean(axis=1)
sem_G2 = all_G2.std(axis=1, ddof=1) / jnp.sqrt(n_trials)

# (J - Jhat)^2: mean and standard error of the mean
err_sq = jnp.square(all_J - all_Jhat)
mean_err_sq = err_sq.mean(axis=1)
sem_err_sq = err_sq.std(axis=1, ddof=1) / jnp.sqrt(n_trials)

# Copy everything to NumPy arrays for matplotlib.
beta_np, mean_G2_np, sem_G2_np, mean_err_sq_np, sem_err_sq_np = (
    np.array(values)
    for values in (beta_values, mean_G2, sem_G2, mean_err_sq, sem_err_sq)
)



# --- plotting ---
# Create results directory if it doesn't exist
os.makedirs('results', exist_ok=True)

fig, ax = plt.subplots(figsize=(6, 4))

# Plot G^2 (mean +/- SEM per beta)
ax.errorbar(
    beta_np,
    mean_G2_np,
    yerr=sem_G2_np,
    fmt='o',
    ms=4,
    capsize=3,
    label=r'$G^2$',
)

# Plot (J - Jhat)^2 (mean +/- SEM per beta)
ax.errorbar(
    beta_np,
    mean_err_sq_np,
    yerr=sem_err_sq_np,
    fmt='s',
    ms=4,
    capsize=3,
    label=r'$(J-\hat J)^2$',
)

# Axis scales
# symlog keeps beta = 0 on the axis. The default linear threshold (2) would
# draw every beta between 0.1 and 2 on the linear segment, bunching those
# log-spaced points together, so the threshold is set to the smallest
# positive beta instead.
ax.set_xscale('symlog', linthresh=float(beta_np[beta_np > 0].min()))
ax.set_yscale('log')

# Labels & legend
ax.set_xlabel(r'$\beta$')
ax.set_ylabel('Mean squared error')
ax.legend(loc='upper right')

# Grid & layout
ax.grid(True, which='both', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.savefig('results/bounds_j_pi_hat.pdf', dpi=300)