import jax
import jax.numpy as jnp
import pickle
import itertools
import zlib

# Project-local environment modules used by run_experiment.
# NOTE(review): the first package is spelled `gridword_env` (not
# `gridworld_env`) — confirm this matches the on-disk package name.
try:
    from gridword_env.gridworld import GridWorldJAX, GridWorldParams
    from synthetic_env.Env import make_finite_mdp_params, FiniteMDPParams
except ImportError as e:
    # Echo the failure to stdout, then re-raise so the script still aborts
    # with the original traceback.
    print(f"Import Error: {e}")
    raise

def get_pi(theta):
    """Turn a table of action logits into a stochastic policy.

    A softmax is applied over the last axis of ``theta``, so every slice
    along that axis is non-negative and sums to one.
    """
    policy = jax.nn.softmax(theta, axis=-1)
    return policy

def compute_exact_q(theta, P, R, gamma, lam):
    """Solve in closed form for the entropy-regularised action values of softmax(theta).

    Args:
        theta: policy logits; the policy is ``get_pi(theta)``.
        P: transition tensor, indexed ``[state, action, next_state]`` by the
            einsums below.
        R: reward table combined elementwise with ``log pi`` (same shape as
            ``theta``).
        gamma: discount factor.
        lam: weight of the ``-log pi`` entropy term added to the reward.

    Returns:
        Array shaped like ``R`` holding
        ``(R - lam * log pi) + gamma * sum_n P[s, a, n] * v[n]``, where ``v``
        solves the linear Bellman system for the policy.
    """
    policy = get_pi(theta)
    # Reward with the entropy bonus folded in; the 1e-12 keeps log() finite
    # when a probability underflows to zero.
    soft_reward = R - lam * jnp.log(policy + 1e-12)

    # State-to-state transition matrix and expected one-step reward under the policy.
    induced_chain = jnp.einsum("san,sa->sn", P, policy)
    per_state_reward = jnp.sum(policy * soft_reward, axis=1)

    # v = (I - gamma * P_pi)^{-1} r_pi
    bellman_system = jnp.eye(P.shape[0]) - gamma * induced_chain
    state_values = jnp.linalg.solve(bellman_system, per_state_reward)

    # One-step lookahead from every (state, action) pair.
    return soft_reward + gamma * jnp.einsum("san,n->sa", P, state_values)

def compute_exact_v(theta, P, R, mu0, gamma, lam):
    """Evaluate the entropy-regularised objective of softmax(theta) exactly.

    Args:
        theta: policy logits; the policy is ``get_pi(theta)``.
        P: transition tensor, indexed ``[state, action, next_state]`` by the
            einsum below.
        R: reward table combined elementwise with ``log pi`` (same shape as
            ``theta``).
        mu0: weights over states used to average the state values
            (presumably the initial-state distribution — confirm with caller).
        gamma: discount factor.
        lam: weight of the ``-log pi`` entropy term added to the reward.

    Returns:
        Scalar ``mu0 . v`` where ``v`` solves ``(I - gamma * P_pi) v = r_pi``.
    """
    policy = get_pi(theta)
    # The 1e-12 keeps log() finite when a probability underflows to zero.
    entropy_adjusted_reward = R - lam * jnp.log(policy + 1e-12)

    # Markov chain and expected reward induced by following the policy.
    induced_chain = jnp.einsum("sa,san->sn", policy, P)
    per_state_reward = jnp.sum(policy * entropy_adjusted_reward, axis=1)

    bellman_system = jnp.eye(P.shape[0]) - gamma * induced_chain
    state_values = jnp.linalg.solve(bellman_system, per_state_reward)
    return jnp.dot(mu0, state_values)

def stable_int_hash(s: str) -> int:
    """Return the Adler-32 checksum of ``s`` (UTF-8 encoded) as a plain int.

    The checksum depends only on the bytes of ``s``, so the same string maps
    to the same integer on every run.
    """
    encoded = s.encode("utf-8")
    checksum = zlib.adler32(encoded)
    return int(checksum)


def update_step(key, theta, q, params, lr_a, lr_c, actor_batch, critic_batch, H, gamma, lam):
    """Run one sampled actor-critic iteration: ``H`` critic steps, then one actor step.

    The policy ``get_pi(theta)`` is held fixed while the critic table is
    updated ``H`` times from freshly sampled transitions; the refreshed table
    then drives a single sampled update of ``theta``.

    Args:
        key: PRNG key for this iteration. Critic step ``i`` uses
            ``fold_in(key, i)`` and the actor uses ``fold_in(key, 999)``, so
            ``H`` must stay below 1000 for those streams to remain distinct.
        theta: policy logits, indexed ``[state, action]``.
        q: current critic table, indexed ``[state, action]``.
        params: object exposing ``P`` (indexed ``[state, action, next_state]``),
            ``R`` (indexed ``[state, action]``) and ``mu0``.
        lr_a: actor step size.
        lr_c: critic step size.
        actor_batch: number of (state, action) samples for the actor update.
        critic_batch: number of transitions sampled per critic step.
        H: number of critic steps (may be a traced integer).
        gamma: discount factor.
        lam: entropy-regularisation weight.

    Returns:
        ``(new_theta, new_q, J)`` where ``J`` is
        ``compute_exact_v(new_theta, ...)``.

    NOTE(review): ``-lam * log_pi`` enters the sampled reward ``r_tilde`` and
    is subtracted again inside the bootstrap value ``v2`` and the advantage,
    whereas ``compute_exact_q`` folds it in through the reward only. Confirm
    which convention is intended; this function keeps the existing one.
    """
    pi = get_pi(theta)
    log_pi = jnp.log(pi + 1e-12)
    num_states, num_actions = params.P.shape[0], params.P.shape[1]

    # Discounted state-visitation distribution of the current policy:
    # d = (1 - gamma) * (I - gamma * P_pi)^{-T} mu0, clipped and renormalised
    # so it is a valid probability vector for jax.random.choice.
    P_pi = jnp.einsum("sa,san->sn", pi, params.P)
    I = jnp.eye(num_states)
    d_dist = (1 - gamma) * jnp.linalg.solve((I - gamma * P_pi).T, params.mu0)
    d_dist = jnp.clip(d_dist, 1e-12)
    d_dist /= jnp.sum(d_dist)

    def critic_body(state):
        i, current_q = state
        k = jax.random.fold_in(key, i)
        ks, ka, ks2 = jax.random.split(k, 3)
        # Sample s ~ d, a ~ pi(.|s), s' ~ P(.|s, a).
        s = jax.random.choice(ks, num_states, shape=(critic_batch,), p=d_dist)
        a = jax.vmap(lambda s_idx, k_idx: jax.random.choice(k_idx, num_actions, p=pi[s_idx]))(s, jax.random.split(ka, critic_batch))
        s2 = jax.vmap(lambda s_idx, a_idx, k_idx: jax.random.choice(k_idx, num_states, p=params.P[s_idx, a_idx]))(s, a, jax.random.split(ks2, critic_batch))

        r_tilde = params.R[s, a] - lam * log_pi[s, a]
        v2 = jnp.sum(pi[s2] * (current_q[s2] - lam * log_pi[s2]), axis=1)
        td_err = r_tilde + gamma * v2 - current_q[s, a]
        # Minibatch-averaged TD step: every sampled entry moves by its OWN TD
        # error (duplicates accumulate through the scatter-add). Broadcasting
        # the batch-mean error to all sampled entries would push unrelated
        # (s, a) pairs by the same amount.
        return (i + 1, current_q.at[s, a].add(lr_c * td_err / critic_batch))

    _, new_q = jax.lax.while_loop(lambda x: x[0] < H, critic_body, (0, q))

    k_a = jax.random.fold_in(key, 999)
    ks_a, ka_a = jax.random.split(k_a)
    s_a = jax.random.choice(ks_a, num_states, shape=(actor_batch,), p=d_dist)
    a_a = jax.vmap(lambda s_idx, k_idx: jax.random.choice(k_idx, num_actions, p=pi[s_idx]))(s_a, jax.random.split(ka_a, actor_batch))

    v = jnp.sum(pi * (new_q - lam * log_pi), axis=1)
    adv = new_q[s_a, a_a] - lam * log_pi[s_a, a_a] - v[s_a]
    # Per-sample scatter of the advantage, averaged over the batch. The
    # advantage has zero mean under pi, so collapsing it to a scalar mean
    # before scattering would leave only sampling noise.
    new_theta = theta.at[s_a, a_a].add(lr_a * adv / ((1 - gamma) * actor_batch))
    return new_theta, new_q, compute_exact_v(new_theta, params.P, params.R, params.mu0, gamma, lam)

def update_step_exact_critic(key, theta, q, params, lr_a, lr_c, actor_batch, critic_batch, H, gamma, lam):
    """Run one actor step using the exact critic instead of a learned one.

    The signature mirrors ``update_step`` so the two are interchangeable in
    ``train_logic``; ``q``, ``lr_c``, ``critic_batch`` and ``H`` are accepted
    but not used here.

    Args:
        key: PRNG key for this iteration (the actor samples from
            ``fold_in(key, 888)``).
        theta: policy logits, indexed ``[state, action]``.
        q: ignored; the critic is recomputed exactly.
        params: object exposing ``P`` (indexed ``[state, action, next_state]``),
            ``R`` and ``mu0``.
        lr_a: actor step size.
        lr_c: ignored.
        actor_batch: number of (state, action) samples for the actor update.
        critic_batch: ignored.
        H: ignored.
        gamma: discount factor.
        lam: entropy-regularisation weight.

    Returns:
        ``(new_theta, true_q, J)`` where ``true_q`` is
        ``compute_exact_q(theta, ...)`` for the policy BEFORE the update and
        ``J`` is ``compute_exact_v(new_theta, ...)``.

    NOTE(review): ``compute_exact_q`` already includes ``-lam * log_pi`` in
    its values, and the advantage below subtracts ``lam * log_pi`` again.
    Confirm the intended convention; this function keeps the existing one.
    """
    pi = get_pi(theta)
    log_pi = jnp.log(pi + 1e-12)
    true_q = compute_exact_q(theta, params.P, params.R, gamma, lam)
    num_states, num_actions = params.P.shape[0], params.P.shape[1]

    # Discounted state-visitation distribution of the current policy, clipped
    # and renormalised so it is a valid probability vector for sampling.
    P_pi = jnp.einsum("sa,san->sn", pi, params.P)
    I = jnp.eye(num_states)
    d_dist = (1 - gamma) * jnp.linalg.solve((I - gamma * P_pi).T, params.mu0)
    d_dist = jnp.clip(d_dist, 1e-12)
    d_dist /= jnp.sum(d_dist)

    k_a = jax.random.fold_in(key, 888)
    ks_a, ka_a = jax.random.split(k_a)
    s_a = jax.random.choice(ks_a, num_states, shape=(actor_batch,), p=d_dist)
    a_a = jax.vmap(lambda s_idx, k_idx: jax.random.choice(k_idx, num_actions, p=pi[s_idx]))(s_a, jax.random.split(ka_a, actor_batch))

    v = jnp.sum(pi * (true_q - lam * log_pi), axis=1)
    adv = true_q[s_a, a_a] - lam * log_pi[s_a, a_a] - v[s_a]
    # Per-sample scatter of the advantage, averaged over the batch. The
    # advantage has zero mean under pi, so collapsing it to a scalar mean
    # before scattering would leave only sampling noise.
    new_theta = theta.at[s_a, a_a].add(lr_a * adv / ((1 - gamma) * actor_batch))

    return new_theta, true_q, compute_exact_v(new_theta, params.P, params.R, params.mu0, gamma, lam)

def train_logic(key, params, H, num_iters, gamma, lam, lr_a, lr_c, actor_batch, critic_batch, use_exact):
    """Train for ``num_iters`` iterations and return the objective trajectory.

    Starting from logits that are 4.0 for action 0 and 0.0 elsewhere, and an
    all-zero critic table, the chosen update rule is scanned over
    ``num_iters`` independent PRNG keys split from ``key``.

    Args:
        key: PRNG key; split into one key per iteration.
        params: object exposing ``P``, ``R`` and ``mu0``.
        H, gamma, lam, lr_a, lr_c, actor_batch, critic_batch: forwarded
            unchanged to the update function.
        num_iters: number of iterations (fixes the scan length).
        use_exact: selects ``update_step_exact_critic`` when truthy, else
            ``update_step``. Evaluated with a Python conditional, so it must
            be a concrete value rather than a traced one.

    Returns:
        1-D array of length ``num_iters + 1``: the exact objective of the
        initial policy followed by the objective after each iteration.
    """
    num_states, num_actions = params.P.shape[:2]

    # Resolved once, outside the scanned body.
    step_fn = update_step_exact_critic if use_exact else update_step

    def scan_fn(carry, step_key):
        theta, q = carry
        next_theta, next_q, j_val = step_fn(
            step_key, theta, q, params, lr_a, lr_c, actor_batch, critic_batch, H, gamma, lam
        )
        return (next_theta, next_q), j_val

    theta0 = jnp.zeros((num_states, num_actions)).at[:, 0].set(4.0)
    q0 = jnp.zeros((num_states, num_actions))
    j0 = compute_exact_v(theta0, params.P, params.R, params.mu0, gamma, lam)

    step_keys = jax.random.split(key, num_iters)
    _, history = jax.lax.scan(scan_fn, (theta0, q0), step_keys)

    # Prepend the pre-training objective so index t is "after t iterations".
    return jnp.concatenate([jnp.array([j0]), history])

# JIT-compiled entry point for `train_logic`. Positional arguments 3, 8, 9 and
# 10 (num_iters, actor_batch, critic_batch, use_exact) are static: they set the
# scan length / sample shapes or drive a Python-level branch, so each distinct
# combination is compiled separately. H (index 2) stays traced; the critic loop
# in `update_step` bounds a `lax.while_loop` with it.
train_jit = jax.jit(train_logic, static_argnums=(3, 8, 9, 10))

def run_experiment(env_type="gridworld", use_exact=False):
    """Run the hyper-parameter sweep for one environment and pickle the results.

    Every combination of critic-step count ``H``, actor step size and critic
    step size is trained ``NUM_RUNS`` times with independent PRNG keys, all in
    one doubly-vmapped call. With ``use_exact=True`` the sweep collapses to
    ``H=1`` and a critic step size of ``0.0``, and only the actor step size
    is varied.

    Args:
        env_type: ``"gridworld"`` builds a 3x4 ``GridWorldJAX``; any other
            value builds a synthetic finite MDP with 16 states and 4 actions.
        use_exact: train with the exact critic instead of the sampled one.

    Side effects:
        Prints progress and writes ``results_{env_type}_{size}{suffix}.pkl``
        in the current working directory. The pickle holds a dict with keys
        ``results`` (label -> list of per-run objective trajectories),
        ``iters``, ``env``, ``size`` and ``Hs``.
    """
    rng = jax.random.PRNGKey(42)
    # Split once so environment construction and training never consume the
    # same key (JAX keys must not be reused after being passed to a sampler).
    env_key, train_key = jax.random.split(rng)

    NUM_RUNS, ITERS = 50, 5000
    Hs = [1] if use_exact else [8, 16, 32, 64]
    GAMMA, LAMBDA = 0.99, 0.05
    ACTOR_BATCH, CRITIC_BATCH = 128, 256
    LR_ACTOR_VALS = [0.003, 0.01, 0.03, 0.1]
    LR_CRITIC_VALS = [0.0] if use_exact else [0.003, 0.01, 0.03, 0.1]

    suffix = "_EXACT" if use_exact else "_gridsearch"

    if env_type == "gridworld":
        rows, cols = 3, 4
        gw = GridWorldJAX(rows=rows, cols=cols)
        params = gw.make_params(env_key, terminal_states=((rows-1, cols-1),))
        size_str = f"{rows}x{cols}"
    else:
        num_states, num_actions = 16, 4
        params, _ = make_finite_mdp_params(env_key, S=num_states, A=num_actions)
        size_str = f"S{num_states}"

    print(f"\n[Running {env_type.upper()} ({size_str}){suffix}]")

    combos = list(itertools.product(Hs, LR_ACTOR_VALS, LR_CRITIC_VALS))
    labels = [f"H={h}_lra={lra}_lrc={lrc}" for (h, lra, lrc) in combos]

    # Column-wise view of the grid: one array per swept hyper-parameter.
    h_vals, lra_vals, lrc_vals = zip(*combos)
    H_arr = jnp.array(h_vals, dtype=jnp.int32)
    lra_arr = jnp.array(lra_vals, dtype=jnp.float32)
    lrc_arr = jnp.array(lrc_vals, dtype=jnp.float32)

    # One key per (combination, run); the trailing 2 is the width of a raw
    # key as produced by jax.random.PRNGKey.
    keys_CR = jax.random.split(train_key, len(combos) * NUM_RUNS).reshape(len(combos), NUM_RUNS, 2)

    # Inner vmap: repeated runs of one combination (only the key varies).
    # Outer vmap: across hyper-parameter combinations.
    run_all = jax.vmap(jax.vmap(
        lambda k, h, la, lc: train_jit(k, params, h, ITERS, GAMMA, LAMBDA, la, lc, ACTOR_BATCH, CRITIC_BATCH, use_exact),
        in_axes=(0, None, None, None)),
        in_axes=(0, 0, 0, 0))

    histories = run_all(keys_CR, H_arr, lra_arr, lrc_arr)
    histories_np = jax.device_get(histories)

    # Plain nested lists so the pickle can be loaded without JAX.
    results_data = {lbl: histories_np[i].tolist() for i, lbl in enumerate(labels)}

    out_name = f"results_{env_type}_{size_str}{suffix}.pkl"

    with open(out_name, "wb") as f:
        pickle.dump({
            "results": results_data,
            "iters": ITERS + 1,
            "env": env_type,
            "size": size_str,
            "Hs": Hs
        }, f)
    print(f"Done. Saved to {out_name}")

if __name__ == "__main__":
    # For each environment, run the sampled-critic grid search first and the
    # exact-critic baseline second.
    for env_name in ("gridworld", "synthetic"):
        for exact in (False, True):
            run_experiment(env_name, use_exact=exact)