import os
from pathlib import Path
import argparse
from openai import OpenAI
from tqdm import tqdm


def load_text_file(file_path):
    """Read a text file and return its entire contents as a string.

    Args:
        file_path: Path (str or path-like) of the file to read.

    Returns:
        The file contents decoded as UTF-8 (text mode, so newlines are
        translated to ``"\\n"``).
    """
    with open(file_path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents


def openai_paragraph_extraction(text, model, prompt_template_path, client=None):
    """Send *text* to an OpenAI chat model via a prompt template and return the reply.

    The template file is read with ``load_text_file`` and every occurrence of
    the literal placeholder ``<PAPER_TEXT>`` is replaced by *text*. The result
    is sent as a single user message with ``temperature=0``.

    Args:
        text: Document text to insert into the prompt.
        model: OpenAI chat model name (e.g. ``"gpt-4o-mini"``).
        prompt_template_path: Path of the prompt template file. It must
            contain the ``<PAPER_TEXT>`` placeholder.
        client: Optional pre-built ``OpenAI`` client to reuse across calls.
            When omitted, a new client is created with the API key taken from
            the ``OPENAI_API_KEY`` environment variable.

    Returns:
        Tuple ``(result, prompt_tokens, completion_tokens, total_tokens)``.
        ``result`` is the message content of the first choice (may be ``None``
        or empty). The token counts are 0 when the response carries no usage
        information.

    Raises:
        ValueError: If the template has no ``<PAPER_TEXT>`` placeholder. Without
            it the model would never see the document, yet the call would
            still succeed and be billed.
    """
    placeholder = "<PAPER_TEXT>"
    prompt_template = load_text_file(prompt_template_path)
    if placeholder not in prompt_template:
        raise ValueError(
            f"Prompt template {prompt_template_path!r} does not contain the "
            f"{placeholder} placeholder"
        )
    prompt = prompt_template.replace(placeholder, text)

    if client is None:
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": prompt},
        ],
        temperature=0,
    )
    result = response.choices[0].message.content

    # ``usage`` is optional on ChatCompletion responses; report zero tokens
    # rather than crashing the caller's accumulation.
    usage = response.usage
    if usage is None:
        return result, 0, 0, 0
    return result, usage.prompt_tokens, usage.completion_tokens, usage.total_tokens


def main(input_dir, output_dir, model=None, prompt_template_path=None):
    """Run LLM extraction over every ``*.txt`` file in *input_dir*.

    Files whose names end in ``_llm.txt`` are ignored, so outputs from a
    previous run in the same directory are not fed back in. For each remaining
    ``<stem>.txt``, the extracted text is written to
    ``<output_dir>/<stem>_llm.txt``.

    An input is skipped when its output file already exists, so an interrupted
    run can be resumed. When the model returns empty content, no output file is
    written, and that input is processed again on the next run.

    Args:
        input_dir: Directory scanned (non-recursively) for ``*.txt`` files.
        output_dir: Directory for ``*_llm.txt`` results. Created if missing.
        model: OpenAI model name passed to ``openai_paragraph_extraction``.
        prompt_template_path: Prompt template path passed to
            ``openai_paragraph_extraction``.
    """
    # Sort so the processing order (and the progress output) is deterministic.
    txt_files = sorted(
        p for p in Path(input_dir).glob("*.txt") if not p.name.endswith("_llm.txt")
    )
    os.makedirs(output_dir, exist_ok=True)
    print(f"Found {len(txt_files)} TXT files in {input_dir}")
    llm_not_found_count = 0
    skipped_count = 0
    total_prompt_tokens = 0
    total_completion_tokens = 0
    total_tokens = 0

    for txt_file in tqdm(txt_files):
        file_name = txt_file.stem
        llm_out_path = os.path.join(output_dir, f"{file_name}_llm.txt")
        if os.path.exists(llm_out_path):
            # tqdm.write prints above the bar instead of breaking it like print().
            tqdm.write(f"Skipped: {llm_out_path} (already exists)")
            skipped_count += 1
            continue
        text = load_text_file(txt_file)
        llm_text, prompt_tokens, completion_tokens, tokens = (
            openai_paragraph_extraction(text, model, prompt_template_path)
        )
        if llm_text:
            with open(llm_out_path, "w", encoding="utf-8") as f:
                f.write(llm_text)
        else:
            tqdm.write(f"No synthesis-related text found by LLM for: {file_name}")
            llm_not_found_count += 1
        total_prompt_tokens += prompt_tokens
        total_completion_tokens += completion_tokens
        total_tokens += tokens

    total = len(txt_files)
    if total == 0:
        print("No TXT files found.")
        return

    # Skipped files were never sent to the LLM, so they must not count toward
    # the "not found" rate; otherwise a resumed run reports a deflated percentage.
    processed_count = total - skipped_count
    if skipped_count:
        print(f"\nSkipped {skipped_count} / {total} files (output already exists)")
    if processed_count > 0:
        llm_not_found_percent = llm_not_found_count / processed_count * 100
        print(
            f"\nLLM-based extraction found NO synthesis-related text in "
            f"{llm_not_found_count} / {processed_count} processed files "
            f"({llm_not_found_percent:.1f}%)"
        )
    print(f"Cumulative prompt tokens: {total_prompt_tokens}")
    print(f"Cumulative response tokens: {total_completion_tokens}")
    print(f"Cumulative total tokens: {total_tokens}")


if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Extract synthesis-related text from TXT files using LLM."
    )
    # Every option is a string. Only --model has a default; the rest are
    # mandatory. Entries are registered in this order, which is also the
    # order shown by --help.
    option_specs = (
        (
            "--input-dir",
            {"required": True, "help": "Directory containing TXT files."},
        ),
        (
            "--output-dir",
            {"required": True, "help": "Directory to save extracted text files."},
        ),
        (
            "--model",
            {
                "default": "gpt-4o-mini",
                "help": "OpenAI model name for LLM-based extraction",
            },
        ),
        (
            "--template-path",
            {
                "required": True,
                "help": "Prompt template file path for LLM-based extraction",
            },
        ),
    )
    for flag, extra in option_specs:
        cli.add_argument(flag, type=str, **extra)

    opts = cli.parse_args()
    main(
        opts.input_dir,
        opts.output_dir,
        model=opts.model,
        prompt_template_path=opts.template_path,
    )
