import os
import json
import math
import collections
import argparse
from openai import OpenAI
from tqdm import tqdm
import random
from openai import AzureOpenAI
import concurrent.futures

import requests
import re
from collections import namedtuple, defaultdict 
from concurrent.futures import as_completed
import warnings


# Silence the "Unverified HTTPS request" warning (presumably urllib3's
# InsecureRequestWarning) triggered because CustomQwenClient posts with
# verify=False. NOTE(review): this hides a real security signal; prefer
# enabling TLS verification with a proper CA bundle if the endpoint allows it.
warnings.filterwarnings('ignore', message='Unverified HTTPS request')

class CustomQwenClient:
    """Minimal OpenAI-style client for a custom Qwen HTTP chat endpoint.

    Exposes ``client.chat.completions.create(...)`` so it can be used where an
    ``OpenAI``/``AzureOpenAI`` client is expected in this script. Results are
    returned as lightweight namedtuples exposing ``choices[i].message.content``
    and ``usage`` (the raw ``usage`` object from the response, or None).
    """

    # Response shapes and the think-block pattern, built once rather than per request.
    _Message = namedtuple("Message", ["content"])
    _Choice = namedtuple("Choice", ["message"])
    _MockCompletion = namedtuple("MockCompletion", ["choices", "usage"])
    _THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)

    def __init__(self, api_url, model_name):
        self.api_url = api_url
        self.model_name = model_name
        self.chat = self._Chat(self)

    class _Chat:
        """Namespace object providing ``.completions`` (mirrors the OpenAI SDK layout)."""

        def __init__(self, parent):
            self.parent = parent
            self.completions = CustomQwenClient._Completions(parent)

    class _Completions:
        """Implements ``create`` against the parent's ``api_url``."""

        def __init__(self, parent):
            self.parent = parent

        @staticmethod
        def _build_messages(messages):
            """Copy role/content pairs and append Qwen's ``/no_think`` switch.

            The suffix is appended to the last message when it is a user
            message with plain-string content. Otherwise a separate user turn
            containing only ``/no_think`` is added. This avoids ``+=`` on
            ``None`` (TypeError) or on list content, which would splice in
            single characters and mutate the caller's list. The input is
            never modified.
            """
            api_messages = [dict(role=m["role"], content=m["content"]) for m in messages]
            last = api_messages[-1] if api_messages else None
            if last is not None and last["role"] == "user" and isinstance(last["content"], str):
                last["content"] += "/no_think"
            else:
                api_messages.append({"role": "user", "content": "/no_think"})
            return api_messages

        @staticmethod
        def _to_completion(resp_json):
            """Convert a chat-completion JSON payload into the namedtuple shape callers read.

            Each choice's text comes from ``message.content``, falling back to
            ``text``. ``<think>...</think>`` blocks are removed and the result
            is stripped. At least one (possibly empty) choice is always
            returned. ``usage`` is passed through unchanged, or None.
            """
            payload = resp_json if isinstance(resp_json, dict) else {}
            mapped = []
            for ch in payload.get("choices") or []:
                raw = None
                if isinstance(ch, dict):
                    msg = ch.get("message")
                    if isinstance(msg, dict):
                        raw = msg.get("content")
                    if raw is None:
                        raw = ch.get("text")

                cleaned = CustomQwenClient._THINK_RE.sub("", str(raw or "")).strip()
                mapped.append(CustomQwenClient._Choice(message=CustomQwenClient._Message(content=cleaned)))

            if not mapped:
                mapped = [CustomQwenClient._Choice(message=CustomQwenClient._Message(content=""))]

            return CustomQwenClient._MockCompletion(choices=mapped, usage=payload.get("usage"))

        def create(self, model, messages, temperature=0.7, max_tokens=512, **kwargs):
            """Send one non-streaming chat request and return an OpenAI-like completion.

            Args:
                model: Accepted for OpenAI compatibility but ignored; the
                    parent client's ``model_name`` is sent instead.
                messages: List of ``{"role", "content"}`` dicts.
                temperature: Sampling temperature.
                max_tokens: Completion token limit.
                **kwargs: Only ``n``, ``top_p``, ``presence_penalty``,
                    ``frequency_penalty`` and ``response_format`` are forwarded.

            Returns:
                A ``MockCompletion`` namedtuple, or None if the HTTP call fails
                or the body is not valid JSON.
            """
            print("--- Inside Adapter: create() method called ---")

            api_messages = self._build_messages(messages)

            headers = {"Accept": "application/json", "Content-Type": "application/json"}
            data = {
                "model": self.parent.model_name,
                "messages": api_messages,
                "stream": False,
                "enable_chain_of_thought": False,
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
            for k in ("n", "top_p", "presence_penalty", "frequency_penalty", "response_format"):
                if k in kwargs:
                    data[k] = kwargs[k]

            try:
                print("Adapter: Sending POST request...")
                resp = requests.post(
                    self.parent.api_url,
                    headers=headers,
                    data=json.dumps(data),
                    # NOTE(review): TLS verification is disabled for this endpoint.
                    verify=False,
                    timeout=500,
                )
                resp.raise_for_status()
                print("Adapter: Received successful response from API.")
                resp_json = resp.json()
            except (requests.exceptions.RequestException, ValueError) as e:
                # ValueError covers a non-JSON body on requests versions whose
                # JSONDecodeError is not a RequestException subclass.
                print(f"❌ Adapter Error: Failed to call custom Qwen API. Details: {e}")
                return None

            return self._to_completion(resp_json)

class Config:
    """Static configuration for the dialogue-simulation / dataset-generation run.

    ``API_CONFIGS``, the model keys and ``MODEL_PRICES`` are empty here and
    must be filled in before running. ``DialogueSimulator`` looks up
    ``API_CONFIGS[ANSWERER_MODEL_KEY]`` and ``API_CONFIGS[k]`` for every
    ``k`` in ``QUESTIONER_MODEL_KEYS``.
    """

    # model key -> {"client": <chat client>, "model_name": str, ...}
    API_CONFIGS = {
    }

    ANSWERER_MODEL_KEY = ""
    QUESTIONER_MODEL_KEYS = []
    LLM_PARSER_MODEL_KEY = ""

    # Input / output locations.
    GOLDEN_JSON_FOLDER = "state_vectors_v3"
    WORLD_MODEL_FILE = "world_model_v3.json"
    OUTPUT_GRPO_FILE = "grpo_preference_dataset_1000_v3.json"
    OUTPUT_GRPO_ABLATION_FILE = "grpo_counting_dataset_1000_v3.json"
    OUTPUT_SFT_FILE = "sft_natural_language_summary_dataset_1000_v3.json"

    # Upper bound on question/answer turns per simulated dialogue.
    MAX_TURNS_PER_DIALOGUE = 15
    # Completions (`n`) requested from each questioner per turn.
    NUM_QUESTIONS_PER_MODEL = 4
    # Questioner models sampled (without replacement) for each dialogue.
    NUM_SAMPLED_QUESTIONERS = 4
    # A dialogue ends once remaining entropy drops below this fraction of the initial entropy.
    ENTROPY_THRESHOLD_RATIO = 0.05
    # Per-turn probability of adding a style-template proposal to the candidates.
    TEMPLATE_PROPOSAL_PROBABILITY = 0.2

    # Model-name fragment -> {"prompt": usd, "completion": usd} per one million tokens.
    MODEL_PRICES = {

    }


    STYLE_TEMPLATES = {
        "professional_blue": {
            "description": "a professional blue and white style, suitable for corporate or academic presentations",
            "attributes": {
                "style_theme": "professional_light",
                "background_color": "#FFFFFF",
                "component_fill_color": "#EAF1FD",
                "component_border_color": "#B4C7E7",
                "font_family": "Helvetica, Arial, sans-serif",
                "text_color": "#000000"
            }
        },
        "academic_grayscale": {
            "description": "a minimalist grayscale theme, perfect for academic papers and formal publications",
            "attributes": {
                "style_theme": "minimalist_grayscale",
                "background_color": "#FFFFFF",
                "component_fill_color": "#F0F0F0",
                "component_border_color": "#666666",
                "font_family": "Times New Roman, serif",
                "text_color": "#000000"
            }
        },
        "vibrant_tech": {
            "description": "a modern and vibrant dark mode style, great for tech startup presentations",
            "attributes": {
                "style_theme": "dark_mode_vibrant",
                "background_color": "#1A1A1A",
                "component_fill_color": "#2D2D2D",
                "component_border_color": "#00BFFF",
                "font_family": "Roboto, sans-serif",
                "text_color": "#EAEAEA"
            }
        },
        "blueprint_schematic": {
            "description": "a technical blueprint style with a dark blue background, ideal for engineering schematics or software architecture diagrams",
            "attributes": {
                "style_theme": "technical_blueprint",
                "background_color": "#0A2342",
                "component_fill_color": "#183454",
                "component_border_color": "#00FFFF",
                "font_family": "Courier New, monospace",
                "text_color": "#F0F0F0"
            }
        },
        "warm_organic": {
            "description": "a warm and inviting theme with earthy tones, suitable for biology, environmental science, or a softer presentation feel",
            "attributes": {
                "style_theme": "natural_warm",
                "background_color": "#FFF8E7",
                "component_fill_color": "#E8F5E9",
                "component_border_color": "#8D6E63",
                "font_family": "Verdana, Geneva, sans-serif",
                "text_color": "#4E342E"
            }
        }
    }

    # Concept name / alias -> dotted target path in the diagram state schema.
    # Each key appears exactly once (duplicate dict keys are silently collapsed by Python).
    SCHEMA_MAPPER = {
        # Global Properties
        "topic": "global_properties.topic",
        "subject": "global_properties.topic",
        "concept": "global_properties.topic",
        "purpose": "global_properties.purpose",
        "goal": "global_properties.purpose",
        "target_audience": "global_properties.target_audience",
        "audience": "global_properties.target_audience",
        "complexity_level": "global_properties.complexity_level",
        "detail_level": "global_properties.complexity_level",
        "background_color": "global_properties.background_color",
        "font_family": "global_properties.font_family",
        "title_text": "global_properties.title.text",
        "title": "global_properties.title.text",
        "diagram_title": "global_properties.title.text",
        "main_title": "global_properties.title.text",
        "domain": "global_properties.domain",
        "field": "global_properties.domain",
        "visual_format": "global_properties.visual_format",
        "diagram_type": "global_properties.diagram_type",
        "type": "global_properties.diagram_type",
        "layout_grid": "global_properties.layout_grid",
        "layout": "global_properties.layout_grid",
        "style_theme": "global_properties.style_theme",
        "theme": "global_properties.style_theme",
        "title_is_present": "global_properties.title.is_present",
        "has_title": "global_properties.title.is_present",
        "show_title": "global_properties.title.is_present",
        "global_title": "global_properties.title.text",
        "figure_title": "global_properties.title.text",

        # Component Properties (Generic)
        "component_type": "component.type",
        "component_shape": "component.geometry.shape",
        "shape": "component.geometry.shape",

        "component_fill_color": "component.styling.fill_color",
        "fill_color": "component.styling.fill_color",
        "component_border_color": "component.styling.border_color",
        "border_color": "component.styling.border_color",
        "component_text_color": "component.text_properties.text_color",
        "text_color": "component.text_properties.text_color",

        "component_border_style": "component.styling.border_style",
        "border_style": "component.styling.border_style",
        "component_font_weight": "component.text_properties.font_weight",
        "font_weight": "component.text_properties.font_weight",
        "component_label": "component.label",

        # Connection Properties (Generic)
        "connection_line_style": "connection.line_properties.style",
        "line_style": "connection.line_properties.style",
        "connection_arrowhead_end": "connection.arrowhead.end_type",
        "arrowhead": "connection.arrowhead.end_type",
        "connection_arrowhead": "connection.arrowhead.end_type",
        "connection_label": "connection.label.text",
        "connection_from_id": "connection.from_id",
        "from_id": "connection.from_id",
        "source_id": "connection.from_id",
        "connection_label_position": "connection.label.position",
        "connection_label_color": "connection.label.text_color",
        "connection_line_type": "connection.line_properties.type",
        "connection_line_color": "connection.line_properties.color",
        "connection_line_width": "connection.line_properties.width",
        "connection_arrowhead_start": "connection.arrowhead.start_type",
        "connection_arrowhead_size": "connection.arrowhead.size",
        "connection_to_id": "connection.to_id",
        "to_id": "connection.to_id",
        "target_id": "connection.to_id",
        "start_id": "connection.from_id",
        "end_id": "connection.to_id",

        # Layout Constraints
        "layout_constraint_type": "layout_constraint.type",
        "layout_alignment": "layout_constraint.alignment_type",
        "layout_padding": "layout_constraint.padding",
        "layout_distribution": "layout_constraint.distribution_type",
        "layout_arrangement": "layout_constraint.arrangement",

        # Meta Properties
        "style_template": "style_template",
    }


class TokenTracker:
    """Accumulate token usage and estimated cost per model key.

    ``price_table`` maps a model-name fragment to
    ``{"prompt": usd, "completion": usd}`` prices per one million tokens. A
    model key is billed with the entry whose fragment it contains, compared
    case-insensitively. When several fragments match, the longest one wins,
    so e.g. ``"gpt-4o-mini"`` is not billed at a ``"gpt-4o"`` rate. Models
    with no matching entry accrue tokens but no cost.
    """

    def __init__(self, price_table: dict):
        self.price_table = {k.lower(): v for k, v in (price_table or {}).items()}
        self.totals = self._empty_totals()

    @staticmethod
    def _empty_totals():
        """Return a fresh per-model accumulator in which unseen keys start at zero."""
        return defaultdict(lambda: {"prompt": 0, "completion": 0, "total": 0, "cost": 0.0})

    def _price_for(self, model_key):
        """Return the price entry for *model_key*, or None if no fragment matches."""
        mk = (model_key or "").lower()
        matches = [k for k in self.price_table if k in mk]
        if not matches:
            return None
        # Most specific fragment wins, independent of dict insertion order.
        return self.price_table[max(matches, key=len)]

    def add_usage(self, model_key: str, prompt_tokens: int = 0, completion_tokens: int = 0, total_tokens: int = None):
        """Add one call's token counts, and its cost if priced, to *model_key*'s totals.

        ``None`` counts are treated as 0. ``total_tokens`` defaults to
        ``prompt_tokens + completion_tokens``.
        """
        pt = int(prompt_tokens or 0)
        ct = int(completion_tokens or 0)
        tt = int(total_tokens if total_tokens is not None else (pt + ct))

        entry = self.totals[model_key]
        entry["prompt"] += pt
        entry["completion"] += ct
        entry["total"] += tt

        price = self._price_for(model_key)
        if price is not None:
            entry["cost"] += (pt / 1_000_000) * price.get("prompt", 0.0)
            entry["cost"] += (ct / 1_000_000) * price.get("completion", 0.0)

    @staticmethod
    def _usage_field(usage, name):
        """Read *name* from a usage dict or from a usage object's attributes."""
        if isinstance(usage, dict):
            return usage.get(name)
        return getattr(usage, name, None)

    def add_from_response(self, model_key: str, resp):
        """Record the ``usage`` block of a chat-completion response, if present.

        Accepts objects exposing ``.usage`` or ``.to_dict()``, as well as
        plain dicts. Responses that are None or have no usage are ignored.
        """
        if resp is None:
            return

        usage = getattr(resp, "usage", None)

        if usage is None and hasattr(resp, "to_dict"):
            try:
                d = resp.to_dict()
            except Exception:
                # Best-effort: a failing to_dict() simply means "no usage available".
                d = None
            if isinstance(d, dict):
                usage = d.get("usage")

        if usage is None and isinstance(resp, dict):
            usage = resp.get("usage")

        if not usage:
            return

        pt = self._usage_field(usage, "prompt_tokens") or 0
        ct = self._usage_field(usage, "completion_tokens") or 0
        # A missing or zero total falls back to pt + ct inside add_usage.
        tt = self._usage_field(usage, "total_tokens") or None
        self.add_usage(model_key, pt, ct, tt)

    def snapshot_and_reset(self):
        """Return a plain-dict deep copy of the totals and start a new accumulation window.

        The copy is made through a JSON round-trip, so non-string model keys
        come back as strings (e.g. ``None`` becomes ``"null"``).
        """
        snap = json.loads(json.dumps(self.totals))
        self.totals = self._empty_totals()
        return snap

class WorldModel:
    """Prior distributions over diagram attributes, plus each attribute's entropy in bits."""

    def __init__(self, model_path):
        """Load the world-model JSON at *model_path* and precompute entropies.

        The file is read as ``{"prior_distributions": {attr: {value: prob, ...}, ...}, ...}``.
        A missing ``prior_distributions`` key yields no attributes.
        """
        print(f"Loading world model from {model_path}...",flush=True)
        with open(model_path, 'r', encoding='utf-8') as f:
            self.model = json.load(f)
        self.attribute_entropies = self._precompute_entropies()
        self.total_initial_entropy = sum(self.attribute_entropies.values())
        print(f"World model loaded. Total initial entropy: {self.total_initial_entropy:.2f} bits.",flush=True)

    def _precompute_entropies(self):
        """Return ``{attr: Shannon entropy in bits}`` for every prior distribution.

        Zero and negative probabilities are skipped; a p == 0 term contributes 0 by convention.
        """
        entropies = collections.defaultdict(float)
        distributions = self.model.get("prior_distributions", {})
        for attr, dist in distributions.items():
            entropy = -sum(p * math.log2(p) for p in dist.values() if p > 0)
            entropies[attr] = entropy
        return entropies

    def get_attribute_entropy(self, attr):
        """Return the prior entropy (bits) of *attr*, or 0.0 for an unknown attribute.

        Uses ``.get`` so that looking up an unknown attribute does not insert
        it into the underlying defaultdict.
        """
        return self.attribute_entropies.get(attr, 0.0)


class SemanticParser:
    """Turn a free-text user answer into a list of (concept, value) constraints.

    In ``'llm'`` mode a chat-completion client is prompted to emit JSON. Any
    other mode uses a small keyword matcher (``_rule_parse``).
    """

    def __init__(self, mode='llm', llm_parser_config=None):
        """
        Args:
            mode: ``'llm'`` for LLM-based parsing; anything else selects the rule-based parser.
            llm_parser_config: ``{'client': <chat client>, 'model_name': str}``.
                Needed for ``'llm'`` mode; without it ``parse`` returns ``[]``.
        """
        self.mode = mode
        self.token_tracker = None
        # Always defined, so a misconfigured 'llm' parser is reported explicitly.
        self.client = None
        self.model_name = None
        if self.mode == 'llm' and llm_parser_config:
            self.client = llm_parser_config['client']
            self.model_name = llm_parser_config['model_name']
        print(f"SemanticParser initialized in '{self.mode}' mode.",flush=True)

    def parse(self, answer: str):
        """Return the constraints extracted from *answer* using the configured mode."""
        if self.mode == 'llm':
            return self._llm_parse(answer)
        return self._rule_parse(answer)

    def set_token_tracker(self, tracker):
        """Attach a tracker; usage of each LLM parse call is reported via ``add_from_response``."""
        self.token_tracker = tracker

    def _rule_parse(self, answer: str):
        """Keyword-based fallback returning ``(concept, value)`` tuples.

        Only a few hard-coded phrases are recognised, matched case-insensitively.
        Agreement to a style template ("sounds good", ...) is not turned into
        a constraint by this parser.
        """
        answer_lower = answer.lower()
        keyword_constraints = (
            ("flowchart", ('diagram_type', 'flowchart')),
            ("schematic", ('diagram_type', 'schematic')),
            ("rounded rectangle", ('component_shape', 'rounded_rectangle')),
        )
        return [constraint for keyword, constraint in keyword_constraints if keyword in answer_lower]


    def _get_llm_prompt_template(self):
        """Return the extraction prompt; ``{user_answer_text}`` is replaced by the answer."""
        return """
        You are a highly precise linguistic analysis tool. Your *only* function is to extract structured constraints from a user's answer.

        **Instructions:**
        1.  Analyze the "User's Answer".
        2.  Identify any core concepts that match the "CONCEPT DEFINITION".
        3.  Your final output MUST be a single, valid JSON object: `{"constraints": [["concept_name", "value"], ...]}`.
        4.  The `concept_name` MUST be one of the simple keys from the definition (e.g., `diagram_type`, `component_label`, `component_shape`).
        5.  If no constraints can be extracted, the value of "constraints" MUST be an empty list `[]`.

        ---
        **CONCEPT DEFINITION (EXPANDED)**

        *   **High-Level Concepts:** `topic`, `purpose`, `target_audience`, `complexity_level`, `domain`, `visual_format`, `diagram_type`, `background_color`, `font_family`, `layout`, `style_theme`
        *   **Component-Level Concepts:**
            *   `component_label`: The text label inside a component.
            *   `component_shape`: The shape of a component (e.g., 'rectangle').
            *   `component_fill_color`: The fill color of a component.
            *   `component_border_style`: The border style of a component.
            *   `text_color`: The default text color used inside components.
        *   **Connection-Level Concepts:**
            *   `connection_label`: The text label on a connection.
            *   `connection_line_style`: The style of a line (e.g., 'solid').
            *   `connection_arrowhead`: The type of arrowhead.
        *   **Meta-Concepts:** `style_template`
        ---
        **EXAMPLES**

        User's Answer: "The main components are an Encoder block and a Decoder block."
        Your Output:
        {"constraints": [["component_label", "Encoder block"], ["component_label", "Decoder block"]]}

        User's Answer: "The connection between them should be labeled 'context vector'."
        Your Output:
        {"constraints": [["connection_label", "context vector"]]}

        User's Answer: "The diagram is a flowchart with rounded rectangle shapes."
        Your Output:
        {"constraints": [["diagram_type", "flowchart"], ["component_shape", "rounded_rectangle"]]}
        ---
        Now, analyze the following user's answer and provide the JSON output.

        **User's Answer:**
        "{user_answer_text}"
        """

    def _llm_parse(self, answer: str):
        """Ask the configured LLM for constraints; return ``[]`` on any failure.

        The reply must be a JSON object ``{"constraints": [[name, value], ...]}``
        in which every item is a two-element list. Anything else is rejected
        with a warning.
        """
        if getattr(self, "client", None) is None:
            print(f"Error during LLM parsing for answer '{answer}': no LLM client configured",flush=True)
            return []

        # str.replace rather than str.format: the template contains literal JSON braces.
        final_prompt = self._get_llm_prompt_template().replace("{user_answer_text}", answer)

        raw_content = None
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": final_prompt}],
                temperature=0.0,
                response_format={"type": "json_object"}
            )

            if self.token_tracker:
                self.token_tracker.add_from_response(self.model_name, completion)

            raw_content = completion.choices[0].message.content
            print(f"{raw_content}",flush=True)
            parsed_json = json.loads(raw_content)

            if isinstance(parsed_json, dict) and "constraints" in parsed_json:
                constraints_list = parsed_json["constraints"]
                if isinstance(constraints_list, list) and all(isinstance(item, list) and len(item) == 2 for item in constraints_list):
                    return constraints_list

            print(f"Warning: LLM parser returned malformed or unhandled data format: {raw_content}",flush=True)
            return []

        except json.JSONDecodeError:
            print(f"Error: Failed to decode JSON from LLM parser. Response: {raw_content}",flush=True)
            return []
        except Exception as e:
            print(f"Error during LLM parsing for answer '{answer}': {e}",flush=True)
            return []

class DialogueSimulator:
    def __init__(self, world_model, parser):
        """Wire up the world model, parser, API configs and a shared token tracker.

        Raises:
            KeyError: if the answerer key or any questioner key is missing
                from ``Config.API_CONFIGS``.
        """
        self.world_model = world_model
        self.parser = parser

        configs = Config.API_CONFIGS
        self.api_configs = configs
        self.answerer_config = configs[Config.ANSWERER_MODEL_KEY]

        questioners = {}
        for key in Config.QUESTIONER_MODEL_KEYS:
            questioners[key] = configs[key]
        self.full_questioner_pool = questioners

        # A single tracker is shared with the parser, so parser LLM calls land
        # in the same per-sample totals.
        self.token_tracker = TokenTracker(Config.MODEL_PRICES)
        if hasattr(self.parser, "set_token_tracker"):
            self.parser.set_token_tracker(self.token_tracker)

        pool_size = len(self.full_questioner_pool)
        print(f"DialogueSimulator initialized with a pool of {pool_size} questioner models.")



    def _track_if_paid(self, model_key: str, resp):
        """Forward *resp*'s token usage to the tracker, except for the "Qwen3-235B" key.

        "Qwen3-235B" is treated as unpaid, so its usage is never recorded.
        """
        if model_key != "Qwen3-235B":
            self.token_tracker.add_from_response(model_key, resp)

    def _generate_candidate_questions(self, dialogue_history, confirmed_attributes, active_questioners):
        """Collect candidate next questions from every active questioner model.

        Args:
            dialogue_history: The conversation so far, as a single string.
            confirmed_attributes: Attribute names already established; they
                are excluded when choosing a hint.
            active_questioners: Mapping of model key -> config dict with
                ``"client"`` and ``"model_name"``. Modified in place: a
                questioner whose request raises is removed, so later turns
                skip it.

        Returns:
            A de-duplicated list of non-empty question strings. At most eight
            come from the models; one style-template proposal may follow.
        """
        entropies = self.world_model.attribute_entropies
        # Most uncertain (highest-entropy) unconfirmed attributes first. Entropies are
        # read from the precomputed table, so only `attribute_entropies` is required.
        unconfirmed_attrs = sorted(
            (attr for attr in entropies if attr not in confirmed_attributes),
            key=lambda attr: entropies.get(attr, 0.0),
            reverse=True,
        )

        hint = ""
        if random.random() < 0.25 and unconfirmed_attrs:
            # Nudge the questioners toward one of the five most uncertain attributes.
            hint_attr = random.choice(unconfirmed_attrs[:5])
            hint = f"\n**Hint: The user has not specified details about `{hint_attr}` yet. Try asking a question related to this aspect.**"
            print(f"INFO: Providing a hint to ask about high-entropy attribute: '{hint_attr}'",flush=True)
        system_prompt = f"""You are a 'Diagram Architect,' a specialist in Socratic questioning for designing scientific diagrams. Your objective is to guide a user to articulate their vision by asking a sequence of precise, high-impact questions.

        **Your Methodology: The Diagram Specification Funnel**
        You must follow this hierarchical approach, moving from general to specific. Do not jump to a lower level if a higher level is still ambiguous.

        *   **Level 1: Core Identity (The 'What')**
            *   Goal: Understand the diagram's fundamental purpose.
            *   Attributes: `topic`, `purpose`, `target_audience`, `domain`.
            *   Example Question: "What is the primary scientific concept this diagram aims to illustrate?"

        *   **Level 2: Macro-Structure (The 'How, Broadly')**
            *   Goal: Define the overall visual structure and main building blocks.
            *   Attributes: `diagram_type`, `visual_format`, `layout_grid`, main `components` and their types.
            *   Example Question: "What type of diagram is this? For example, a flowchart, a system architecture, or a schematic?"

        *   **Level 3: Micro-Details (The 'Specifics')**
            *   Goal: Flesh out the details of components and connections.
            *   Attributes: `style_theme`, `colors`, `border_styles`, `arrowheads`, `font_weights`.
            *   Example Question: "Should the arrows connecting the components be solid or dashed?"

        **Strict Rules of Engagement:**
        1.  **One Question at a Time:** Your entire output must be a single, focused question.
        2.  **Analyze History:** Carefully review the "Conversation History" to understand what is already known. Do not ask for information that has already been provided.
        3.  **Use the Hint:** If a hint is provided at the end of these instructions, it points to a high-value, unknown area. You should strongly consider formulating a question about it.
        4.  **No Fluff:** Do not use greetings, apologies, or any conversational filler. Be direct and professional.

        Based on the conversation history and any provided hint, formulate your next single, best question.
        {hint}
        """

        # temperature / max_tokens are not sent for these model keys
        # (presumably those deployments reject them, per the set's name).
        AZURE_N2_NO_TEMPMAX = {"o3-mini", "o4-mini", "gpt-5", "o3"}
        n_value = Config.NUM_QUESTIONS_PER_MODEL
        candidate_questions = []
        for model_key, config in list(active_questioners.items()):
            try:
                kwargs = {
                    "model": config["model_name"],
                    "messages": [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": f"Here is the conversation history:\n\n{dialogue_history}\n\nWhat is your next question?"},
                    ],
                    "n": n_value,
                }
                if model_key not in AZURE_N2_NO_TEMPMAX:
                    kwargs["temperature"] = 0.7
                    kwargs["max_tokens"] = 80

                completion = config["client"].chat.completions.create(**kwargs)

                self._track_if_paid(model_key, completion)

                for choice in completion.choices:
                    # content may be None for some providers; skip such choices instead of
                    # letting .strip() raise and evict an otherwise working questioner.
                    text = (choice.message.content or "").strip()
                    if text:
                        candidate_questions.append(text)
            except Exception as e:
                print(f"Warning: Failed to get questions from {model_key}. Error: {e}", flush=True)
                # Drop the failing questioner for the rest of this dialogue.
                active_questioners.pop(model_key, None)

        style_proposal_text = None
        if random.random() < Config.TEMPLATE_PROPOSAL_PROBABILITY:
            template_key, template_data = random.choice(list(Config.STYLE_TEMPLATES.items()))
            style_proposal_text = (
                f"To speed things up, would you like to use {template_data['description']}? "
                f"This would set a consistent look and feel for the diagram."
            )
            print(f"INFO: Proposing style template: '{template_key}'", flush=True)

        # Order-preserving de-duplication, then cap the model-generated pool.
        BASE_MAX = 8
        candidate_questions = list(dict.fromkeys(candidate_questions))[:BASE_MAX]

        if style_proposal_text is not None and style_proposal_text not in candidate_questions:
            candidate_questions.append(style_proposal_text)

        return candidate_questions


    def _get_answer_from_oracle(self, question, golden_json, dialogue_history):
        """Have the answerer model role-play the user, answering only from the ground-truth JSON.

        Args:
            question: The candidate question to answer.
            golden_json: Ground-truth diagram description. It is serialised
                with ``json.dumps`` into the system prompt.
            dialogue_history: The conversation so far, as a single string.

        Returns:
            The stripped answer text. This is an empty string when the model
            returns no content (``run_simulation_for_file`` skips empty
            answers). A fixed apology string is returned if the request raises.
        """
        system_prompt = """You are acting as a user who has a complete scientific diagram in mind. Your complete knowledge of the diagram is provided in the following JSON object. Your task is to answer the assistant's questions with specific, concrete details derived ONLY from this JSON data.
        Answers should be simple and clear, with no more than 50 words.
        **Your Answering Principles (CRITICAL):**
        1.  **BE FACTUAL AND SPECIFIC:** Your primary goal is to provide concrete information from the JSON.
            - If asked "What is the overall structure?", and the JSON says `"diagram_type": "flowchart"`, you MUST answer with something like "It is a flowchart." or "The diagram is a flowchart.".
            - If asked "What shape are the components?", and the JSON has components with `"shape": "rounded_rectangle"`, you MUST answer "The main components are rounded rectangles."
            - **DO NOT** give vague, generic, or evasive answers like "It depends", "What do you think?", or "That's a good question." This is a simulation, and your role is to provide the facts from the JSON.
        2.  **BE CONCISE:** Answer the question directly. Do not add conversational fluff.
        3.  **HANDLE SUGGESTIONS:** If the assistant proposes a style template (e.g., "professional blue style"), check if its attributes align with the JSON. If they do, agree enthusiastically (e.g., "Yes, that professional blue style sounds perfect!").
        4.  **HANDLE UNANSWERABLE QUESTIONS:** If the question asks for information not present in the JSON, state that clearly (e.g., "That detail is not specified in my current design.").
        Here is the complete diagram information (Ground Truth JSON):
        """
        try:
            completion = self.answerer_config['client'].chat.completions.create(
                model=self.answerer_config['model_name'],
                messages=[
                    {"role": "system", "content": f"{system_prompt}\n{json.dumps(golden_json, indent=2)}"},
                    {"role": "user", "content": f"Here is our conversation so far:\n{dialogue_history}\n\nNow, please answer this question: {question}"}
                ],
                temperature=0.1,
                max_tokens=100
            )
            # Count answerer tokens too, consistent with the questioner calls.
            self._track_if_paid(Config.ANSWERER_MODEL_KEY, completion)

            answer = (completion.choices[0].message.content or "").strip()
            print(answer, flush=True)
            return answer
        except Exception as e:
            print(f"Warning: Answerer model failed. Error: {e}",flush=True)
            return "I'm sorry, I can't answer that right now."

    def run_simulation_for_file(self, golden_json_path):

        try:
            with open(golden_json_path, 'r', encoding='utf-8') as f:
                golden_json = json.load(f)
        except Exception as e:
            print(f"Error loading golden JSON {golden_json_path}: {e}",flush=True)
            return [], [], None
        num_to_sample = min(Config.NUM_SAMPLED_QUESTIONERS, len(self.full_questioner_pool))
        sampled_keys = random.sample(list(self.full_questioner_pool.keys()), num_to_sample)
        active_questioners = {key: self.full_questioner_pool[key] for key in sampled_keys}
        print(f"--- Starting simulation for {os.path.basename(golden_json_path)} with questioners: {list(active_questioners.keys())} ---", flush=True)

        dialogue_history = "User: I want to create a scientific diagram."
        confirmed_attributes = set()
        grpo_data_for_this_dialogue = []

        grpo_ablation_data_for_this_dialogue = []       

        remaining_entropy = self.world_model.total_initial_entropy

        for turn in range(Config.MAX_TURNS_PER_DIALOGUE):
            print(f"\n--- Turn {turn + 1}/{Config.MAX_TURNS_PER_DIALOGUE} for {os.path.basename(golden_json_path)} ---",flush=True)
            
            candidate_questions = self._generate_candidate_questions(dialogue_history, confirmed_attributes, active_questioners)            
            if not candidate_questions:
                print("No candidate questions generated. Ending dialogue.",flush=True)
                break

            question_rewards_data = []
            for question in candidate_questions:

                if not question or not question.strip():
                    print("[Filter] Empty question skipped.", flush=True)
                    continue
                answer = self._get_answer_from_oracle(question, golden_json, dialogue_history)

                if not answer or not answer.strip():
                    print(f"[Filter] Empty answer for question skipped: {question!r}", flush=True)
                    continue

                constraints = self.parser.parse(answer)
                
                reward_entropy, _ = self._calculate_reward(constraints, confirmed_attributes)
                
                reward_counting, _ = self._calculate_reward_counting(constraints, confirmed_attributes)
                
                question_rewards_data.append({
                    "question": question, 
                    "answer": answer, 
                    "reward_entropy": reward_entropy, 
                    "reward_counting": reward_counting,
                    "constraints": constraints
                })

            K = len(question_rewards_data)

            ranked_list_entropy = sorted(question_rewards_data, key=lambda x: x['reward_entropy'], reverse=True)[:K]

            if len(ranked_list_entropy) > 1:
                grpo_point_entropy = {"prompt": dialogue_history, "responses": [qr['question'] for qr in ranked_list_entropy],"reward": [float(qr['reward_entropy']) for qr in ranked_list_entropy]}
                grpo_data_for_this_dialogue.append(grpo_point_entropy)
                print(f"Generated ENTROPY-based GRPO point with {len(ranked_list_entropy)} responses.",flush=True)
            
            ranked_list_counting = sorted(question_rewards_data, key=lambda x: x['reward_counting'], reverse=True)[:K]

            if len(ranked_list_counting) > 1:
                grpo_point_counting = {"prompt": dialogue_history, "responses": [qr['question'] for qr in ranked_list_counting],"reward": [float(qr['reward_counting']) for qr in ranked_list_counting]}
                grpo_ablation_data_for_this_dialogue.append(grpo_point_counting)
                print(f"Generated COUNTING-based GRPO point with {len(ranked_list_counting)} responses.",flush=True)
            
            if not ranked_list_entropy:
                print("No valid Q/A this turn. Ending dialogue.", flush=True)
                break
            best_choice = ranked_list_entropy[0]
            
            dialogue_history += f"\nAssistant: {best_choice['question']}\nUser: {best_choice['answer']}"
            
            _, newly_confirmed = self._calculate_reward(best_choice['constraints'], confirmed_attributes)
            confirmed_attributes.update(newly_confirmed)
            
            remaining_entropy -= best_choice['reward_entropy'] 
            
            print(f"Best question: '{best_choice['question']}' | Entropy Reward: {best_choice['reward_entropy']:.3f} | Remaining Entropy: {remaining_entropy:.2f}",flush=True)

            if remaining_entropy < self.world_model.total_initial_entropy * Config.ENTROPY_THRESHOLD_RATIO:
                print("Entropy threshold reached. Ending dialogue.",flush=True)
                break

        print("INFO: Translating final JSON to natural language description for SFT data...")
        natural_language_completion = self._translate_json_to_description(golden_json)

        totals = self.token_tracker.snapshot_and_reset()
        print("\n[Token Usage + Cost for this sample]")
        for k, v in totals.items():
            print(f"- {k}: prompt={v['prompt']}, completion={v['completion']}, total={v['total']}, cost=${v['cost']:.4f}")
        print("")

        sft_prompt = f"""You are an expert at summarizing conversational information into a detailed, descriptive paragraph. Based on the following dialogue, provide a comprehensive natural language description of the scientific diagram. The description should be clear enough for an artist to recreate the image accurately.
        --- DIALOGUE HISTORY ---
        {dialogue_history}
        --- END OF DIALOGUE ---
        Now, provide a detailed, natural language description of the final diagram."""
        
        sft_point = { 
            "metadata": {"source_file": golden_json_path},
            "prompt": sft_prompt,
            "completion": natural_language_completion }
        
        return grpo_data_for_this_dialogue, grpo_ablation_data_for_this_dialogue, sft_point

    def _calculate_reward(self, constraints, confirmed_attributes):
        reward = 0.0
        newly_confirmed = set()
        
        for simple_attr, value in constraints:
            full_path_attr = Config.SCHEMA_MAPPER.get(simple_attr)

            if full_path_attr:
                if full_path_attr not in confirmed_attributes:
                    if full_path_attr == 'style_template':
                        template_attrs_map = Config.STYLE_TEMPLATES.get(value, {}).get("attributes", {})
                        for template_attr_key, _ in template_attrs_map.items():
                            full_template_attr_path = Config.SCHEMA_MAPPER.get(template_attr_key)
                            if full_template_attr_path and full_template_attr_path not in confirmed_attributes:
                                reward += self.world_model.get_attribute_entropy(full_template_attr_path)
                                newly_confirmed.add(full_template_attr_path)
                    else:
                        entropy_gain = self.world_model.get_attribute_entropy(full_path_attr)
                        reward += entropy_gain
                        newly_confirmed.add(full_path_attr)
            else:
                print(f"Warning: Mapper could not find a path for simple attribute '{simple_attr}'.",flush=True)
                
        return reward, newly_confirmed



    def _calculate_reward_counting(self, constraints: list, confirmed_attributes: set):

        if not constraints:
            return 0.0, set()
        
        newly_confirmed = set()
        for simple_attr, _ in constraints:
            full_path_attr = Config.SCHEMA_MAPPER.get(simple_attr)
            if full_path_attr and full_path_attr not in confirmed_attributes:

                newly_confirmed.add(full_path_attr)
        
        reward = float(len(newly_confirmed))
        return reward, newly_confirmed


    def _translate_json_to_description(self, golden_json):

        system_prompt = """You are a technical writer specializing in describing complex diagrams. Your task is to convert the following structured JSON object into a clear, detailed, and vivid natural language description. The description should be comprehensive enough for a graphic designer to recreate the diagram without seeing it. Describe the overall layout, style, colors, components, their text, and how they are connected.
        Here is the JSON data to be translated:
        """
        client = self.answerer_config['client']
        model_name = self.answerer_config['model_name']

        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": json.dumps(golden_json, indent=2)}
                ],
                temperature=0.3,
                max_tokens=4096 
            )
            return completion.choices[0].message.content.strip()
        except Exception as e:
            print(f"Warning: JSON to text translation failed. Error: {e}",flush=True)
            return "A detailed description of the scientific diagram."

# Per-worker-process state: assigned by init_worker() and read by
# process_single_file(); None until the initializer has run.
worker_simulator = worker_args = None

def init_worker(args):
    """ProcessPoolExecutor initializer: build this worker's DialogueSimulator.

    Stores *args* in the module-level ``worker_args``. Builds a WorldModel
    from ``Config.WORLD_MODEL_FILE`` and a SemanticParser configured from
    ``Config.API_CONFIGS[Config.LLM_PARSER_MODEL_KEY]``. The resulting
    simulator is stored in ``worker_simulator`` for ``process_single_file``.
    """
    global worker_simulator, worker_args
    print(f"Initializing worker process {os.getpid()}...",flush=True)

    worker_args = args

    world_model = WorldModel(Config.WORLD_MODEL_FILE)

    parser_api = Config.API_CONFIGS[Config.LLM_PARSER_MODEL_KEY]
    semantic_parser = SemanticParser(
        mode=worker_args.parser_mode,
        llm_parser_config={
            "client": parser_api['client'],
            "model_name": parser_api['model_name'],
        },
    )

    worker_simulator = DialogueSimulator(world_model, semantic_parser)

def process_single_file(filepath):
    """Simulate one golden-JSON file inside a pool worker.

    Returns ``(grpo_points_entropy, grpo_points_counting, sft_point)`` as
    produced by ``worker_simulator.run_simulation_for_file``. If this process
    was never set up by ``init_worker``, prints an error and returns
    ``([], [], None)``.
    """
    if worker_simulator is not None:
        entropy_points, counting_points, sft = worker_simulator.run_simulation_for_file(filepath)
        return entropy_points, counting_points, sft

    print(f"Error: Worker {os.getpid()} not initialized. Skipping {os.path.basename(filepath)}.",flush=True)
    return [], [], None


def main(args):
    """Simulate every not-yet-processed golden JSON and aggregate the datasets.

    Existing outputs (``Config.OUTPUT_GRPO_FILE``,
    ``Config.OUTPUT_GRPO_ABLATION_FILE``, ``Config.OUTPUT_SFT_FILE``) are
    loaded and extended, so a run can resume. A golden file counts as already
    processed when some SFT item's ``metadata.source_file`` has the same
    basename.

    Work is fanned out to ``args.num_workers`` processes. Results are flushed
    to disk every ``args.save_every`` successfully processed files, and once
    more when the pool loop exits. That final flush also runs if the loop
    exits via an exception, so finished work is not lost.

    Args:
        args: parsed CLI namespace providing ``num_workers`` and
            ``save_every``. The whole namespace is forwarded to
            ``init_worker``.
    """

    def load_existing_dataset(filepath):
        """Return the JSON list stored at *filepath*, or [] if there is none.

        If the file cannot be decoded or does not hold a list, it is moved
        aside (``<path>.corrupt``, ``<path>.corrupt1``, ...) before starting
        fresh. Otherwise ``save_all`` would overwrite it and the earlier
        results would be gone.
        """
        if not os.path.exists(filepath):
            return []
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                print(f"Found existing dataset at '{filepath}'. Loading...", flush=True)
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            data = None
        if isinstance(data, list):
            return data

        backup = filepath + ".corrupt"
        suffix = 1
        while os.path.exists(backup):
            backup = f"{filepath}.corrupt{suffix}"
            suffix += 1
        os.replace(filepath, backup)
        print(f"Warning: Could not parse existing dataset at '{filepath}'. "
              f"Moved it to '{backup}'. Starting fresh.", flush=True)
        return []

    master_grpo_dataset = load_existing_dataset(Config.OUTPUT_GRPO_FILE)
    master_grpo_ablation_dataset = load_existing_dataset(Config.OUTPUT_GRPO_ABLATION_FILE)
    master_sft_dataset = load_existing_dataset(Config.OUTPUT_SFT_FILE)

    processed_files = set()
    for sft_item in master_sft_dataset:
        md = sft_item.get("metadata", {})
        src = md.get("source_file")
        if src:
            processed_files.add(os.path.basename(src))
    print(f"Found {len(processed_files)} already processed files. They will be skipped.", flush=True)

    if not os.path.isdir(Config.GOLDEN_JSON_FOLDER):
        print(f"[Error] GOLDEN_JSON_FOLDER not found: {Config.GOLDEN_JSON_FOLDER}", flush=True)
        return

    # Sorted so submission order is reproducible across runs and platforms.
    all_json_files = [
        os.path.join(Config.GOLDEN_JSON_FOLDER, f)
        for f in sorted(os.listdir(Config.GOLDEN_JSON_FOLDER))
        if f.endswith('.json')
    ]
    files_to_process = [f for f in all_json_files if os.path.basename(f) not in processed_files]

    if not files_to_process:
        print("All files have already been processed. Exiting.", flush=True)
        return

    print(f"Total files to process: {len(files_to_process)} out of {len(all_json_files)}.", flush=True)

    new_grpo_count = 0
    new_ablation_count = 0
    new_sft_count = 0

    def save_all():
        def atomic_dump(path, data):
            # Write a sibling temp file, then rename it over the target, so an
            # interrupted write does not truncate the existing dataset.
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            tmp = path + ".tmp"
            with open(tmp, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            os.replace(tmp, path)

        atomic_dump(Config.OUTPUT_GRPO_FILE, master_grpo_dataset)
        atomic_dump(Config.OUTPUT_GRPO_ABLATION_FILE, master_grpo_ablation_dataset)
        atomic_dump(Config.OUTPUT_SFT_FILE, master_sft_dataset)

    processed_since_last_save = 0

    try:
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=args.num_workers,
            initializer=init_worker,
            initargs=(args,)
        ) as executor:
            future_to_fp = {
                executor.submit(process_single_file, fp): fp
                for fp in files_to_process
            }

            for fut in tqdm(as_completed(future_to_fp), total=len(future_to_fp),
                            desc="Running Parallel Simulations"):
                fp = future_to_fp[fut]
                try:
                    grpo_points_entropy, grpo_points_counting, sft_point = fut.result()
                except Exception as e:
                    print(f"[Error] File failed: {os.path.basename(fp)} | {e}", flush=True)
                    continue

                if grpo_points_entropy:
                    master_grpo_dataset.extend(grpo_points_entropy)
                    new_grpo_count += len(grpo_points_entropy)
                if grpo_points_counting:
                    master_grpo_ablation_dataset.extend(grpo_points_counting)
                    new_ablation_count += len(grpo_points_counting)
                if sft_point:
                    master_sft_dataset.append(sft_point)
                    new_sft_count += 1

                processed_since_last_save += 1

                if processed_since_last_save >= args.save_every:
                    save_all()
                    processed_since_last_save = 0
    finally:
        # Checkpoint whatever was collected, including when the loop above is
        # interrupted (e.g. Ctrl-C), so completed simulations are not lost.
        save_all()

    print(f"Added {new_grpo_count} new entropy GRPO points, "
          f"{new_ablation_count} new counting GRPO points, "
          f"{new_sft_count} new SFT points.", flush=True)

    print("\nAggregation complete. Final dataset sizes:")
    print(f"Total ENTROPY-based GRPO data points: {len(master_grpo_dataset)}")
    print(f"Total COUNTING-based GRPO data points: {len(master_grpo_ablation_dataset)}")
    print(f"Total SFT data points: {len(master_sft_dataset)}")


if __name__ == "__main__":
    # CLI entry point: parse options, pick a process start method, run main().
    parser = argparse.ArgumentParser(description="Generate GRPO and SFT datasets through parallel dialogue simulation.")
    parser.add_argument(
        '--parser_mode', 
        type=str, 
        choices=['llm', 'rule'], 
        default='llm',
        help="Choose the semantic parser mode: 'llm' (default) or 'rule'."
    )
    parser.add_argument(
        '--save_every',
        type=int,
        default=5,  
        help="Flush aggregated datasets to disk every K processed files."
    )
    parser.add_argument(
        '--num_workers', 
        type=int, 
        default=4,
        help="Number of parallel worker processes to use."
    )
    args = parser.parse_args()

    if os.name != 'posix':
        # The start method is configured on `multiprocessing` itself.
        # `concurrent.futures.process` exposes no `set_start_method`, and
        # that submodule is not imported yet at this point.
        import multiprocessing
        multiprocessing.set_start_method('spawn', force=True)

    main(args)