import jax
import jax.numpy as jnp
from rl.gae import Transition_reach
import optax

# Gradient utilities for combining multi-objective policy updates
def _tree_add(a, b):
    """Return the leaf-wise sum ``a + b`` of two pytrees with matching structure."""
    return jax.tree_util.tree_map(lambda lhs, rhs: lhs + rhs, a, b)


def _tree_sub(a, b):
    """Return the leaf-wise difference ``a - b`` of two pytrees with matching structure."""
    return jax.tree_util.tree_map(lambda minuend, subtrahend: minuend - subtrahend, a, b)


def _tree_scale(a, s):
    """Return a copy of pytree ``a`` with every leaf multiplied by ``s``."""
    return jax.tree_util.tree_map(lambda leaf: leaf * s, a)


def _tree_dot(a, b):
    """Return the inner product of two pytrees, summed over all leaves.

    Each pair of corresponding leaves is reduced with ``jnp.vdot``, which
    flattens both inputs and conjugates the first. The per-leaf results are
    then added together. An empty pytree yields ``0``.

    Raises:
        An error (``ValueError`` from JAX's tree utilities) if ``a`` and ``b``
        do not share the same tree structure. The previous ``zip``-based
        implementation silently dropped the unmatched leaves instead.
    """
    # tree_map checks b against a's structure rather than truncating like zip.
    per_leaf = jax.tree_util.tree_map(jnp.vdot, a, b)
    return sum(jax.tree_util.tree_leaves(per_leaf))


def _tree_norm(a):
    """Return the Euclidean (L2) norm of pytree ``a`` taken across all leaves."""
    squared_norm = _tree_dot(a, a).real
    # Defensive clamp at zero before the square root.
    return jnp.sqrt(jnp.maximum(squared_norm, 0.0))


def _proj(u, v, eps=1e-12):
    """Return the component of pytree ``u`` along pytree ``v``.

    Computes ``(<u, v> / <v, v>) * v`` using real parts of the inner products.
    ``<v, v>`` is floored at ``eps``, so a zero ``v`` yields a zero projection
    instead of NaNs.
    """
    u_dot_v = _tree_dot(u, v).real
    v_dot_v = _tree_dot(v, v).real
    return _tree_scale(v, u_dot_v / jnp.maximum(v_dot_v, eps))


def _compute_policy_terms(apply_fn, params, traj_batch):
    """Evaluate the policy on a rollout batch.

    Returns:
        ``(ratio, entropy)``. ``ratio`` is ``exp(new_log_prob - traj_batch.log_prob)``
        for the stored actions. ``entropy`` is the distribution's per-sample entropy.
    """
    dist = apply_fn(params, traj_batch.obs)
    new_log_prob = dist.log_prob(traj_batch.action)
    return jnp.exp(new_log_prob - traj_batch.log_prob), dist.entropy()


def _normalize(x, eps=1e-8):
    """Standardize ``x`` to zero mean and (approximately) unit variance.

    The variance is floored at ``eps`` and ``eps`` is added to the resulting
    std, so a constant input maps to zeros rather than NaNs.
    """
    centered = x - jnp.mean(x)
    variance = jnp.mean(centered ** 2)
    scale = jnp.sqrt(jnp.maximum(variance, eps)) + eps
    return centered / scale


def _masked_mean(x, mask, eps=1e-8):
    """Return ``sum(x * mask) / (sum(mask) + eps)``: a mask-weighted average.

    ``eps`` keeps the result finite (zero) when the mask is all zeros.
    """
    total_weight = jnp.sum(mask) + eps
    return jnp.sum(x * mask) / total_weight


def _normalize_adv(x, mask=None, eps=1e-8):
    """Standardize advantages, optionally using only masked samples for the statistics.

    With ``mask=None`` this defers to ``_normalize``. Otherwise the mean and
    variance are mask-weighted averages. Every element of ``x``, masked or not,
    is shifted and scaled by those statistics.
    """
    if mask is None:
        return _normalize(x, eps)
    centered = x - _masked_mean(x, mask, eps)
    variance = _masked_mean(centered ** 2, mask, eps)
    return centered / (jnp.sqrt(jnp.maximum(variance, eps)) + eps)


def _entropy_grads(params, apply_fn, traj_batch, ent_coef):
    """Differentiate the entropy-bonus loss ``-ent_coef * mean(entropy)``.

    Returns:
        ``(grads, loss_value)``. ``grads`` has the same pytree structure as ``params``.
    """
    def negative_scaled_entropy(p):
        mean_entropy = jnp.mean(apply_fn(p, traj_batch.obs).entropy())
        return -ent_coef * mean_entropy

    value, grad_tree = jax.value_and_grad(negative_scaled_entropy)(params)
    return grad_tree, value


def compute_gradients_R(params, apply_fn, traj_batch, advantages_reach_gamma, clip_eps, mask=None):
    """Gradient of the clipped surrogate built from the reach advantages.

    The advantages are standardized (mask-aware) and the loss is the (masked)
    mean of ``max(r * A, clip(r, 1 - clip_eps, 1 + clip_eps) * A)``, where ``r``
    is the likelihood ratio against ``traj_batch.log_prob``. This is the
    pessimistic clipped bound for a quantity to be *lowered*: gradient descent
    on it decreases the advantage-weighted ratio. Presumably lower reach
    advantage is better here; confirm against the reach-value definition.

    Args:
        params: policy parameters to differentiate with respect to.
        apply_fn: policy forward function ``apply_fn(params, obs) -> distribution``.
        traj_batch: rollout batch exposing ``obs``, ``action`` and ``log_prob``.
        advantages_reach_gamma: per-sample advantages.
        clip_eps: ratio clipping radius.
        mask: optional per-sample weights, used both for the normalization
            statistics and for averaging the loss.

    Returns:
        ``(grads, loss)``, with ``grads`` shaped like ``params``.
    """
    adv = _normalize_adv(advantages_reach_gamma, mask)
    lo, hi = 1.0 - clip_eps, 1.0 + clip_eps

    def surrogate(p):
        dist = apply_fn(p, traj_batch.obs)
        ratio = jnp.exp(dist.log_prob(traj_batch.action) - traj_batch.log_prob)
        per_sample = jnp.maximum(ratio * adv, jnp.clip(ratio, lo, hi) * adv)
        return jnp.mean(per_sample) if mask is None else _masked_mean(per_sample, mask)

    loss, grads = jax.value_and_grad(surrogate)(params)
    return grads, loss


def compute_gradients_C(params, apply_fn, traj_batch, advantages_V, clip_eps, mask=None):
    """Gradient of the clipped surrogate built from the value (cost/energy) advantages.

    Same loss form as ``compute_gradients_R``: the (masked) mean of
    ``max(r * A, clip(r, 1 - clip_eps, 1 + clip_eps) * A)`` over standardized
    advantages ``A``. Gradient descent on it lowers the advantage-weighted ratio.

    Args:
        params: policy parameters to differentiate with respect to.
        apply_fn: policy forward function ``apply_fn(params, obs) -> distribution``.
        traj_batch: rollout batch exposing ``obs``, ``action`` and ``log_prob``.
        advantages_V: per-sample value advantages.
        clip_eps: ratio clipping radius.
        mask: optional per-sample weights for normalization and averaging.

    Returns:
        ``(grads, loss)``, with ``grads`` shaped like ``params``.
    """
    norm_adv = _normalize_adv(advantages_V, mask)
    aggregate = jnp.mean if mask is None else (lambda t: _masked_mean(t, mask))

    def cost_surrogate(p):
        log_ratio = apply_fn(p, traj_batch.obs).log_prob(traj_batch.action) - traj_batch.log_prob
        ratio = jnp.exp(log_ratio)
        unclipped = ratio * norm_adv
        clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * norm_adv
        return aggregate(jnp.maximum(unclipped, clipped))

    loss, grads = jax.value_and_grad(cost_surrogate)(params)
    return grads, loss


def combine_gradients_PCRPO(grads_R, grads_C, eps):
    """Merge reach (R) and cost (C) policy gradients with conflict projection.

    If the cosine similarity of the two gradients is below ``-eps`` they are
    treated as conflicting. In that case each gradient has its component along
    the other removed, and the two projected gradients are summed. Otherwise
    the raw gradients are simply summed.

    Args:
        grads_R: pytree gradient of the reach objective.
        grads_C: pytree gradient of the cost/value objective (same structure).
        eps: conflict margin; with ``eps=0`` any obtuse angle is a conflict.

    Returns:
        ``(g_comb, cos_theta, conflict, nR, nC, nU)``:
            - ``g_comb``: the combined gradient.
            - ``cos_theta``: the (real) cosine similarity.
            - ``conflict``: the conflict flag as float32 0.0/1.0.
            - ``nR``, ``nC``, ``nU``: the L2 norms of ``grads_R``, ``grads_C``
              and ``g_comb``.
    """
    # Use the real part, as _tree_norm/_proj do, so cos_theta is real-valued
    # and the ordered comparison below is well defined.
    dot = _tree_dot(grads_R, grads_C).real
    nR = _tree_norm(grads_R)
    nC = _tree_norm(grads_C)
    # The 1e-12 guard gives cos_theta == 0 (never a conflict for eps >= 0)
    # when either gradient is all zeros, e.g. when a region mask is empty.
    cos_theta = dot / (nR * nC + 1e-12)
    conflict = cos_theta < (-eps)

    def proj_rule():
        # Strip from each gradient its component along the other, then sum.
        gR_proj = _tree_sub(grads_R, _proj(grads_R, grads_C))
        gC_proj = _tree_sub(grads_C, _proj(grads_C, grads_R))
        return _tree_add(gR_proj, gC_proj)

    def sum_rule():
        return _tree_add(grads_R, grads_C)

    g_comb = jax.lax.cond(conflict, proj_rule, sum_rule)
    nU = _tree_norm(g_comb)
    return g_comb, cos_theta, conflict.astype(jnp.float32), nR, nC, nU

def _env_step(env, env_params, runner_state, _):
    """Advance every parallel environment by one step and record a transition.

    Samples actions from the policy and evaluates the energy, reach and phi
    networks on the current observations. It then steps the environment with
    ``jax.vmap`` (one RNG key per environment, shared ``env_params``) and
    packages the results into a ``Transition_reach``.

    The ``(carry, x) -> (carry, y)`` shape suggests this is a ``jax.lax.scan``
    body with ``env``/``env_params`` bound beforehand; the last argument is
    unused. Confirm against the caller.

    Returns:
        ``(runner_state, transition)``. ``runner_state`` keeps the same tuple
        layout, with the new env state, new observations and advanced RNG.
    """
    (policy_ts, energy_ts, reach_ts, phi_ts,
     prev_env_state, prev_obs, key) = runner_state

    def predict(ts):
        return ts.apply_fn(ts.params, prev_obs)

    # Action selection (first RNG split).
    key, action_key = jax.random.split(key)
    dist = predict(policy_ts)
    energy_value = predict(energy_ts)
    reach_value = predict(reach_ts)
    phi_pred = predict(phi_ts)
    act = dist.sample(seed=action_key)
    act_log_prob = dist.log_prob(act)

    # Environment step (second RNG split, one key per environment).
    key, step_key = jax.random.split(key)
    per_env_keys = jax.random.split(step_key, prev_obs.shape[0])
    batched_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    next_obs, next_env_state, rew, terminal, step_info = batched_step(
        per_env_keys, prev_env_state, act, env_params
    )

    # g and h are read from the pre-step env state; their meaning is defined by the env.
    record = Transition_reach(
        terminal, act, energy_value, reach_value, rew, act_log_prob, prev_obs, step_info,
        prev_env_state.g, prev_env_state.h, phi_pred
    )
    next_runner_state = (policy_ts, energy_ts, reach_ts, phi_ts,
                         next_env_state, next_obs, key)
    return next_runner_state, record


def _rapcppo_update(config, update_state, ent):
    """Run one shuffled pass of minibatch updates over a collected rollout.

    The rollout in ``update_state`` is flattened, permuted, split into
    ``config["NUM_MINIBATCHES"]`` minibatches and fed through
    ``_update_minbatch`` with ``jax.lax.scan``. Each minibatch step updates
    four train states:

    - the policy, using region-partitioned clipped-surrogate gradients
      combined via ``combine_gradients_PCRPO``;
    - the energy critic;
    - the reach critic;
    - the phi network.

    Args:
        config: dict of hyperparameters. Reads CLIP_EPS, VF_COEF,
            MINIBATCH_SIZE, NUM_MINIBATCHES, NUM_STEPS and NUM_ENVS, plus the
            optional PCRPO_EPS (default 0.0), G_BOUND (default 300) and
            TASK_PROB (default 0.8).
        update_state: tuple ``(train_state_policy, train_state_energy,
            train_state_reach, train_state_phi, traj_batch, advantages_reach,
            targets_reach, advantages_V, targets_V, phi_targets, phi_mask, rng)``.
            The first two axes of every rollout leaf are merged into
            ``MINIBATCH_SIZE * NUM_MINIBATCHES`` rows. Presumably the layout is
            (NUM_STEPS, NUM_ENVS, ...); confirm against the caller.
        ent: entropy-bonus coefficient passed to ``_entropy_grads``.

    Returns:
        ``(update_state, total_loss)``. ``update_state`` has the same layout,
        with the four train states updated and ``rng`` advanced. ``total_loss``
        is the metrics dict from ``_update_minbatch``, stacked over minibatches
        by ``jax.lax.scan``.
    """
    # No Lagrange; combine reach_gamma and value advantages directly in gradient space
    (train_state_policy, train_state_energy, train_state_reach, train_state_phi,
     traj_batch, advantages_reach, targets_reach, advantages_V, targets_V, phi_targets, phi_mask, rng) = update_state
    # NOTE(review): this _rng is never used. The split is kept because removing
    # it would change the RNG stream, and hence the minibatch permutation.
    rng, _rng = jax.random.split(rng)

    def _update_minbatch(train_state, batch_info):
        """Apply one minibatch update to all four train states.

        Policy gradients are computed first, using the critic and phi
        parameters from before this step. The phi, reach and energy networks
        are then updated, and the policy last. Returns the updated train-state
        tuple and a dict of scalar metrics.
        """
        train_state_policy, train_state_energy, train_state_reach, train_state_phi = train_state
        traj_batch, advantages_reach, targets_reach, advantages_V, targets_V, phi_targets, phi_mask = batch_info

        def _loss_fn_reach(params, traj_batch, targets_reach):
            # Clipped value loss for the reach critic. The prediction may move at
            # most CLIP_EPS away from the rollout-time value (the same constant as
            # the policy ratio clip); the larger of the clipped/unclipped errors is used.

            value_reach = train_state_reach.apply_fn(params, traj_batch.obs)

            value_pred_clipped_reach = traj_batch.value_reach + (
                    value_reach - traj_batch.value_reach
            ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
            value_losses_reach = jnp.square(value_reach - targets_reach)
            value_losses_clipped_reach = jnp.square(value_pred_clipped_reach - targets_reach)
            value_loss_reach = (
                    0.5 * jnp.maximum(value_losses_reach, value_losses_clipped_reach).mean()
            )

            # Returns (VF_COEF-scaled loss to differentiate, unscaled loss as aux).
            total_loss = config["VF_COEF"] * value_loss_reach
            return total_loss, value_loss_reach

        def _loss_fn_energy(params, traj_batch, targets_V):
            # Same clipped value loss as _loss_fn_reach, for the energy critic.
            value = train_state_energy.apply_fn(params, traj_batch.obs)

            value_pred_clipped = traj_batch.value + (
                    value - traj_batch.value
            ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
            value_losses = jnp.square(value - targets_V)
            value_losses_clipped = jnp.square(value_pred_clipped - targets_V)
            value_loss_V = (
                    0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
            )

            total_loss = config["VF_COEF"] * value_loss_V
            return total_loss, value_loss_V

        def _loss_fn_phi(params, traj_batch, phi_targets, phi_mask):
            # Masked MSE over safe-success segments; no PPO-style clipping
            phi_pred = train_state_phi.apply_fn(params, traj_batch.obs)
            errs = jnp.square(phi_pred - phi_targets)
            # Alternative losses that were tried (kept for reference):
            # bce = optax.sigmoid_binary_cross_entropy(phi_pred, phi_targets)
            # masked_sum = (bce * phi_mask).sum()
            # huber = optax.huber_loss(phi_pred, phi_targets, delta=1.0)
            # masked_sum = (huber * phi_mask).sum()
            masked_sum = (errs * phi_mask).sum()
            # +1e-8 keeps the loss finite (zero) for a minibatch with an all-zero mask.
            norm = phi_mask.sum() + 1e-8
            phi_loss = 0.5 * (masked_sum / norm)
            total_loss = config["VF_COEF"] * phi_loss
            return total_loss, phi_loss

        # Partitioned policy gradients: gamma region vs mixed region
        eps = config.get("PCRPO_EPS", 0.0)
        g_bound = config.get("G_BOUND", 300)
        task_prob = config.get("TASK_PROB", 0.8)

        # The region test uses the *current* reach critic and phi network,
        # not the rollout-time values stored in traj_batch.
        value_reach_now = train_state_reach.apply_fn(train_state_reach.params, traj_batch.obs)
        phi_now = train_state_phi.apply_fn(train_state_phi.params, traj_batch.obs)
        # The phi network output is exponentiated, so it is presumably a
        # log-scale prediction. TODO confirm against how phi_targets are built.
        phi_now = jnp.exp(phi_now)
        # Partition rule: if condition holds, only optimize reach (gamma region);
        # otherwise use mixed gradient (reach + value).
        # Threshold: -TASK_PROB * G_BOUND * exp(phi_net(obs)).
        reach_mask_bool = value_reach_now > (-task_prob * g_bound * phi_now)
        mask_gamma = reach_mask_bool.astype(jnp.float32)
        mask_mixed = 1.0 - mask_gamma

        # Ablation: uncommenting these applies every policy gradient to every sample.
        # mask_gamma = jnp.ones_like(value_reach_now)
        # mask_mixed = jnp.ones_like(value_reach_now)

        # Floor phi at 0.01 before dividing the reach advantage by it, so small
        # phi predictions cannot blow up the scaled advantage. stop_gradient is
        # defensive only: the policy losses below differentiate w.r.t. policy params.
        phi = jnp.maximum(phi_now, 0.01)
        phi = jax.lax.stop_gradient(phi)

        # Region A: gamma region uses the phi-scaled reach advantage only.
        # Each compute_gradients_* call averages over its own mask, so the two
        # regions contribute equally regardless of how many samples each holds.
        grads_R_gamma, loss_R_gamma = compute_gradients_R(
            train_state_policy.params, train_state_policy.apply_fn, traj_batch, advantages_reach/phi, config["CLIP_EPS"], mask_gamma
        )
        # Region B: mixed region combines reach and cost advantages via PCRPO
        grads_R_mixed, loss_R_mixed = compute_gradients_R(
            train_state_policy.params, train_state_policy.apply_fn, traj_batch, advantages_reach/phi, config["CLIP_EPS"], mask_mixed
        )
        # (An earlier variant passed advantages_reach without the 1/phi scaling
        # to both compute_gradients_R calls above.)
        grads_C_mixed, loss_C_mixed = compute_gradients_C(
            train_state_policy.params, train_state_policy.apply_fn, traj_batch, advantages_V, config["CLIP_EPS"], mask_mixed
        )
        # cos_theta / conflict / norms below describe the mixed region only.
        g_comb_mixed, cos_theta, conflict, nR, nC, nU = combine_gradients_PCRPO(grads_R_mixed, grads_C_mixed, eps)

        # Final policy gradient = gamma-region reach grad + PCRPO-combined
        # mixed-region grad + entropy-bonus grad (entropy is not masked).
        grads_ent, ent_loss = _entropy_grads(train_state_policy.params, train_state_policy.apply_fn, traj_batch, ent)
        grads_policy = _tree_add(_tree_add(grads_R_gamma, g_comb_mixed), grads_ent)

        # Critic/phi updates. Each value_and_grad(has_aux=True) returns
        # ((scaled_loss, unscaled_loss), grads).
        grad_fn = jax.value_and_grad(_loss_fn_phi, has_aux=True)
        total_loss_phi, grads = grad_fn(
            train_state_phi.params, traj_batch, phi_targets, phi_mask
        )
        train_state_phi = train_state_phi.apply_gradients(grads=grads)

        grad_fn = jax.value_and_grad(_loss_fn_reach, has_aux=True)
        total_loss_reach, grads = grad_fn(
            train_state_reach.params, traj_batch, targets_reach
        )
        train_state_reach = train_state_reach.apply_gradients(grads=grads)

        grad_fn = jax.value_and_grad(_loss_fn_energy, has_aux=True)
        total_loss_energy, grads = grad_fn(
            train_state_energy.params, traj_batch, targets_V
        )
        train_state_energy = train_state_energy.apply_gradients(grads=grads)

        # Apply combined policy grads
        train_state_policy = train_state_policy.apply_gradients(grads=grads_policy)

        # Compose losses for logging
        # actor_loss is an unweighted average of the three surrogate values (logging only).
        actor_loss = (loss_R_gamma + loss_R_mixed + loss_C_mixed) / 3.0
        # Recover mean entropy from ent_loss = -ent * mean(entropy). The 1e-8
        # floor avoids dividing by zero (the metric reads 0 when ent == 0).
        entropy_mean = -ent_loss / jnp.maximum(ent, 1e-8)

        # [1] selects the unscaled aux loss from each (scaled, unscaled) pair.
        return (train_state_policy, train_state_energy, train_state_reach, train_state_phi), {
            "actor_loss": actor_loss,
            "entropy_loss": entropy_mean,
            "energy_loss": total_loss_energy[1],
            "reach_loss": total_loss_reach[1],
            "phi_loss": total_loss_phi[1],
            "pcrpo_cos_theta": cos_theta,
            "pcrpo_conflict_ratio": conflict,
            "grad_norm_R": nR,
            "grad_norm_C": nC,
            "grad_norm_update": nU,
            "gamma_ratio": jnp.mean(mask_gamma),
        }

    rng, _rng = jax.random.split(rng)
    batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
    # Plain assert: this consistency check is skipped under `python -O`.
    assert (
            batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
    ), "batch size must be equal to number of steps * number of envs"
    permutation = jax.random.permutation(_rng, batch_size)
    batch = (traj_batch, advantages_reach, targets_reach, advantages_V, targets_V, phi_targets, phi_mask)
    # Merge the two leading axes of every leaf into a single batch axis.
    batch = jax.tree_util.tree_map(
        lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
    )
    shuffled_batch = jax.tree_util.tree_map(
        lambda x: jnp.take(x, permutation, axis=0), batch
    )
    # Split into (NUM_MINIBATCHES, MINIBATCH_SIZE, ...) so scan iterates over minibatches.
    minibatches = jax.tree_util.tree_map(
        lambda x: jnp.reshape(
            x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
        ),
        shuffled_batch,
    )

    # NOTE(review): _rng from this split is unused; only the advanced rng is returned.
    rng, _rng = jax.random.split(rng)
    train_state, total_loss = jax.lax.scan(
        _update_minbatch, (train_state_policy, train_state_energy, train_state_reach, train_state_phi), minibatches
    )
    # The unshuffled, un-reshaped rollout arrays are passed through unchanged.
    update_state = (train_state[0], train_state[1], train_state[2], train_state[3],
                    traj_batch, advantages_reach, targets_reach, advantages_V, targets_V, phi_targets, phi_mask, rng)
    return update_state, total_loss
