import asyncio
import numpy as np
import torch
import json
import os
from collections import defaultdict
from verl import DataProto

from trpo.aps.llm_extract_aps import NeuroSymbolicExtractor
from trpo.ltl_reward_engine import LTLDualTrackEngine, LTLFormulaToolbox
from trpo.llm_client_factory import get_async_client
from trpo.trace_utils import reconstruct_readable_traces
from trpo.prompt.alfworld import (
    SYSTEM_TRACE_PROMPT, INPUT_TRACE_PROMPT,
    SYSTEM_LTL_LITE_PROMPT, INPUT_LTL_LITE_PROMPT,
    SYSTEM_POS_LTL_PROMPT, SYSTEM_NEG_LTL_PROMPT
)

def _extract_tag_content(text, tag):
    """Return the stripped text inside the last ``<tag>...</tag>`` pair.

    The last closing tag in ``text`` is located first, then the nearest
    opening tag that lies entirely before it. ``None`` is returned when
    either tag cannot be found.
    """
    # Everything before the final closing tag; `closer` is empty if absent.
    before_close, closer, _ = text.rpartition(f"</{tag}>")
    if not closer:
        return None

    # Within that prefix, keep what follows the last opening tag.
    _, opener, inner = before_close.rpartition(f"<{tag}>")
    if not opener:
        return None

    return inner.strip()

async def _analyze_group_traces(goal: str, traces: list, client, max_retries=3):
    """Stage 1: ask the LLM to analyze a group's successful and failed traces.

    Traces are partitioned by their ``success`` flag, rendered as numbered
    observation/action steps, and sent together with the goal. The request is
    retried (sleeping 1s, 2s, 4s, ... between attempts) whenever it raises or
    returns empty content.

    Returns:
        ``(analysis_text, usage)`` where ``usage`` is a dict with
        ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens`` or
        ``None`` when the response carries no usage data. ``(None, None)``
        is returned once every attempt has failed.
    """
    def render(group, label):
        # One block per trace; each step flattened onto a single line.
        if not group:
            return f"No {label} traces."
        blocks = []
        for number, item in enumerate(group, start=1):
            lines = []
            for pos, step in enumerate(item['trace'], start=1):
                flat_obs = step['observation'].replace('\n', ' ')
                lines.append(f"  - Step {pos}: Obs: {flat_obs} Act: {step['action']}")
            blocks.append(f"{label}_{number} (uid: {item['uid']}):\n" + "\n".join(lines))
        return "\n\n".join(blocks)

    wins, losses = [], []
    for item in traces:
        (wins if item['success'] else losses).append(item)

    user_prompt = INPUT_TRACE_PROMPT.format(
        goal=goal,
        successful_traces_str=render(wins, "successful_traj"),
        failed_traces_str=render(losses, "failed_traj"),
    )
    conversation = [
        {"role": "system", "content": SYSTEM_TRACE_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            reply = await client.chat.completions.create(
                messages=conversation,
                temperature=1.0,
                max_tokens=8196,
            )
            text = reply.choices[0].message.content
            if not text:
                raise ValueError("Empty response from Stage 1")

            # Token accounting is optional on the response object.
            usage = getattr(reply, 'usage', None)
            token_counts = None
            if usage:
                token_counts = {
                    "prompt_tokens": usage.prompt_tokens,
                    "completion_tokens": usage.completion_tokens,
                    "total_tokens": usage.total_tokens,
                }
            return text, token_counts
        except Exception:
            # Best-effort: back off and retry; give up silently after the last attempt.
            if attempt < final_attempt:
                await asyncio.sleep(1 * (2 ** attempt))

    return None, None

async def _translate_pos_ltl(goal, ap_list, trace_analysis, client, max_retries=3):
    """Stage 2 (positive track): turn the milestone analysis into a positive LTL formula.

    Only APs whose name starts with ``obs_`` are offered to the model. Each
    reply must contain a ``<pos_ltl>`` tag whose content passes
    ``LTLFormulaToolbox.validate_formula(..., strict_aps=True)``. On a
    validation failure the error is appended to the conversation and the model
    is asked again immediately; on any exception the call is retried after an
    exponential backoff.

    Returns:
        The raw ``<pos_ltl>`` content on success, otherwise the trivial
        formula ``"true"`` (also returned up-front when the analysis JSON has
        an explicitly empty ``milestones`` entry).
    """
    # Shortcut: an explicitly empty milestone list needs no LLM call.
    if trace_analysis:
        try:
            analysis_obj = json.loads(trace_analysis)
        except (json.JSONDecodeError, TypeError):
            analysis_obj = None
        if (
            isinstance(analysis_obj, dict)
            and "milestones" in analysis_obj
            and not analysis_obj["milestones"]
        ):
            return "true"

    obs_aps = [name for name in ap_list if name.startswith("obs_")]
    user_prompt = INPUT_LTL_LITE_PROMPT.format(
        goal=goal,
        ap_list_str="\n".join(f"- {name}" for name in obs_aps),
        trace_analysis=trace_analysis,
    )
    conversation = [
        {"role": "system", "content": SYSTEM_POS_LTL_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            reply = await client.chat.completions.create(
                messages=conversation,
                temperature=1.0,
                max_tokens=8196,
            )
            raw = reply.choices[0].message.content
            if not raw:
                raise ValueError("Empty response from Pos LTL generation")

            formula = _extract_tag_content(raw, "pos_ltl")
            if formula is None:
                raise ValueError("No <pos_ltl> tag found")

            # validate_formula returns a falsy value when the formula is acceptable.
            problem = LTLFormulaToolbox.validate_formula(
                LTLFormulaToolbox.preprocess_positive(formula),
                set(obs_aps),
                strict_aps=True,
            )
            if not problem:
                return formula

            # Feed the validation error back so the next attempt can correct it.
            conversation.append({"role": "assistant", "content": raw})
            conversation.append({"role": "user", "content": f"Validation Error: {problem}\nPlease correct the formula using ONLY the provided APs."})
            if attempt >= final_attempt:
                return "true"
        except Exception:
            if attempt < final_attempt:
                await asyncio.sleep(1 * (2 ** attempt))

    return "true"

async def _translate_neg_ltl(goal, ap_list, trace_analysis, client, max_retries=3):
    """Stage 2 (negative track): turn the bad-behavior analysis into negative LTL formulas.

    Every AP in ``ap_list`` is offered to the model. Each reply must contain a
    ``<neg_ltl>`` tag; its content is split by
    ``LTLFormulaToolbox.preprocess_negative`` and each resulting formula is
    validated with ``strict_aps=False``. The first validation error is fed
    back to the model for an immediate retry; exceptions trigger a retry after
    an exponential backoff.

    Returns:
        The raw ``<neg_ltl>`` content on success, otherwise ``"true"``.
    """
    user_prompt = INPUT_LTL_LITE_PROMPT.format(
        goal=goal,
        ap_list_str="\n".join(f"- {name}" for name in ap_list),
        trace_analysis=trace_analysis,
    )
    conversation = [
        {"role": "system", "content": SYSTEM_NEG_LTL_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            reply = await client.chat.completions.create(
                messages=conversation,
                temperature=1.0,
                max_tokens=8196,
            )
            raw = reply.choices[0].message.content
            if not raw:
                raise ValueError("Empty response from Neg LTL generation")

            formulas = _extract_tag_content(raw, "neg_ltl")
            if formulas is None:
                raise ValueError("No <neg_ltl> tag found")

            # Lazily validate each formula and stop at the first reported error.
            checks = (
                LTLFormulaToolbox.validate_formula(single, set(ap_list), strict_aps=False)
                for single in LTLFormulaToolbox.preprocess_negative(formulas)
            )
            first_error = next((msg for msg in checks if msg), None)
            if first_error is None:
                return formulas

            conversation.append({"role": "assistant", "content": raw})
            conversation.append({"role": "user", "content": f"Syntax Error: {first_error}\nPlease correct the formula."})
            if attempt >= final_attempt:
                return "true"
        except Exception:
            if attempt < final_attempt:
                await asyncio.sleep(1 * (2 ** attempt))
    return "true"

async def _generate_ltl_for_group_two_stage(goal: str, traces: list, ap_list: list, client, examine: bool = False, max_retries=3):
    """Generate positive and negative LTL formulas for one group in two LLM stages.

    Stage 1 analyzes the group's traces. Stage 2 translates that analysis into
    a positive and a negative formula concurrently. When the analysis contains
    a fenced ``json`` block that parses to a dict, only its ``milestones`` are
    passed to the positive translation and only its ``bad_behaviors`` to the
    negative one; otherwise both receive the same text.

    Returns:
        A dict with ``pos_ltl_formula``, ``neg_ltl_formula``, ``error`` and
        ``analysis_usage``; on success it also holds ``trace_analysis``. If
        stage 1 yields nothing, both formulas are ``"true"`` and ``error`` is
        set.
    """
    def log(*args):
        # Diagnostics are printed only for groups selected for examination.
        if examine:
            print(*args)

    # ---- Stage 1: trace analysis ----
    log("===LLM STAGE 1 INPUT (Analysis)===")
    analysis_content, analysis_usage = await _analyze_group_traces(goal, traces, client, max_retries)
    if not analysis_content:
        return {"pos_ltl_formula": "true", "neg_ltl_formula": "true", "error": "Stage 1 Max Retries Exceeded", "analysis_usage": None}

    # ---- Stage 2: LTL generation ----
    # Take the last ```json fenced block if it is properly closed; else use the raw text.
    fence_open = analysis_content.rfind("```json")
    fence_close = -1
    if fence_open != -1:
        fence_close = analysis_content.find("```", fence_open + 7)

    if fence_close != -1:
        json_analysis = analysis_content[fence_open + 7:fence_close].strip()
    else:
        log("Warning: Invalid json format(from trace analysis). Using raw analysis content.")
        json_analysis = analysis_content

    log("===LLM STAGE 2 INPUT (LTL)===")

    # By default both translations see the full analysis; narrow it when it parses.
    pos_analysis_input = neg_analysis_input = json_analysis
    try:
        parsed = json.loads(json_analysis)
        if isinstance(parsed, dict):
            pos_analysis_input = json.dumps(
                {"milestones": parsed.get("milestones", [])}, ensure_ascii=False, indent=2
            )
            neg_analysis_input = json.dumps(
                {"bad_behaviors": parsed.get("bad_behaviors", [])}, ensure_ascii=False, indent=2
            )
            log(f"Split Pos Input: {pos_analysis_input}")
            log(f"Split Neg Input: {neg_analysis_input}")
    except Exception as e:
        log(f"Warning: Failed to split JSON analysis ({e}). Using full content for both.")

    # Run both translations concurrently.
    pos_ltl, neg_ltl = await asyncio.gather(
        _translate_pos_ltl(goal, ap_list, pos_analysis_input, client, max_retries),
        _translate_neg_ltl(goal, ap_list, neg_analysis_input, client, max_retries),
    )

    log(f"Stage 2 Pos Output: {pos_ltl}")
    log(f"Stage 2 Neg Output: {neg_ltl}")

    return {
        "pos_ltl_formula": pos_ltl,
        "neg_ltl_formula": neg_ltl,
        "error": None,
        "trace_analysis": json_analysis,
        "analysis_usage": analysis_usage
    }

async def apply_ltl_penalty(
    batch: DataProto,
    tokenizer,
    penalty: float = -1.0,
    reward: float = 2.0,
    trend_reward: float = 0.0,
    llm_client_name: str = "gpt-4o-mini",
    verbose: bool = False,
    num_examine: int = 0,
    global_step: int = 0,
    log_dir: str = None,
) -> DataProto:
    """
    Orchestrates the LTL Reward Pipeline:
    1. Reconstruct Traces -> 2. Extract APs -> 3. Generate LTL -> 4. Compute Rewards

    Args:
        batch: Step-level batch. ``non_tensor_batch`` must contain ``uid``
            (group id) and ``traj_uid`` (trajectory id) for every row.
        tokenizer: Passed through to ``reconstruct_readable_traces``.
        penalty: Value handed to the engine as ``negative_reward``.
        reward: Value handed to the engine as ``positive_reward``; also forced
            onto the last step of every successful trajectory.
        trend_reward: Value handed to the engine as ``trend_reward``; also
            forced onto non-milestone steps of successful trajectories.
        llm_client_name: Name given to ``get_async_client``.
        verbose: Forwarded to ``LTLDualTrackEngine``.
        num_examine: The first ``num_examine`` groups print debug output
            during LTL generation.
        global_step: Written as ``step`` into each logged metadata record.
        log_dir: If not ``None``, per-group metadata is appended to
            ``<log_dir>/ltl_metadata.jsonl``.

    Returns:
        The same ``batch``, mutated in place with:
        ``non_tensor_batch["step_aps"]``, ``non_tensor_batch["ltl_formulas"]``,
        ``batch["ltl_pos_rewards"]``, ``batch["ltl_neg_rewards"]``,
        ``batch["ltl_rewards"]`` (their sum) and ``ltl/*`` entries under
        ``meta_info["metrics"]``.

    Raises:
        ValueError: If ``uid`` is missing from ``batch.non_tensor_batch``.
    """
    if "uid" not in batch.non_tensor_batch:
        raise ValueError("Batch missing 'uid' for grouping.")

    # The LLM client limits its own request concurrency (16 in flight).
    client = get_async_client(llm_client_name, max_concurrency=16)

    # AP extraction has its own, separate concurrency cap.
    extractor = NeuroSymbolicExtractor(max_concurrency=512)

    # 1. Reconstruct Traces (traj_uid -> {goal, success, trace: [...]})
    traces_map = reconstruct_readable_traces(batch, tokenizer)

    # 2. Extract APs (Batch Processing with Global Deduplication)
    print("Extracting APs (Batch Optimized)...")
    trace_aps_map, ap_metrics = await extractor.process_batch(traces_map)

    # Store AP metrics
    if "metrics" not in batch.meta_info:
        batch.meta_info["metrics"] = {}

    batch.meta_info["metrics"]["ltl/ap_unique_obs_count"] = ap_metrics["ap_unique_obs_count"]
    batch.meta_info["metrics"]["ltl/ap_unique_obs_failures"] = ap_metrics["ap_unique_obs_failures"]
    batch.meta_info["metrics"]["ltl/ap_trace_failures_count"] = ap_metrics["ap_trace_failures_count"]
    batch.meta_info["metrics"]["ltl/ap_step_failures_total"] = ap_metrics["ap_step_failures_total"]

    # Store APs in batch for observability
    step_aps_array = np.empty(len(batch), dtype=object)
    for uid, aps_list in trace_aps_map.items():
        # A trajectory may map to None (the code below guards for this as well);
        # zip() over None would raise, so leave its rows to the empty-set fill.
        if aps_list is None:
            continue
        for step_data, step_aps in zip(traces_map[uid]['trace'], aps_list):
            step_aps_array[step_data['batch_index']] = step_aps
    # Rows that received no APs get an empty set instead of None.
    for i, stored in enumerate(step_aps_array):
        if stored is None:
            step_aps_array[i] = set()
    batch.non_tensor_batch["step_aps"] = step_aps_array

    # 3. Generate LTL (Grouped by Task/UID)
    # Map group_uid -> set of traj_uids
    group_map = defaultdict(set)
    for i, uid in enumerate(batch.non_tensor_batch["uid"]):
        group_map[uid].add(batch.non_tensor_batch["traj_uid"][i])

    # At most 5 groups are processed (LLM generation + reward computation) at once.
    sem_ltl = asyncio.Semaphore(5)

    # LTL Metrics Accumulators
    total_neg_ltl_failures = 0 # Formula level
    total_pos_ltl_failures = 0 # Formula level
    total_neg_ltl_compiled = 0
    total_pos_ltl_compiled = 0

    groups_with_neg_failures = 0 # Group level
    groups_with_pos_failures = 0 # Group level
    total_groups_with_failures = 0 # Group level
    total_groups_with_generation_errors = 0 # Generation error level

    # Create the reward tensors on the same device as the batch's 'responses'
    # tensor, falling back to CPU when that entry is absent. (torch.device has
    # no `.device` attribute, so the fallback must not be dereferenced.)
    responses = batch.batch.get('responses', None)
    device = responses.device if responses is not None else torch.device('cpu')
    ltl_pos_rewards = torch.zeros(len(batch), dtype=torch.float32, device=device)
    ltl_neg_rewards = torch.zeros(len(batch), dtype=torch.float32, device=device)
    final_ltl_formulas = np.empty(len(batch), dtype=object)

    async def ltl_task(group_uid, traj_uids, verbose, examine):
        """Generate formulas for one group and score each of its trajectories."""
        async with sem_ltl:
            if not traj_uids: return {}

            # Prepare data for this group
            group_traces_data = [traces_map[tid] for tid in traj_uids]
            goal = group_traces_data[0]['goal']

            # Collect unique APs seen in this group
            group_unique_aps = set()
            for tid in traj_uids:
                aps_list = trace_aps_map[tid]
                if aps_list is not None:
                    for ap_set in aps_list:
                        group_unique_aps.update(ap_set)

            sorted_unique_aps = sorted(list(group_unique_aps))

            # Generate LTL formulas
            result = await _generate_ltl_for_group_two_stage(goal, group_traces_data, sorted_unique_aps, client, examine=examine)

            pos_ltl = result.get("pos_ltl_formula")
            neg_ltl = result.get("neg_ltl_formula")
            generation_error = result.get("error")
            trace_analysis = result.get("trace_analysis")
            analysis_usage = result.get("analysis_usage")

            # Prepare traces for logging
            traces_log = {}
            for tid in traj_uids:
                if tid in traces_map:
                    traces_log[tid] = [
                        {"obs": s['observation'], "act": s['action']} 
                        for s in traces_map[tid]['trace']
                    ]

            # Metadata for logging
            metadata = {
                "group_uid": group_uid,
                "task_goal": goal,
                "traces": traces_log,
                "ap_list": sorted_unique_aps,
                "pos_ltl": pos_ltl,
                "neg_ltl": neg_ltl,
                "generation_error": generation_error,
                "trace_analysis": trace_analysis,
                "analysis_usage": analysis_usage
            }

            # Instantiate Engine
            engine = LTLDualTrackEngine(
                negative_formula=neg_ltl,
                positive_formula=pos_ltl,
                known_aps=list(group_unique_aps),
                negative_reward=penalty,
                positive_reward=reward,
                trend_reward=trend_reward,
                verbose=verbose,
            )

            # Collect stats
            compilation_stats = engine.get_compilation_stats()

            # Execution Results Container
            execution_results = [] # List of (batch_idx, neg_r, pos_r, info)
            formula_logs = []      # List of (batch_idx, formula_dict)

            # Process all trajectories in this group
            if engine.valid:
                for tid in traj_uids:
                    # Retrieve data
                    data = traces_map[tid]
                    step_aps_list = trace_aps_map[tid]
                    if step_aps_list is None: continue

                    # One (neg_reward, pos_reward, info) tuple per step.
                    trajectory_results = engine.process_trajectory(step_aps_list)

                    # === Fallback Mechanism: Enforce Rewards for Successful Traces ===
                    # If the task is successful, the last step IS a milestone, and the path IS a valid trend.
                    # This protects against Pos LTL generation failures (false negatives).
                    if data.get('success', False) and len(trajectory_results) > 0:
                        # 1. Enforce Last Step as Milestone (its negative reward is reset to 0.0)
                        _, _, last_info = trajectory_results[-1]
                        trajectory_results[-1] = (0.0, reward, {**last_info, 'is_milestone': True, 'is_trend': False})

                        # 2. Enforce Prior Steps as Trends (if not already Milestones)
                        for k in range(len(trajectory_results) - 1):
                            n_k, p_k, info_k = trajectory_results[k]
                            if not info_k['is_milestone']:
                                # Upgrade to Trend if not already a Milestone
                                # Note: We use the closure variable 'trend_reward'
                                trajectory_results[k] = (n_k, trend_reward, {**info_k, 'is_trend': True})

                    # Map back to batch indices
                    for i, (neg_r, pos_r, info) in enumerate(trajectory_results):
                        if i >= len(data['trace']): break # Safety check

                        batch_idx = data['trace'][i]['batch_index']
                        execution_results.append((batch_idx, neg_r, pos_r, info))

                        # Log formulas (sparse logging could be optimized here)
                        if pos_ltl or neg_ltl:
                            formula_logs.append((batch_idx, {"pos_ltl": pos_ltl, "neg_ltl": neg_ltl}))

            return {
                "stats": compilation_stats,
                "rewards": execution_results,
                "formulas": formula_logs,
                "has_generation_error": generation_error is not None,
                "metadata": metadata,
                "analysis_usage": analysis_usage
            }

    print("Generating LTL formulas and Computing Rewards...")
    # Determine which tasks to examine
    tasks = []
    for i, (gid, tids) in enumerate(group_map.items()):
        examine_this = (i < num_examine)
        tasks.append(ltl_task(gid, tids, verbose, examine_this))

    ltl_task_results = await asyncio.gather(*tasks)

    # Log metadata if log_dir is provided
    if log_dir is not None:
        try:
            os.makedirs(log_dir, exist_ok=True)
            log_file = os.path.join(log_dir, "ltl_metadata.jsonl")

            with open(log_file, "a", encoding='utf-8') as f:
                for res in ltl_task_results:
                    if not res or "metadata" not in res: continue
                    meta = res["metadata"]
                    meta["step"] = global_step
                    f.write(json.dumps(meta, ensure_ascii=False) + "\n")
            print(f"LTL metadata logged to {log_file}")
        except Exception as e:
            # Logging is best-effort; a failure here must not abort reward computation.
            print(f"Failed to log LTL metadata: {e}")

    # Metrics Counters
    total_milestones = 0
    total_trends = 0

    # Token Usage Counters
    total_analysis_prompt_tokens = 0
    total_analysis_completion_tokens = 0

    # Aggregation
    for res in ltl_task_results:
        if not res: continue

        if res.get("has_generation_error"):
            total_groups_with_generation_errors += 1

        # Agg Stats
        stats = res["stats"]
        failed_neg = stats["failed_neg_count"]
        failed_pos = stats["failed_pos_count"]

        total_neg_ltl_failures += failed_neg
        total_pos_ltl_failures += failed_pos
        total_neg_ltl_compiled += stats["negative_formulas_compiled"]
        total_pos_ltl_compiled += stats["positive_formulas_compiled"]

        # Token Usage
        usage = res.get("analysis_usage")
        if usage:
            total_analysis_prompt_tokens += usage.get("prompt_tokens", 0)
            total_analysis_completion_tokens += usage.get("completion_tokens", 0)

        # Group Level Stats
        if failed_neg > 0: groups_with_neg_failures += 1
        if failed_pos > 0: groups_with_pos_failures += 1
        if failed_neg > 0 or failed_pos > 0: total_groups_with_failures += 1

        # Fill Tensors & Stats
        for batch_idx, neg_r, pos_r, info in res["rewards"]:
            ltl_neg_rewards[batch_idx] = neg_r
            ltl_pos_rewards[batch_idx] = pos_r

            if info['is_milestone']: total_milestones += 1
            if info['is_trend']: total_trends += 1

        # Fill Formulas
        for batch_idx, f_dict in res["formulas"]:
            final_ltl_formulas[batch_idx] = f_dict

    # Calculate Trigger Rates (the 1e-6 keeps the division defined for an empty batch)
    total_steps = len(batch)
    neg_trigger_count = (ltl_neg_rewards < 0).sum().item()
    pos_trigger_count = (ltl_pos_rewards > 0).sum().item() # Any positive reward (Milestone OR Trend)

    batch.meta_info["metrics"]["ltl/neg_trigger_rate"] = neg_trigger_count / (total_steps + 1e-6)
    batch.meta_info["metrics"]["ltl/pos_trigger_rate"] = pos_trigger_count / (total_steps + 1e-6)

    # New Semantic Metrics
    batch.meta_info["metrics"]["ltl/milestone_rate"] = total_milestones / (total_steps + 1e-6)
    batch.meta_info["metrics"]["ltl/trend_rate"] = total_trends / (total_steps + 1e-6)

    # Store LTL metrics
    batch.meta_info["metrics"]["ltl/formula_failures_neg_total"] = total_neg_ltl_failures
    batch.meta_info["metrics"]["ltl/formula_failures_pos_total"] = total_pos_ltl_failures
    batch.meta_info["metrics"]["ltl/group_failures_neg_count"] = groups_with_neg_failures
    batch.meta_info["metrics"]["ltl/group_failures_pos_count"] = groups_with_pos_failures
    batch.meta_info["metrics"]["ltl/group_failures_any_count"] = total_groups_with_failures
    batch.meta_info["metrics"]["ltl/group_generation_error_count"] = total_groups_with_generation_errors
    batch.meta_info["metrics"]["ltl/compiled_neg_total"] = total_neg_ltl_compiled
    batch.meta_info["metrics"]["ltl/compiled_pos_total"] = total_pos_ltl_compiled
    batch.meta_info["metrics"]["ltl/analysis_prompt_tokens"] = total_analysis_prompt_tokens
    batch.meta_info["metrics"]["ltl/analysis_completion_tokens"] = total_analysis_completion_tokens

    # Store split rewards in batch for ablation/analysis
    batch.batch["ltl_pos_rewards"] = ltl_pos_rewards
    batch.batch["ltl_neg_rewards"] = ltl_neg_rewards

    # Sum them up for the total reward used in advantage calculation
    batch.batch["ltl_rewards"] = ltl_pos_rewards + ltl_neg_rewards

    batch.non_tensor_batch["ltl_formulas"] = final_ltl_formulas

    return batch

def compute_trpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    ltl_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    traj_index: np.ndarray,
    ltl_beta: float = 0.1,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    add_ltl_to_reward_in_norm: bool = True,
):
    """
    Group-normalized advantage with an LTL shaping term.

    Each row's score is the sum of its token-level rewards. Rows sharing the
    same ``index`` value form a group; every row of the group contributes to
    the group's mean and (unbiased) standard deviation. A group with a single
    row uses mean 0 and std 1. The normalized score is then broadcast over the
    row's tokens through ``response_mask``.

    Args:
        token_level_rewards: Per-token rewards; summed over the last dimension.
        ltl_rewards: One LTL reward per row, scaled by ``ltl_beta``.
        response_mask: Multiplied with the per-row advantage to give the
            token-level result.
        index: Group id of each row.
        traj_index: Accepted for interface compatibility; not used here.
        ltl_beta: Weight of the LTL reward.
        epsilon: Added to the group std before dividing.
        norm_adv_by_std_in_grpo: If True, divide the centered score by the
            group std; otherwise only subtract the group mean.
        add_ltl_to_reward_in_norm:
            If True, the LTL term is added to the raw score BEFORE
            normalization (it competes within the group).
            If False, it is added to the advantage AFTER normalization
            (an absolute correction independent of the group).

    Returns:
        ``(advantages, advantages)`` — the same token-level tensor twice.
    """
    # Row-level raw score.
    row_scores = token_level_rewards.sum(dim=-1)  # (bs,)
    if add_ltl_to_reward_in_norm:
        # Pre-normalization injection of the LTL term.
        row_scores = row_scores + ltl_beta * ltl_rewards

    members = defaultdict(list)
    group_mean = {}
    group_std = {}

    with torch.no_grad():
        n_rows = row_scores.shape[0]

        # Every row (i.e. every step) is a member of its group.
        for row in range(n_rows):
            members[index[row]].append(row_scores[row])

        for gid, values in members.items():
            stacked = torch.stack(values)
            count = len(stacked)
            if count == 1:
                # A lone row cannot be compared against anything.
                group_mean[gid] = torch.tensor(0.0, device=row_scores.device)
                group_std[gid] = torch.tensor(1.0, device=row_scores.device)
            elif count > 1:
                group_mean[gid] = torch.mean(stacked)
                # torch.std defaults to the unbiased estimator.
                group_std[gid] = torch.std(stacked)
            else:
                raise ValueError(f"no score in prompt index: {gid}")

        # Center (and optionally scale) each row in place.
        for row in range(n_rows):
            gid = index[row]
            if norm_adv_by_std_in_grpo:
                row_scores[row] = (row_scores[row] - group_mean[gid]) / (group_std[gid] + epsilon)
            else:
                row_scores[row] = row_scores[row] - group_mean[gid]

    # Post-normalization injection of the LTL term.
    if add_ltl_to_reward_in_norm:
        advantages = row_scores
    else:
        advantages = row_scores + ltl_beta * ltl_rewards

    # (bs, 1) * (bs, seq_len) -> (bs, seq_len)
    advantages = advantages.unsqueeze(-1) * response_mask
    return advantages, advantages
