import fitz
import os


# Reference (bibliography) extraction from PDF files.

# Substrings that identify the heading of the reference section.  The odd
# 'referenCes' spelling is kept verbatim; presumably it matches a heading as
# extracted from some PDFs -- TODO confirm.
_REFERENCE_HEADINGS = ('References', 'REFERENCES', 'referenCes')

# When the lines contain more than this many '[' and ']' characters in total,
# the list is treated as using the numbered "[1] ..." citation style.
_BRACKET_STYLE_THRESHOLD = 20

# Raw text blocks of this length or shorter are discarded as noise.
_MIN_BLOCK_LENGTH = 10


def _has_reference_heading(text):
    """Return True if *text* contains one of the reference-section headings."""
    return any(heading in text for heading in _REFERENCE_HEADINGS)


def _block_text(block):
    """Return the text payload of one ``page.get_text('blocks')`` entry."""
    # Each entry is (x0, y0, x1, y1, text, block_no, block_type).
    return block[4]


def _find_reference_pages(doc, section_titles):
    """Return the 0-based indices of the pages holding the reference section.

    Every page is scanned; when several pages contain a heading, the last one
    is used as the start (this is the behaviour the function has always had).
    An empty list is returned when no page contains a heading.
    """
    page_count = len(doc)
    start_page = None
    for page_index, page in enumerate(doc):
        blocks = page.get_text('blocks')
        if any(_has_reference_heading(_block_text(block)) for block in blocks):
            start_page = page_index
    if start_page is None:
        return []

    # Walk the table of contents backwards to find the earliest entry that
    # lies after the heading page.  TOC pages are compared as 1-based numbers
    # (hence ``start_page + 1``); using such a number as an exclusive 0-based
    # bound means the page on which the next section starts is still included.
    end_page = page_count
    for entry in reversed(section_titles or []):
        if entry[2] > start_page + 1:
            end_page = entry[2]
        else:
            break
    # Never run past the end of the document, whatever the TOC claims.
    return list(range(start_page, min(end_page, page_count)))


def _collect_block_texts(doc, page_indices):
    """Return the text of every block on the given pages, skipping tables."""
    texts = []
    for page_index in page_indices:
        for block in doc[page_index].get_text('blocks'):
            text = _block_text(block)
            # A tab character is taken as a sign of table content.
            if '\t' in text:
                continue
            texts.append(text)
    return texts


def _starts_numbered_entry(line):
    """Return True if *line* begins with '[' followed by a digit, e.g. '[12]'."""
    # The slice keeps this safe for one-character and empty strings.
    return line.startswith('[') and line[1:2].isdigit()


def _starts_author_entry(line):
    """Return True if *line* looks like the start of an author-style entry.

    Only the text before the first '.' is inspected.  Every word there must
    begin with an upper-case letter, except for the word 'and'.
    """
    head = line.partition('.')[0]
    # Consecutive or leading spaces produce empty words; ignore them instead
    # of indexing into an empty string.
    words = [word for word in head.split(' ') if word]
    return bool(words) and all(
        word[0].isupper() or word == 'and' for word in words
    )


def _merge_continuations(lines, is_start):
    """Join every non-start line onto the entry that precedes it.

    The first line always opens an entry, whatever its flag says.
    """
    entries = []
    for line, starts_entry in zip(lines, is_start):
        if entries and not starts_entry:
            entries[-1].append(line)
        else:
            entries.append([line])
    return [' '.join(parts) for parts in entries]


def extract_references(pdf_path, section_titles) -> list:
    """
    Extract the reference list from a PDF file.

    Parameters:
    - pdf_path: path of the PDF file, passed to ``fitz.open``.
    - section_titles: list of table-of-contents entries, e.g.
      [(1, 'Introduction', 2), ...] where 1 is the heading level,
      'Introduction' is the title and 2 is the page number of the section
      (compared as a 1-based page number).  May be empty or None.

    Returns:
    - list: one string per detected reference.  The list is empty when no
      page contains a reference heading.
    """
    doc = fitz.open(pdf_path)
    try:
        page_indices = _find_reference_pages(doc, section_titles)
        block_texts = _collect_block_texts(doc, page_indices)
    finally:
        # Close the document even if text extraction raises.
        doc.close()

    # Start right after the first block containing the heading.  If that
    # block is the last one there is nothing after it, so it is kept.
    heading_index = next(
        (i for i, text in enumerate(block_texts) if _has_reference_heading(text)),
        0,
    )
    first = heading_index
    if heading_index + 1 < len(block_texts):
        first = heading_index + 1

    lines = []
    for text in block_texts[first:]:
        # The length filter is applied to the raw block, before newlines are
        # removed.
        if len(text) <= _MIN_BLOCK_LENGTH:
            continue
        flattened = text.replace('\n', '')
        # A block made only of newlines flattens to nothing; drop it.
        if flattened:
            lines.append(flattened)

    bracket_total = sum(line.count('[') + line.count(']') for line in lines)
    if bracket_total > _BRACKET_STYLE_THRESHOLD:
        # Numbered style: keep only lines that open an entry or end a
        # sentence, then attach the remaining lines to the entry before them.
        lines = [
            line for line in lines
            if _starts_numbered_entry(line) or line.endswith('.')
        ]
        is_start = [_starts_numbered_entry(line) for line in lines]
    else:
        # Author style: an entry starts with capitalised words (author names).
        is_start = [_starts_author_entry(line) for line in lines]

    # TODO: split fused words such as 'TomJerry' -> 'Tom Jerry' after merging.
    # TODO: when very few references are found the paper may use a
    # non-standard format; an AI-based extraction fallback was planned.
    return _merge_continuations(lines, is_start)
