import asyncio
import json
import re
import os
import numpy as np
import torch
from collections import defaultdict
from verl import DataProto
from trpo.llm_client_factory import get_async_client
from trpo.trace_utils import reconstruct_readable_traces
from trpo.prompt.alfworld_ablation import (
    SYSTEM_DIRECT_GROUP_LEVEL_PROMPT, SYSTEM_DIRECT_TRACE_LEVEL_PROMPT, SYSTEM_DIRECT_STEP_LEVEL_PROMPT,
    INPUT_GROUP_PROMPT, INPUT_TRACE_PROMPT, INPUT_STEP_PROMPT
)

# ============================================================================
#  Helpers
# ============================================================================

def extract_json_from_response(content):
    """
    Robustly extracts JSON from an LLM response, handling Markdown code blocks.

    Two candidate substrings are tried in order:
      1. The body of the first ```json fenced block, if one is present and closed.
      2. The span from the first '{' to the last '}' of the whole response.

    The first candidate that parses wins. A malformed fenced block does not
    prevent the brace-delimited fallback from being attempted.

    Args:
        content: Raw text returned by the LLM. Anything that is not a string
            (e.g. None) is treated as unparseable.

    Returns:
        The decoded JSON value, or None if no candidate could be parsed.
    """
    if not isinstance(content, str):
        return None

    candidates = []

    # Candidate 1: explicit ```json ... ``` block.
    fence = "```json"
    fence_start = content.find(fence)
    if fence_start != -1:
        body_start = fence_start + len(fence)
        fence_end = content.find("```", body_start)
        if fence_end != -1:
            candidates.append(content[body_start:fence_end].strip())

    # Candidate 2: outermost brace-delimited span.
    brace_start = content.find("{")
    brace_end = content.rfind("}")
    if brace_start != -1 and brace_end > brace_start:
        candidates.append(content[brace_start:brace_end + 1])

    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            # Try the next candidate instead of giving up on the first failure.
            continue
    return None

# ============================================================================
#  LLM Scoring Logic
# ============================================================================

async def _score_group(goal, traces, client, max_retries=3, examine=False):
    """
    Mode 1: Scores all traces in a group with a single LLM call.

    Args:
        goal: Task goal substituted into INPUT_GROUP_PROMPT.
        traces: Iterable of trace dicts, each providing 'uid' and 'trace'
            (a list of step dicts with 'observation' and 'action').
        client: Async client exposing `chat.completions.create(...)`.
        max_retries: Maximum number of attempts. Failed requests and
            unparseable replies each consume one attempt.
        examine: If True, print the raw LLM reply of every attempt.

    Returns:
        Tuple (scores, raw_content, is_error, usage):
          - scores: the "scores" object of the reply, expected as
            Dict {uid: {'reward_steps': [], 'penalty_steps': []}}; {} on failure.
          - raw_content: text of the last reply received ("" if none).
          - is_error: True when no usable "scores" dict was obtained.
          - usage: {'prompt_tokens', 'completion_tokens'} for the returned
            attempt, or None when unavailable.
    """
    traces_text = []
    for t in traces:
        # 1-based step index for the prompt
        steps_text = [
            f"    Step {i}: Obs: {s['observation']} -> Act: {s['action']}"
            for i, s in enumerate(t['trace'], start=1)
        ]
        traces_text.append(f"[Trace ID: {t['uid']}]\n" + "\n".join(steps_text))

    prompt = INPUT_GROUP_PROMPT.format(goal=goal, traces_str="\n\n".join(traces_text))

    messages = [
        {"role": "system", "content": SYSTEM_DIRECT_GROUP_LEVEL_PROMPT},
        {"role": "user", "content": prompt}
    ]

    last_content = ""
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = await client.chat.completions.create(
                messages=messages,
                response_format={"type": "text"}, # Use text to allow <think>
                temperature=0.2,
                max_tokens=4096
            )
            content = response.choices[0].message.content
            last_content = content
            if examine:
                print(f"\n{'='*20} [EXAMINE LLM RESPONSE (GROUP)] {'='*20}\n{content}\n{'='*60}")

            usage = None
            if hasattr(response, 'usage') and response.usage:
                usage = {"prompt_tokens": response.usage.prompt_tokens, "completion_tokens": response.usage.completion_tokens}

            data = extract_json_from_response(content)
            # Only a dict-valued "scores" entry is usable downstream; anything
            # else (missing key, null, list, ...) counts as a format failure.
            scores = data.get("scores") if isinstance(data, dict) else None
            if isinstance(scores, dict):
                return scores, content, False, usage

            # Invalid format without an exception: retry immediately, or give up.
            if is_last_attempt:
                return {}, content, True, usage

        except Exception as e:
            if is_last_attempt:
                print(f"[Reward Ablation] Group scoring failed after {max_retries} attempts: {e}")
                return {}, last_content, True, None
            # Exponential backoff before the next request attempt.
            await asyncio.sleep(1 * (2 ** attempt))

    # Only reached when max_retries <= 0.
    return {}, last_content, True, None

async def _score_trace(goal, trace_data, client, max_retries=3, examine=False):
    """
    Mode 2: Scores a single trace.

    Args:
        goal: Task goal substituted into INPUT_TRACE_PROMPT.
        trace_data: Dict with a 'trace' list whose items carry 'observation'
            and 'action'.
        client: Async client exposing `chat.completions.create(...)`.
        max_retries: Maximum number of attempts. Failed requests and
            unparseable replies each consume one attempt.
        examine: If True, print the raw LLM reply of every attempt.

    Returns:
        Tuple (scores, raw_content, is_error, usage):
          - scores: the parsed JSON object, expected as
            Dict {'reward_steps': [], 'penalty_steps': []}; {} on failure.
          - raw_content: text of the last reply received ("" if none).
          - is_error: True when no non-empty JSON object was obtained.
          - usage: {'prompt_tokens', 'completion_tokens'} for the returned
            attempt, or None when unavailable.
    """
    # 1-based step index for the prompt
    steps_text = [
        f"Step {i}: Obs: {s['observation']} -> Act: {s['action']}"
        for i, s in enumerate(trace_data['trace'], start=1)
    ]

    prompt = INPUT_TRACE_PROMPT.format(goal=goal, trace_str="\n".join(steps_text))

    messages = [
        {"role": "system", "content": SYSTEM_DIRECT_TRACE_LEVEL_PROMPT},
        {"role": "user", "content": prompt}
    ]

    last_content = ""
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = await client.chat.completions.create(
                messages=messages,
                response_format={"type": "text"},
                temperature=0.2,
                max_tokens=2048
            )
            content = response.choices[0].message.content
            last_content = content
            if examine:
                print(f"\n{'='*20} [EXAMINE LLM RESPONSE (TRACE)] {'='*20}\n{content}\n{'='*60}")

            usage = None
            if hasattr(response, 'usage') and response.usage:
                usage = {"prompt_tokens": response.usage.prompt_tokens, "completion_tokens": response.usage.completion_tokens}

            data = extract_json_from_response(content)
            # Callers read the result with `.get(...)`, so only a non-empty
            # JSON object is a success; lists/scalars count as format failures.
            if data and isinstance(data, dict):
                return data, content, False, usage

            if is_last_attempt:
                return {}, content, True, usage

        except Exception as e:
            if is_last_attempt:
                print(f"[Reward Ablation] Trace scoring failed after {max_retries} attempts: {e}")
                return {}, last_content, True, None
            # Exponential backoff before the next request attempt.
            await asyncio.sleep(1 * (2 ** attempt))

    # Only reached when max_retries <= 0.
    return {}, last_content, True, None

async def _score_step(context, action, client, max_retries=3, examine=False):
    """
    Mode 3: Scores a single step.

    Args:
        context: Text substituted for `context` in INPUT_STEP_PROMPT.
        action: Text substituted for `action` in INPUT_STEP_PROMPT.
        client: Async client exposing `chat.completions.create(...)`.
        max_retries: Maximum number of attempts. Failed requests and
            unparseable replies each consume one attempt.
        examine: If True, print the raw LLM reply of every attempt.

    Returns:
        Tuple (score, raw_content, is_error, usage):
          - score: 1.0 for label "milestone", -1.0 for "penalty", 0.0 for any
            other label (or on failure). A missing "label" key is neutral.
          - raw_content: text of the last reply received ("" if none).
          - is_error: True when no JSON object with a string label was obtained.
          - usage: {'prompt_tokens', 'completion_tokens'} for the returned
            attempt, or None when unavailable.
    """
    prompt = INPUT_STEP_PROMPT.format(
        context=context,
        action=action
    )

    messages = [
        {"role": "system", "content": SYSTEM_DIRECT_STEP_LEVEL_PROMPT},
        {"role": "user", "content": prompt}
    ]

    last_content = ""
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = await client.chat.completions.create(
                messages=messages,
                response_format={"type": "text"},
                temperature=0.2,
                max_tokens=512
            )
            content = response.choices[0].message.content
            last_content = content
            if examine:
                print(f"\n{'='*20} [EXAMINE LLM RESPONSE (STEP)] {'='*20}\n{content}\n{'='*60}")

            data = extract_json_from_response(content)

            usage = None
            if hasattr(response, 'usage') and response.usage:
                usage = {"prompt_tokens": response.usage.prompt_tokens, "completion_tokens": response.usage.completion_tokens}

            # A usable reply is a non-empty JSON object whose label (if any)
            # is a string; everything else is an explicit format failure.
            label = None
            if data and isinstance(data, dict):
                label = data.get("label", "neutral")
            if not isinstance(label, str):
                if is_last_attempt:
                    return 0.0, content, True, usage
                continue

            # Tolerate stray whitespace and casing around the label.
            label = label.strip().lower()
            if label == "milestone":
                return 1.0, content, False, usage
            if label == "penalty":
                return -1.0, content, False, usage
            return 0.0, content, False, usage

        except Exception as e:
            if is_last_attempt:
                # Report the final failure, matching the group/trace scorers.
                print(f"[Reward Ablation] Step scoring failed after {max_retries} attempts: {e}")
                return 0.0, last_content, True, None
            # Exponential backoff before the next request attempt.
            await asyncio.sleep(0.5 * (2 ** attempt))

    # Only reached when max_retries <= 0.
    return 0.0, last_content, True, None

# ============================================================================
#  Main Entry Point
# ============================================================================

async def apply_llm_reward(
    batch: DataProto,
    tokenizer,
    mode: str = 'trace', # 'group', 'trace', 'step'
    pos_reward: float = 1.0,
    neg_reward: float = -1.0,
    trend_reward: float = 0.0,
    llm_client_name: str = "gpt-4o-mini",
    max_concurrency: int = 16,
    max_retries: int = 3,
    num_examine: int = 0,
    log_dir: str = None,
    global_step: int = 0
) -> DataProto:
    """
    Applies LLM-based rewards to the batch.
    Stores result in batch.batch['llm_rewards'].

    The traces are rebuilt with `reconstruct_readable_traces`, scored by the
    LLM at the granularity selected by `mode`, and the resulting 1-based
    reward/penalty step indices are turned into one scalar per batch row:

      - a step listed as reward gets `pos_reward`;
      - otherwise, if `trend_reward > 0` and the step precedes the last reward
        step of its trace, it gets `trend_reward`;
      - a step listed as penalty additionally gets `neg_reward` (additive).

    Args:
        batch: Batch to score. `non_tensor_batch` must contain 'uid' (group id)
            and 'traj_uid' (trace id) per row.
        tokenizer: Passed to `reconstruct_readable_traces`; in 'step' mode it
            also decodes prompts/responses when no decoded text is present
            in `non_tensor_batch`.
        mode: 'group' (one call per group), 'trace' (one call per trace) or
            'step' (one call per batch row).
        pos_reward: Value added for reward steps.
        neg_reward: Value added for penalty steps.
        trend_reward: Value for non-reward steps before the last reward step.
        llm_client_name: Name handed to `get_async_client`.
        max_concurrency: Client concurrency (multiplied by 4 in 'step' mode).
        max_retries: Attempts per scoring call.
        num_examine: Number of leading tasks whose raw replies are printed.
        log_dir: If set, one JSON line per group is appended to
            `<log_dir>/llm_reward_<mode>.jsonl`.
        global_step: Recorded in each log entry.

    Returns:
        The same `batch`, with `batch.batch['llm_rewards']` set and
        `batch.meta_info['metrics']` extended with `ablation/<mode>_*` entries.

    Raises:
        ValueError: If 'uid' is missing from the batch or `mode` is unknown.
    """
    if "uid" not in batch.non_tensor_batch:
        raise ValueError("Batch missing 'uid'.")
    # An unknown mode would otherwise schedule no tasks and silently yield
    # all-zero rewards.
    if mode not in ('group', 'trace', 'step'):
        raise ValueError(f"Unknown mode '{mode}'; expected 'group', 'trace' or 'step'.")

    print(f"--- Applying Ablation Reward: Mode = {mode.upper()} ---")

    # 1. Setup Client
    actual_concurrency = max_concurrency * 4 if mode == 'step' else max_concurrency
    client = get_async_client(llm_client_name, max_concurrency=actual_concurrency)

    # 2. Reconstruct Traces
    traces_map = reconstruct_readable_traces(batch, tokenizer)

    # Build a reverse map for consistent logging across all modes
    # batch_index -> (trace_uid, step_idx_in_trace)
    batch_idx_to_trace = {}
    for uid, data in traces_map.items():
        for step_idx, step_data in enumerate(data['trace']):
            batch_idx_to_trace[step_data['batch_index']] = (uid, step_idx)

    # 3. Prepare Tasks
    tasks = []
    task_map = [] # List of (uid, metadata)

    # Always build group mapping for unified logging.
    # A dict is used as an insertion-ordered set so that the order of traces
    # in prompts and logs is reproducible across runs.
    group_to_tids = defaultdict(dict)
    group_to_goal = {}

    group_uids = batch.non_tensor_batch['uid']
    for i, traj_uid in enumerate(batch.non_tensor_batch['traj_uid']):
        group_to_tids[group_uids[i]][traj_uid] = None

    # Pre-populate group goals (assumes all traces in a group share the goal)
    for gid, tids in group_to_tids.items():
        first_tid = next(iter(tids))
        if first_tid in traces_map:
            group_to_goal[gid] = traces_map[first_tid]['goal']

    if mode == 'group':
        for gid, tids in group_to_tids.items():
            # Only traces that were actually reconstructed can be scored.
            present_tids = [tid for tid in tids if tid in traces_map]
            if not present_tids:
                continue
            group_traces = [traces_map[tid] for tid in present_tids]
            goal = group_traces[0]['goal']
            examine_this = (len(tasks) < num_examine)
            tasks.append(_score_group(goal, group_traces, client, max_retries, examine=examine_this))
            task_map.append((gid, present_tids))

    elif mode == 'trace':
        for i, (uid, data) in enumerate(traces_map.items()):
            examine_this = (i < num_examine)
            tasks.append(_score_trace(data['goal'], data, client, max_retries, examine=examine_this))
            task_map.append((uid, None))

    elif mode == 'step':
        if 'prompts' in batch.non_tensor_batch:
            prompts = batch.non_tensor_batch['prompts']
        else:
            prompts = tokenizer.batch_decode(batch.batch['prompts'], skip_special_tokens=True)

        if 'responses' in batch.non_tensor_batch:
            responses = batch.non_tensor_batch['responses']
        else:
            responses = tokenizer.batch_decode(batch.batch['responses'], skip_special_tokens=True)

        for i, (p, r) in enumerate(zip(prompts, responses)):
            examine_this = (i < num_examine)
            tasks.append(_score_step(p, r, client, max_retries, examine=examine_this))
            task_map.append((None, i)) # Store batch index directly

    # 4. Execute
    print(f"Executing {len(tasks)} LLM scoring tasks...")
    results = await asyncio.gather(*tasks)

    # 5. Unify Results into Trace-Level Scores & Collect Raw Outputs
    # trace_uid -> {'reward_steps': set(), 'penalty_steps': set(), 'errors': int}
    trace_scores_map = defaultdict(lambda: {'reward_steps': set(), 'penalty_steps': set(), 'errors': 0})

    # Storage for raw LLM outputs:
    # mode='group': raw_outputs[group_uid] = content_str
    # mode='trace': raw_outputs[trace_uid] = content_str
    # mode='step': raw_outputs[trace_uid] = {step_idx: content_str}
    raw_outputs = {}

    # Number of failed scoring calls (one per failed task, in every mode).
    parse_errors = 0

    def _sanitize_indices(indices):
        """Helper to ensure indices are a list of integers."""
        if not indices or not isinstance(indices, list):
            return []

        valid_indices = []
        for x in indices:
            # bool is a subclass of int; JSON true/false are not step numbers.
            if isinstance(x, int) and not isinstance(x, bool):
                valid_indices.append(x)
            elif isinstance(x, str):
                # isdecimal() only admits characters that int() can parse.
                token = x.strip()
                if token.isdecimal():
                    valid_indices.append(int(token))
            elif isinstance(x, dict) and "step" in x: # Handle edge case: [{"step": 1}]
                try:
                    valid_indices.append(int(x["step"]))
                except (ValueError, TypeError):
                    pass
            # Ignore other types
        return valid_indices

    # Token usage: every scorer returns (result, content, is_error, usage).
    total_prompt_tokens = 0
    total_completion_tokens = 0
    for _, _, _, usage in results:
        if usage:
            total_prompt_tokens += usage.get('prompt_tokens', 0) or 0
            total_completion_tokens += usage.get('completion_tokens', 0) or 0

    if mode == 'group':
        for (gid, traj_uids), (res_dict, content, is_error, _) in zip(task_map, results):
            raw_outputs[gid] = content # Store raw group response

            if is_error or not isinstance(res_dict, dict):
                parse_errors += 1
                for uid in traj_uids:
                    trace_scores_map[uid]['errors'] += 1
                continue

            for uid in traj_uids:
                # JSON object keys are always strings, so also try str(uid).
                trace_scores = res_dict.get(uid)
                if trace_scores is None:
                    trace_scores = res_dict.get(str(uid))
                if not isinstance(trace_scores, dict):
                    continue
                trace_scores_map[uid]['reward_steps'].update(_sanitize_indices(trace_scores.get('reward_steps', [])))
                trace_scores_map[uid]['penalty_steps'].update(_sanitize_indices(trace_scores.get('penalty_steps', [])))

    elif mode == 'trace':
        for (uid, _), (res_dict, content, is_error, _) in zip(task_map, results):
            raw_outputs[uid] = content # Store raw trace response

            if is_error or not isinstance(res_dict, dict):
                parse_errors += 1
                trace_scores_map[uid]['errors'] += 1
                continue

            trace_scores_map[uid]['reward_steps'].update(_sanitize_indices(res_dict.get('reward_steps', [])))
            trace_scores_map[uid]['penalty_steps'].update(_sanitize_indices(res_dict.get('penalty_steps', [])))

    elif mode == 'step':
        for (_, batch_idx), (val, content, is_error, _) in zip(task_map, results):
            if is_error:
                parse_errors += 1

            # Rows that belong to no reconstructed trace cannot be attributed.
            location = batch_idx_to_trace.get(batch_idx)
            if location is None:
                continue
            trace_uid, step_idx_0_based = location

            raw_outputs.setdefault(trace_uid, {})[step_idx_0_based] = content # Store raw step response

            if is_error:
                trace_scores_map[trace_uid]['errors'] += 1
                continue

            step_num = step_idx_0_based + 1 # 1-based
            val = float(val)
            if val > 0.5:
                trace_scores_map[trace_uid]['reward_steps'].add(step_num)
            elif val < -0.5:
                trace_scores_map[trace_uid]['penalty_steps'].add(step_num)

    # 6. Fill Rewards & Prepare Logs (Unified Logic)
    # Follow the device of the responses tensor when present; otherwise CPU.
    responses_tensor = batch.batch.get('responses', None)
    device = responses_tensor.device if responses_tensor is not None else torch.device('cpu')
    llm_rewards = torch.zeros(len(batch), dtype=torch.float32, device=device)

    total_pos = 0
    total_neg = 0
    total_trend = 0

    # Populate Rewards Tensor first (using trace_scores_map)
    for uid, trace_data in traces_map.items():
        scores = trace_scores_map[uid]

        r_indices = scores['reward_steps']
        p_indices = scores['penalty_steps']

        max_milestone_step = max(r_indices) if r_indices else -1

        # Fill Tensor
        for step_idx, step_data in enumerate(trace_data['trace']):
            current_1_based = step_idx + 1
            val = 0.0

            # 1. Positive Component (Milestone OR Trend)
            if current_1_based in r_indices:
                val += pos_reward
                total_pos += 1
            elif trend_reward > 0.0 and current_1_based < max_milestone_step:
                val += trend_reward
                total_trend += 1

            # 2. Negative Component (Penalty) - Additive
            if current_1_based in p_indices:
                val += neg_reward
                total_neg += 1

            llm_rewards[step_data['batch_index']] = val

    # Prepare Logs (Group-Centric)
    log_data_buffer = []
    if log_dir:
        for gid, tids in group_to_tids.items():
            if not tids:
                continue

            # 1. Goal
            goal = group_to_goal.get(gid, "Unknown")

            # 2. Traces
            traces_log_dict = {}
            scores_log_dict = {}

            for tid in tids:
                if tid not in traces_map:
                    continue
                # Trace content
                traces_log_dict[tid] = [
                    {"obs": s['observation'], "act": s['action']}
                    for s in traces_map[tid]['trace']
                ]
                # Scores
                scores_log_dict[tid] = {
                    "reward_steps": sorted(trace_scores_map[tid]['reward_steps']),
                    "penalty_steps": sorted(trace_scores_map[tid]['penalty_steps']),
                    "errors": trace_scores_map[tid]['errors']
                }

            # 3. Model Output (Polymorphic)
            model_output_log = None

            if mode == 'group':
                model_output_log = raw_outputs.get(gid, "")
            elif mode == 'trace':
                model_output_log = {tid: raw_outputs.get(tid, "") for tid in tids}
            elif mode == 'step':
                model_output_log = {}
                for tid in tids:
                    step_map = raw_outputs.get(tid, {})
                    # Convert map {0: '...', 1: '...'} to list ['...', '...']
                    max_step = max(step_map.keys()) if step_map else -1
                    model_output_log[tid] = [step_map.get(k, "") for k in range(max_step + 1)]

            log_data_buffer.append({
                "global_step": global_step,
                "mode": mode,
                "group_uid": gid,
                "goal": goal,
                "traces": traces_log_dict,
                "scores": scores_log_dict,
                "model_output": model_output_log
            })

    # 7. Store Rewards
    batch.batch['llm_rewards'] = llm_rewards

    # 8. Metrics
    total_steps = len(batch)
    if "metrics" not in batch.meta_info:
        batch.meta_info["metrics"] = {}

    batch.meta_info["metrics"][f"ablation/{mode}_pos_count"] = total_pos
    batch.meta_info["metrics"][f"ablation/{mode}_neg_count"] = total_neg
    batch.meta_info["metrics"][f"ablation/{mode}_trend_count"] = total_trend
    batch.meta_info["metrics"][f"ablation/{mode}_pos_rate"] = total_pos / (total_steps + 1e-6)
    batch.meta_info["metrics"][f"ablation/{mode}_neg_rate"] = total_neg / (total_steps + 1e-6)
    batch.meta_info["metrics"][f"ablation/{mode}_parse_error_count"] = parse_errors

    # Token usage metrics
    batch.meta_info["metrics"][f"ablation/{mode}_prompt_tokens"] = total_prompt_tokens
    batch.meta_info["metrics"][f"ablation/{mode}_completion_tokens"] = total_completion_tokens

    # 9. Write Logs (best effort: a logging failure must not abort training)
    if log_dir and log_data_buffer:
        try:
            os.makedirs(log_dir, exist_ok=True)
            log_file = os.path.join(log_dir, f"llm_reward_{mode}.jsonl")
            with open(log_file, "a", encoding='utf-8') as f:
                for entry in log_data_buffer:
                    f.write(json.dumps(entry, ensure_ascii=False) + "\n")
        except Exception as e:
            print(f"Failed to write LLM reward logs: {e}")

    return batch


def compute_reward_outcome_advantage(
    token_level_rewards: torch.Tensor, # The sparse environment rewards
    llm_rewards: torch.Tensor,         # The dense LLM rewards (already scaled)
    response_mask: torch.Tensor,
    index: np.ndarray,
    traj_index: np.ndarray = None,     # Added for signature consistency
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
):
    """
    Group-normalized advantage over the sum of sparse and dense rewards.

    Each row's score is `token_level_rewards.sum(-1) + llm_rewards`. Rows are
    grouped by `index`; every score is centered by its group mean and, when
    `norm_adv_by_std_in_grpo` is set, divided by (group std + epsilon).
    Single-member groups use mean 0 and std 1. The per-row value is then
    broadcast over the token axis and multiplied by `response_mask`.

    `traj_index` is accepted but not used.

    Returns:
        (advantages, advantages) - the same tensor twice.
    """
    # Per-row total: summed token rewards plus the dense LLM reward.
    per_row_total = token_level_rewards.sum(dim=-1) + llm_rewards

    with torch.no_grad():
        num_rows = per_row_total.shape[0]

        # Collect the scores belonging to each group id.
        grouped = defaultdict(list)
        for row in range(num_rows):
            grouped[index[row]].append(per_row_total[row])

        # Per-group statistics.
        group_mean = {}
        group_std = {}
        for key, members in grouped.items():
            stacked = torch.stack(members)
            if len(stacked) > 1:
                group_mean[key] = torch.mean(stacked)
                group_std[key] = torch.std(stacked)
            else:
                # A lone sample carries no group statistics: leave it unscaled.
                group_mean[key] = torch.tensor(0.0, device=per_row_total.device)
                group_std[key] = torch.tensor(1.0, device=per_row_total.device)

        normalized = per_row_total.clone()
        for row in range(num_rows):
            key = index[row]
            centered = per_row_total[row] - group_mean[key]
            if norm_adv_by_std_in_grpo:
                centered = centered / (group_std[key] + epsilon)
            normalized[row] = centered

    # (B, 1) * (B, T) -> (B, T)
    advantages = normalized.unsqueeze(-1) * response_mask
    return advantages, advantages