from extractors import (extract_bookmarks, extract_sections_content, get_sections_content, extract_title,
                        extract_references, extract_related_work)
from api_client import ask_for_openai
from ref_in_rw import extract_refs_from_rw
import json
import os


def process_pdf(pdf_file, output_dir):
    """Extract structured data from one PDF into ``output_dir/<pdf base name>/``.

    Files written to that folder:
      - ``title.txt``: the string returned by ``extract_title``.
      - ``reference.txt``: one raw reference (from ``extract_references``) per line.
      - ``bookmarks.txt``: one bookmark per line, its fields joined with ``+``.
      - ``content.json``: JSON Lines, one ``{"<joined bookmark>": <section text>}``
        object per section returned by ``extract_sections_content``.
      - ``related_work.json``: when a bookmark title contains "related work", a
        JSON list of ``{"subtitle": ..., "paragraph": ...}`` objects; otherwise
        (only if some section text mentions "related work") the raw answer
        returned by ``ask_for_openai``.
      - ``references.json``: JSON list of references with ``[n]`` markers removed.
      - ``related_work_references.json``: the result of ``extract_refs_from_rw``.

    Args:
        pdf_file: Path to the PDF to process.
        output_dir: Directory in which the per-PDF folder is created.
    """
    # Folder named after the PDF file name without its extension.
    base_name = os.path.splitext(os.path.basename(pdf_file))[0]
    pdf_output_dir = os.path.join(output_dir, base_name)
    os.makedirs(pdf_output_dir, exist_ok=True)

    def out_path(file_name):
        return os.path.join(pdf_output_dir, file_name)

    # Write title.txt
    title = extract_title(pdf_file)
    with open(out_path("title.txt"), "w", encoding="utf-8") as title_file:
        title_file.write(title)

    # Table of contents (bookmarks).
    section_titles = extract_bookmarks(pdf_file)

    # Write reference.txt (one raw reference per line).
    raw_references = extract_references(pdf_file, section_titles)
    with open(out_path("reference.txt"), "w", encoding="utf-8") as reference_file:
        for reference in raw_references:
            reference_file.write(reference + '\n')

    # Write bookmarks.txt
    with open(out_path("bookmarks.txt"), "w", encoding="utf-8") as bookmarks_file:
        for section_title in section_titles:
            bookmarks_file.write(_join_section_key(section_title) + '\n')

    # Write content.json as JSON Lines, one object per section.
    content_dict = extract_sections_content(pdf_file, section_titles)
    with open(out_path("content.json"), "w", encoding="utf-8") as content_file:
        for section_key, section_content in content_dict.items():
            record = {_join_section_key(section_key): section_content}
            content_file.write(json.dumps(record, ensure_ascii=False) + '\n')

    # Write related_work.json, either from bookmarks or via the LLM fallback.
    related_work = _collect_related_work(content_dict)
    if related_work:
        _write_related_work_json(related_work, out_path("related_work.json"))
    else:
        _write_related_work_via_llm(content_dict, title, out_path("related_work.json"))

    # Re-read reference.txt so that every line becomes one list entry.
    with open(out_path("reference.txt"), "r", encoding="utf-8") as reference_file:
        references = [line.strip() for line in reference_file]
    references = _strip_reference_prefixes(references)
    # TODO: when entries carry no "[n]" prefix the list may still contain
    # non-reference text; consider cleaning it (e.g. via ask_for_openai).

    # Write references.json
    with open(out_path("references.json"), "w", encoding="utf-8") as references_file:
        references_file.write(json.dumps(references, ensure_ascii=False) + '\n')

    # Write the references cited inside the Related Work section.
    related_work_references = extract_refs_from_rw(related_work, raw_references)
    with open(out_path("related_work_references.json"), "w",
              encoding="utf-8") as related_work_references_file:
        related_work_references_file.write(
            json.dumps(related_work_references, ensure_ascii=False) + '\n')

    # Progress message (user-facing, kept verbatim).
    print(f"已处理 {pdf_file}, 共计 {len(section_titles)} 章节。")


def _join_section_key(section_key):
    """Join a bookmark tuple (presumably ``(level, title, page)``) into ``"level+title+page"``."""
    return '+'.join(str(part) for part in section_key)


def _section_subtitle(joined_key):
    """Return the title field of a ``level+title+page`` key.

    Only the first ``+`` (after the level) and the last ``+`` (before the page)
    are treated as delimiters, so titles such as "C++ Templates" stay intact.
    """
    return joined_key.split('+', 1)[1].rsplit('+', 1)[0]


def _collect_related_work(content_dict):
    """Collect the first "related work" section and its deeper-level subsections.

    The first section whose joined key contains "related work" (case-insensitive)
    is stored under the synthetic key ``"1+None+1"`` with its content unchanged.
    Following sections with a numerically greater level are added under their
    joined key, with the first line (their heading) removed. Collection stops at
    the first following section whose level is not greater.

    Returns:
        An insertion-ordered dict; empty when no matching section exists.
    """
    related_work = {}
    related_level = None
    for section_key, section_content in content_dict.items():
        joined_key = _join_section_key(section_key)
        if related_level is None:
            if 'related work' in joined_key.lower():
                related_level = int(joined_key.split('+')[0])
                # NOTE(review): unlike subsections, the heading line is kept here.
                related_work["1+None+1"] = section_content
        elif int(joined_key.split('+')[0]) > related_level:
            # Section content begins with its own heading; drop everything up to the first '\n'.
            related_work[joined_key] = section_content[section_content.find('\n') + 1:]
        else:
            break
    return related_work


def _write_related_work_json(related_work, path):
    """Write collected sections as a JSON list of ``{"subtitle", "paragraph"}`` objects.

    Newlines in each paragraph become spaces and surrounding whitespace is stripped.
    """
    total_data = [
        {"subtitle": _section_subtitle(joined_key),
         "paragraph": section_content.replace('\n', ' ').strip()}
        for joined_key, section_content in related_work.items()
    ]
    with open(path, "w", encoding="utf-8") as related_work_file:
        related_work_file.write(json.dumps(total_data, ensure_ascii=False) + '\n')


def _write_related_work_via_llm(content_dict, title, path):
    """Fallback when no bookmark is a Related Work section: let the LLM structure it.

    Each section whose text mentions "related work" (case-insensitive) is sent to
    ``ask_for_openai`` with the example format from ``sample_related_work.json``
    (resolved relative to the current working directory). Each answer is written
    to ``path`` verbatim. Nothing is written if no section matches.
    """
    pre_prompt = None
    for section_content in content_dict.values():
        if 'related work' not in section_content.lower():
            continue
        if pre_prompt is None:
            # Read the sample and build the prompt once, only when actually needed.
            with open("sample_related_work.json", "r", encoding="utf-8") as sample_related_work_file:
                sample_related_work = sample_related_work_file.read()
            pre_prompt = (
                f"Here is a JSON file used to represent each small part of Related Work of \n\n{title}\n\n. "
                f"Please process the text I will give you like this, "
                f"without outputting any extra characters:"
                f"\n\n"
                f"{sample_related_work}"
                f"\n\n"
                f"Here is the text you need to process:"
                f"\n\n")
        answer = ask_for_openai(pre_prompt + section_content)
        # NOTE(review): every match overwrites the file, so the LAST matching
        # section wins and earlier API calls are wasted — confirm this is intended.
        with open(path, "w", encoding="utf-8") as related_work_file:
            related_work_file.write(answer + '\n')


def _strip_reference_prefixes(references):
    """Remove leading ``[n]`` markers when the list uses that numbering style.

    The style is detected from the first entry. Only entries that themselves
    start with ``[`` and contain ``]`` are trimmed, so unprefixed lines that
    merely contain brackets (e.g. "[Online]") are left intact.
    """
    # An empty list or empty first line means no prefix style (and no IndexError).
    if not references or not references[0].startswith('['):
        return references
    stripped = []
    for reference in references:
        if reference.startswith('[') and ']' in reference:
            reference = reference[reference.index(']') + 1:].strip()
        stripped.append(reference)
    return stripped