import jax
import jax.numpy as jnp
import pickle
import itertools
import zlib

# --- 1. Environment Imports ---
try:
    from gridword_env.gridworld import GridWorldJAX, GridWorldParams
    from synthetic_env.Env import make_finite_mdp_params, FiniteMDPParams
except ImportError as e:
    print(f"Import Error: {e}")
    print("Ensure folders 'gridword_env' and 'synthetic_env' exist with __init__.py")
    raise


# --- 2. Mathematical Helpers ---
def get_pi(theta):
    """Turn policy logits into action probabilities.

    A softmax is taken over the last axis of ``theta``, so every slice along
    that axis is non-negative and sums to one.
    """
    policy = jax.nn.softmax(theta, axis=-1)
    return policy

def compute_exact_v(theta, P, R, mu0, gamma, lam):
    """Exactly evaluate the entropy-regularised objective of the softmax policy.

    Args:
        theta: policy logits, indexed ``[state, action]``.
        P: transition tensor indexed ``[state, action, next_state]``.
        R: rewards, combined elementwise with the ``[state, action]`` policy.
        mu0: initial-state weights, dotted with the per-state values.
        gamma: discount factor.
        lam: weight of the ``-log pi`` entropy bonus added to the reward.

    Returns:
        Scalar ``mu0 . v_pi`` where ``v_pi`` solves
        ``(I - gamma * P_pi) v = r_tilde``.
    """
    policy = get_pi(theta)
    # The small constant keeps the log finite when a probability underflows.
    log_policy = jnp.log(policy + 1e-12)

    # State-to-state transition matrix obtained by averaging P over the policy.
    transition = jnp.einsum("sa,san->sn", policy, P)
    # Expected one-step reward including the entropy bonus.
    soft_reward = jnp.sum(policy * (R - lam * log_policy), axis=1)

    num_states = P.shape[0]
    bellman_matrix = jnp.eye(num_states) - gamma * transition
    state_values = jnp.linalg.solve(bellman_matrix, soft_reward)
    return jnp.dot(mu0, state_values)

def stable_int_hash(s: str) -> int:
    """Return the Adler-32 checksum of ``s`` (UTF-8 encoded) as a plain int.

    The value is reproducible across runs, which Python's built-in ``hash``
    does not guarantee for strings.
    """
    encoded = s.encode("utf-8")
    checksum = zlib.adler32(encoded)
    return int(checksum)


# --- 3. Core ERAC Logic ---
def update_step(key, theta, q, params, lr_a, lr_c, actor_batch, critic_batch, H, gamma, lam):
    """Run one actor-critic iteration: ``H`` critic TD steps, then one actor step.

    Samples are drawn from the exact discounted state-visitation distribution
    of the current policy, computed in closed form from ``params.P`` and
    ``params.mu0``.

    Args:
        key: PRNG key for this iteration.
        theta: policy logits indexed ``[state, action]``.
        q: current soft Q-table indexed ``[state, action]``.
        params: object exposing ``P`` (indexed ``[state, action, next_state]``),
            ``R`` (indexed ``[state, action]``) and ``mu0`` (initial-state
            weights).
        lr_a: actor step size.
        lr_c: critic step size.
        actor_batch: number of (state, action) samples for the actor step;
            used as an array shape, so it must be a static Python int.
        critic_batch: number of transitions per critic step; also static.
        H: number of critic steps; may be a traced integer.
        gamma: discount factor.
        lam: entropy-regularisation weight.

    Returns:
        ``(new_theta, new_q, j_val)`` where ``j_val`` is the exact
        regularised objective of ``new_theta``.
    """
    num_states, num_actions = params.P.shape[0], params.P.shape[1]

    pi = get_pi(theta)
    log_pi = jnp.log(pi + 1e-12)
    P_pi = jnp.einsum("sa,san->sn", pi, params.P)
    I = jnp.eye(num_states)
    # Discounted state visitation: d = (1 - gamma) * (I - gamma * P_pi^T)^{-1} mu0.
    d_dist = (1 - gamma) * jnp.linalg.solve((I - gamma * P_pi).T, params.mu0)

    # (optional safety) normalize distribution to avoid rare numerical issues
    d_dist = jnp.clip(d_dist, 0.0)
    d_dist = d_dist / (jnp.sum(d_dist) + 1e-12)

    # Independent streams for critic and actor. Folding a fixed constant into
    # the same key the critic folds its step index into would reuse a key as
    # soon as H exceeds that constant.
    critic_key, actor_key = jax.random.split(key)

    def sample_actions(states, action_key, batch):
        # One action per sampled state, drawn from that state's policy row.
        return jax.vmap(
            lambda s_idx, k_idx: jax.random.choice(k_idx, num_actions, p=pi[s_idx])
        )(states, jax.random.split(action_key, batch))

    # --- Critic Update (H steps) ---
    # Use while_loop so H can be a traced value (needed to vmap over configs with different H)
    def critic_cond(state):
        i, _q = state
        return i < H

    def critic_body(state):
        i, current_q = state
        step_key = jax.random.fold_in(critic_key, i)
        ks, ka, ks2 = jax.random.split(step_key, 3)

        s = jax.random.choice(ks, num_states, shape=(critic_batch,), p=d_dist)
        a = sample_actions(s, ka, critic_batch)
        s2 = jax.vmap(
            lambda s_idx, a_idx, k_idx: jax.random.choice(k_idx, num_states, p=params.P[s_idx, a_idx])
        )(s, a, jax.random.split(ks2, critic_batch))

        # Soft TD target: entropy-augmented reward plus discounted soft value of s2.
        r_tilde = params.R[s, a] - lam * log_pi[s, a]
        v2 = jnp.sum(pi[s2] * (current_q[s2] - lam * log_pi[s2]), axis=1)
        td_err = r_tilde + gamma * v2 - current_q[s, a]

        # Each transition moves only its own Q entry by its own TD error,
        # averaged over the batch; repeated (s, a) pairs accumulate in the
        # scatter-add. Adding the batch-mean scalar to every sampled entry
        # instead would discard per-sample credit assignment.
        new_q = current_q.at[s, a].add(lr_c * td_err / critic_batch)
        return (i + 1, new_q)

    _, new_q = jax.lax.while_loop(critic_cond, critic_body, (jnp.asarray(0, dtype=jnp.int32), q))

    # --- Actor Update ---
    ks_a, ka_a = jax.random.split(actor_key)

    s_a = jax.random.choice(ks_a, num_states, shape=(actor_batch,), p=d_dist)
    a_a = sample_actions(s_a, ka_a, actor_batch)

    # Soft advantage under the freshly updated critic; v is its exact
    # per-state baseline, so the advantage has zero mean under pi.
    v = jnp.sum(pi * (new_q - lam * log_pi), axis=1)
    adv = new_q[s_a, a_a] - lam * log_pi[s_a, a_a] - v[s_a]
    # Per-sample scatter-add, averaged over the batch (same reasoning as the critic).
    new_theta = theta.at[s_a, a_a].add(lr_a * adv / ((1 - gamma) * actor_batch))

    j_val = compute_exact_v(new_theta, params.P, params.R, params.mu0, gamma, lam)
    return new_theta, new_q, j_val


# --- 4. Trainer ---
def train_erac_logic(key, params, H, num_iters, gamma, lam, lr_a, lr_c, actor_batch=32, critic_batch=64):
    """Train for ``num_iters`` iterations and return the objective trajectory.

    The policy logits start at zero except action 0, which is set to 4.0 in
    every state; the Q-table starts at zero. ``key`` is split into one
    sub-key per iteration.

    Returns:
        Array of length ``num_iters + 1``: the exact objective of the initial
        policy followed by the objective after each ``update_step``.
    """
    num_states, num_actions = params.P.shape[:2]

    theta0 = jnp.zeros((num_states, num_actions)).at[:, 0].set(4.0)
    q0 = jnp.zeros((num_states, num_actions))

    def one_iteration(carry, iter_key):
        # Carry is (theta, q); the per-iteration output is the exact objective.
        theta, q = carry
        theta, q, objective = update_step(
            iter_key, theta, q, params, lr_a, lr_c, actor_batch, critic_batch, H, gamma, lam
        )
        return (theta, q), objective

    initial_objective = compute_exact_v(theta0, params.P, params.R, params.mu0, gamma, lam)  # scalar
    iter_keys = jax.random.split(key, num_iters)
    _, objectives = jax.lax.scan(one_iteration, (theta0, q0), iter_keys)

    # Prepend the starting objective so the result has num_iters + 1 entries.
    return jnp.concatenate([jnp.asarray([initial_objective]), objectives], axis=0)


# JIT-compiled trainer. Positional args 3 (num_iters), 8 (actor_batch) and
# 9 (critic_batch) are static because they are used as array sizes (number of
# split keys / sample shapes). H (arg 2) is deliberately NOT static so that
# configs with different H can be batched with vmap.
train_erac_jit = jax.jit(train_erac_logic, static_argnums=(3, 8, 9))


# --- 5. Parallel gridsearch over configs AND runs ---
def run_experiment(env_type="gridworld"):
    """Grid-search H and the actor/critic learning rates on one environment.

    Every (H, lr_actor, lr_critic) combination is trained for ``NUM_RUNS``
    seeds in a single vmapped and jitted call. The objective histories are
    then pickled to ``results_{env_type}_gridsearch.pkl``.

    Args:
        env_type: ``"gridworld"`` builds params from a 3x3 ``GridWorldJAX``;
            ``"synthetic"`` builds them with ``make_finite_mdp_params``
            (S=8, A=3).

    Raises:
        ValueError: if ``env_type`` is not one of the two supported names.
    """
    # Validate before doing any work: an unrecognised name must not silently
    # fall through to the synthetic MDP and be saved under the wrong label.
    if env_type not in ("gridworld", "synthetic"):
        raise ValueError(
            f"Unknown env_type {env_type!r}; expected 'gridworld' or 'synthetic'"
        )

    rng = jax.random.PRNGKey(42)

    NUM_RUNS = 10
    ITERS = 5000
    Hs = [8, 16, 32, 64]
    GAMMA, LAMBDA = 0.99, 0.05
    ACTOR_BATCH, CRITIC_BATCH = 128, 256

    LR_ACTOR_VALS = [0.003, 0.01, 0.03, 0.1]
    LR_CRITIC_VALS = [0.003, 0.01, 0.03, 0.1]

    print(f"\n[Running Grid Search - PARALLEL] {env_type.upper()}")

    # build env params once
    if env_type == "gridworld":
        gw = GridWorldJAX(rows=3, cols=3)
        params = gw.make_params(rng, terminal_states=((2, 2),))
    else:
        params, _ = make_finite_mdp_params(rng, S=8, A=3)

    # enumerate combinations on host (for labeling + stable hashing)
    combos = list(itertools.product(Hs, LR_ACTOR_VALS, LR_CRITIC_VALS))
    labels = [f"H={h}_lra={lra}_lrc={lrc}" for (h, lra, lrc) in combos]

    # pack hyperparams into device arrays, one entry per config: shape (C,)
    H_arr = jnp.asarray([h for (h, _, _) in combos], dtype=jnp.int32)
    lra_arr = jnp.asarray([lra for (_, lra, _) in combos], dtype=jnp.float32)
    lrc_arr = jnp.asarray([lrc for (_, _, lrc) in combos], dtype=jnp.float32)
    # The label hash gives each config a seed offset that does not depend on
    # its position in the grid.
    cfg_ids = jnp.asarray([stable_int_hash(lbl) for lbl in labels], dtype=jnp.uint32)

    run_ids = jnp.arange(NUM_RUNS, dtype=jnp.uint32)  # (R,)

    # make keys for every (config, run): shape (C,R,2)
    def make_keys_for_cfg(cfg_id):
        # keys over runs
        def mk(run_id):
            k = jax.random.fold_in(rng, cfg_id)
            k = jax.random.fold_in(k, run_id)
            return k
        return jax.vmap(mk)(run_ids)

    keys_CR = jax.vmap(make_keys_for_cfg)(cfg_ids)  # (C,R,2)

    # run one config over all runs (vectorized over run)
    def run_cfg(keys_R, H, lra, lrc):
        def run_one(k):
            return train_erac_jit(
                k, params, H, ITERS, GAMMA, LAMBDA, lra, lrc, ACTOR_BATCH, CRITIC_BATCH
            )
        return jax.vmap(run_one)(keys_R)  # (R, T)

    # run all configs (vectorized over config)
    run_all_cfgs = jax.jit(lambda keys_CR, H_arr, lra_arr, lrc_arr:
                           jax.vmap(run_cfg)(keys_CR, H_arr, lra_arr, lrc_arr))

    print(f"  Launching batched run: configs={len(labels)} x runs={NUM_RUNS} ...")
    histories_CRT = run_all_cfgs(keys_CR, H_arr, lra_arr, lrc_arr)  # (C,R,T)
    histories_CRT = jax.device_get(histories_CRT)  # bring back once

    # repack into dict: label -> list of runs (each run is list[T])
    results_data = {
        lbl: [histories_CRT[i, r].tolist() for r in range(NUM_RUNS)]
        for i, lbl in enumerate(labels)
    }

    out = f"results_{env_type}_gridsearch.pkl"
    with open(out, "wb") as f:
        pickle.dump(
            {
                "results": results_data,
                # Each history also holds the pre-training objective.
                "iters": ITERS + 1,
                "env": env_type,
                "num_runs": NUM_RUNS,
                "Hs": Hs,
                "lr_actor_vals": LR_ACTOR_VALS,
                "lr_critic_vals": LR_CRITIC_VALS,
            },
            f,
        )

    print(f"  Saved grid search to {out}")


if __name__ == "__main__":
    # Run the grid search on each environment in turn.
    for env_name in ("gridworld", "synthetic"):
        run_experiment(env_type=env_name)
