#!/usr/bin/env python3
import os
import json
import re
import openai
import pdfplumber
from pathlib import Path
from json_repair import repair_json
from pylatexenc.latex2text import LatexNodes2Text

import logging
logging.getLogger("pdfminer").setLevel(logging.ERROR)


class ExtractionPdfAgent:
    """Extract structured JSON from paper sources (``.tex`` / ``.pdf``) via a chat LLM.

    The agent reads a paper, flattens it to whitespace-normalized text,
    builds a prompt from a JSON schema plus one few-shot example
    (LaTeX source and its expected JSON), and asks an OpenAI-compatible
    chat-completion endpoint to fill the schema.
    """

    # Default character budget for ``preprocess_text``. The limit is applied
    # to ``len(text)`` (characters), not to model tokens.
    DEFAULT_MAX_CHARS = 60000
    # Only this many leading lines of the few-shot LaTeX file are used.
    FEWSHOT_MAX_LINES = 300
    # Input suffixes (lower-case) that ``read_input_file`` understands.
    SUPPORTED_SUFFIXES = ('.tex', '.pdf')

    def __init__(self, schema_path, fewshot_tex_path, fewshot_json_path, api_key, api_base, model="qwen2.5-14b-instruct"):
        """Load the extraction schema and remember the LLM / few-shot settings.

        Args:
            schema_path: Path to a JSON file describing the output schema.
            fewshot_tex_path: Path to the example LaTeX source.
            fewshot_json_path: Path to the example's expected JSON output.
            api_key: API key assigned to ``openai.api_key`` before each call.
            api_base: Base URL assigned to ``openai.api_base`` before each call.
            model: Model name passed to the chat-completion request.

        Raises:
            OSError: If ``schema_path`` cannot be opened.
            json.JSONDecodeError: If the schema file is not valid JSON.
        """
        with open(schema_path, 'r', encoding='utf-8') as f:
            self.prompt_schema = json.load(f)
        self.fewshot_tex_path = fewshot_tex_path
        self.fewshot_json_path = fewshot_json_path
        self.api_key = api_key
        self.api_base = api_base
        self.model = model
        # ((tex_path, json_path), (tex_plain, json_content)) once loaded.
        self._fewshot_cache = None

    def pdf_to_text(self, pdf_path):
        """Return the text of every readable page of a PDF, joined by newlines.

        Pages whose extraction raises are skipped with a warning (best effort).

        Raises:
            ValueError: If the PDF cannot be opened/read at all, or if no page
                yielded any text.
        """
        # Imported lazily so the rest of the class works without pdfplumber.
        import pdfplumber
        text_parts = []
        try:
            with pdfplumber.open(pdf_path) as pdf:
                for page in pdf.pages:
                    try:
                        text = page.extract_text()
                    except Exception as e:
                        # Deliberate best effort: one bad page must not
                        # discard the rest of the document.
                        print(f"[WARN] Skip page error: {e}")
                        continue
                    if text:
                        text_parts.append(text)
        except Exception as e:
            print(f"[ERROR] PDF cannot be read: {pdf_path}, error: {e}")
            raise ValueError(f"PDF file corrupted or format not supported: {e}") from e
        if not text_parts:
            raise ValueError("PDF file is empty or cannot extract text")
        return "\n".join(text_parts).strip()

    def read_input_file(self, path):
        """Read a ``.pdf`` or ``.tex`` file and return its text.

        The suffix is matched case-insensitively. ``.tex`` files are decoded
        as UTF-8 with undecodable bytes dropped.

        Raises:
            ValueError: For any other suffix, or when PDF extraction fails.
        """
        suffix = Path(path).suffix.lower()
        if suffix == '.pdf':
            return self.pdf_to_text(path)
        if suffix == '.tex':
            with open(path, encoding='utf-8', errors='ignore') as f:
                return f.read()
        raise ValueError(f"Unsupported file type: {suffix}")

    def preprocess_text(self, text, max_chars=None):
        """Collapse whitespace runs to single spaces and cap the length.

        Args:
            text: Raw document text.
            max_chars: Maximum number of characters to keep; defaults to
                ``DEFAULT_MAX_CHARS``.

        Returns:
            The normalized text, stripped, at most ``max_chars`` characters.
        """
        if max_chars is None:
            max_chars = self.DEFAULT_MAX_CHARS
        # Strip before measuring so leading whitespace does not eat the budget.
        text = re.sub(r'\s+', ' ', text).strip()
        if len(text) > max_chars:
            print(f"[WARN] Text too long ({len(text)}), truncating to {max_chars}")
            # The cut may land right after a space; strip again.
            text = text[:max_chars].strip()
        return text

    def load_fewshot_example(self):
        """Return ``(tex_plain, json_content)`` for the few-shot example.

        ``tex_plain`` is the first ``FEWSHOT_MAX_LINES`` lines of the example
        LaTeX file converted to plain text; ``json_content`` is the raw text
        of the example JSON file. The result is cached per pair of paths so a
        batch run does not redo the conversion for every paper.
        """
        from itertools import islice
        from pylatexenc.latex2text import LatexNodes2Text

        cache_key = (self.fewshot_tex_path, self.fewshot_json_path)
        # getattr: instances built without __init__ have no cache slot yet.
        cached = getattr(self, "_fewshot_cache", None)
        if cached is not None and cached[0] == cache_key:
            return cached[1]

        with open(self.fewshot_tex_path, encoding='utf-8', errors='ignore') as f:
            tex_content = ''.join(islice(f, self.FEWSHOT_MAX_LINES))
        tex_plain = LatexNodes2Text().latex_to_text(tex_content)
        with open(self.fewshot_json_path, encoding='utf-8') as f:
            json_content = f.read()

        example = (tex_plain, json_content)
        self._fewshot_cache = (cache_key, example)
        return example

    def build_prompt(self, content_plain):
        """Build the user prompt: the schema, the paper text, and the instruction."""
        schema_str = json.dumps(self.prompt_schema, ensure_ascii=False, indent=2)
        return (
            f"Schema:\n{schema_str}\n\n"
            f"Paper content:\n{content_plain}\n\n"
            f"Please strictly follow the schema keys to output JSON, try to fill content and avoid null values."
        )

    def call_llm(self, prompt_text):
        """Send the prompt to the chat model and return the parsed JSON reply.

        The system message carries the few-shot example. The reply is passed
        through ``repair_json`` before parsing.

        Raises:
            ValueError: If the reply cannot be repaired/parsed as JSON, or if
                it parses to something other than a JSON object or array.
        """
        # Imported lazily so the rest of the class works without these packages.
        import openai
        from json_repair import repair_json
        # NOTE: module-level configuration and ``openai.ChatCompletion`` are
        # the pre-1.0 interface of the openai package.
        openai.api_key = self.api_key
        openai.api_base = self.api_base
        tex_plain, json_plain = self.load_fewshot_example()
        system_prompt = (
            "You are an information extraction expert. Please refer to the following example and its extraction results, and perform the same format of structured output for new paper content.\n"
            f"[Example LaTeX Plain Text]\n{tex_plain}\n\n"
            f"[Example Extraction JSON]\n{json_plain}\n"
            "-----------------------------\n"
            "Below is a new paper, please strictly output JSON according to the schema format."
        )
        resp = openai.ChatCompletion.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt_text}
            ],
            temperature=0.2,
            max_tokens=2048
        )
        raw = resp['choices'][0]['message']['content']
        try:
            parsed = json.loads(repair_json(raw))
        except Exception as e:
            raise ValueError(f"JSON repair failed: {e}") from e
        # A bare scalar (typically an empty string produced from an
        # unparseable reply) is not a usable extraction; fail instead of
        # letting the caller save it as a success.
        if not isinstance(parsed, (dict, list)):
            raise ValueError(
                f"LLM reply is not a JSON object or array (got {type(parsed).__name__})"
            )
        return parsed

    def extract_structured_info(self, paper_path):
        """Run the full pipeline for one paper and return the extracted JSON data."""
        raw = self.read_input_file(paper_path)
        clean = self.preprocess_text(raw)
        prompt = self.build_prompt(clean)
        return self.call_llm(prompt)

    @staticmethod
    def _unique_json_name(rel_path, used_names):
        """Return a flat ``*.json`` file name for ``rel_path`` not yet in ``used_names``.

        Path components are joined with ``_`` and only the final suffix is
        replaced. On a clash (e.g. ``paper.tex`` next to ``paper.pdf``) the
        source extension, then a counter, is appended. ``used_names`` holds
        case-folded names and is updated in place.
        """
        stem = "_".join(rel_path.with_suffix("").parts)
        ext_tag = rel_path.suffix.lstrip(".").lower()
        candidate = f"{stem}.json"
        if candidate.casefold() in used_names:
            candidate = f"{stem}_{ext_tag}.json"
        counter = 2
        while candidate.casefold() in used_names:
            candidate = f"{stem}_{ext_tag}_{counter}.json"
            counter += 1
        used_names.add(candidate.casefold())
        return candidate

    def batch_extract(self, input_dir, output_dir, log_file):
        """Extract every ``.tex`` / ``.pdf`` file under ``input_dir``.

        Files are discovered recursively (suffix matched case-insensitively)
        and processed in sorted path order. Each result is written to
        ``output_dir`` as JSON under a flattened, collision-free name. A
        failure on one file is written to ``log_file`` and printed, and the
        batch continues.

        Args:
            input_dir: Directory searched recursively for inputs.
            output_dir: Directory for the JSON results (created if missing).
            log_file: Path of the error log; overwritten on every run.
        """
        input_root = Path(input_dir)
        output_root = Path(output_dir)
        output_root.mkdir(parents=True, exist_ok=True)
        files = sorted(
            p for p in input_root.rglob("*")
            if p.is_file() and p.suffix.lower() in self.SUPPORTED_SUFFIXES
        )
        used_names = set()
        with open(log_file, 'w', encoding='utf-8') as log_f:
            for file in files:
                json_name = self._unique_json_name(file.relative_to(input_root), used_names)
                json_path = output_root / json_name
                try:
                    result = self.extract_structured_info(str(file))
                    with open(json_path, 'w', encoding='utf-8') as out_f:
                        json.dump(result, out_f, ensure_ascii=False, indent=2)
                    print(f"[OK] {file} -> {json_path}")
                except Exception as e:
                    # Batch boundary: record the failure and keep going.
                    log_f.write(f"[ERROR] {file}: {e}\n")
                    print(f"[ERROR] {file}: {e}")

if __name__ == "__main__":
    # Command-line entry point. The three batch paths are taken from the
    # command line: hard-coded empty strings made the run fail immediately,
    # because open("") raises FileNotFoundError for the log file.
    import argparse  # only needed when the module is run as a script

    parser = argparse.ArgumentParser(
        description="Batch-extract structured JSON from .tex/.pdf papers."
    )
    parser.add_argument("input_dir", help="directory searched recursively for .tex/.pdf files")
    parser.add_argument("output_dir", help="directory that receives the extracted .json files")
    parser.add_argument("log_file", help="file that receives one line per failed paper")
    args = parser.parse_args()

    agent = ExtractionPdfAgent(
        schema_path="extraction_prompt.json",
        fewshot_tex_path="Typical-Example/information-extraction/2305.00673.tex",
        fewshot_json_path="Typical-Example/information-extraction/2305.00673.json",
        api_key="EMPTY",
        api_base="http://localhost:8000/v1"
    )
    agent.batch_extract(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        log_file=args.log_file
    )