import os
import json
import math
import collections
from typing import Any, List, Dict, Set, Tuple
from collections import defaultdict

try:
    from openai import AzureOpenAI
except Exception:
    AzureOpenAI = None

class RewardConfig:
    """Static configuration for the counting reward.

    Connection settings and model names are read from environment variables
    once, when this class body executes (i.e. at module import time).

    ``SCHEMA_MAPPER`` maps the flat concept names that the parser LLM (and the
    style templates) emit onto dotted paths of the diagram schema. Several
    aliases intentionally share one path, so mentioning two aliases of the
    same attribute counts it only once. The special value ``"style_template"``
    marks a meta concept that is expanded via ``STYLE_TEMPLATES``.

    ``STYLE_TEMPLATES`` bundles named style presets; each entry's
    ``attributes`` keys are themselves ``SCHEMA_MAPPER`` keys.
    """

    API_KEY     = os.getenv("AZURE_OPENAI_API_KEY", "")
    BASE_URL    = os.getenv("AZURE_OPENAI_ENDPOINT", "")
    API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION", "")

    ANSWERER_MODEL_NAME = os.getenv("ANSWERER_MODEL_NAME", "")
    PARSER_MODEL_NAME   = os.getenv("PARSER_MODEL_NAME", "")

    # Not read by the reward engine in this module; kept for external users.
    WORLD_MODEL_PATH = os.getenv(
        "WORLD_MODEL_JSON",
        "world_model_v3.json"
    )

    # Every key appears exactly once (earlier revisions repeated the
    # connection endpoint aliases; the values were identical).
    SCHEMA_MAPPER = {
        # Global Properties
        "topic": "global_properties.topic",
        "subject": "global_properties.topic",
        "concept": "global_properties.topic",
        "purpose": "global_properties.purpose",
        "goal": "global_properties.purpose",
        "target_audience": "global_properties.target_audience",
        "audience": "global_properties.target_audience",
        "complexity_level": "global_properties.complexity_level",
        "detail_level": "global_properties.complexity_level",
        "background_color": "global_properties.background_color",
        "font_family": "global_properties.font_family",
        "title_text": "global_properties.title.text",
        "title": "global_properties.title.text",
        "diagram_title": "global_properties.title.text",
        "main_title": "global_properties.title.text",
        "global_title": "global_properties.title.text",
        "figure_title": "global_properties.title.text",
        "domain": "global_properties.domain",
        "field": "global_properties.domain",
        "visual_format": "global_properties.visual_format",
        "diagram_type": "global_properties.diagram_type",
        "type": "global_properties.diagram_type",
        "layout_grid": "global_properties.layout_grid",
        "layout": "global_properties.layout_grid",
        "style_theme": "global_properties.style_theme",
        "theme": "global_properties.style_theme",
        "title_is_present": "global_properties.title.is_present",
        "has_title": "global_properties.title.is_present",
        "show_title": "global_properties.title.is_present",

        # Component Properties (Generic)
        "component_type": "component.type",
        "component_shape": "component.geometry.shape",
        "shape": "component.geometry.shape",

        "component_fill_color": "component.styling.fill_color",
        "fill_color": "component.styling.fill_color",
        "component_border_color": "component.styling.border_color",
        "border_color": "component.styling.border_color",
        "component_text_color": "component.text_properties.text_color",
        "text_color": "component.text_properties.text_color",

        "component_border_style": "component.styling.border_style",
        "border_style": "component.styling.border_style",
        "component_font_weight": "component.text_properties.font_weight",
        "font_weight": "component.text_properties.font_weight",
        "component_label": "component.label",

        # Connection Properties (Generic)
        "connection_line_style": "connection.line_properties.style",
        "line_style": "connection.line_properties.style",
        "connection_line_type": "connection.line_properties.type",
        "connection_line_color": "connection.line_properties.color",
        "connection_line_width": "connection.line_properties.width",
        "connection_arrowhead_end": "connection.arrowhead.end_type",
        "arrowhead": "connection.arrowhead.end_type",
        "connection_arrowhead": "connection.arrowhead.end_type",
        "connection_arrowhead_start": "connection.arrowhead.start_type",
        "connection_arrowhead_size": "connection.arrowhead.size",
        "connection_label": "connection.label.text",
        "connection_label_position": "connection.label.position",
        "connection_label_color": "connection.label.text_color",
        # Connection endpoints
        "connection_from_id": "connection.from_id",
        "from_id": "connection.from_id",
        "source_id": "connection.from_id",
        "start_id": "connection.from_id",
        "connection_to_id": "connection.to_id",
        "to_id": "connection.to_id",
        "target_id": "connection.to_id",
        "end_id": "connection.to_id",

        # Layout Constraints
        "layout_constraint_type": "layout_constraint.type",
        "layout_alignment": "layout_constraint.alignment_type",
        "layout_padding": "layout_constraint.padding",
        "layout_distribution": "layout_constraint.distribution_type",
        "layout_arrangement": "layout_constraint.arrangement",

        # Meta Properties (expanded through STYLE_TEMPLATES)
        "style_template": "style_template",
    }

    STYLE_TEMPLATES = {
        "professional_blue": {
            "description": "a professional blue and white style, suitable for corporate or academic presentations",
            "attributes": {
                "style_theme": "professional_light",
                "background_color": "#FFFFFF",
                "component_fill_color": "#EAF1FD",
                "component_border_color": "#B4C7E7",
                "font_family": "Helvetica, Arial, sans-serif",
                "text_color": "#000000"
            }
        },
        "academic_grayscale": {
            "description": "a minimalist grayscale theme, perfect for academic papers and formal publications",
            "attributes": {
                "style_theme": "minimalist_grayscale",
                "background_color": "#FFFFFF",
                "component_fill_color": "#F0F0F0",
                "component_border_color": "#666666",
                "font_family": "Times New Roman, serif",
                "text_color": "#000000"
            }
        },
        "vibrant_tech": {
            "description": "a modern and vibrant dark mode style, great for tech startup presentations",
            "attributes": {
                "style_theme": "dark_mode_vibrant",
                "background_color": "#1A1A1A",
                "component_fill_color": "#2D2D2D",
                "component_border_color": "#00BFFF",
                "font_family": "Roboto, sans-serif",
                "text_color": "#EAEAEA"
            }
        },
        "blueprint_schematic": {
            "description": "a technical blueprint style with a dark blue background, ideal for engineering schematics or software architecture diagrams",
            "attributes": {
                "style_theme": "technical_blueprint",
                "background_color": "#0A2342",
                "component_fill_color": "#183454",
                "component_border_color": "#00FFFF",
                "font_family": "Courier New, monospace",
                "text_color": "#F0F0F0"
            }
        },
        "warm_organic": {
            "description": "a warm and inviting theme with earthy tones, suitable for biology, environmental science, or a softer presentation feel",
            "attributes": {
                "style_theme": "natural_warm",
                "background_color": "#FFF8E7",
                "component_fill_color": "#E8F5E9",
                "component_border_color": "#8D6E63",
                "font_family": "Verdana, Geneva, sans-serif",
                "text_color": "#4E342E"
            }
        }
    }


class _CountingRewardEngine:
    def __init__(self):
        print("Initializing _CountingRewardEngine (Singleton)...")
        self.answerer_client = self._create_api_client()
        self.parser_client   = self._create_api_client()
        self.dialogue_tracker: Dict[str, Set[str]] = collections.defaultdict(set)
        self.answerer_prompt_template = self._get_answerer_prompt_template()
        self.parser_prompt_template   = self._get_parser_prompt_template()
        print("_CountingRewardEngine initialized successfully.")

    def _create_api_client(self):
        if AzureOpenAI is None:
            raise RuntimeError("AzureOpenAI SDK not available. Please install/openai>=1.0 and configure Azure endpoint.")
        return AzureOpenAI(
            azure_endpoint=RewardConfig.BASE_URL,
            api_key=RewardConfig.API_KEY,
            api_version=RewardConfig.API_VERSION
        )

    def _get_answerer_prompt_template(self) -> str:
        return """You are acting as a user who has a complete scientific diagram in mind. Your complete knowledge of the diagram is provided in the following JSON object. Your task is to answer the assistant's questions with specific, concrete details derived ONLY from this JSON data.
Answers should be simple and clear, with no more than 50 words.
**Your Answering Principles (CRITICAL):**
1.  **BE FACTUAL AND SPECIFIC:** Use only facts from the JSON.
2.  **BE CONCISE:** Answer directly without fluff.
3.  **HANDLE SUGGESTIONS:** If a style template is proposed and consistent with the JSON, agree.
4.  **HANDLE UNANSWERABLE QUESTIONS:** If the JSON lacks the info, say so.
Here is the complete diagram information (Ground Truth JSON):
"""

    def _get_parser_prompt_template(self) -> str:
        return """
You are a highly precise linguistic analysis tool. Your *only* function is to extract structured constraints from a user's answer.

**Instructions:**
1. Analyze the "User's Answer".
2. Identify core concepts that match the "CONCEPT DEFINITION".
3. Output MUST be: {"constraints": [["concept_name", "value"], ...]} (a single JSON object).
4. concept_name MUST be one of: topic, purpose, target_audience, complexity_level, domain, visual_format, diagram_type, layout, style_theme,
   component_label, component_shape, component_fill_color, component_border_style,
   connection_label, connection_line_style, connection_arrowhead,
   style_template
5. If nothing can be extracted, return {"constraints": []}.

**EXAMPLES**
User's Answer: "The main components are an Encoder block and a Decoder block."
Your Output:
{"constraints": [["component_label", "Encoder block"], ["component_label", "Decoder block"]]}

User's Answer: "The connection between them should be labeled 'context vector'."
Your Output:
{"constraints": [["connection_label", "context vector"]]}

User's Answer: "The diagram is a flowchart with rounded rectangle shapes."
Your Output:
{"constraints": [["diagram_type", "flowchart"], ["component_shape", "rounded_rectangle"]]}

Now analyze:
**User's Answer:**
"{user_answer_text}"
"""

    def _get_simulated_answer(self, dialogue_history: str, question: str, golden_json: Dict) -> str:
        print(f"{question}",flush=True)
        sys_content = f"{self.answerer_prompt_template}\n{json.dumps(golden_json, indent=2)}"
        user_content = f"Here is our conversation so far:\n{dialogue_history}\n\nNow, please answer this question: {question}"
        try:
            completion = self.answerer_client.chat.completions.create(
                model=RewardConfig.ANSWERER_MODEL_NAME,
                messages=[{"role": "system", "content": sys_content},
                          {"role": "user",   "content": user_content}],
                temperature=0.1, max_tokens=100
            )
            answer=completion.choices[0].message.content.strip()
            print(f"{answer}",flush=True)
            return completion.choices[0].message.content.strip()
        except Exception as e:
            print(f"Warning: Answerer call failed for question '{question}': {e}")
            return ""

    def _parse_answer(self, answer: str) -> List[List[str]]:
        if not answer:
            return []
        prompt = self.parser_prompt_template.replace("{user_answer_text}", answer)
        try:
            completion = self.parser_client.chat.completions.create(
                model=RewardConfig.PARSER_MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0, response_format={"type": "json_object"}
            )
            raw = completion.choices[0].message.content
            print(f"{raw}", flush=True)
            obj = json.loads(raw)
            lst = obj.get("constraints", [])
            if isinstance(lst, list) and all(isinstance(x, list) and len(x) == 2 for x in lst):
                return [[str(x[0]), str(x[1])] for x in lst]
            print(f"Warning: Parser returned malformed data: {raw}")
            return []
        except Exception as e:
            print(f"Error during parsing: {e}")
            return []


    def _calculate_reward_counting(
        self,
        constraints: List[List[str]]
    ) -> Tuple[float, Set[str]]:

        mapper = RewardConfig.SCHEMA_MAPPER
        confirmed_now: Set[str] = set()

        for simple_attr, _ in constraints:
            full = mapper.get(simple_attr)
            if full and full != "style_template":
                confirmed_now.add(full)

        for simple_attr, value in constraints:
            if mapper.get(simple_attr) == "style_template":
                tpl = RewardConfig.STYLE_TEMPLATES.get(value, {})
                for k in (tpl.get("attributes") or {}).keys():
                    fp = mapper.get(k)
                    if fp:
                        confirmed_now.add(fp)
        reward=len(confirmed_now)
        print(reward)           
        return float(len(confirmed_now)), confirmed_now


    def process_batch(self, reward_inputs: List[Dict[str, Any]]) -> List[Dict[str, float]]:

        scores: List[Dict[str, float]] = []
        gt_cache: Dict[str, Dict] = {}

        for i, item in enumerate(reward_inputs):
            q = item.get("response", "")
            gt_path = item.get("ground_truth", "")
            if not q or not gt_path:
                print(f"Warning: Missing key in item {i}: {item}", flush=True)
                scores.append({"counting_reward": 0.0})
                continue

            golden_json = gt_cache.get(gt_path)
            if golden_json is None:
                try:
                    with open(gt_path, "r", encoding="utf-8") as f:
                        golden_json = json.load(f)
                    gt_cache[gt_path] = golden_json
                except Exception as e:
                    print(f"Error loading golden_json '{gt_path}': {e}. Assigning 0 reward.", flush=True)
                    scores.append({"counting_reward": 0.0})
                    continue

            answer = self._get_simulated_answer("", q, golden_json)
            constraints = self._parse_answer(answer)
            reward, _ = self._calculate_reward_counting(constraints)

            scores.append({"counting_reward": reward})

        return scores

# Lazily created, process-wide engine instance (None until first requested).
_ENGINE_SINGLETON: "_CountingRewardEngine | None" = None

def get_reward_engine() -> "_CountingRewardEngine":
    """Return the shared reward engine, constructing it on the first call.

    Construction builds the API clients, so it is deferred until needed and
    then reused. There is no lock: concurrent first calls from several
    threads could each construct an engine.
    """
    global _ENGINE_SINGLETON
    if _ENGINE_SINGLETON is None:
        _ENGINE_SINGLETON = _CountingRewardEngine()
    return _ENGINE_SINGLETON

def compute_score(reward_inputs: list[dict[str, Any]], **kwargs) -> list[dict[str, float]]:
    """Compute the counting reward for a batch of rollouts.

    Args:
        reward_inputs: One dict per rollout, forwarded unchanged to the
            engine's ``process_batch``.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        One dict per input whose ``"overall"`` and ``"counting_reward"``
        entries hold the same float.

    Raises:
        ValueError: If ``reward_inputs`` is not a list.
    """
    if not isinstance(reward_inputs, list):
        raise ValueError("This reward function expects `reward_inputs` to be a list.")

    reward_engine = get_reward_engine()
    print(reward_inputs)

    def _as_output(raw_score: dict) -> dict[str, float]:
        # Mirror the single reward into the "overall" slot the trainer reads.
        value = float(raw_score.get("counting_reward", 0.0))
        return {"overall": value, "counting_reward": value}

    results = [_as_output(raw) for raw in reward_engine.process_batch(reward_inputs)]
    print(results, flush=True)
    return results


# Example usage from another module:
#   from counting_reward_model import compute_score as counting_score
#   scores = counting_score(batch_inputs)
