
import itertools
import json
import os
import re

import openai
from pylatexenc.latex2text import LatexNodes2Text

class ExtractionAgent:
    """Extract structured JSON from LaTeX papers with a few-shot prompted chat model.

    The agent loads a JSON schema, shows the model one worked example
    (a LaTeX source plus its expected JSON) in the system prompt, and asks
    it to produce JSON with the same keys for each new paper. Requests go
    through the module-level ``openai.ChatCompletion`` interface, so
    ``api_key`` / ``api_base`` are set as ``openai`` module globals.
    """

    # Number of leading lines of the few-shot .tex file shown to the model.
    FEWSHOT_MAX_LINES = 300

    # A '%' starts a comment only when preceded by an even number of
    # backslashes: '\%' is a literal percent sign, '\\%' is a line break
    # followed by a comment. Group 1 keeps those backslash pairs.
    _COMMENT_RE = re.compile(r'(?<!\\)((?:\\\\)*)%.*')
    _BIBLIOGRAPHY_START_RE = re.compile(r'\\begin\{thebibliography\}|\\bibliography\{')
    _PREAMBLE_CMD_RE = re.compile(
        r'\\(usepackage|bibliographystyle|bibliography|inputenc|documentclass|IEEEoverridecommandlockouts|makeatletter)[^\n]*'
    )
    _WHITESPACE_RE = re.compile(r'\s+')

    def __init__(self, schema_path, fewshot_tex_path, fewshot_json_path, api_key, api_base, model="qwen2.5-14b-instruct"):
        """Load the extraction schema and remember the endpoint settings.

        Args:
            schema_path: Path to a UTF-8 JSON file describing the output schema.
            fewshot_tex_path: Path to the example LaTeX source.
            fewshot_json_path: Path to the example's expected JSON output.
            api_key: Value assigned to ``openai.api_key`` before each request.
            api_base: Value assigned to ``openai.api_base`` before each request.
            model: Model name passed to the chat completion call.
        """
        with open(schema_path, 'r', encoding='utf-8') as f:
            self.prompt_schema = json.load(f)
        self.fewshot_tex_path = fewshot_tex_path
        self.fewshot_json_path = fewshot_json_path
        self.api_key = api_key
        self.api_base = api_base
        self.model = model
        # Filled lazily on first use and then reused for every request.
        self._fewshot_cache = None
        self._session = None

    def read_tex_file(self, tex_path):
        """Return the contents of ``tex_path``, ignoring undecodable bytes."""
        with open(tex_path, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()

    def preprocess_tex(self, tex_content):
        """Strip comments, the bibliography and preamble noise from LaTeX source.

        Steps, in order:
          1. Remove ``%`` comments (escaped ``\\%`` is left untouched).
          2. Cut everything from the first ``\\begin{thebibliography}`` or
             ``\\bibliography{`` onwards. This must happen before step 3,
             which would otherwise delete the ``\\bibliography{...}`` marker.
          3. Remove a fixed list of preamble commands up to end of line.
          4. Collapse every whitespace run into a single space.

        Args:
            tex_content: Raw LaTeX source.

        Returns:
            The cleaned source as a single line of text.
        """
        tex_content = self._COMMENT_RE.sub(r'\1', tex_content)
        tex_content = self._BIBLIOGRAPHY_START_RE.split(tex_content, maxsplit=1)[0]
        tex_content = self._PREAMBLE_CMD_RE.sub('', tex_content)
        return self._WHITESPACE_RE.sub(' ', tex_content)

    def load_fewshot_example(self):
        """Return ``(example_text, example_json)`` for the system prompt.

        Only the first ``FEWSHOT_MAX_LINES`` lines of the example LaTeX file
        are read; they are converted to plain text with ``LatexNodes2Text``.
        The JSON file is returned verbatim as a string. The result is cached
        on the instance so a batch run reads and converts the files once.
        """
        cached = getattr(self, '_fewshot_cache', None)
        if cached is not None:
            return cached
        with open(self.fewshot_tex_path, 'r', encoding='utf-8', errors='ignore') as f:
            tex_content = ''.join(itertools.islice(f, self.FEWSHOT_MAX_LINES))
        with open(self.fewshot_json_path, 'r', encoding='utf-8') as f:
            json_content = f.read()
        tex_text = LatexNodes2Text().latex_to_text(tex_content)
        self._fewshot_cache = (tex_text, json_content)
        return self._fewshot_cache

    def build_prompt(self, tex_content):
        """Build the user message: the schema, the paper text, and the instruction."""
        schema_str = json.dumps(self.prompt_schema, ensure_ascii=False, indent=2)
        prompt = (
            f"Schema:\n{schema_str}\n"
            "LaTeX content:\n"
            f"{tex_content}\n"
            "Please strictly follow the schema keys to output JSON, try to fill content and avoid null values."
        )
        return prompt

    @staticmethod
    def _parse_json_from_text(text):
        """Return the JSON object that starts at the first ``{`` in ``text``.

        Decoding stops at the end of that object, so trailing prose or a
        closing code fence (even one containing braces) does not break
        parsing. Later ``{`` positions are deliberately not tried: on a
        truncated reply that would silently return a nested fragment.

        Raises:
            ValueError: If ``text`` has no ``{`` or no valid JSON object
                starts at the first one.
        """
        start = text.find('{') if text else -1
        if start == -1:
            raise ValueError("Failed to extract JSON from model output.")
        try:
            parsed, _end = json.JSONDecoder().raw_decode(text, start)
        except json.JSONDecodeError as err:
            raise ValueError(f"Failed to extract JSON from model output: {err}") from err
        return parsed

    def call_openai(self, prompt):
        """Send ``prompt`` with the few-shot system prompt and parse the JSON reply.

        Args:
            prompt: User message, normally produced by ``build_prompt``.

        Returns:
            The JSON object decoded from the model's reply.

        Raises:
            ValueError: If no JSON object can be decoded from the reply.
        """
        openai.api_key = self.api_key
        openai.api_base = self.api_base
        # One HTTP session per agent instead of a fresh one for every request.
        if getattr(self, '_session', None) is None:
            import requests
            self._session = requests.Session()
        openai.requestssession = self._session
        fewshot_tex, fewshot_json = self.load_fewshot_example()
        system_prompt = (
            "You are an AI assistant skilled at information extraction. Please refer to the following example and its structured extraction results. I will give you new LaTeX paper content, please output JSON in the same format as the example.\n"
            "[Example LaTeX Content]\n" + fewshot_tex + "\n"
            "[Example Extraction Result]\n" + fewshot_json + "\n"
            "-----------------------------\n"
            "For each subsequent request, I will only give you new LaTeX content and schema."
        )
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ],
            temperature=0.2,
            max_tokens=2048
        )
        text = response['choices'][0]['message']['content']
        return self._parse_json_from_text(text)

    def extract_structured_info(self, tex_path):
        """Run the full pipeline for one file: read, clean, convert to text, query.

        Args:
            tex_path: Path to the LaTeX file to process.

        Returns:
            The JSON object returned by ``call_openai``.
        """
        tex_content = self.read_tex_file(tex_path)
        tex_content = self.preprocess_tex(tex_content)
        tex_text = LatexNodes2Text().latex_to_text(tex_content)
        prompt = self.build_prompt(tex_text)
        return self.call_openai(prompt)

    @staticmethod
    def _json_name_for(rel_path):
        """Map a relative ``.tex`` path to a flat ``.json`` file name.

        Path separators become ``_`` and only the trailing ``.tex`` suffix is
        swapped, so a ``.tex`` substring elsewhere in the path is preserved.
        """
        suffix = '.tex'
        stem = rel_path[:-len(suffix)] if rel_path.endswith(suffix) else rel_path
        return stem.replace(os.sep, '_') + '.json'

    def batch_extract(self, single_tex_dir, output_dir, log_file):
        """Extract every ``.tex`` file under ``single_tex_dir`` into ``output_dir``.

        Each result is written as ``<relative path with '_' separators>.json``.
        A failure on one file is written to ``log_file`` and printed, and the
        run continues with the next file.

        Args:
            single_tex_dir: Directory walked recursively for ``.tex`` files.
            output_dir: Directory for the JSON results (created if missing).
            log_file: Path of the error log; it is truncated at the start.
        """
        os.makedirs(output_dir, exist_ok=True)
        # The log may live outside output_dir; make sure its folder exists too.
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(log_file, 'w', encoding='utf-8') as log_f:
            for root, _dirs, files in os.walk(single_tex_dir):
                for file in files:
                    if not file.endswith('.tex'):
                        continue
                    tex_path = os.path.join(root, file)
                    rel_path = os.path.relpath(tex_path, single_tex_dir)
                    json_path = os.path.join(output_dir, self._json_name_for(rel_path))
                    try:
                        output = self.extract_structured_info(tex_path)
                        with open(json_path, 'w', encoding='utf-8') as out_f:
                            json.dump(output, out_f, ensure_ascii=False, indent=2)
                    except Exception as e:
                        # Best-effort batch: record the failure and keep going.
                        log_f.write(f"[ERROR] {tex_path}: {str(e)}\n")
                        log_f.flush()
                        print(f"[ERROR] {tex_path}: {str(e)}")
                    else:
                        print(f"[OK] {tex_path} -> {json_path}")

if __name__ == "__main__":
    # Script entry point: run a batch extraction with fixed, relative paths.
    results_dir = "extracted_content"
    error_log = os.path.join(results_dir, "extraction_error.log")

    # Arguments are passed positionally in the constructor's declared order:
    # schema, few-shot .tex, few-shot .json, API key, API base URL.
    # NOTE(review): "EMPTY" looks like a placeholder key for a local
    # OpenAI-compatible server - confirm against the deployment.
    agent = ExtractionAgent(
        "extraction_prompt.json",
        "Typical-Example/information-extraction/2305.00673.tex",
        "Typical-Example/information-extraction/2305.00673.json",
        "EMPTY",
        "http://localhost:8000/v1",
    )

    # Errors for individual files are collected in a log inside the results folder.
    agent.batch_extract("single_tex", results_dir, error_log)