import einops as ei
from functools import partial
from typing import Tuple, NamedTuple

import jax
import jax.numpy as jnp


class Transition_reach(NamedTuple):
    """Immutable named record bundling the arrays stored for a transition.

    Every field is annotated as a ``jnp.ndarray``; ``phi`` is the only field
    with a default (``None``), so it may be omitted when constructing a record.

    The per-field notes below are inferred from the field names only --
    TODO confirm against the code that fills these records.
    """

    done: jnp.ndarray  # presumably the termination flag
    action: jnp.ndarray  # presumably the action that was taken
    value: jnp.ndarray  # presumably the critic's value estimate
    value_reach: jnp.ndarray  # presumably the value estimate of a separate "reach" critic
    reward: jnp.ndarray  # presumably the scalar reward signal
    log_prob: jnp.ndarray  # presumably the log-probability of `action` under the policy
    obs: jnp.ndarray  # presumably the observation
    info: jnp.ndarray  # presumably auxiliary environment info
    g: jnp.ndarray  # NOTE(review): elsewhere in this file `g < 0` is treated as "target reached" -- confirm it is the same quantity
    h: jnp.ndarray  # NOTE(review): elsewhere in this file `h >= 0` is treated as "unsafe" -- confirm it is the same quantity
    phi: jnp.ndarray = None  # optional; defaults to None when not provided

@partial(jax.jit)
def calculate_gae_reach4(
    gamma: float,
    gae_lambda: float,
    T_gs: jnp.ndarray,
    T_Vhs: jnp.ndarray,
    done: jnp.ndarray,
    T_hs: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute lambda-weighted value targets and advantages with a backward scan.

    The scan walks the trajectory from the last step to the first.  At every
    step it maintains a ``(T + 1, nh)`` table of multi-step backups of the form
    ``max(h, min(g, gamma * V_next))`` (row ``k`` is the backup that looks
    ``k + 1`` steps ahead) and blends the rows with normalised GAE
    coefficients to obtain one target per step.

    Args:
        gamma: Factor multiplying the bootstrapped next-step values.
        gae_lambda: GAE mixing coefficient.  NOTE(review): the expression
            ``gae_lambda / (1 - gae_lambda)`` is evaluated unconditionally, so
            ``gae_lambda == 1`` divides by zero -- confirm callers never pass 1.
        T_gs: ``(T + 1, nh)`` array; only the first ``T`` rows are scanned.
        T_Vhs: Value estimates, indexed as ``[T, :]``, ``[1:]`` and ``[:-1]``
            (so presumably also ``(T + 1, nh)`` -- TODO confirm).
        done: Flags with leading length ``T``; cast to ``int``.  Each row is
            carried alongside an ``(nh,)`` integer carry, so rows must have
            shape ``(nh,)``.
        T_hs: Array whose first ``T`` rows are scanned (presumably
            ``(T + 1, nh)`` like ``T_gs`` -- TODO confirm).

    Returns:
        ``(Qhs_GAEs - T_Vhs[:-1], Qhs_GAEs)``: the advantages and the blended
        targets, one row per scanned step.
    """

    Tp1, nh = T_gs.shape
    T = Tp1 - 1

    def loop(carry, inp):
        # ii counts the steps already processed (0 for the last time step),
        # because `ts` is reversed and the scan itself runs in reverse.
        ii, gs, hs, Vhs, done_row = inp
        # pre_done_row is the done_row of the previously processed step,
        # i.e. of the step that follows this one in time.
        next_Vhs_row, gae_coeffs, pre_done_row = carry

        # Update GAE coeffs. [1] -> [1, λ/(1-λ)] -> [1 λ λ²/(1-λ)] -> [1 λ λ² λ³/(1-λ)]
        # Every existing coefficient moves one row deeper and is scaled by
        # λ when pre_done_row == 0, or by λ/(1-λ) when pre_done_row == 1;
        # a set done_row zeroes all shifted coefficients.  Row 0 is then reset
        # to 1 (this also discards the row wrapped around by jnp.roll).
        gae_coeffs = (jnp.roll(gae_coeffs, 1, axis=0) * gae_lambda * (1 - pre_done_row) +
                      jnp.roll(gae_coeffs, 1, axis=0) * (gae_lambda / (1 - gae_lambda)) * pre_done_row) * (1 - done_row)
        gae_coeffs = gae_coeffs.at[0, :].set(1.0)

        # Only rows 0..ii hold backups built from real data so far.
        mask = jnp.arange(T + 1) < ii + 1
        mask_h = mask[:, None]

        # DP for Vh.
        # done_row_processed = jnp.where(jnp.isnan(done_row * jnp.inf), 0, done_row * jnp.inf)
        # disc_to_gh = gamma * (next_Vhs_row + done_row_processed)
        disc_to_gh = gamma * next_Vhs_row

        # Core computation: take both hs and gs into account.
        # hs / gs broadcast over the lookahead axis (rows) of disc_to_gh.
        Vhs_row = jnp.maximum(hs, jnp.minimum(gs, disc_to_gh))
        Vhs_row = mask_h * Vhs_row

        # Normalise the coefficients over the lookahead axis and blend the rows.
        normed_gae_coeffs = gae_coeffs / jnp.sum(gae_coeffs, axis=0)
        Qhs_GAE = jnp.sum(Vhs_row * normed_gae_coeffs, axis=0)

        # Setup Vs_row for next timestep.
        # Shift every backup one row deeper and seed row 0 with this step's Vhs.
        # NOTE(review): Vhs comes from T_Vhs[1:], so after processing step t
        # row 0 holds T_Vhs[t + 1], and step t - 1 then bootstraps from
        # T_Vhs[t + 1] rather than T_Vhs[t] -- confirm this offset is intended.
        Vhs_row = jnp.roll(Vhs_row, 1, axis=0)
        Vhs_row = Vhs_row.at[0, :].set(Vhs)

        return (Vhs_row, gae_coeffs, done_row), Qhs_GAE

    done = jnp.array(done, dtype=int)
    init_gae_coeffs = jnp.zeros((T + 1, nh))

    # Row 0 of the initial table is the value at the final index T.
    init_Vhs = jnp.zeros((T + 1, nh)).at[0, :].set(T_Vhs[T, :])
    init_carry = (init_Vhs, init_gae_coeffs, jnp.zeros(nh, dtype=int))

    # Reversed indices combined with reverse=True make `ii` count upwards
    # from 0 as the scan moves from the last step towards the first.
    ts = jnp.arange(T)[::-1]
    inps = (ts, T_gs[:-1], T_hs[:-1], T_Vhs[1:], done)

    _, Qhs_GAEs = jax.lax.scan(loop, init_carry, inps, reverse=True)
    return Qhs_GAEs - T_Vhs[:-1], Qhs_GAEs


@jax.jit
def calculate_advantage2(
    gae_nval_gamma_lambda: Tuple[jnp.ndarray, jnp.ndarray, float, float],
    inp
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, float, float], jnp.array]:
    """Advance the advantage recursion by one step (usable as a scan body).

    Args:
        gae_nval_gamma_lambda: Carry ``(gae, next_value, Gamma, Lambda)``: the
            advantage and value handed over by the previously processed step,
            followed by the discount and lambda coefficients.
        inp: ``(transition, done, next_done)`` for the current step, where
            ``transition`` exposes ``reward`` and ``value`` attributes.

    Returns:
        ``(carry, gae)``: the carry for the following call, i.e.
        ``(gae, value, Gamma, Lambda)``, and the advantage of this step.
    """
    running_gae, bootstrap_value, Gamma, Lambda = gae_nval_gamma_lambda
    step, done, next_done = inp

    keep_next = 1 - next_done
    keep_now = 1 - done

    # `done` zeroes the whole target; `next_done` zeroes only the bootstrap term.
    td_error = (step.reward + Gamma * bootstrap_value * keep_next) * keep_now - step.value
    # The recursive term is dropped as soon as either flag is set.
    running_gae = td_error + Gamma * Lambda * keep_next * keep_now * running_gae

    return (running_gae, step.value, Gamma, Lambda), running_gae


@jax.jit
def calculate_gae2(
    gamma: float,
    gae_lambda: float,
    trajectory_batch,
    done: jnp.ndarray,
    last_value: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run the ``calculate_advantage2`` recursion backwards over a trajectory.

    Args:
        gamma: Discount coefficient forwarded to every step.
        gae_lambda: Lambda coefficient forwarded to every step.
        trajectory_batch: Time-major batch of transitions; its ``value`` field
            is added to the advantages to form the targets.
        done: Time-major array of flags (indexed with two axes here).
        last_value: Value used to bootstrap the final step; the initial
            advantage carry is zeros of the same shape.

    Returns:
        ``(advantages, advantages + trajectory_batch.value)``.
    """
    # shifted[t] == done[t + 1]; the wrapped-around final row is replaced by
    # the row just before it, so the last step reuses the final `done` flag.
    shifted = jnp.roll(done, -1, axis=0)
    next_done = shifted.at[-1, :].set(shifted[-2, :])

    initial_carry = (jnp.zeros_like(last_value), last_value, gamma, gae_lambda)
    per_step_inputs = (trajectory_batch, done, next_done)

    _, advantages = jax.lax.scan(
        calculate_advantage2,
        initial_carry,
        per_step_inputs,
        reverse=True,
        unroll=16,
    )
    targets = advantages + trajectory_batch.value
    return advantages, targets


def calculate_done(traj_batch, h_append):
    """Flag the entries of ``h_append`` that are non-negative.

    Args:
        traj_batch: Unused; accepted only to keep the call signature.
        h_append: Array of values compared against zero.

    Returns:
        ``int32`` array with the shape of ``h_append``: 1 where
        ``h_append >= 0`` and 0 elsewhere.
    """
    return (h_append >= 0).astype(jnp.int32)

@jax.jit
def calculate_phi_targets(
    gamma: float,
    g_values: jnp.ndarray,
) -> jnp.ndarray:
    """Build discounted "target reached" labels along the time axis.

    A step counts as a target step when ``g < 0``.  For every time ``t`` the
    label is the largest of

    * ``1`` if ``t`` itself is a target step, and
    * ``gamma ** (u - t)`` over all later target steps ``u > t``.

    Batch columns that contain no target step at all are labelled ``1``
    everywhere.

    Args:
        gamma: Discount base applied per step of distance to a later target.
        g_values: ``[T, batch_size]`` array; ``g < 0`` marks a target step.

    Returns:
        ``[T, batch_size]`` float array of labels.

    Note:
        The computation materialises a ``[T, T, batch_size]`` intermediate.
    """
    num_steps, _ = g_values.shape

    reached = (g_values < 0).astype(jnp.float32)  # [T, B]
    any_reached = jnp.any(reached, axis=0)  # [B]

    step_idx = jnp.arange(num_steps)
    lag = step_idx[None, :] - step_idx[:, None]  # [T, T], lag[t, u] = u - t
    is_future = (lag > 0).astype(jnp.float32)  # [T, T]

    # Clamp the exponent at 0 for non-future pairs before taking the power.
    # Raising gamma < 1 to a large negative exponent overflows to inf, and
    # inf * 0 (the mask) is NaN, which would poison the max below.
    discount = jnp.power(gamma, jnp.maximum(lag, 0)) * is_future  # [T, T]

    # propagated[t, u, b] = gamma ** (u - t) if u > t and u is a target step.
    propagated = discount[:, :, None] * reached[None, :, :]  # [T, T, B]
    best_future = jnp.max(propagated, axis=1)  # [T, B]

    phi_targets = jnp.maximum(reached, best_future)  # [T, B]

    # Columns without any target step fall back to all ones.
    return jnp.where(any_reached, phi_targets, jnp.ones_like(phi_targets))


# @partial(jax.jit)
# def calculate_phi_targets_success(
#     gamma: float,
#     g_values: jnp.ndarray,
#     h_values: jnp.ndarray,
# ) -> Tuple[jnp.ndarray, jnp.ndarray]:
#     """
#     Construct phi targets and mask using the "steps-to-go" definition over safe successful segments.

#     For each time t, find the earliest u >= t such that:
#       - g[u] < 0 (hit target at u)
#       - No unsafe (h >= 0) occurs in (t, u]
#     If such u exists, set target[t] = gamma ** (u - t); otherwise no supervision at t.

#     Includes the hit step itself (u == t => target 1.0).

#     Args:
#       gamma: Discount factor in (0, 1].
#       g_values: [T, B] values for target condition (g < 0 means target reached).
#       h_values: [T, B] values for safety (unsafe when h >= 0).

#     Returns:
#       phi_targets: [T, B] discounted labels (0 where masked out).
#       phi_mask:    [T, B] mask (1.0 where supervised, 0.0 otherwise).
#     """
#     T, B = g_values.shape

#     hit = (g_values < 0)
#     unsafe = (h_values >= 0)

#     # Cumulative count of unsafe to check violations in (t, u]
#     cum_unsafe = jnp.cumsum(unsafe.astype(jnp.int32), axis=0)  # [T, B]

#     t_idx = jnp.arange(T)[:, None]  # [T, 1]
#     u_idx = jnp.arange(T)[None, :]  # [1, T]
#     future = (u_idx >= t_idx)       # [T, T]

#     # Broadcast cumulative sums to [T, T, B] and test violations in (t, u]
#     cum_u = cum_unsafe[u_idx, :]    # [T, T, B]
#     cum_t = cum_unsafe[t_idx, :]    # [T, T, B]

#     unsafe_at_t = unsafe[t_idx, :]  # [T, T, B]
#     viol_between = (cum_u - cum_t) > 0
#     viol_inclusive = viol_between | unsafe_at_t
#     safe_between = ~viol_inclusive   # [T, T, B]

#     cand = (future[:, :, None] & safe_between) & hit[None, :, :]  # [T, T, B]
#     has_cand = jnp.any(cand, axis=1)        # [T, B]
#     u_first = jnp.argmax(cand, axis=1)      # [T, B] earliest valid u

#     steps = u_first - t_idx                  # [T, B]
#     # Use prior target for unlabeled positions: gamma ** 100
#     prior_k = 100
#     prior_target = jnp.power(gamma, prior_k)
#     phi_targets = jnp.where(has_cand, jnp.power(gamma, steps), prior_target)  # [T, B]
#     # Since we now have targets everywhere (real or prior), use full mask of ones
#     phi_mask = jnp.ones_like(phi_targets, dtype=jnp.float32)

#     return phi_targets.astype(jnp.float32), phi_mask


# @partial(jax.jit)
# def calculate_phi_targets_success(
#     gamma: float,
#     g_values: jnp.ndarray,
#     h_values: jnp.ndarray,
# ) -> Tuple[jnp.ndarray, jnp.ndarray]:
#     """
#     Construct *log* phi targets and mask using the "steps-to-go" definition over safe successful segments.

#     对每个时间 t，寻找最早的 u >= t，使得：
#       - g[u] < 0 （在 u 时刻到达目标）
#       - 在 (t, u] 内没有发生 unsafe：h >= 0

#     若存在这样的 u，则：
#       T = u - t
#       phi(t) = gamma ** T
#       本函数返回的是 log_phi(t) = T * log(gamma)

#     若不存在这样的 u，则用先验 T = prior_k（例如 100），
#     即 log_phi_prior = prior_k * log(gamma)。

#     Args:
#       gamma: 折扣因子，(0, 1].
#       g_values: [T, B]，g < 0 表示到达目标。
#       h_values: [T, B]，unsafe 当 h >= 0。

#     Returns:
#       log_phi_targets: [T, B]，log(phi(t)) = T * log(gamma)，或者先验 log 值。
#       phi_mask:        [T, B]，此处给全 1（都有监督，只是有些是先验）。
#     """
#     T, B = g_values.shape

#     hit = (g_values < 0)
#     unsafe = (h_values >= 0)

#     # 累积 unsafe 计数，用来检查 (t, u] 区间是否有 unsafe
#     cum_unsafe = jnp.cumsum(unsafe.astype(jnp.int32), axis=0)  # [T, B]

#     t_idx = jnp.arange(T)[:, None]  # [T, 1]
#     u_idx = jnp.arange(T)[None, :]  # [1, T]
#     future = (u_idx >= t_idx)       # [T, T]

#     # Broadcast cumulative sums to [T, T, B] and test violations in (t, u]
#     cum_u = cum_unsafe[u_idx, :]    # [T, T, B]
#     cum_t = cum_unsafe[t_idx, :]    # [T, T, B]

#     unsafe_at_t = unsafe[t_idx, :]  # [T, T, B]
#     viol_between = (cum_u - cum_t) > 0
#     viol_inclusive = viol_between | unsafe_at_t
#     safe_between = ~viol_inclusive   # [T, T, B]

#     cand = (future[:, :, None] & safe_between) & hit[None, :, :]  # [T, T, B]
#     has_cand = jnp.any(cand, axis=1)        # [T, B]
#     u_first = jnp.argmax(cand, axis=1)      # [T, B] earliest valid u

#     # 步数 T = u - t
#     steps = u_first - t_idx                 # [T, B] （广播）

#     # ---- 这里开始是和你原来不一样的地方：改成 log(phi) 的形式 ----
#     log_gamma = jnp.log(gamma)             # 标量 < 0

#     # 对没有监督的位置，用一个先验步数 prior_k
#     prior_k = 150
#     # log(phi_prior) = prior_k * log_gamma
#     prior_target = prior_k * log_gamma     # 标量

#     # 有合法的 u 时：log_phi = steps * log_gamma
#     # 否则：log_phi = prior_target
#     log_phi_targets = jnp.where(
#         has_cand,
#         steps * log_gamma,                 # [T, B]，log(gamma**steps)
#         prior_target                       # 标量广播到 [T, B]
#     )  # [T, B]

#     # 现在我们在所有位置都有一个 log_phi（真实的或先验），mask 全 1 即可
#     phi_mask = jnp.ones_like(log_phi_targets, dtype=jnp.float32)

#     return log_phi_targets.astype(jnp.float32), phi_mask

@jax.jit
def calculate_phi_targets_success(
    gamma: float,
    g_values: jnp.ndarray,
    h_values: jnp.ndarray,
    prior_k: int = 200,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Construct *log* phi targets from "steps-to-go" over safe, successful segments.

    For every time ``t`` the function looks for the earliest ``u >= t`` such
    that

    * ``g[u] < 0`` (the target is hit at ``u``), and
    * no unsafe step (``h >= 0``) occurs anywhere in ``[t, u]``.

    If such a ``u`` exists the target is ``log(gamma ** (u - t))``, i.e.
    ``(u - t) * log(gamma)`` (``0`` when ``t`` itself is a safe hit).
    Otherwise -- including every ``t`` that is itself unsafe -- the target is
    the prior ``prior_k * log(gamma)``.

    Args:
        gamma: Discount factor; its logarithm is taken, so it must be positive.
        g_values: ``[T, B]`` array; ``g < 0`` means the target is reached.
        h_values: ``[T, B]`` array; ``h >= 0`` means the step is unsafe.
        prior_k: Number of steps assumed for positions without a safe
            successful segment.  Defaults to 200.

    Returns:
        ``(log_phi_targets, phi_mask)``, both ``[T, B]`` ``float32``.  The mask
        is all ones because every position receives a target (real or prior).

    Note:
        The computation materialises ``[T, T, B]`` intermediates.
    """
    num_steps, _ = g_values.shape

    hit = g_values < 0
    unsafe = h_values >= 0

    # Running count of unsafe steps; differences give counts over intervals.
    cum_unsafe = jnp.cumsum(unsafe.astype(jnp.int32), axis=0)  # [T, B]

    step_idx = jnp.arange(num_steps)
    t_idx = step_idx[:, None]  # [T, 1]
    u_idx = step_idx[None, :]  # [1, T]
    future = u_idx >= t_idx  # [T, T]

    # cum_unsafe[u] - cum_unsafe[t] counts unsafe steps in (t, u]; adding the
    # flag at t itself extends the check to the closed interval [t, u].
    unsafe_after_t = (cum_unsafe[None, :, :] - cum_unsafe[:, None, :]) > 0  # [T, T, B]
    unsafe_in_segment = unsafe_after_t | unsafe[:, None, :]  # [T, T, B]
    safe_segment = ~unsafe_in_segment

    cand = (future[:, :, None] & safe_segment) & hit[None, :, :]  # [T, T, B]
    has_cand = jnp.any(cand, axis=1)  # [T, B]
    # argmax returns the first True along u, i.e. the earliest valid hit.
    # Where no candidate exists it returns 0; those entries are replaced below.
    u_first = jnp.argmax(cand, axis=1)  # [T, B]

    steps = u_first - t_idx  # [T, B]

    log_gamma = jnp.log(gamma)
    log_phi = jnp.where(has_cand, steps * log_gamma, prior_k * log_gamma)  # [T, B]

    phi_mask = jnp.ones_like(log_phi, dtype=jnp.float32)

    return log_phi.astype(jnp.float32), phi_mask
