import json, os, re
from typing import List

# Bookmark and titles
def check_if_section_title(line, title):
    """Decide whether ``line`` is the heading line of the section ``title``.

    Both strings are normalised before comparison: surrounding whitespace is
    stripped, text is lower-cased, all spaces are removed and the typographic
    apostrophe ’ is replaced by '.

    ``line`` is treated as the heading when either

    * the normalised line ends with the normalised title and

      * at most two characters precede it (e.g. "3." or a bullet), or
      * the stripped raw line is "<section number> <title>", where the number
        looks like "2", "2.1" or "2.1." (regex ``((\\d+\\.)*\\d+\\.?)\\s+(.*)``);

    * or the normalised line contains "<title>:" starting within its first
      three characters (e.g. "Related Work: ...").

    Args:
        line: the text line being tested.
        title: the section title to look for.

    Returns:
        Tuple ``(is_section_title, title, match)``: the verdict, the ``title``
        argument unchanged, and the ``re.Match`` of the section-number regex
        when that rule accepted the line (``None`` otherwise).
    """
    def simplify_string(s):
        return s.strip().lower().replace(" ", "").replace("’", "'")

    simplified_line = simplify_string(line)
    simplified_title = simplify_string(title)
    is_section_title = False
    match = None

    if simplified_line.endswith(simplified_title):
        if len(simplified_line) - len(simplified_title) <= 2:
            is_section_title = True
        else:
            # Only a section number may precede the title; the text after the
            # number must be the title itself, not a sentence ending with it.
            candidate = re.match(r'((\d+\.)*\d+\.?)\s+(.*)', line.strip())
            if candidate and simplify_string(candidate.group(3)) == simplified_title:
                is_section_title = True
                match = candidate
    else:
        # The line is normalised, so the title must be normalised as well;
        # the raw title (capitals, spaces) could never be found in it.
        position = simplified_line.find(simplified_title + ':')
        is_section_title = 0 <= position <= 2

    return is_section_title, title, match


# "<number> <text>" heading line, e.g. "2.1 Related Work" or "3. Method".
_HEADING_LINE_RE = re.compile(r'((\d+\.)*\d+\.?)\s+(.*)')
# A line that holds only a heading number; its text is on the next line.
_HEADING_NUMBER_ONLY_RE = re.compile(r'(\d+\.)*\d+\.?')
# ". <Capital>" inside the text suggests body sentences rather than a heading.
_SENTENCE_BREAK_RE = re.compile(r'\.\s+[A-Z]')


def _parse_heading_number(token):
    """Turn a heading number such as '2.1' or '2.1.' into [2, 1]; None if malformed."""
    parts = token.rstrip('.').split('.')
    if not all(part.isdecimal() for part in parts):
        return None
    return [int(part) for part in parts]


def _follows_heading(previous, current):
    """Return True if heading number ``current`` may directly follow ``previous``.

    For ``previous == [2, 1]`` the accepted successors are [2, 1, 1] (first
    child), [2, 2] (next sibling) and [3] (next sibling of an ancestor).
    """
    if len(current) == len(previous) + 1:
        return current == previous + [1]
    if len(current) <= len(previous):
        depth = len(current)
        return (current[:-1] == previous[:depth - 1]
                and current[-1] == previous[depth - 1] + 1)
    return False


def _scan_numbered_headings(doc):
    """Recover ``[level, title, page]`` entries from numbered heading lines of the page text."""
    headings = []  # (numbers, title, 1-based page)
    for page_index in range(doc.page_count):
        pending_number = ""
        for raw_line in doc.load_page(page_index).get_text().split('\n'):
            line = raw_line.strip()
            if pending_number:
                # The previous line was only a number: this line is its title.
                line = pending_number + ' ' + line
                pending_number = ""
            elif _HEADING_NUMBER_ONLY_RE.fullmatch(line):
                pending_number = line
                continue

            match = _HEADING_LINE_RE.match(line)
            if not match:
                continue
            text = match.group(3).strip()
            if not text or _SENTENCE_BREAK_RE.search(text):
                continue
            numbers = _parse_heading_number(match.group(1))
            if numbers is None:
                continue

            if headings:
                # Keep the numbering consistent with the previous accepted heading.
                if not _follows_heading(headings[-1][0], numbers):
                    continue
            elif numbers != [1]:
                # The first heading must be numbered "1" or "1.".
                continue
            # Page numbers are 1-based, matching the entries of doc.get_toc().
            headings.append((numbers, text, page_index + 1))
    return [[len(numbers), text, page] for numbers, text, page in headings]


def _write_section_titles(section_titles, output_dir):
    """Write one 'level+title+page' line per entry to <output_dir>/section_titles.txt."""
    path = os.path.join(output_dir, "section_titles.txt")
    with open(path, "w", encoding="utf-8") as section_titles_file:
        for entry in section_titles:
            section_titles_file.write('+'.join(str(item) for item in entry) + '\n')


def extract_section_titles(doc, output_dir=None, write_section_titles=False):
    """Return the section titles of ``doc`` as ``[level, title, page]`` entries.

    The document's own bookmarks (``doc.get_toc()``) are returned unchanged
    when present. Otherwise the page text is scanned for numbered headings
    ("1 Introduction", "2.1. Setup", ...). A heading is accepted only if its
    number is a valid successor of the previous accepted one, and the first
    must be "1". The number is dropped from the returned title, and pages
    are 1-based like ``get_toc()``.

    Args:
        doc: document object providing ``get_toc()``, ``page_count`` and
            ``load_page(i).get_text()`` (presumably a PyMuPDF ``fitz.Document``).
        output_dir: directory receiving ``section_titles.txt`` when writing.
        write_section_titles: if True, write the entries (joined with "+")
            to ``<output_dir>/section_titles.txt``.

    Returns:
        List of ``[level, title, page]`` entries.
    """
    section_titles = doc.get_toc()
    if not section_titles:
        section_titles = _scan_numbered_headings(doc)
    if write_section_titles:
        _write_section_titles(section_titles, output_dir)
    return section_titles


# related work
def extract_related_work(doc, section_titles) -> List:
    """Return the text lines of the "Related Work" section of ``doc``.

    The section is located by name. The first entry of ``section_titles``
    whose normalised title contains "relatedwork" names the section. The next
    entry with a non-empty title names the section that ends it.

    Args:
        doc: document object supporting ``len(doc)`` and ``doc[i].get_text()``
            (presumably a PyMuPDF ``fitz.Document`` -- confirm with callers).
        section_titles: sequence of ``[level, title, page]`` entries; only the
            title (index 1) is read.

    Returns:
        List of raw text lines. It starts with the first related-work heading
        line and stops before the heading line of the following section. If
        related work is the last listed section, lines are collected to the
        end of the document.

    Raises:
        AssertionError: if no related-work title is listed, or its heading
            line cannot be found in the page text.
    """
    section_number = re.compile(r'(\d+\.)*\d+\.?')

    def simplify_string(s):
        return s.strip().lower().replace(" ", "").replace("’", "'")

    def check_if_section_title(simplified_line, simplified_title):
        # Both arguments are already simplified (lower case, no spaces), so a
        # section number appears glued to the title, e.g. "2.1.relatedwork".
        if simplified_line.endswith(simplified_title):
            prefix = simplified_line[:len(simplified_line) - len(simplified_title)]
            # Accept a short prefix ("3.", a bullet) or a bare section number.
            return len(prefix) <= 2 or section_number.fullmatch(prefix) is not None
        # Title followed by a colon near the start of the line: "relatedwork:...".
        return 0 <= simplified_line.find(simplified_title + ':') <= 2

    simplified_titles = [simplify_string(entry[1]) for entry in section_titles]
    rw_index = next(
        (i for i, title in enumerate(simplified_titles) if "relatedwork" in title), None)
    if rw_index is None:
        raise AssertionError("Conflict: Cannot find 'related work' in section titles")
    rw_title = simplified_titles[rw_index]
    # Empty when related work is the last section: then read to the end.
    next_section_title = next((t for t in simplified_titles[rw_index + 1:] if t), "")

    num_pages = len(doc)

    # Locate the first heading line of the section.
    start = None
    for page_num in range(num_pages):
        lines = doc[page_num].get_text().split('\n')
        for line_num, line in enumerate(lines):
            if check_if_section_title(simplify_string(line), rw_title):
                start = (page_num, line_num)
                break
        if start is not None:
            break
    if start is None:
        raise AssertionError("Conflict: Cannot find content of 'related work' in pages")

    # Collect lines until the heading of the next section appears.
    start_page, start_line = start
    rw_content = []
    for page_num in range(start_page, num_pages):
        lines = doc[page_num].get_text().split('\n')
        if page_num == start_page:
            lines = lines[start_line:]
        for line in lines:
            if next_section_title and check_if_section_title(
                    simplify_string(line), next_section_title):
                return rw_content
            rw_content.append(line)
    return rw_content

# Reference
def check_if_APA(line: str):
    """Guess the citation style from the first reference line.

    Args:
        line: first line of the reference list.

    Returns:
        False when ``line`` begins with a bracketed number such as "[12]"
        (IEEE style); True for anything else (treated as APA style).
    """
    ieee_marker = re.match(r'\[\d+\]', line)
    return ieee_marker is None


# Spellings of the references heading recognised in the page text.
_REFERENCE_HEADINGS = ('References', 'REFERENCES', 'referenCes')


def _mentions_references(text):
    """Return True if ``text`` contains one of the recognised references headings."""
    return any(heading in text for heading in _REFERENCE_HEADINGS)


def _is_listed_section_title(section_titles, content):
    """Return True if short ``content`` (<= 10 space-separated tokens) contains a listed title."""
    if len(content.split(' ')) > 10:
        return False
    lowered = content.lower()
    return any(title[1].lower() in lowered for title in section_titles)


def _locate_reference_pages(doc, section_titles):
    """Find the references section's start page and the page indices it spans.

    Returns ``(start_page, page_indices)`` with 0-based page indices, or
    ``(None, [])`` when no text block contains a references heading.
    """
    num_pages = len(doc)
    start_page, page_indices = None, []
    for num, page in enumerate(doc):
        for block in page.get_text('blocks'):
            if not _mentions_references(''.join(block[4:-2])):
                continue
            start_page = num
            end_page = num_pages
            # Walk the titles backwards while they start after the references
            # page (titles carry 1-based pages); the earliest such page ends
            # the range.
            for entry in reversed(section_titles):
                if entry[2] > start_page + 1:
                    end_page = entry[2]
                else:
                    break
            page_indices = list(range(start_page, end_page))
            break
        # NOTE(review): there is no outer break, so a later page mentioning
        # "References" overrides an earlier one (behaviour kept as-is).
    return start_page, page_indices


def _split_ieee_references(references):
    """Rebuild numbered ("[n] ...") references from extracted text lines."""
    def starts_reference(text):
        return text[:1] == '[' and text[1:2].isdigit()

    # Keep only lines that open a reference or end with a full stop.
    kept = [ref for ref in references if starts_reference(ref) or ref.endswith('.')]

    # Append continuation lines to the reference they follow.
    merged = []
    for ref in kept:
        if starts_reference(ref) or not merged:
            merged.append(ref)
        else:
            merged[-1] += ' ' + ref

    # One line may hold several references: split on '[' and glue back any
    # piece that does not start with a digit (e.g. "[a]" or a trailing '[').
    split_refs = []
    for ref in merged:
        head, *tail = ref.strip().split('[')
        parts = [head]
        for piece in tail:
            if piece[:1].isdigit():
                parts.append('[' + piece)
            else:
                parts[-1] += '[' + piece
        split_refs.extend(part.rstrip() for part in parts if part.strip())
    return split_refs


def _merge_apa_references(references):
    """Join author-year (APA-style) reference lines into one string per reference."""
    def starts_reference(text):
        # Only the text before the first '.' (typically the author list) is
        # inspected: every word must be capitalised, except the word "and".
        # split() ignores repeated spaces, which previously caused IndexError.
        words = text.split('.', 1)[0].split()
        return bool(words) and all(word[0].isupper() or word == 'and' for word in words)

    merged = []
    for ref in references:
        if starts_reference(ref) or not merged:
            merged.append(ref)
        else:
            merged[-1] += ' ' + ref
    return merged


def extract_references(doc, section_titles, output_dir=None, write_references=False):
    """Extract the reference list from a PDF.

    Args:
        doc: fitz.Document, the PDF document object.
        section_titles: table-of-contents entries such as
            ``[(1, 'Introduction', 2), ...]`` -- level, title and page number.
        output_dir: directory for ``reference.txt`` / ``references.json``.
        write_references: if True, write the references to ``output_dir``.

    Returns:
        Tuple ``(references, is_APA, start_page)``:

        * ``references`` is the list of reference strings;
        * ``is_APA`` is True for author-year style and False for numbered
          "[n]" style;
        * ``start_page`` is the 0-based index of the page holding the
          references heading.

    Raises:
        AssertionError: if no references heading or no reference text is found.
    """
    start_page, ref_pages = _locate_reference_pages(doc, section_titles)
    if start_page is None:
        raise AssertionError("Failed to locate the references section.")

    # Collect the text of every block on the reference pages, skipping blocks
    # containing tab characters (treated as tables).
    ref_list = []
    for page_index in ref_pages:
        for block in doc[page_index].get_text('blocks'):
            texts = list(block[4:-2])
            if '\t' in ''.join(texts):
                continue
            ref_list.extend(texts)

    # Start after the block holding the references heading (or at the first
    # block if that heading block was skipped above).
    first = next((i + 1 for i, text in enumerate(ref_list) if _mentions_references(text)), 0)

    references = []
    for text in ref_list[first:]:
        n_ref = text.replace('\n', '')
        if _is_listed_section_title(section_titles, n_ref):
            break  # a listed section title marks the end of the references
        if len(n_ref.split(' ')) > 5:  # drop fragments of five tokens or fewer
            references.append(n_ref)
    if not references:
        raise AssertionError("Failed to get correct reference content.")

    is_APA = check_if_APA(references[0])
    if is_APA:
        # TODO: APA entries are sometimes read as one merged line; not handled yet.
        references = _merge_apa_references(references)
    else:
        references = _split_ieee_references(references)

    if write_references:
        with open(os.path.join(output_dir, "reference.txt"), "w", encoding="utf-8") as reference_file:
            for reference in references:
                reference_file.write(reference + '\n')
        with open(os.path.join(output_dir, "references.json"), "w", encoding="utf-8") as references_file:
            references_file.write(json.dumps(references, ensure_ascii=False) + '\n')

    return references, is_APA, start_page