import json
from functools import partial
from typing import Dict

import jax, jax.numpy as jnp
import numpy as np
import optax
from flax.metrics.tensorboard import SummaryWriter
from flax.training.train_state import TrainState
from brax import envs

from wrappers import (
    LogWrapper, BraxGymnaxWrapper, VecEnv,
    NormalizeVecObservation, NormalizeVecReward, ClipAction,
)

from models import ActorCritic, AdvNet
from rapo_math import (
    kl_hat_from_values, robust_expectation_dual, project_eta_to_delta, straight_through_eta,
)
from theta_env import build_ensemble_kernels, infer_obs_mode
from sampling import sample_snext_weighted_allk, tilt_theta_weights_allk
from ppo_loss import ppo_loss
from utils import Transition, default_run_dir, save_params


def compute_robust_targets(cfg: Dict, key, step_allK, w_curr, net, params, advnet, adv_params,
                           obs_buf, act_buf, rew_buf, done_buf, done_tail):
    """Compute one-step robust value targets for a rollout buffer.

    The (time, env) axes of the buffers are flattened, ``cfg["M_NEXT"]``
    next-state samples per row are requested from
    ``sample_snext_weighted_allk``, the critic head of ``net`` (second output
    of ``net.apply``) is evaluated on them, and the resulting values are
    combined through ``robust_expectation_dual`` using the multiplier produced
    by ``advnet`` after ``project_eta_to_delta`` / ``straight_through_eta``.

    Args:
        cfg: Mapping providing ``M_NEXT``, ``ADV_USE_ACTION``, ``KL_DELTA``,
            ``KL_EPS`` and ``GAMMA``.
        key: PRNG key handed to the sampler; the key it returns is passed back.
        step_allK, w_curr: Forwarded unchanged to ``sample_snext_weighted_allk``.
        net, params: Actor-critic module and its parameters.
        advnet, adv_params: Adversary module and its parameters.
        obs_buf: Rank-3 array unpacked as ``(T, N, D)``.
        act_buf: Array reshaped to ``(T * N, -1)``.
        rew_buf: Array reshaped to ``(T * N,)``.
        done_buf: Done flags; row ``t + 1`` masks the bootstrap of step ``t``
            (row 0 is never read).
        done_tail: Done flags masking the bootstrap of the last step.

    Returns:
        ``(y, key, (Vn, s, a, r, msk), stats)`` where ``y`` has shape
        ``(T, N)``, ``Vn`` has shape ``(T * N, M_NEXT)``, ``s``/``a``/``r``/
        ``msk`` are the flattened inputs and continuation mask, and ``stats``
        maps ``rapo/*`` tags to scalar means.
    """
    n_steps, n_envs, obs_dim = obs_buf.shape
    flat = n_steps * n_envs
    n_next = cfg["M_NEXT"]

    # Collapse (time, env) into a single leading batch axis.
    s = obs_buf.reshape(flat, obs_dim)
    a = act_buf.reshape(flat, -1)
    r = rew_buf.reshape(flat)

    # Continuation mask: step t is gated by done_buf[t + 1]; the final step by done_tail.
    keep_body = 1.0 - done_buf[1:]
    keep_last = 1.0 - done_tail[None, :]
    msk = jnp.concatenate([keep_body, keep_last], axis=0).reshape(flat)

    # Critic values of the sampled next states, one row of n_next values per transition.
    s_next, key = sample_snext_weighted_allk(step_allK, w_curr, key, s, a, n_next)
    critic_out = net.apply(params, s_next.reshape(-1, obs_dim))
    Vn = critic_out[1].reshape(flat, n_next)

    # The adversary always receives an action-shaped input; it is zeroed when actions are disabled.
    if cfg["ADV_USE_ACTION"]:
        adv_action = a
    else:
        adv_action = jnp.zeros_like(a)
    eta_t = advnet.apply(adv_params, s, adv_action)

    # Row-wise projection (batched over Vn and eta_t; the remaining arguments are shared).
    project_rows = jax.vmap(project_eta_to_delta, in_axes=(0, 0, None, None, None))
    eta_star = project_rows(Vn, eta_t, cfg["KL_DELTA"], cfg["KL_EPS"], 25)
    eta_used = straight_through_eta(eta_t, eta_star)

    v_rob = robust_expectation_dual(Vn, eta_used)
    y = (r + cfg["GAMMA"] * msk * v_rob).reshape(n_steps, n_envs)

    kl_vals = kl_hat_from_values(Vn, eta_used)[0]
    stats = {
        "rapo/eta_mean": jnp.mean(eta_used),
        "rapo/kl_eta_mean": jnp.mean(kl_vals),
        "rapo/vrob_mean": jnp.mean(v_rob),
        "rapo/y_mean": jnp.mean(y),
    }
    return y, key, (Vn, s, a, r, msk), stats


def make_train(config: dict):
    """Build the training function for one robust-PPO run configured by ``config``.

    Derived values ``NUM_UPDATES`` and ``MINIBATCH_SIZE`` are written back into
    ``config`` (the dict is mutated in place). The Brax environment is wrapped
    (logging, action clipping, vectorisation, optional normalisation) and the
    policy / adversary optimizers are created here; everything else happens
    inside the returned ``train`` function.

    Args:
        config: Run configuration. Keys read include ``NUM_STEPS``,
            ``NUM_ENVS``, ``TOTAL_TIMESTEPS``, ``NUM_MINIBATCHES``,
            ``UPDATE_EPOCHS``, ``ENV_NAME``, ``NORMALIZE_ENV``, ``LR``,
            ``ANNEAL_LR``, ``MAX_GRAD_NORM``, ``ADV_LR``, ``ADV_STEPS``,
            ``ADV_KL_PENALTY``, ``GAMMA``, ``GAE_LAMBDA``, ``CLIP_EPS``,
            ``ENT_COEF``, ``VF_COEF``, ``RHO_THETA``,
            ``THETA_UPDATE_SUBSAMPLE``, ``NUM_MODEL_ENVS``, the ``THETA_*``
            prior settings, ``M_NEXT``, ``ADV_USE_ACTION``, ``KL_DELTA``,
            ``KL_EPS`` and ``ACTIVATION``. Optional keys: ``BRAX_BACKEND``
            (default ``"positional"``), ``SAVE_EVERY_UPDATES`` (10),
            ``PRINT_EVERY_UPDATES`` (1) and ``DEBUG`` (``True``).

    Returns:
        ``train(rng)``, which runs ``NUM_UPDATES`` updates, logs to
        TensorBoard, writes best/final checkpoints into the run directory and
        returns a dict with ``runner_state``, ``best_mean_return`` and
        ``run_dir``.

    Raises:
        ValueError: If ``NUM_STEPS * NUM_ENVS`` is not divisible by
            ``NUM_MINIBATCHES`` (the minibatch reshape would otherwise fail
            later, inside the jitted update).
    """
    # Same absolute module path as the file-level ``rapo_math`` import; a
    # relative import fails when this file is loaded as a top-level module.
    from rapo_math import to_valid_prob

    batch_size = int(config["NUM_STEPS"] * config["NUM_ENVS"])
    if batch_size % config["NUM_MINIBATCHES"] != 0:
        raise ValueError(
            f"NUM_STEPS * NUM_ENVS ({batch_size}) must be divisible by "
            f"NUM_MINIBATCHES ({config['NUM_MINIBATCHES']})"
        )
    config["NUM_UPDATES"] = int(config["TOTAL_TIMESTEPS"] // batch_size)
    config["MINIBATCH_SIZE"] = int(batch_size // config["NUM_MINIBATCHES"])

    backend = config.get("BRAX_BACKEND", "positional")
    env, env_params = BraxGymnaxWrapper(config["ENV_NAME"], backend=backend), None
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if config["NORMALIZE_ENV"]:
        env = NormalizeVecObservation(env)
        env = NormalizeVecReward(env, config["GAMMA"])

    obs_shape = env.observation_space(env_params).shape
    act_dim = int(env.action_space(env_params).shape[0])

    # One optimizer step per minibatch per epoch per update.
    total_steps = config["NUM_UPDATES"] * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"]
    lr = optax.linear_schedule(config["LR"], 0.0, total_steps) if config["ANNEAL_LR"] else config["LR"]
    tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(lr, eps=1e-5))
    tx_adv = optax.adam(config["ADV_LR"], eps=1e-5)

    # Subset of the configuration handed to the ensemble builder and to
    # compute_robust_targets.
    tcfg = dict(
        THETA_PRIOR_MASS_STD=config["THETA_PRIOR_MASS_STD"],
        THETA_PRIOR_INERTIA_STD=config["THETA_PRIOR_INERTIA_STD"],
        THETA_PRIOR_FRICTION_STD=config["THETA_PRIOR_FRICTION_STD"],
        THETA_PRIOR_TORQUE_STD=config["THETA_PRIOR_TORQUE_STD"],
        THETA_PRIOR_COM_STD=config["THETA_PRIOR_COM_STD"],
        THETA_SCALE_MIN=config["THETA_SCALE_MIN"],
        THETA_SCALE_MAX=config["THETA_SCALE_MAX"],
        THETA_COM_ABSMAX=config["THETA_COM_ABSMAX"],
        M_NEXT=config["M_NEXT"], ADV_USE_ACTION=config["ADV_USE_ACTION"],
        KL_DELTA=config["KL_DELTA"], KL_EPS=config["KL_EPS"], GAMMA=config["GAMMA"],
    )

    base_env = envs.get_environment(config["ENV_NAME"], backend=backend)
    obs_mode = infer_obs_mode(obs_shape[0], int(base_env.sys.q_size()), int(base_env.sys.qd_size()))

    def train(rng):
        """Run the full training loop from PRNG key ``rng`` and return a summary dict."""
        run_dir = default_run_dir(config)
        writer = SummaryWriter(run_dir)
        writer.text("config/json", json.dumps(config, indent=2), step=0)

        net = ActorCritic(act_dim, activation=config["ACTIVATION"])
        rng, k = jax.random.split(rng)
        params = net.init(k, jnp.zeros(obs_shape))
        state = TrainState.create(apply_fn=net.apply, params=params, tx=tx)

        advnet = AdvNet(use_action=config["ADV_USE_ACTION"])
        rng, k = jax.random.split(rng)
        adv_params = advnet.init(k, jnp.zeros((1, obs_shape[0])), jnp.zeros((1, act_dim)))
        adv_opt_state = tx_adv.init(adv_params)

        rng, k = jax.random.split(rng)
        step_allK, K, _ = build_ensemble_kernels(base_env.sys, int(config["NUM_MODEL_ENVS"]), k, tcfg,
                                                 backend=backend,
                                                 obs_mode=obs_mode)
        # Uniform prior over the K ensemble members; also the initial weights.
        w_prior = jnp.ones((K,), jnp.float32) / K
        w_curr = w_prior

        rng, k = jax.random.split(rng)
        reset_keys = jax.random.split(k, config["NUM_ENVS"])
        obs, env_state = env.reset(reset_keys, env_params)

        def _env_step(carry, _):
            """Sample one action per env, step the envs and record the transition."""
            st, es, last_obs, rng = carry
            rng, k1, k2 = jax.random.split(rng, 3)
            pi, val = st.apply_fn(st.params, last_obs)
            act = pi.sample(seed=k1)
            logp = pi.log_prob(act)
            step_keys = jax.random.split(k2, config["NUM_ENVS"])
            nxt, es, rew, done, info = env.step(step_keys, es, act, env_params)
            # ``done`` is the flag returned by this step; it is stored next to
            # the observation the action was taken from.
            tr = Transition(done, act, val, rew, logp, last_obs, info)
            return (st, es, nxt, rng), tr

        def _robust_gae(y, vals, done):
            """Return (advantages, value targets) from a reverse scan over time.

            ``y`` already contains the bootstrapped one-step targets, so the
            TD residual is ``y - vals``; ``done[t]`` cuts the accumulation
            across the transition taken at step ``t``.
            """
            def step(gae, tup):
                y_t, v_t, d_t = tup
                delta = y_t - v_t
                gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1.0 - d_t) * gae
                return gae, gae
            _, adv = jax.lax.scan(step, jnp.zeros_like(y[0]), (y, vals, done), reverse=True, unroll=16)
            return adv, adv + vals

        @partial(jax.jit, static_argnums=(7,))
        def adv_update(adv_params, adv_opt_state, Vn, s, a, r, msk, adv_steps):
            """Run ``adv_steps`` Adam steps on the adversary and return the step-averaged stats."""
            def loss_parts(p):
                eta_t = advnet.apply(p, s, a if config["ADV_USE_ACTION"] else jnp.zeros_like(a))
                eta_star = jax.vmap(project_eta_to_delta, in_axes=(0, 0, None, None, None))(
                    Vn, eta_t, config["KL_DELTA"], config["KL_EPS"], 25)
                eta_used = straight_through_eta(eta_t, eta_star)
                v_rob = robust_expectation_dual(Vn, eta_used)
                kl_resid, _ = kl_hat_from_values(Vn, eta_used)
                # Quadratic penalty, active only where the KL value exceeds KL_DELTA.
                pen = jnp.maximum(kl_resid - config["KL_DELTA"], 0.0) ** 2
                y = r + config["GAMMA"] * msk * v_rob
                # The adversary minimises the mean target plus the penalty.
                loss = jnp.mean(y) + config["ADV_KL_PENALTY"] * jnp.mean(pen)
                aux = {"rapo/eta_mean": jnp.mean(eta_used),
                       "rapo/kl_eta_mean": jnp.mean(kl_resid),
                       "rapo/vrob_mean": jnp.mean(v_rob),
                       "rapo/y_mean": jnp.mean(y),
                       "rapo/adv_loss": loss}
                return loss, aux

            def body(carry, _):
                p, opt_state = carry
                (loss, aux), grads = jax.value_and_grad(loss_parts, has_aux=True)(p)
                updates, opt_state = tx_adv.update(grads, opt_state, p)
                p = optax.apply_updates(p, updates)
                return (p, opt_state), aux

            (p_new, opt_state_new), aux_hist = jax.lax.scan(
                body, (adv_params, adv_opt_state), None, length=adv_steps)
            aux_mean = jax.tree.map(lambda x: jnp.mean(x), aux_hist)
            return p_new, opt_state_new, aux_mean

        def _update_epoch(carry, _):
            """One PPO epoch: shuffle the flattened batch and scan over its minibatches."""
            st, traj, adv, targets, rng = carry
            rng, key = jax.random.split(rng)
            traj_flat = jax.tree.map(lambda x: x.reshape((batch_size,) + x.shape[2:]), traj)
            adv_flat, tgt_flat = adv.reshape((batch_size,)), targets.reshape((batch_size,))
            idx = jax.random.permutation(key, batch_size)

            def take(x):
                return jnp.take(x, idx, axis=0)

            traj_shuf = jax.tree.map(take, traj_flat)
            adv_shuf, tgt_shuf = take(adv_flat), take(tgt_flat)
            M, MB = int(config["NUM_MINIBATCHES"]), int(config["MINIBATCH_SIZE"])
            traj_mb = jax.tree.map(lambda x: x.reshape((M, MB) + x.shape[1:]), traj_shuf)
            adv_mb = adv_shuf.reshape(M, MB)
            tgt_mb = tgt_shuf.reshape(M, MB)

            def _one_mb(st, batch):
                tb, ab, tb2 = batch

                def loss_fn(p):
                    return ppo_loss(net, p, tb, ab, tb2,
                                    config["CLIP_EPS"], config["ENT_COEF"], config["VF_COEF"])

                (total, aux), g = jax.value_and_grad(loss_fn, has_aux=True)(st.params)
                up, opt_state = st.tx.update(g, st.opt_state, st.params)
                st = st.replace(params=optax.apply_updates(st.params, up), opt_state=opt_state)
                return st, aux

            st, aux_mb = jax.lax.scan(_one_mb, st, (traj_mb, adv_mb, tgt_mb))
            aux_epoch = jax.tree.map(lambda x: jnp.mean(x, axis=0), aux_mb)
            return (st, traj, adv, targets, rng), aux_epoch

        @jax.jit
        def update_step(st, es, last_obs, rng, adv_p, adv_opt_state, w_curr):
            """Collect one rollout, update adversary, policy and ensemble weights."""
            (st, es, last_obs, rng), traj = jax.lax.scan(
                _env_step, (st, es, last_obs, rng), None, config["NUM_STEPS"])

            # traj.done[t] is the flag returned by the env step taken from
            # traj.obs[t], whereas compute_robust_targets masks the bootstrap
            # of step t with done_buf[t + 1] (done_tail for the last step).
            # Shift by one so the effective mask for step t is
            # 1 - traj.done[t], the same convention _robust_gae uses.
            # Row 0 of the shifted buffer is never read.
            done_shifted = jnp.concatenate(
                [jnp.zeros_like(traj.done[:1]), traj.done[:-1]], axis=0)
            y, rng, aux, y_stats = compute_robust_targets(
                tcfg, rng, step_allK, w_curr, net, st.params, advnet, adv_p,
                traj.obs, traj.action, traj.reward, done_shifted, traj.done[-1]
            )
            Vn, s_flat, a_flat, r_flat, msk_flat = aux
            adv_p, adv_opt_state, adv_stats = adv_update(
                adv_p, adv_opt_state, Vn, s_flat, a_flat, r_flat, msk_flat, config["ADV_STEPS"])
            A, R = _robust_gae(y, traj.value, traj.done)
            A = (A - A.mean()) / (A.std() + 1e-8)
            (st, _, _, _, rng), ppo_stats = jax.lax.scan(
                _update_epoch, (st, traj, A, R, rng), None, config["UPDATE_EPOCHS"])
            ppo_stats = jax.tree.map(lambda x: jnp.mean(x, axis=0), ppo_stats)

            # Re-weight the ensemble members using the freshly updated critic.
            val_apply = lambda x: net.apply(st.params, x)[1]
            w_new, rng = tilt_theta_weights_allk(step_allK, val_apply, w_prior, rng,
                                                 traj.obs.reshape(batch_size, -1),
                                                 traj.action.reshape(batch_size, -1),
                                                 config["RHO_THETA"], int(config["THETA_UPDATE_SUBSAMPLE"]))
            w_curr = to_valid_prob(w_new)
            w_ent = -jnp.sum(w_curr * jnp.log(w_curr + 1e-12))
            w_kl = jnp.sum(w_curr * (jnp.log(w_curr + 1e-12) - jnp.log(w_prior + 1e-12)))

            rollout_ret_mean = jnp.mean(jnp.sum(traj.reward, axis=0))

            # adv_stats reuses the four rapo/* keys of y_stats and is unpacked
            # later, so its values are the ones that end up logged.
            scalars = {"rapo/w_entropy": w_ent, "rapo/w_kl_prior": w_kl, **y_stats, **adv_stats, **ppo_stats}
            return (st, es, last_obs, rng, adv_p, adv_opt_state, w_curr), {
                "info": traj.info, "scalars": scalars, "rollout_return_mean": rollout_ret_mean
            }

        def _print_episodes(info_cb):
            """Print the return of every episode that finished during the rollout."""
            idx_cb = info_cb["returned_episode"]
            ret_cb = info_cb["returned_episode_returns"][idx_cb]
            ts_cb = info_cb["timestep"][idx_cb] * config["NUM_ENVS"]
            for t in range(ret_cb.shape[0]):
                print(f"global step={int(ts_cb[t])}, episodic return={float(ret_cb[t])}")

        best_mean = -1e9
        best_ckpt_actor = f"{run_dir}/best_actor.msgpack"
        best_ckpt_adv = f"{run_dir}/best_advnet.msgpack"
        final_ckpt_actor = f"{run_dir}/final_actor.msgpack"
        final_ckpt_adv = f"{run_dir}/final_advnet.msgpack"
        SAVE_EVERY = config.get("SAVE_EVERY_UPDATES", 10)
        PRINT_EVERY = config.get("PRINT_EVERY_UPDATES", 1)

        runner_state = (state, env_state, obs, rng, adv_params, adv_opt_state, w_curr)
        try:
            for upd in range(int(config["NUM_UPDATES"])):
                runner_state, outm = update_step(*runner_state)
                info = outm["info"]
                scalars = outm["scalars"]
                rollout_ret_mean = float(np.asarray(outm["rollout_return_mean"]))

                gstep = (upd + 1) * config["NUM_ENVS"] * config["NUM_STEPS"]
                for tag, value in scalars.items():
                    writer.scalar(tag, float(value), gstep)

                # Prefer the mean return of episodes that finished in this
                # rollout; fall back to the summed rollout reward otherwise.
                idx = np.asarray(info["returned_episode"])
                save_metric = rollout_ret_mean
                if idx.any():
                    rets = np.asarray(info["returned_episode_returns"])[idx]
                    steps = (np.asarray(info["timestep"])[idx] * config["NUM_ENVS"]).astype(np.int64)
                    writer.scalar("rollout/return_mean", float(rets.mean()), gstep)
                    writer.scalar("rollout/return_std", float(rets.std() if rets.size > 1 else 0.0), gstep)
                    for r, s in zip(rets, steps):
                        writer.scalar("rollout/return", float(r), int(s))
                    save_metric = float(rets.mean())
                else:
                    writer.scalar("rollout/return_mean", save_metric, gstep)
                    writer.scalar("rollout/return_std", 0.0, gstep)

                writer.scalar("rollout/save_metric", save_metric, gstep)

                if (upd + 1) % PRINT_EVERY == 0:
                    print(f"[train] update={upd+1}  global_step={gstep}  save_metric={save_metric:.2f}")

                # The "best" checkpoint is only considered every SAVE_EVERY updates.
                if (upd + 1) % SAVE_EVERY == 0:
                    if save_metric > best_mean:
                        best_mean = save_metric
                        save_params(best_ckpt_actor, runner_state[0].params)
                        save_params(best_ckpt_adv, runner_state[4])
                        print(f"[ckpt] best@upd={upd+1} metric={best_mean:.2f} -> {best_ckpt_actor}")

                if config.get("DEBUG", True):
                    jax.debug.callback(_print_episodes, info)

            save_params(final_ckpt_actor, runner_state[0].params)
            save_params(final_ckpt_adv, runner_state[4])
        finally:
            # Flush buffered summaries even if an update raised.
            writer.flush()
            writer.close()
        print(f"[ckpt] final -> {final_ckpt_actor}")
        return {"runner_state": runner_state, "best_mean_return": best_mean, "run_dir": run_dir}

    return train