import asyncio
import json
import re
from typing import List, Dict, Set, Tuple
from collections import defaultdict
from trpo.llm_client_factory import get_async_client


# Alias of the LLM endpoint handed to get_async_client() in
# NeuroSymbolicExtractor.__init__; the alias is defined in config.ini.
# Alternative alias previously used here: "gpt-4o-mini".
MODEL_ALIAS = "sglang1" # Alias in config.ini

import os
from trpo.prompt.alfworld import SYSTEM_AP_PROMPT, INPUT_AP_PROMPT

class NeuroSymbolicExtractor:
    """Turn agent traces into per-step sets of atomic propositions (APs).

    Each observation is sent to an LLM that returns ``[subject, category, value]``
    triplets. The triplets are folded into a running per-trace state. A later value
    for the same (subject, category) overwrites the earlier one. Each step's AP set
    is the flattened state (``obs_<subj>_<cat>_<val>``) plus one action AP
    (``act_<action>``).
    """

    # Default destination of the (normally disabled) shadow log written by
    # _log_debug_data.
    _DEBUG_LOG_PATH = "/workspace/Code/verl-agent/trpo_dataset/ap_log/ap_extraction_raw1.jsonl"

    def __init__(self, max_concurrency: int = 400):
        """
        No batch_size needed. This extractor processes single traces independently.

        Args:
            max_concurrency: Maximum number of concurrent LLM API calls; forwarded
                to ``get_async_client`` together with ``MODEL_ALIAS``.
        """
        # Initialize the async LLM client via the factory.
        self.client = get_async_client(MODEL_ALIAS, max_concurrency=max_concurrency)

    def _log_debug_data(self, unique_keys, results, log_path=None):
        """Append raw ``(goal, observation) -> triplets`` mappings to a JSONL file.

        Args:
            unique_keys: Sequence of ``(goal, observation)`` pairs.
            results: Extraction results aligned with ``unique_keys``; ``None``
                entries (failed extractions) are skipped.
            log_path: Output file. Defaults to ``_DEBUG_LOG_PATH``. Missing parent
                directories are created.
        """
        if log_path is None:
            log_path = self._DEBUG_LOG_PATH
        log_dir = os.path.dirname(log_path)
        if log_dir:
            # Create the directory on first use so logging cannot fail with
            # FileNotFoundError on a fresh machine.
            os.makedirs(log_dir, exist_ok=True)
        with open(log_path, "a", encoding="utf-8") as f:
            for (g, o), res in zip(unique_keys, results):
                if res is not None:
                    entry = {"task_goal": g, "observation": o, "triplets": res}
                    f.write(json.dumps(entry, ensure_ascii=False) + "\n")

    async def process_batch(self, traces_map: Dict[str, Dict]) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
        """
        Batch processing with GLOBAL OBSERVATION DEDUPLICATION.

        Each distinct ``(goal, observation)`` pair is extracted exactly once.
        Identical pairs therefore receive identical triplets across the whole batch.

        Args:
            traces_map: Dict {uid: {'goal': str, 'trace': [{'observation': str, 'action': str}, ...]}}

        Returns:
            trace_aps_map: Dict {uid: [set(ap_step1), set(ap_step2), ...]}
            metrics: Dict with
                ap_unique_obs_count: number of distinct (goal, observation) pairs,
                ap_unique_obs_failures: distinct pairs whose extraction returned None,
                ap_trace_failures_count: traces with at least one failed step,
                ap_step_failures_total: failed steps summed over all traces.
        """
        # 1. Collect unique (goal, observation) keys, preserving first-seen order.
        # NOTE(review): _llm_extract builds its prompt from the observation only, so
        # including `goal` in the key can cause repeated LLM calls (and, at
        # temperature 1.0, possibly different triplets) for the same observation
        # under different goals. Kept as-is to preserve the metric semantics.
        unique_keys = list(dict.fromkeys(
            (data['goal'], step['observation'])
            for data in traces_map.values()
            for step in data['trace']
        ))
        print(f"[NeuroSymbolic] Extracting APs for {len(unique_keys)} unique observations (Batch Size: {len(traces_map)})")

        # 2. Parallel extraction; concurrency is limited inside the client.
        results = await asyncio.gather(*(self._llm_extract(goal, obs) for goal, obs in unique_keys))

        # --- SHADOW LOGGING (disabled) ---
        # self._log_debug_data(unique_keys, results)

        # 3. Cache: (goal, obs) -> triplets, or None on failure.
        obs_cache = dict(zip(unique_keys, results))
        total_obs_failures = sum(1 for r in results if r is None)

        # 4. Reconstruct every trace from the cache.
        trace_aps_map = {}
        traces_with_failures = 0
        total_step_failures = 0
        for uid, data in traces_map.items():
            goal = data['goal']
            trace = data['trace']
            actions = [step['action'] for step in trace]
            trace_triplets = [obs_cache.get((goal, step['observation'])) for step in trace]

            step_failures = sum(1 for t in trace_triplets if t is None)
            total_step_failures += step_failures
            if step_failures:
                traces_with_failures += 1

            # Convert to APs (pure logic, no LLM calls).
            trace_aps_map[uid] = self.process_triplets_to_aps(trace_triplets, actions)

        metrics = {
            "ap_unique_obs_count": len(unique_keys),
            "ap_unique_obs_failures": total_obs_failures,
            "ap_trace_failures_count": traces_with_failures,
            "ap_step_failures_total": total_step_failures
        }

        return trace_aps_map, metrics

    async def process_trace(self, goal: str, obs_list: List[str], actions: List[str]) -> Tuple[List[Set[str]], int]:
        """
        Legacy method: full trace processing (extraction + state update) for one trace.

        This method does NOT deduplicate observations across traces; use
        ``process_batch`` for batches.

        Returns:
            (per-step AP sets, number of steps whose extraction returned None)

        Raises:
            AssertionError: if ``obs_list`` and ``actions`` differ in length.
        """
        assert len(obs_list) == len(actions), "Obs and Actions length mismatch"

        # 1. Extract triplets for every step in parallel.
        trace_triplets = await asyncio.gather(*(self._llm_extract(goal, obs) for obs in obs_list))

        # 2. Count failed extractions.
        failures = sum(1 for t in trace_triplets if t is None)

        # 3. Convert to APs (sequential, stateful logic).
        trace_ap_sets = self.process_triplets_to_aps(trace_triplets, actions)

        return trace_ap_sets, failures

    def process_triplets_to_aps(self, trace_triplets: List[List[Tuple]], actions: List[str]) -> List[Set[str]]:
        """
        Pure logic: convert per-step triplets (from the LLM) and actions into stateful APs.
        No LLM calls here.

        Args:
            trace_triplets: One entry per step: a list of (subject, category, value)
                triplets, or ``None`` when extraction failed for that step. Must have
                at least ``len(actions)`` entries.
            actions: Raw action strings, one per step.

        Returns:
            One AP set per action: the accumulated state APs plus that step's action AP.
        """
        current_state: Dict[str, Dict[str, str]] = {}
        trace_ap_sets = []

        for i, action in enumerate(actions):
            triplets = trace_triplets[i]
            if triplets is None:
                # A failed extraction adds nothing; the state carried over from
                # earlier steps is kept.
                triplets = []

            self._update_state_memory(current_state, triplets)

            step_aps = set(self._flatten_state(current_state))
            step_aps.add(self._parse_action_rule(action))
            trace_ap_sets.append(step_aps)

        return trace_ap_sets

    def _update_state_memory(self, state_dict, triplets):
        """Fold triplets into ``state_dict`` ({subject: {category: value}}) in place."""
        # WebShop scope-flush detection (disabled):
        # self._check_webshop_flush(state_dict, triplets)

        for subj, cat, val in triplets:
            subj, cat, val = self._normalize(subj, cat, val)
            if not subj:
                continue
            # Standard update: one value per (subject, category), so a new value
            # replaces the previous one (mutual exclusion within a category).
            state_dict.setdefault(subj, {})[cat] = val

    def _flatten_state(self, state_dict) -> List[str]:
        """Render the state as ``obs_<subject>_<category>_<value>`` strings."""
        return [
            f"obs_{subj}_{cat}_{val}"
            for subj, attrs in state_dict.items()
            for cat, val in attrs.items()
        ]

    def _parse_action_rule(self, action_str: str) -> str:
        """Rule-based action AP: ``"Take apple 1."`` -> ``"act_take_apple1"``."""
        # Strip whitespace before the trailing punctuation, so "take apple 1. "
        # also loses its final ".".
        clean = action_str.strip().lower().rstrip(".,;:!?").strip()

        # Align with _normalize by removing whitespace before numbers:
        # "take apple 1" -> "take apple1" -> "act_take_apple1"
        clean = re.sub(r'\s+(\d+)', r'\1', clean)

        # Collapse any remaining whitespace run (spaces, tabs) into a single "_".
        clean = "_".join(clean.split())
        return f"act_{clean}"

    def _normalize(self, s, c, v):
        """Lower-case, drop spaces, map "-" to "_", and trim "_" on each field."""
        def clean(t):
            # Remove spaces entirely to merge "bath room"->"bathroom", "apple 1"->"apple1"
            return str(t).lower().strip().replace(" ", "").replace("-", "_").strip("_")
        return clean(s), clean(c), clean(v)

    def _check_webshop_flush(self, state, triplets):
        """Delete ``item_*`` subjects from ``state`` when a triplet reports a new ``self.page``.

        The stored page value itself is not updated here.
        """
        for s, c, v in triplets:
            s, c, v = self._normalize(s, c, v)
            if s == "self" and c == "page":
                current_page = state.get("self", {}).get("page")
                if current_page and current_page != v:
                    # Page changed: clear stale items from the previous page.
                    for k in list(state.keys()):
                        if k.startswith("item_"):
                            del state[k]

    @staticmethod
    def _validate_and_parse(text):
        """Parse an LLM response and return its ``triplets`` list.

        Raises:
            ValueError: if the response is empty, is not valid JSON (JSONDecodeError
                subclasses ValueError), is not a JSON object, lacks a ``triplets``
                list, or contains an entry that is not a 3-element list/tuple.
        """
        if not text:
            raise ValueError("Empty response")
        d = json.loads(text)
        if not isinstance(d, dict) or not isinstance(d.get("triplets"), list):
            raise ValueError("Missing or invalid 'triplets' list")
        # Strict structure check to catch "too many values to unpack" downstream.
        for t in d["triplets"]:
            if not isinstance(t, (list, tuple)) or len(t) != 3:
                raise ValueError(f"Invalid triplet structure: {t}")
        return d["triplets"]

    async def _llm_extract(self, goal, obs, max_retries=3):
        """Ask the LLM for ``[subject, category, value]`` triplets describing ``obs``.

        ``goal`` is currently not used when building the prompt.

        Each attempt makes one completion call. If that call returned content that
        fails validation, one in-context repair call is made that shows the model its
        own output. Between failed attempts the method sleeps 1, 2, 4, ... seconds.

        Returns:
            The validated triplet list, ``[]`` for the literal observation
            "Nothing happens.", or ``None`` if every attempt failed.
        """
        if obs == "Nothing happens.":
            return []

        content = INPUT_AP_PROMPT.format(obs=obs)
        messages = [{"role": "system", "content": SYSTEM_AP_PROMPT},
                    {"role": "user", "content": content}]

        for attempt in range(max_retries):
            # Reset every attempt: a failed API call must never trigger a "repair"
            # of a previous attempt's stale response.
            raw_content = None
            try:
                # --- 1. Initial call ---
                resp = await self.client.chat.completions.create(
                    messages=messages,
                    response_format={"type": "json_object"},
                    temperature=1.0,
                    max_tokens=1024,
                )
                raw_content = resp.choices[0].message.content
                return self._validate_and_parse(raw_content)

            except Exception as e_initial:
                # --- 2. In-context repair (one shot) ---
                # Only possible when this attempt produced some content to show.
                repair_error = None
                if raw_content:
                    try:
                        repair_msgs = messages + [
                            {"role": "assistant", "content": raw_content},
                            {"role": "user", "content": "Error: The output is truncated or the JSON format is incorrect. Please regenerate the JSON object correctly."}
                        ]
                        resp_repair = await self.client.chat.completions.create(
                            messages=repair_msgs,
                            response_format={"type": "json_object"},
                            temperature=1.0,
                            max_tokens=1024,
                        )
                        return self._validate_and_parse(resp_repair.choices[0].message.content)
                    except Exception as e_repair:
                        # Remember the failure so the final report includes it.
                        repair_error = e_repair

                # --- 3. Exponential backoff, or give up on the last attempt ---
                if attempt < max_retries - 1:
                    await asyncio.sleep(1 * (2 ** attempt))
                else:
                    repair_note = f" (Repair Error: {repair_error})" if repair_error is not None else ""
                    print(f"Error: LLM AP extraction failed after {max_retries} attempts (Final Error: {e_initial}){repair_note}")
                    return None
        return None