import json
import math
import collections
from typing import Any, List, Dict, Set
from openai import AzureOpenAI
from collections import defaultdict
from string import Template
import os


class RewardConfig:
    """Static configuration for the entropy reward module.

    All values are class attributes evaluated once, when this module is
    imported. The Azure OpenAI settings and model names come from environment
    variables; the remaining attributes are literals defined here.
    """

    # Azure OpenAI connection settings. Each falls back to an empty string
    # when the corresponding environment variable is not set.
    API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "")
    BASE_URL = os.getenv("AZURE_OPENAI_ENDPOINT", "")
    API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION", "")

    # Model names passed as `model=` to the chat-completions API
    # (default "gpt-4o" for both roles).
    ANSWERER_MODEL_NAME = os.getenv("ANSWERER_MODEL_NAME", "gpt-4o")
    PARSER_MODEL_NAME = os.getenv("PARSER_MODEL_NAME", "gpt-4o")

    # JSON file read by the reward engine; its "prior_distributions" entry
    # is turned into per-attribute entropies.
    WORLD_MODEL_PATH = "world_model_v3.json"

    # Maps a simple attribute name (several aliases may share one target) to
    # the dotted attribute path used for entropy lookup and for tracking
    # which attributes have been confirmed.
    # NOTE: every key must appear exactly once. A repeated key in a dict
    # literal silently overwrites the earlier entry.
    SCHEMA_MAPPER = {
        # Global Properties
        "topic": "global_properties.topic",
        "subject": "global_properties.topic",
        "concept": "global_properties.topic",
        "purpose": "global_properties.purpose",
        "goal": "global_properties.purpose",
        "target_audience": "global_properties.target_audience",
        "audience": "global_properties.target_audience",
        "complexity_level": "global_properties.complexity_level",
        "detail_level": "global_properties.complexity_level",
        "background_color": "global_properties.background_color",
        "font_family": "global_properties.font_family",
        "title_text": "global_properties.title.text",
        "title": "global_properties.title.text",
        "diagram_title": "global_properties.title.text",
        "main_title": "global_properties.title.text",
        "domain": "global_properties.domain",
        "field": "global_properties.domain",
        "visual_format": "global_properties.visual_format",
        "diagram_type": "global_properties.diagram_type",
        "type": "global_properties.diagram_type",
        "layout_grid": "global_properties.layout_grid",
        "layout": "global_properties.layout_grid",
        "style_theme": "global_properties.style_theme",
        "theme": "global_properties.style_theme",
        "title_is_present": "global_properties.title.is_present",
        "has_title": "global_properties.title.is_present",
        "show_title": "global_properties.title.is_present",
        "global_title": "global_properties.title.text",
        "figure_title": "global_properties.title.text",

        # Component Properties (Generic)
        "component_type": "component.type",
        "component_shape": "component.geometry.shape",
        "shape": "component.geometry.shape",

        "component_fill_color": "component.styling.fill_color",
        "fill_color": "component.styling.fill_color",
        "component_border_color": "component.styling.border_color",
        "border_color": "component.styling.border_color",
        "component_text_color": "component.text_properties.text_color",
        "text_color": "component.text_properties.text_color",

        "component_border_style": "component.styling.border_style",
        "border_style": "component.styling.border_style",
        "component_font_weight": "component.text_properties.font_weight",
        "font_weight": "component.text_properties.font_weight",
        "component_label": "component.label",

        # Connection Properties (Generic)
        "connection_line_style": "connection.line_properties.style",
        "line_style": "connection.line_properties.style",
        "connection_arrowhead_end": "connection.arrowhead.end_type",
        "arrowhead": "connection.arrowhead.end_type",
        "connection_arrowhead": "connection.arrowhead.end_type",
        "connection_label": "connection.label.text",
        "connection_from_id": "connection.from_id",
        "from_id": "connection.from_id",
        "source_id": "connection.from_id",
        "connection_label_position": "connection.label.position",
        "connection_label_color": "connection.label.text_color",
        "connection_line_type": "connection.line_properties.type",
        "connection_line_color": "connection.line_properties.color",
        "connection_line_width": "connection.line_properties.width",
        "connection_arrowhead_start": "connection.arrowhead.start_type",
        "connection_arrowhead_size": "connection.arrowhead.size",
        "connection_to_id": "connection.to_id",
        "to_id": "connection.to_id",
        "target_id": "connection.to_id",
        # Extra endpoint aliases
        "start_id": "connection.from_id",
        "end_id": "connection.to_id",

        # Layout Constraint Properties
        "layout_constraint_type": "layout_constraint.type",
        "layout_alignment": "layout_constraint.alignment_type",
        "layout_padding": "layout_constraint.padding",
        "layout_distribution": "layout_constraint.distribution_type",
        "layout_arrangement": "layout_constraint.arrangement",

        # Meta Properties
        "style_template": "style_template",
    }

    # Named style presets. Each "attributes" key is itself a SCHEMA_MAPPER
    # key, so a preset can be expanded into the attribute paths it fixes.
    STYLE_TEMPLATES = {
        "professional_blue": {
            "description": "a professional blue and white style, suitable for corporate or academic presentations",
            "attributes": {
                "style_theme": "professional_light",
                "background_color": "#FFFFFF",
                "component_fill_color": "#EAF1FD",
                "component_border_color": "#B4C7E7",
                "font_family": "Helvetica, Arial, sans-serif",
                "text_color": "#000000",
            },
        },
        "academic_grayscale": {
            "description": "a minimalist grayscale theme, perfect for academic papers and formal publications",
            "attributes": {
                "style_theme": "minimalist_grayscale",
                "background_color": "#FFFFFF",
                "component_fill_color": "#F0F0F0",
                "component_border_color": "#666666",
                "font_family": "Times New Roman, serif",
                "text_color": "#000000",
            },
        },
        "vibrant_tech": {
            "description": "a modern and vibrant dark mode style, great for tech startup presentations",
            "attributes": {
                "style_theme": "dark_mode_vibrant",
                "background_color": "#1A1A1A",
                "component_fill_color": "#2D2D2D",
                "component_border_color": "#00BFFF",
                "font_family": "Roboto, sans-serif",
                "text_color": "#EAEAEA",
            },
        },
        "blueprint_schematic": {
            "description": "a technical blueprint style with a dark blue background, ideal for engineering schematics or software architecture diagrams",
            "attributes": {
                "style_theme": "technical_blueprint",
                "background_color": "#0A2342",
                "component_fill_color": "#183454",
                "component_border_color": "#00FFFF",
                "font_family": "Courier New, monospace",
                "text_color": "#F0F0F0",
            },
        },
        "warm_organic": {
            "description": "a warm and inviting theme with earthy tones, suitable for biology, environmental science, or a softer presentation feel",
            "attributes": {
                "style_theme": "natural_warm",
                "background_color": "#FFF8E7",
                "component_fill_color": "#E8F5E9",
                "component_border_color": "#8D6E63",
                "font_family": "Verdana, Geneva, sans-serif",
                "text_color": "#4E342E",
            },
        },
    }



class _EntropyRewardEngine:
    """Scores candidate questions by the prior entropy of the attributes they resolve.

    For each question an "answerer" LLM replies using a ground-truth diagram
    JSON, a "parser" LLM extracts ``[attribute, value]`` constraints from that
    reply, and the reward is the sum of the world-model entropies of the
    attribute paths those constraints map to.

    Constructing an instance performs I/O: it reads the world model file and
    creates two Azure OpenAI clients.
    """

    def __init__(self):
        print("Initializing _EntropyRewardEngine (Singleton)...",flush=True)

        self.world_model = self._load_world_model()
        self.answerer_client = self._create_api_client()
        self.parser_client = self._create_api_client()

        # Not read anywhere in this class; kept so external code that
        # accesses the attribute keeps working.
        self.dialogue_tracker = collections.defaultdict(set)

        self.answerer_prompt_template = self._get_answerer_prompt_template()
        self.parser_prompt_template = self._get_parser_prompt_template()
        print("_EntropyRewardEngine initialized successfully.",flush=True)

    def _load_world_model(self) -> Dict[str, float]:
        """Load the world model and return the entropy of each attribute prior.

        Reads ``RewardConfig.WORLD_MODEL_PATH`` as JSON and, for every entry
        of its ``"prior_distributions"`` mapping, computes the Shannon entropy
        in bits (``-sum(p * log2(p))`` over the strictly positive values).

        Returns:
            A ``defaultdict(float)`` keyed by attribute name.

        Raises:
            Any exception raised while reading or processing the file is
            logged and re-raised unchanged.
        """
        path = RewardConfig.WORLD_MODEL_PATH
        print(f"RewardModule: Loading world model from {path}...",flush=True)
        try:
            with open(path, 'r', encoding='utf-8') as f:
                model_data = json.load(f)

            entropies = collections.defaultdict(float)
            distributions = model_data.get("prior_distributions", {})
            for attr, dist in distributions.items():
                # Zero-probability outcomes contribute nothing and would
                # make log2 fail, so they are skipped.
                entropies[attr] = -sum(p * math.log2(p) for p in dist.values() if p > 0)
            print("RewardModule: World model loaded successfully.",flush=True)
            return entropies
        except Exception as e:
            print(f"FATAL ERROR loading World Model from '{path}': {e}",flush=True)
            raise

    def _create_api_client(self):
        """Create an ``AzureOpenAI`` client from the ``RewardConfig`` settings."""
        return AzureOpenAI(
            azure_endpoint=RewardConfig.BASE_URL,
            api_key=RewardConfig.API_KEY,
            api_version=RewardConfig.API_VERSION
        )

    def _get_answerer_prompt_template(self) -> str:
        """Return the system prompt for the simulated user (answerer).

        The ground-truth JSON is appended after this text by
        ``_get_simulated_answer``.
        """
        return """You are acting as a user who has a complete scientific diagram in mind. Your complete knowledge of the diagram is provided in the following JSON object. Your task is to answer the assistant's questions with specific, concrete details derived ONLY from this JSON data.
        Answers should be simple and clear, with no more than 50 words.
        **Your Answering Principles (CRITICAL):**
        1.  **BE FACTUAL AND SPECIFIC:** Your primary goal is to provide concrete information from the JSON.
            - If asked "What is the overall structure?", and the JSON says `"diagram_type": "flowchart"`, you MUST answer with something like "It is a flowchart." or "The diagram is a flowchart.".
            - If asked "What shape are the components?", and the JSON has components with `"shape": "rounded_rectangle"`, you MUST answer "The main components are rounded rectangles."
            - **DO NOT** give vague, generic, or evasive answers like "It depends", "What do you think?", or "That's a good question." This is a simulation, and your role is to provide the facts from the JSON.
        2.  **BE CONCISE:** Answer the question directly. Do not add conversational fluff.
        3.  **HANDLE SUGGESTIONS:** If the assistant proposes a style template (e.g., "professional blue style"), check if its attributes align with the JSON. If they do, agree enthusiastically (e.g., "Yes, that professional blue style sounds perfect!").
        4.  **HANDLE UNANSWERABLE QUESTIONS:** If the question asks for information not present in the JSON, state that clearly (e.g., "That detail is not specified in my current design.").
        Here is the complete diagram information (Ground Truth JSON):
        """

    def _get_parser_prompt_template(self) -> str:
        """Return the parser prompt.

        The literal placeholder ``{user_answer_text}`` is substituted with
        ``str.replace`` in ``_parse_answer`` (not ``str.format``), so the
        other braces in the text need no escaping.
        """
        # Instruction 4 names `component_label` rather than `label`:
        # SCHEMA_MAPPER has no `label` key, so a constraint emitted under
        # that name would be dropped without any reward.
        return """
        You are a highly precise linguistic analysis tool. Your *only* function is to extract structured constraints from a user's answer.
        **Instructions:**
        1.  Analyze the "User's Answer".
        2.  Identify any core concepts that match the "CONCEPT DEFINITION".
        3.  Your final output MUST be a single, valid JSON object: `{"constraints": [["concept_name", "value"], ...]}`.
        4.  The `concept_name` MUST be one of the simple keys from the definition (e.g., `diagram_type`, `component_label`, `shape`).
        5.  If no constraints can be extracted, the value of "constraints" MUST be an empty list `[]`.
        ---
        **CONCEPT DEFINITION (EXPANDED)**
        *   **High-Level Concepts:** `topic`, `purpose`, `target_audience`, `complexity_level`, `domain`, `visual_format`, `diagram_type`, `layout`, `style_theme`
        *   **Component-Level Concepts:**
            *   `component_label`: The text label inside a component.
            *   `component_shape`: The shape of a component (e.g., 'rectangle').
            *   `component_fill_color`: The fill color of a component.
            *   `component_border_style`: The border style of a component.
        *   **Connection-Level Concepts:**
            *   `connection_label`: The text label on a connection.
            *   `connection_line_style`: The style of a line (e.g., 'solid').
            *   `connection_arrowhead`: The type of arrowhead.
        *   **Meta-Concepts:** `style_template`
        ---
        **EXAMPLES**
        User's Answer: "The main components are an Encoder block and a Decoder block."
        Your Output:
        {"constraints": [["component_label", "Encoder block"], ["component_label", "Decoder block"]]}
        User's Answer: "The connection between them should be labeled 'context vector'."
        Your Output:
        {"constraints": [["connection_label", "context vector"]]}
        User's Answer: "The diagram is a flowchart with rounded rectangle shapes."
        Your Output:
        {"constraints": [["diagram_type", "flowchart"], ["component_shape", "rounded_rectangle"]]}
        ---
        Now, analyze the following user's answer and provide the JSON output.
        **User's Answer:**
        "{user_answer_text}"
        """

    def _get_simulated_answer(self, dialogue_history: str, question: str, golden_json: Dict) -> str:
        """Ask the answerer model to reply to ``question`` from the ground truth.

        Args:
            dialogue_history: Prior conversation text placed in the user message.
            question: The candidate question to answer.
            golden_json: Ground-truth diagram data, serialized into the system
                message.

        Returns:
            The stripped answer text, or ``""`` if the API call fails or the
            model returns no content.
        """
        print(f"{question}",flush=True)
        system_content = f"{self.answerer_prompt_template}\n{json.dumps(golden_json, indent=2)}"
        user_content = f"Here is our conversation so far:\n{dialogue_history}\n\nNow, please answer this question: {question}"
        try:
            completion = self.answerer_client.chat.completions.create(
                model=RewardConfig.ANSWERER_MODEL_NAME,
                messages=[{"role": "system", "content": system_content}, {"role": "user", "content": user_content}],
                temperature=0.1, max_tokens=160
            )
            # `content` can be None; treat that as an empty answer instead
            # of relying on a caught AttributeError.
            answer = (completion.choices[0].message.content or "").strip()
        except Exception as e:
            print(f"Warning: Answerer model failed for question '{question}': {e}",flush=True)
            return ""
        print(f"{answer}",flush=True)
        return answer

    def _parse_answer(self, answer: str) -> List[List[str]]:
        """Extract ``[concept_name, value]`` constraints from ``answer``.

        Sends the parser prompt to the parser model in JSON mode and validates
        that the reply is ``{"constraints": [[name, value], ...]}``.

        Returns:
            A list of two-element string lists. An empty list is returned if
            the call fails, the reply is not valid JSON, or any entry is not a
            two-element list.
        """
        prompt = self.parser_prompt_template.replace("{user_answer_text}", answer)
        try:
            completion = self.parser_client.chat.completions.create(
                model=RewardConfig.PARSER_MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0, response_format={"type": "json_object"}
            )
            raw_content = completion.choices[0].message.content
            print(f"{raw_content}",flush=True)
            parsed_json = json.loads(raw_content)
            if isinstance(parsed_json, dict) and "constraints" in parsed_json:
                constraints_list = parsed_json["constraints"]
                if isinstance(constraints_list, list) and all(
                    isinstance(item, list) and len(item) == 2 for item in constraints_list
                ):
                    return [[str(item[0]), str(item[1])] for item in constraints_list]
            print(f"Warning: LLM parser returned malformed data: {raw_content}",flush=True)
            return []
        except Exception as e:
            print(f"Error during LLM parsing for answer '{answer}': {e}",flush=True)
            return []

    def _calculate_reward(self, constraints: List[List[str]], confirmed_attributes: Set[str]) -> tuple[float, Set[str]]:
        """Sum the entropies of the attribute paths newly resolved by ``constraints``.

        Each constraint name is mapped through ``RewardConfig.SCHEMA_MAPPER``;
        unknown names are ignored. A ``style_template`` constraint is expanded
        into the attribute paths of the named entry in
        ``RewardConfig.STYLE_TEMPLATES`` (an unknown template name expands to
        nothing).

        Every attribute path is rewarded at most once per call. Paths already
        in ``confirmed_attributes`` earn nothing, and neither does a path
        already credited earlier in the same call. Without the second check,
        aliases such as ``shape`` / ``component_shape``, or a template plus an
        explicit attribute it covers, would be counted twice.

        Args:
            constraints: ``[simple_attribute_name, value]`` pairs.
            confirmed_attributes: Attribute paths confirmed before this call;
                not modified.

        Returns:
            ``(reward, newly_confirmed)`` where ``newly_confirmed`` is the set
            of attribute paths credited by this call.
        """
        reward = 0.0
        newly_confirmed: Set[str] = set()
        for simple_attr, value in constraints:
            full_path_attr = RewardConfig.SCHEMA_MAPPER.get(simple_attr)
            if not full_path_attr or full_path_attr in confirmed_attributes:
                continue

            if full_path_attr == 'style_template':
                template_attrs = RewardConfig.STYLE_TEMPLATES.get(value, {}).get("attributes", {})
                # Template keys are simple names; translate each to its path
                # (a key without a mapping yields None and is skipped below).
                candidate_paths = [RewardConfig.SCHEMA_MAPPER.get(key) for key in template_attrs]
            else:
                candidate_paths = [full_path_attr]

            for path in candidate_paths:
                if not path or path in confirmed_attributes or path in newly_confirmed:
                    continue
                reward += self.world_model.get(path, 0.0)
                newly_confirmed.add(path)
        return reward, newly_confirmed

    def process_batch(self, reward_inputs: List[Dict[str, Any]]) -> List[Dict[str, float]]:
        """Compute one ``{"entropy_reward": float}`` score per input item.

        Each item must be a dict with ``"response"`` (the candidate question)
        and ``"ground_truth"`` (path to the golden JSON file). Items that are
        malformed, whose ground truth cannot be loaded, or for which the
        answerer produces no answer receive a reward of ``0.0``.

        Returns:
            A list of score dicts in the same order as ``reward_inputs``.
        """
        scores: List[Dict[str, float]] = []
        # Ground-truth files are cached per batch, keyed by path.
        gt_cache: Dict[str, Dict] = {}

        for i, item in enumerate(reward_inputs):
            # The isinstance check keeps a non-dict item (e.g. None) from
            # raising and aborting the whole batch.
            if not isinstance(item, dict) or "response" not in item or "ground_truth" not in item:
                print(f"Warning: Missing key in item {i}: {item}", flush=True)
                scores.append({"entropy_reward": 0.0})
                continue

            question = item["response"]
            golden_json_path = item["ground_truth"]

            golden_json = gt_cache.get(golden_json_path)
            if golden_json is None:
                try:
                    with open(golden_json_path, "r", encoding="utf-8") as f:
                        golden_json = json.load(f)
                    gt_cache[golden_json_path] = golden_json
                except Exception as e:
                    print(f"Error loading golden_json '{golden_json_path}': {e}. Assigning 0 reward.", flush=True)
                    scores.append({"entropy_reward": 0.0})
                    continue

            # NOTE(review): the dialogue history is always passed as ""; the
            # item's "prompt" field is not used here. Confirm this is intended.
            answer = self._get_simulated_answer("", question, golden_json)
            if not answer:
                # The answerer failed or returned nothing, so no information
                # was gained. Skip the parser call rather than let it
                # extract constraints from an empty string.
                scores.append({"entropy_reward": 0.0})
                continue

            constraints = self._parse_answer(answer)
            reward, _ = self._calculate_reward(constraints, confirmed_attributes=set())

            scores.append({"entropy_reward": reward})

        return scores


# Module-wide engine instance; stays None until the first successful
# get_reward_engine() call.
_REWARD_ENGINE_INSTANCE = None


def get_reward_engine():
    """Return the shared reward engine, constructing it on the first call.

    Subsequent calls return the same instance. If construction raises, the
    cached value is left as ``None`` and the next call tries again.
    """
    global _REWARD_ENGINE_INSTANCE
    if _REWARD_ENGINE_INSTANCE is not None:
        return _REWARD_ENGINE_INSTANCE
    _REWARD_ENGINE_INSTANCE = _EntropyRewardEngine()
    return _REWARD_ENGINE_INSTANCE

def compute_score(reward_inputs: list[dict[str, Any]], **kwargs) -> list[dict[str, float]]:
    """
    Entry point handed to the training framework.

    Scores a batch of candidate questions with the shared reward engine and
    reports each entropy reward under two keys.

    Args:
        reward_inputs (list[dict[str, Any]]):
            One dict per sample, containing:
            - "prompt": The dialogue history.
            - "response": The candidate question generated by the model.
            - "ground_truth": The path to the golden JSON file.
        **kwargs: Accepted and ignored, so the framework may pass extras.

    Returns:
        list[dict[str, float]]:
            One dict per sample, in input order, of the form
            {"overall": r, "entropy_reward": r}.

    Raises:
        ValueError: If `reward_inputs` is not a list.
    """
    if not isinstance(reward_inputs, list):
        raise ValueError("This reward function expects `reward_inputs` to be a list.")

    # The engine is obtained before the debug print, so a construction
    # failure surfaces before anything is printed.
    engine = get_reward_engine()

    print(reward_inputs)
    batch_scores = engine.process_batch(reward_inputs)

    # Mirror each reward under "overall" as well as "entropy_reward".
    rewards = (entry.get("entropy_reward", 0.0) for entry in batch_scores)
    final_scores = [{"overall": r, "entropy_reward": r} for r in rewards]

    print(final_scores,flush=True)
    return final_scores


# Example usage:
#     from entropy_reward_model import compute_score as entropy_score
#     scores = entropy_score(batch_inputs)