from typing import Any, Callable, Dict, Tuple

import jax
import jax.numpy as jnp
from brax import base


# ---------- obs helpers ----------

def infer_obs_mode(obs_dim, nq, nv):
    """Guess the observation layout from its length.

    Args:
        obs_dim: Length of the observation vector.
        nq: Number of position coordinates.
        nv: Number of velocity coordinates.

    Returns:
        "full" when obs_dim == nq + nv, "omit_1" / "omit_2" when the
        observation is one / two entries shorter than that, otherwise None.
    """
    # The position of each label in the tuple is the number of entries the
    # observation is short of the full nq + nv layout.
    for n_omitted, label in enumerate(("full", "omit_1", "omit_2")):
        if obs_dim == nq - n_omitted + nv:
            return label
    return None


def obs_to_q_qd(obs, tmpl_q, nq, nv, mode):
    """Split an observation vector into positions ``q`` and velocities ``qd``.

    Args:
        obs: Observation laid out as [positions..., velocities...].
        tmpl_q: Template positions; its first one or two entries stand in for
            the position entries missing from ``obs`` in the omit modes.
            Not read in "full" mode.
        nq: Number of position coordinates in the returned ``q``.
        nv: Number of velocity coordinates in the returned ``qd``.
        mode: One of "full", "omit_1", "omit_2".

    Returns:
        Tuple ``(q, qd)``.

    Raises:
        ValueError: If ``mode`` is not one of the three known modes.
    """
    known_modes = ("full", "omit_1", "omit_2")
    if mode not in known_modes:
        raise ValueError("unknown obs mode")

    # Position in the tuple == number of leading q entries absent from obs.
    n_missing = known_modes.index(mode)
    n_present = nq - n_missing

    velocities = obs[n_present:n_present + nv]
    if n_missing == 0:
        return obs[:n_present], velocities

    positions = jnp.concatenate([tmpl_q[:n_missing], obs[:n_present]], 0)
    return positions, velocities


def q_qd_to_obs(q, qd, nq, nv, mode):
    """Pack positions and velocities into an observation vector.

    Args:
        q: Position coordinates.
        qd: Velocity coordinates.
        nq: Unused; kept for signature symmetry with ``obs_to_q_qd``.
        nv: Unused; kept for signature symmetry with ``obs_to_q_qd``.
        mode: "full" keeps all of ``q``; "omit_1" / "omit_2" drop its first
            one / two entries.

    Returns:
        The concatenation of the kept part of ``q`` followed by ``qd``.

    Raises:
        ValueError: If ``mode`` is not one of the three known modes.
    """
    known_modes = ("full", "omit_1", "omit_2")
    if mode not in known_modes:
        raise ValueError("unknown obs mode")

    # Position in the tuple == number of leading q entries to drop.
    n_dropped = known_modes.index(mode)
    kept_q = q if n_dropped == 0 else q[n_dropped:]
    return jnp.concatenate([kept_q, qd], 0)


# ---------- system param application ----------

def apply_theta_to_sys(sys: base.System, theta: Dict[str, jnp.ndarray]) -> base.System:
    """Return a copy of ``sys`` with physical parameters perturbed by ``theta``.

    Every perturbation is best-effort: if the field cannot be read from
    ``sys``, the key is absent from ``theta``, or the arithmetic fails, that
    perturbation is skipped and the remaining ones are still attempted.

    Args:
        sys: System to perturb. It is not modified; updates are applied through
            ``sys.tree_replace``.
        theta: Perturbation parameters. Keys read here:
            "mass_scale" (multiplies ``link.inertia.mass``),
            "inertia_scale" (multiplies ``link.inertia.i``),
            "com_shift" (added to ``link.transform.pos`` for every link except
            index 0),
            "friction_scale" (multiplies ``geom.friction``),
            "torque_scale" (multiplies the first readable actuator field among
            ``strength``, ``gear``, ``gain``).

    Returns:
        ``sys.tree_replace(updates)`` when at least one perturbation applied,
        otherwise ``sys`` itself.
    """
    upd = {}

    # NOTE: `except Exception` (rather than a bare `except:`) keeps the
    # skip-on-failure behaviour while letting KeyboardInterrupt / SystemExit
    # propagate instead of being silently swallowed.
    try:
        upd["link.inertia.mass"] = sys.link.inertia.mass * theta["mass_scale"]
    except Exception:
        pass

    try:
        upd["link.inertia.i"] = sys.link.inertia.i * theta["inertia_scale"]
    except Exception:
        pass

    try:
        pos = sys.link.transform.pos
        # Link 0 is left where it is; with a single link nothing is shifted.
        if pos.shape[0] > 1:
            upd["link.transform.pos"] = pos.at[1:].add(theta["com_shift"])
        else:
            upd["link.transform.pos"] = pos
    except Exception:
        pass

    try:
        upd["geom.friction"] = sys.geom.friction * theta["friction_scale"]
    except Exception:
        pass

    # Scale exactly one actuator field: the first candidate that can be read,
    # trying `sys.tree_get` first and plain attribute access as a fallback.
    for k in ["actuator.strength", "actuator.gear", "actuator.gain"]:
        try:
            upd[k] = sys.tree_get(k) * theta["torque_scale"]
            break
        except Exception:
            try:
                arr = getattr(sys.actuator, k.split(".")[1])
                upd[k] = arr * theta["torque_scale"]
                break
            except Exception:
                pass

    return sys.tree_replace(upd) if upd else sys


# ---------- vectorized K-env stepping ----------

def make_step_from_obs(sys: base.System, obs_mode: str, bpipeline):
    """Create a single-sample step function that maps an observation to the next one.

    Args:
        sys: System passed to ``bpipeline.init`` and ``bpipeline.step``.
        obs_mode: Observation layout understood by ``obs_to_q_qd`` and
            ``q_qd_to_obs`` ("full", "omit_1" or "omit_2").
        bpipeline: Object exposing ``init(sys, q, qd)`` and
            ``step(sys, state, action)``; the state returned by ``step`` must
            carry ``q`` and ``qd``.

    Returns:
        A function ``(s, a) -> next_obs`` for one (unbatched) observation ``s``
        and action ``a``.
    """
    nq = int(sys.q_size())
    nv = int(sys.qd_size())

    def _step_one(s, a):
        # Position entries missing from the observation are filled with zeros.
        zero_q = jnp.zeros((nq,), s.dtype)
        q0, qd0 = obs_to_q_qd(s, zero_q, nq, nv, obs_mode)

        # Actions are clamped to [-1, 1] before being handed to the pipeline.
        action = jnp.clip(a, -1.0, 1.0)

        state = bpipeline.init(sys, q0, qd0)
        next_state = bpipeline.step(sys, state, action)
        return q_qd_to_obs(next_state.q, next_state.qd, nq, nv, obs_mode)

    return _step_one


def build_ensemble_kernels(base_sys: base.System, K: int, key, cfg: Dict, backend: str, obs_mode: str) -> Tuple[Callable, int, Any]:
    """Build a jitted kernel that steps a batch through K randomly perturbed systems.

    K parameter sets are sampled from the prior described by ``cfg``, each is
    applied to ``base_sys`` via ``apply_theta_to_sys``, and one step function
    per perturbed system is created with ``make_step_from_obs``.

    Args:
        base_sys: Nominal system that the perturbations are applied to.
        K: Number of ensemble members; must be at least 1.
        key: JAX PRNG key used for sampling the perturbations.
        cfg: Prior configuration. Keys read: "THETA_PRIOR_MASS_STD",
            "THETA_PRIOR_INERTIA_STD", "THETA_PRIOR_FRICTION_STD",
            "THETA_PRIOR_TORQUE_STD", "THETA_PRIOR_COM_STD",
            "THETA_SCALE_MIN", "THETA_SCALE_MAX", "THETA_COM_ABSMAX".
        backend: "positional" selects ``brax.positional.pipeline``; any other
            value selects ``brax.generalized.pipeline``.
        obs_mode: Observation layout forwarded to ``make_step_from_obs``.

    Returns:
        ``(step_allK, K, key)`` where ``step_allK(s_batch, a_batch)`` maps each
        member's step function over the leading axis of both inputs and stacks
        the K results along axis 1, and ``key`` is the PRNG key left over after
        sampling.

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """
    # With no members the kernel would only fail later, inside jnp.stack, with
    # an unrelated message; reject the configuration up front instead.
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")

    if backend == "positional":
        from brax.positional import pipeline as bpipeline
    else:
        from brax.generalized import pipeline as bpipeline

    def _sample_theta_prior(key):
        key, k1, k2, k3, k4, k5 = jax.random.split(key, 6)

        def _clipped_lognormal(k, std_name):
            # exp(N(0, std)) multiplier, clamped to the configured scale range.
            raw = jnp.exp(jax.random.normal(k, ()) * cfg[std_name])
            return jnp.clip(raw, cfg["THETA_SCALE_MIN"], cfg["THETA_SCALE_MAX"])

        com_shift = jnp.clip(
            jax.random.normal(k5, (3,)) * cfg["THETA_PRIOR_COM_STD"],
            -cfg["THETA_COM_ABSMAX"],
            cfg["THETA_COM_ABSMAX"],
        )
        return key, {
            "mass_scale": _clipped_lognormal(k1, "THETA_PRIOR_MASS_STD"),
            "inertia_scale": _clipped_lognormal(k2, "THETA_PRIOR_INERTIA_STD"),
            "friction_scale": _clipped_lognormal(k3, "THETA_PRIOR_FRICTION_STD"),
            "torque_scale": _clipped_lognormal(k4, "THETA_PRIOR_TORQUE_STD"),
            "com_shift": com_shift,
        }

    step_fns = []
    for _ in range(K):
        # The key is threaded through so every member gets fresh randomness.
        key, th = _sample_theta_prior(key)
        sys_k = apply_theta_to_sys(base_sys, th)
        step_fns.append(make_step_from_obs(sys_k, obs_mode, bpipeline))
    step_fns = tuple(step_fns)

    @jax.jit
    def step_allK(s_batch, a_batch):
        outs = [jax.vmap(f)(s_batch, a_batch) for f in step_fns]
        return jnp.stack(outs, axis=1)  # (B,K,D)

    return step_allK, K, key