import os
import sys
import re
import argparse
import json
from tqdm import tqdm
from openai import OpenAI
from collections import OrderedDict

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from evaluation.evaluate import validate_jsonld


def load_text_file(file_path):
    """Return the entire contents of ``file_path`` decoded as UTF-8 text."""
    with open(file_path, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents


# Name prefixes of OpenAI o-series reasoning models. These reject a
# non-default `temperature`, so the parameter must be omitted for them.
_REASONING_MODEL_PREFIXES = ("o1", "o3", "o4")


def extract_synthesis_procedure(
    text_path, template_path, model, example_path=None, client=None
):
    """Build the extraction prompt for one text file and query the model.

    The prompt template is read from ``template_path``:
    ``<IN_CONTEXT_EXAMPLE>`` is replaced by the contents of ``example_path``
    (or, when no example is given, a ``# Example <IN_CONTEXT_EXAMPLE>``
    section is removed), then ``<SYNTHESIS_TEXT>`` is replaced by the
    contents of ``text_path``.

    Args:
        text_path: Path of the synthesis text to extract from.
        template_path: Path of the prompt template.
        model: OpenAI chat model name. Reasoning models (``o1*``, ``o3*``,
            ``o4*``) are called without ``temperature``; all others use
            ``temperature=0.0``.
        example_path: Optional path of a one-shot example to inline.
        client: Optional pre-built OpenAI client to reuse across calls. When
            omitted, a client is created from ``OPENAI_API_KEY``.

    Returns:
        Tuple ``(result, prompt_tokens, completion_tokens, total_tokens)``:
        ``result`` is the message content of the first choice, and the other
        three values are the token counts reported in ``response.usage``.
    """
    txt_text = load_text_file(text_path)
    template = load_text_file(template_path)
    if example_path:
        example_text = load_text_file(example_path)
        template = template.replace("<IN_CONTEXT_EXAMPLE>", example_text)
    else:
        template = re.sub(r"#\s*Example\s*<IN_CONTEXT_EXAMPLE>\s*", "", template)
    prompt = template.replace("<SYNTHESIS_TEXT>", txt_text)

    if client is None:
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    request = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt},
        ],
    }
    if not model.startswith(_REASONING_MODEL_PREFIXES):
        # Deterministic decoding where the model supports it.
        request["temperature"] = 0.0
    response = client.chat.completions.create(**request)

    result = response.choices[0].message.content
    usage = response.usage
    return result, usage.prompt_tokens, usage.completion_tokens, usage.total_tokens


# Matches a reply that is entirely wrapped in a Markdown code fence, optionally
# tagged with a language on the opening line (e.g. ```json ... ```). Group 1 is
# the fenced content.
_CODE_FENCE_RE = re.compile(
    r"\A\s*```(?:[\w+-]*[ \t]*\n)?(.*?)\s*```\s*\Z", re.DOTALL
)


def _strip_code_fences(text):
    """Return the body of ``text`` if it is one fenced code block, else ``text``."""
    match = _CODE_FENCE_RE.match(text)
    return match.group(1) if match else text


def write_json_output(result, output_path):
    """Parse, validate, normalize and save one model response as JSON.

    ``result`` is parsed with ``json.loads``. A surrounding Markdown code
    fence is stripped first. The parsed value is passed through
    ``validate_jsonld`` and ``normalize_jsonld``, then written to
    ``output_path`` as UTF-8 JSON with ``indent=2``.

    This is best-effort: failures print a warning instead of raising. If
    parsing, validation, normalization or serialization fails, no file is
    created. That matters because ``main`` skips any output that already
    exists, so a stray empty file would never be retried.

    Args:
        result: Raw model response text. May be ``None``, which is reported
            as a parse failure.
        output_path: Destination ``.json`` path.
    """
    try:
        data = json.loads(_strip_code_fences(result))
        parsed = normalize_jsonld(validate_jsonld(data))
        # Serialize fully, and check the text is encodable (e.g. no lone
        # surrogates decoded from "\ud800"-style escapes), before the output
        # file is opened, so a failure here leaves nothing on disk.
        serialized = json.dumps(parsed, ensure_ascii=False, indent=2)
        serialized.encode("utf-8")
    except Exception as e:
        # Per-file boundary: validate_jsonld may raise arbitrary errors and
        # one bad response must not abort the whole batch.
        print(
            f"Warning: Could not parse output as JSON. Reason: {e}. Skipping output for {output_path}"
        )
        return
    try:
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(serialized)
    except OSError as e:
        print(
            f"Warning: Could not write output file. Reason: {e}. Skipping output for {output_path}"
        )


def normalize_jsonld(parsed):
    """Rewrite parsed JSON-LD documents into the fixed MatPROV output shape.

    For every document, ``@context`` is replaced by a fixed context list and
    placed first. All other top-level keys are kept in their original order.
    Nodes in ``@graph`` whose ``@type`` is ``"Entity"`` or ``"Activity"`` are
    enriched in place:

    * label objects without ``@language`` get ``"EN"``;
    * ``matprov:width``/``height``/``thickness``/``diameter`` are renamed to
      their ``matprov:length_*`` forms. All ``matprov:*`` keys are moved to
      the end of the node.
    * dict items of list-valued ``matprov:*`` properties without ``@type``
      get ``"xsd:string"``.

    Values that are not JSON objects (plain strings, numbers) are left
    untouched rather than raising.

    Args:
        parsed: A list of JSON-LD documents (dicts), or a single document dict,
            which is treated as a one-element list.

    Returns:
        A list of ``OrderedDict`` documents, one per input document.
    """
    FIXED_CONTEXT = [
        {
            "xsd": "http://www.w3.org/2001/XMLSchema#",
            "prov": "http://www.w3.org/ns/prov#",
        },
        "https://openprovenance.org/prov-jsonld/context.jsonld",
        "URL of MatPROV's context schema omitted for double-blind review",
    ]

    param_key_map = {
        "matprov:width": "matprov:length_width",
        "matprov:height": "matprov:length_height",
        "matprov:thickness": "matprov:length_thickness",
        "matprov:diameter": "matprov:length_diameter",
    }

    def iter_objects(value):
        # JSON-LD allows a single object wherever an array is allowed; yield
        # only the dict members so string/number items are skipped safely.
        if isinstance(value, dict):
            yield value
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    yield item

    def enrich_node(node):
        if not isinstance(node, dict) or node.get("@type") not in ("Entity", "Activity"):
            return node
        # Normalize label language
        if "label" in node:
            for label in iter_objects(node["label"]):
                if "@language" not in label:
                    label["@language"] = "EN"
        # Normalize matprov:* keys: remove them, then re-insert under the
        # mapped names (same order, so later duplicates win as before).
        matprov_items = {k: v for k, v in node.items() if k.startswith("matprov:")}
        for k in matprov_items:
            del node[k]
        for k, v in matprov_items.items():
            node[param_key_map.get(k, k)] = v
        # Add xsd:string type if missing (list-valued properties only)
        for k, v in node.items():
            if k.startswith("matprov:") and isinstance(v, list):
                for item in v:
                    if isinstance(item, dict) and "@type" not in item:
                        item["@type"] = "xsd:string"
        return node

    documents = [parsed] if isinstance(parsed, dict) else parsed
    normalized = []
    for elem in documents:
        new_elem = OrderedDict()
        new_elem["@context"] = FIXED_CONTEXT
        for k, v in elem.items():
            if k == "@context":
                continue
            if k == "@graph":
                nodes = v if isinstance(v, list) else [v]
                new_elem["@graph"] = [enrich_node(n) for n in nodes]
            else:
                new_elem[k] = v
        normalized.append(new_elem)
    return normalized


def main():
    """Command-line entry point: run extraction over a directory of texts.

    Every regular file named ``<base>_llm.txt`` in ``--input-dir`` is
    processed in sorted order. The model reply is saved as ``<base>.json`` in
    ``--output-dir``. Inputs whose output file already exists are skipped, so
    an interrupted run can be resumed. Cumulative token usage is printed at
    the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="Directory containing synthesis-related TXT files.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Directory to save extracted JSON outputs",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="o4-mini",
        help="OpenAI model name for LLM-based extraction",
    )
    parser.add_argument(
        "--template-path",
        type=str,
        required=True,
        help="Prompt template file path for LLM-based extraction",
    )
    parser.add_argument(
        "--example-path",
        type=str,
        default=None,
        help="Path to one-shot example text file",
    )
    args = parser.parse_args()

    input_dir = args.input_dir
    output_dir = args.output_dir
    suffix = "_llm.txt"

    # Loop-invariant: create the output directory once, not per input file.
    os.makedirs(output_dir, exist_ok=True)

    txt_files = sorted(
        os.path.join(input_dir, name)
        for name in os.listdir(input_dir)
        if name.endswith(suffix) and os.path.isfile(os.path.join(input_dir, name))
    )

    total_prompt_tokens = 0
    total_completion_tokens = 0
    total_tokens = 0
    for txt_file in tqdm(txt_files):
        # "<base>_llm.txt" -> "<base>.json"
        base = os.path.basename(txt_file)[: -len(suffix)]
        output_path = os.path.join(output_dir, base + ".json")
        if os.path.exists(output_path):
            # Resume support: keep outputs produced by an earlier run.
            print(f"Skipped: {output_path} (already exists)")
            continue
        result, prompt_tokens, completion_tokens, tokens = extract_synthesis_procedure(
            txt_file, args.template_path, args.model, args.example_path
        )
        write_json_output(result, output_path)
        # Usage is accumulated whether or not the reply could be saved.
        total_prompt_tokens += prompt_tokens
        total_completion_tokens += completion_tokens
        total_tokens += tokens
    print(f"Cumulative prompt tokens: {total_prompt_tokens}")
    print(f"Cumulative response tokens: {total_completion_tokens}")
    print(f"Cumulative total tokens: {total_tokens}")


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
