import base64
import collections
import json
import mimetypes
import os

from openai import AzureOpenAI
from openai import OpenAI
from tqdm import tqdm


# --- Service configuration (placeholders; fill in before running) ---
# Credential passed as `api_key` when the client is constructed in the
# script entry point.
# NOTE(review): presumably this should not be committed with a real key —
# consider loading it from the environment.
API_KEY = "" 
# Passed as `azure_endpoint` when the client is constructed.
BASE_URL = ""
# Passed as `model` on every chat-completion request.
MODEL_NAME = "" 

# --- Filesystem configuration ---
# Directory scanned for input images (.png / .jpg / .jpeg).
IMAGE_FOLDER = ""  
# Directory where one state-vector JSON per image is written and from which
# the world model is aggregated.
JSON_OUTPUT_FOLDER = ""  
# Output path of the aggregated world model.
# NOTE(review): the file name says "v3" while the model written carries
# schema_version "2.0_dynamic" — confirm which is intended.
WORLD_MODEL_FILE = "world_model_v3.json"  

def encode_image(image_path):
    """Return the contents of *image_path* as a base64 text string.

    The file is read in binary mode and the base64 bytes are decoded as
    UTF-8. Any exception raised while opening, reading or encoding is
    reported on stdout and signalled to the caller by returning ``None``
    instead of propagating.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded_text = base64.b64encode(raw_bytes).decode('utf-8')
    except Exception as e:
        # Best-effort helper: callers treat None as "skip this image".
        print(f"Error encoding image {image_path}: {e}",flush=True)
        return None
    return encoded_text

def get_generation_prompt_v2() -> str:
    """Return the system prompt used to request a diagram "State Vector".

    The prompt asks the model to answer with a single JSON object following
    the "v2.0" schema described inline (`global_properties`, `components`,
    `connections`, `layout_constraints`) and embeds one worked example.
    The text is returned verbatim, including its leading indentation.

    NOTE(review): the embedded example references component ids ("c2", "c4",
    "c5", "icon1", "icon2") that are not defined in its `components` array,
    and the schema lists `schema_version` / `metadata` as required top-level
    keys without describing them — confirm this is intentional.
    """
    return """
    You are a world-class expert in visual information systems and computational graphics. Your task is to meticulously analyze the provided scientific diagram and deconstruct it into a highly structured, comprehensive, and standardized JSON object. This JSON represents the diagram's "State Vector" and must strictly adhere to the schema and conventions detailed below.

    **OUTPUT MUST BE A SINGLE, VALID JSON OBJECT AND NOTHING ELSE.**

    ---
    **JSON SCHEMA DEFINITION (v2.0)**

    The JSON object must contain these top-level keys: `schema_version`, `metadata`, `global_properties`, `components`, `connections`, `layout_constraints`.

    1.  **`global_properties`**: Describes diagram-wide settings.
        *   `topic`: (String) The main subject, concept, or title of the diagram.
        *   `purpose`: (String) The primary goal of the diagram. Enum: 'illustration', 'comparison', 'data_flow', 'workflow', 'architecture_overview'.
        *   `target_audience`: (String) The intended audience. Enum: 'researcher', 'student', 'general_public', 'engineer'.
        *   `complexity_level`: (String) The level of detail. Enum: 'high_level_overview', 'detailed_schematic', 'publication_quality'.
        *   `domain`: (String) The scientific or technical field, e.g., 'Computer Vision', 'Biology'.
        *   `visual_format`: (String) The primary visual representation style. Enum: 'flowchart', 'block_diagram', 'comparison_layout', 'conceptual_map'.
        *   `diagram_type`: (String) Enum: 'flowchart', 'schematic', 'architecture_diagram', 'timeline'.
        *   `layout_grid`: (String) The overall grid structure, e.g., '1xN', '2x2', '3-tier_vertical', 'freeform'.
        *   `style_theme`: (String) e.g., 'professional_light', 'minimalist', 'corporate_blue'.
        *   `background_color`: (String) HEX color code, e.g., "#FFFFFF".
        *   `font_family`: (String) e.g., "Helvetica, Arial, sans-serif".
        *   `title`: (Object) Contains `text` (String) and `is_present` (boolean).

    2.  **`components`**: An array of all visual elements. Each element is an object with:
        *   `id`: (String) A unique identifier, e.g., "c1", "op1".
        *   `quantity`: (Integer, Optional) The number of similar components being represented by this node.
        *   `type`: (String) Enum: 'shape_node', 'text_node', 'icon_node', 'group_container', 'operator_node', 'custom_element'.
        *   `label`, `sub_label`: (String) The primary and secondary text.
        *   `geometry`: (Object) Contains `shape` (e.g., 'rounded_rectangle', 'circle').
        *   `styling`: (Object) Contains `fill_color` (HEX), `border_color` (HEX), `border_width` (pixels), `border_style` ('solid', 'dashed').
        *   `text_properties`: (Object) Contains `font_weight` ('normal', 'bold'), `text_color` (HEX).

    3.  **`connections`**: An array of all lines and arrows. **(Schema Upgraded)**
        *   `id`: (String) A unique identifier.
        *   `from_id`, `to_id`: (String) The IDs of the components it connects.
        *   `label`: (Object) Contains `text` (String), `position` ('start', 'middle_above', 'end'), `text_color` (HEX).
        *   `line_properties`: (Object) Contains `type` ('straight', 'curved_clockwise', 'orthogonal'), `style` ('solid', 'dashed'), `color` (HEX), `width` (pixels).
        *   `arrowhead`: (Object) Contains `start_type` and `end_type` ('none', 'solid_triangle'), `size` ('small', 'medium', 'large').

    4.  **`layout_constraints`**: An array describing the spatial relationships. **(Schema Upgraded)**
        *   `type`: (String) Enum: 'relative_arrangement', 'containment', 'alignment', 'distribution'.
        *   If `type` is 'alignment', use `alignment_type`: ('left_edge', 'horizontal_center', etc.).
        *   If `type` is 'distribution', use `distribution_type`: ('horizontal_equal_spacing', etc.).
        *   If `type` is 'containment', it can have an optional `padding` property: ('small', 'medium', 'large').
        *   `element_ids` or `container_id`: The relevant component IDs.

    ---
    **EXAMPLE OF A PERFECT OUTPUT:**
    ```json
    {
    "schema_version": "2.0",
    "metadata": {"diagram_id": "example_01"},
    "global_properties": {
        "diagram_type": "flowchart",
        "style_theme": "professional_light",
        "background_color": "#FFFFFF",
        "font_family": "sans-serif",
        "title": {"text": "Example Process", "is_present": true}
    },
    "components": [
        {
        "id": "c1",
        "type": "shape_node",
        "label": "Input Data",
        "geometry": { "shape": "rectangle" },
        "styling": { "fill_color": "#D9EAD3", "border_color": "#000000", "border_width": 1, "border_style": "solid" },
        "text_properties": { "font_weight": "normal", "text_color": "#000000" }
        }
    ],
      "connections": [
        {
          "id": "conn1",
          "from_id": "c1",
          "to_id": "c2",
          "label": { "text": "Query", "position": "middle_above" },
          "line_properties": { "type": "straight", "style": "solid", "color": "#000000", "width": 1.5 },
          "arrowhead": { "end_type": "solid_triangle", "size": "medium" }
        }
      ],
      "layout_constraints": [
        { "type": "alignment", "alignment_type": "left_edge", "element_ids": ["c4", "c5"] },
        { "type": "distribution", "distribution_type": "horizontal_equal_spacing", "element_ids": ["icon1", "icon2"] }
      ]
    }
    ```
    ---
    Now, analyze the user-provided image and generate its State Vector according to these precise instructions.
    """

def _guess_image_mime_type(image_path):
    """Return the image MIME type implied by the extension of *image_path*.

    Falls back to ``image/jpeg`` when the extension is unknown or does not
    map to an image type.
    """
    mime_type, _ = mimetypes.guess_type(image_path)
    if mime_type and mime_type.startswith("image/"):
        return mime_type
    return "image/jpeg"


def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if present.

    Handles both "```json" and bare "```" openers, and only removes a
    closing fence when one is actually there, so an unterminated fence no
    longer loses the last characters of the payload.
    """
    body = text.strip()
    if not body.startswith("```"):
        return body
    body = body[3:]
    if body[:4].lower() == "json":
        body = body[4:]
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def generate_state_vector_for_image(client, image_path):
    """Ask the model for the state-vector JSON describing one image.

    The image is base64-encoded and sent as a data URL together with the
    system prompt from ``get_generation_prompt_v2()``. The first request
    uses ``response_format={"type": "json_object"}``; if that request or the
    parsing of its answer raises, a second request is made without
    ``response_format`` and any Markdown code fence is stripped before
    parsing.

    Args:
        client: Object exposing ``chat.completions.create(...)``.
        image_path: Path of the image file to analyse.

    Returns:
        The value parsed from the model's JSON answer, or ``None`` when the
        image cannot be encoded or both attempts fail.
    """
    base64_image = encode_image(image_path)
    if not base64_image:
        return None

    # Label the payload with the file's real type instead of always claiming
    # JPEG (the script also feeds .png files through this function).
    image_url = f"data:{_guess_image_mime_type(image_path)};base64,{base64_image}"
    messages = [
        {
            "role": "system",
            "content": get_generation_prompt_v2()
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Please analyze this scientific diagram and generate its State Vector in the specified JSON format (v2.0)."},
                {"type": "image_url", "image_url": {"url": image_url}}
            ]
        }
    ]

    # Attempt 1: native JSON mode.
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=0.0,
            response_format={"type": "json_object"}
        )
        return json.loads(completion.choices[0].message.content)
    except Exception as e:
        print(f"API call with json_object mode failed or is not supported, trying manual parsing. Error: {e}",flush=True)

    # Attempt 2: plain completion, then extract the JSON ourselves.
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=0.0
        )
        response_text = completion.choices[0].message.content
        if not response_text:
            raise ValueError("model returned an empty response")
        return json.loads(_strip_code_fence(response_text))
    except Exception as final_e:
        print(f"Failed to process {os.path.basename(image_path)}: {final_e}",flush=True)
        return None

def _count_scalar_attributes(node, path_prefix, counters):
    """Tally every scalar leaf of *node* into *counters* by dotted path.

    Nested dicts are walked recursively; lists, nulls and non-dict nodes are
    ignored. ``counters`` maps "prefix.key[.subkey...]" to a value -> count
    mapping and is updated in place.
    """
    if not isinstance(node, dict):
        return
    for key, value in node.items():
        attribute_name = f"{path_prefix}.{key}"
        if isinstance(value, (str, int, float, bool)):
            # True/False compare and hash equal to 1/0, so booleans are
            # stringified to keep them as separate keys from numeric values.
            value_key = str(value) if isinstance(value, bool) else value
            counters[attribute_name][value_key] += 1
        elif isinstance(value, dict):
            _count_scalar_attributes(value, attribute_name, counters)


def build_world_model_v2_dynamic(json_folder):
    """Aggregate the state-vector JSON files in *json_folder* into a world model.

    Every ``*.json`` file is loaded and the scalar attributes found under
    ``global_properties``, ``components``, ``connections`` and
    ``layout_constraints`` are counted per dotted attribute path. Counts are
    turned into add-one smoothed frequencies,
    ``(count + 1) / (total + num_categories)``, where ``total`` is the number
    of diagrams / components / connections / layout constraints seen and
    ``num_categories`` is the number of distinct values observed for that
    attribute.

    Files that cannot be read or parsed, or whose top-level value is not a
    JSON object, are skipped with a warning and do not count as diagrams.

    Args:
        json_folder: Directory containing the state-vector JSON files.

    Returns:
        A dict with ``schema_version``, ``metadata`` and
        ``prior_distributions`` keys, or ``None`` when the folder contains
        no ``.json`` file.
    """
    print("\n--- Building Dynamic World Model v2.0 ---",flush=True)

    counters = collections.defaultdict(lambda: collections.defaultdict(int))
    totals = collections.defaultdict(int)

    # (top-level key in a state vector, attribute-path prefix used in the model)
    list_sections = (
        ("components", "component"),
        ("connections", "connection"),
        ("layout_constraints", "layout_constraint"),
    )

    # Sorted so the key order of the resulting model does not depend on the
    # order in which the filesystem lists the directory.
    json_files = sorted(f for f in os.listdir(json_folder) if f.endswith('.json'))
    if not json_files:
        print("No state vector JSON files found.",flush=True)
        return None

    for filename in tqdm(json_files, desc="Aggregating JSONs (v2.0)"):
        filepath = os.path.join(json_folder, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError):
            print(f"Warning: Skipping corrupted or invalid JSON file {filename}",flush=True)
            continue

        if not isinstance(data, dict):
            # A top-level list/string/number is not a state vector; key
            # lookups on it would either crash or silently match nothing.
            print(f"Warning: Skipping {filename}, top-level JSON value is not an object",flush=True)
            continue

        totals['diagrams'] += 1

        if "global_properties" in data:
            _count_scalar_attributes(data["global_properties"], "global_properties", counters)

        for section_key, prefix in list_sections:
            entries = data.get(section_key)
            if isinstance(entries, list):
                totals[section_key] += len(entries)
                for entry in entries:
                    _count_scalar_attributes(entry, prefix, counters)

    probabilities = collections.defaultdict(dict)
    for attr, value_counts in counters.items():
        # Pick the denominator matching the section the attribute came from;
        # anything else (global_properties.*) is normalised per diagram.
        total_key = 'diagrams'
        for section_key, prefix in list_sections:
            if attr.startswith(prefix + '.'):
                total_key = section_key
                break

        total = totals[total_key]
        if total <= 0:
            continue
        num_categories = len(value_counts)
        for value, count in value_counts.items():
            probabilities[attr][value] = (count + 1) / (total + num_categories)

    world_model = {
        "schema_version": "2.0_dynamic",
        "metadata": {
            "corpus_size": totals['diagrams'],
            "total_components": totals['components'],
            "total_connections": totals['connections'],
            "total_layout_constraints": totals['layout_constraints'],
            "smoothing_method": "laplace"
        },
        "prior_distributions": probabilities
    }

    print("Dynamic World Model v2.0 built successfully.",flush=True)
    return world_model


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Stage 1: produce one state-vector JSON per image that lacks one.
    # ------------------------------------------------------------------
    print("--- Step 1: Generating State Vectors (v2.0) ---",flush=True)

    for folder in (IMAGE_FOLDER, JSON_OUTPUT_FOLDER):
        os.makedirs(folder, exist_ok=True)

    client = AzureOpenAI(azure_endpoint=BASE_URL,api_key=API_KEY,api_version="")

    supported_suffixes = ('.png', '.jpg', '.jpeg')
    image_files = [
        entry for entry in os.listdir(IMAGE_FOLDER)
        if entry.lower().endswith(supported_suffixes)
    ]

    if image_files:
        for filename in tqdm(image_files, desc="Generating JSONs (v2.0)"):
            stem, _ = os.path.splitext(filename)
            json_path = os.path.join(JSON_OUTPUT_FOLDER, stem + ".json")

            # Results are cached on disk: never re-query an image that
            # already has its JSON.
            if os.path.exists(json_path):
                print(f"Skipping {filename}, JSON already exists.",flush=True)
                continue

            state_vector = generate_state_vector_for_image(
                client, os.path.join(IMAGE_FOLDER, filename)
            )
            if not state_vector:
                print(f"Failed to generate state vector for {filename}",flush=True)
                continue

            with open(json_path, 'w', encoding='utf-8') as out_file:
                json.dump(state_vector, out_file, indent=2, ensure_ascii=False)
            print(f"Successfully generated state vector for {filename}",flush=True)
    else:
        print(f"Error: No images found in the '{IMAGE_FOLDER}' directory.",flush=True)

    # ------------------------------------------------------------------
    # Stage 2: aggregate every JSON in the output folder into the model.
    # ------------------------------------------------------------------
    print("\n--- Step 2: Building World Model from generated JSONs (v2.0) ---",flush=True)

    final_world_model = build_world_model_v2_dynamic(JSON_OUTPUT_FOLDER)

    if final_world_model:
        with open(WORLD_MODEL_FILE, 'w', encoding='utf-8') as out_file:
            json.dump(final_world_model, out_file, indent=2, ensure_ascii=False)
        print(f"\nWorld model v2.0 saved to '{WORLD_MODEL_FILE}'",flush=True)
        print("\nExample of learned attributes and their distributions:",flush=True)
        # Preview: the first five attributes, up to three values each.
        preview = list(final_world_model['prior_distributions'].items())[:5]
        for attr, dist in preview:
            print(f"- {attr}: {list(dist.keys())[:3]}...",flush=True)
