import os
import re
import pandas as pd
import orjson

PROBLEMS_DIR = "/fs/cml-projects/E2H/HMMT/problems"

def preprocess_problem(problem_lines):
    r"""Clean the LaTeX lines of one problem statement and extract its score.

    Per input line:
      * everything from a ``Proposed by:`` credit (optionally wrapped in
        ``\section*{``) to the end of the line is cut off;
      * the line is stripped; empty lines and ``enumerate`` scaffolding
        (``\begin{enumerate}``, ``\end{enumerate}``, ``\setcounter{enumi}``)
        are dropped;
      * a leading ``\item`` marker is removed, and the first marker that
        carries a ``(..)`` or ``[..]`` tag supplies the score.

    Afterwards ``\includegraphics`` lines (together with a directly
    surrounding ``center`` environment) and ``\footnotetext`` blocks are
    removed. A footnote block runs from the ``\footnotetext`` line up to and
    including the first following line that is exactly ``}``; when no such
    line exists only the ``\footnotetext`` line itself is removed.

    Args:
        problem_lines: iterable of raw text lines of the problem statement.

    Returns:
        A ``(text, score)`` tuple: the surviving lines joined by single
        spaces, and the score string (``""`` when no tagged ``\item`` marker
        was found).
    """
    scaffolding = ("\\begin{enumerate}", "\\end{enumerate}", "\\setcounter{enumi}")
    new_problem_lines = []
    score = ""
    for line in problem_lines:
        # Prefer the \section*{...} form so the section command is cut as well.
        credit = (re.search(r"\\section\*\{Proposed by:", line)
                  or re.search(r"Proposed by:", line))
        if credit:
            line = line[:credit.start()]
        line = line.strip()
        if not line or line.startswith(scaffolding):
            continue
        if line.startswith("\\item "):
            # NOTE(review): `[.*?]` below is a character class matching exactly
            # one of '.', '*' or '?', so b1 only matches markers such as
            # `\item (*) `. It looks like a typo for `\(.*?\)`, but the pattern
            # is kept as-is because changing it would alter which markers are
            # stripped and what ends up in `score` -- confirm the intent.
            b1 = re.search(r"\\item \([.*?]\) ", line)
            b2 = re.search(r"\\item \[.*?\] ", line)
            b3 = re.search(r"\\item ", line)
            if not score:
                if b1:
                    # The parenthesised form keeps the whole matched marker.
                    score = line[b1.start():b1.end()]
                elif b2:
                    # Skip the 7 characters of `\item [` and the trailing `] `.
                    score = line[b2.start() + 7:b2.end() - 2]
            # b3 always matches here because the line starts with `\item `.
            marker = b1 or b2 or b3
            line = line[marker.end():]
        new_problem_lines.append(line)

    last_idx = len(new_problem_lines) - 1
    figure_line_idx = [idx for idx, line in enumerate(new_problem_lines)
                       if line.startswith("\\includegraphics")]
    for idx in figure_line_idx:
        new_problem_lines[idx] = ""
        # Only look at neighbours that really exist: idx-1 must not wrap around
        # to the last line and idx+1 must not run past the end of the list.
        if (0 < idx < last_idx
                and new_problem_lines[idx - 1].startswith("\\begin{center}")
                and new_problem_lines[idx + 1].startswith("\\end{center}")):
            new_problem_lines[idx - 1] = ""
            new_problem_lines[idx + 1] = ""

    footnote_line_idx = [idx for idx, line in enumerate(new_problem_lines)
                         if line.startswith("\\footnotetext")]
    for idx in footnote_line_idx:
        if not new_problem_lines[idx].startswith("\\footnotetext"):
            # Already blanked as part of an earlier footnote block.
            continue
        try:
            closing_idx = new_problem_lines.index("}", idx)
        except ValueError:
            # No closing-brace line: remove just the \footnotetext line.
            closing_idx = idx
        for pos in range(idx, closing_idx + 1):
            new_problem_lines[pos] = ""

    return ' '.join(line for line in new_problem_lines if line), score
    

def preprocess_answer(answer_lines):
    r"""Extract the answer text from the line that carries the ``Answer:`` tag.

    Args:
        answer_lines: despite the name, a single string (it is sliced and
            passed to ``re.match``), or any falsy value when the problem has
            no answer line.

    Returns:
        ``None`` for falsy input. Otherwise a ``\section*{Answer: ...}``
        wrapper is peeled off first (the leading ``\section*{`` and the final
        character of the line). If the remaining text starts with
        ``Answer:``, the text after that tag is returned with leading
        whitespace removed; if not, the text is returned unchanged but
        prefixed with ``Last_``.
    """
    if not answer_lines:
        return None

    text = answer_lines
    section_prefix = "\\section*{"
    if re.match(r'\\section\*\{Answer:(.*?)\}', text):
        # Drop the opening command and the last character of the line.
        text = text[len(section_prefix):-1]

    tag = "Answer:"
    if not text.startswith(tag):
        return f"Last_{text}"
    return text[len(tag):].lstrip()
    

def preprocess_solution(solution_lines):
    r"""Join the solution lines into one string, dropping figures and footnotes.

    ``\includegraphics`` lines are removed, together with a directly
    surrounding ``\begin{center}`` / ``\end{center}`` pair. A
    ``\footnotetext`` block is removed from its first line up to and
    including the first following line that is exactly ``}``; when no such
    line exists only the ``\footnotetext`` line itself is removed. Lines are
    compared as given (they are not stripped here).

    Args:
        solution_lines: list of text lines; it is not modified.

    Returns:
        The remaining non-empty lines joined by single spaces.
    """
    # Work on a copy so the caller's list is left untouched.
    new_solution_lines = list(solution_lines)
    last_idx = len(new_solution_lines) - 1

    figure_line_idx = [idx for idx, line in enumerate(new_solution_lines)
                       if line.startswith("\\includegraphics")]
    for idx in figure_line_idx:
        new_solution_lines[idx] = ""
        # Only look at neighbours that really exist: idx-1 must not wrap around
        # to the last line and idx+1 must not run past the end of the list.
        if (0 < idx < last_idx
                and new_solution_lines[idx - 1].startswith("\\begin{center}")
                and new_solution_lines[idx + 1].startswith("\\end{center}")):
            new_solution_lines[idx - 1] = ""
            new_solution_lines[idx + 1] = ""

    footnote_line_idx = [idx for idx, line in enumerate(new_solution_lines)
                         if line.startswith("\\footnotetext")]
    for idx in footnote_line_idx:
        if not new_solution_lines[idx].startswith("\\footnotetext"):
            # Already blanked as part of an earlier footnote block.
            continue
        try:
            closing_idx = new_solution_lines.index("}", idx)
        except ValueError:
            # No closing-brace line: remove just the \footnotetext line.
            closing_idx = idx
        for pos in range(idx, closing_idx + 1):
            new_solution_lines[pos] = ""

    return ' '.join(line for line in new_solution_lines if line)


def parsing_problem(section_lines):
    r"""Split one problem section into statement, answer and solution fields.

    The first line containing ``Answer:`` is the answer line; the lines
    before it form the problem statement and the lines after it the solution.
    Without such a line the whole section is treated as the statement.

    Args:
        section_lines: list of text lines belonging to one problem.

    Returns:
        A dict with the keys ``answer``, ``problem``, ``score``, ``solution``
        and ``no_answer`` (inserted in that order). When the answer is
        non-empty, `` \(\fbox{<answer>}\).`` is appended to the solution.
        ``no_answer`` is ``True`` when the answer is ``None`` or ``""``.
    """
    answer_idx = next(
        (idx for idx, line in enumerate(section_lines) if "Answer:" in line),
        None,
    )
    if answer_idx is None:
        problem_lines = section_lines
        answer_lines = []
        solution_lines = []
    else:
        problem_lines = section_lines[:answer_idx]
        # A single line (a string), not a list, is handed to preprocess_answer.
        answer_lines = section_lines[answer_idx]
        solution_lines = section_lines[answer_idx + 1:]

    problem_dict = {}
    answer = preprocess_answer(answer_lines)
    problem_dict["answer"] = answer
    problem_dict["problem"], problem_dict["score"] = preprocess_problem(problem_lines)
    solution = preprocess_solution(solution_lines)
    if answer:
        # Backslashes are escaped explicitly; "\(" and "\)" are invalid escape
        # sequences in a normal string literal and trigger a SyntaxWarning.
        solution += " \\(\\fbox{" + answer + "}\\)."
    problem_dict["solution"] = solution
    problem_dict["no_answer"] = (answer == "" or answer is None)
    return problem_dict


def extract_latex_content(latex_path):
    r"""Parse one LaTeX file under ``{PROBLEMS_DIR}/ori`` into problem records.

    The file name without its extension is split on ``_``; the 2nd, 3rd and
    4th parts are taken as month, year and subtest. Each line starting with
    ``\begin{enumerate}`` opens a new problem, which extends up to the next
    such line. The last problem stops before the final line of the file.
    # NOTE(review): the ``len(content_lines) - 1`` sentinel excludes the last
    # line of the file from the last problem -- presumably to skip a closing
    # ``\end{document}``; confirm this is intended.

    Subtests ``gen1`` and ``gen2`` are both renamed to ``gen``; problems of
    ``gen2`` additionally get their index shifted by 10.

    Args:
        latex_path: file name relative to ``{PROBLEMS_DIR}/ori``.

    Returns:
        A list of dicts, one per problem, holding ``year``, ``month``,
        ``subtest``, ``idx`` (1-based) plus the fields produced by
        ``parsing_problem``.
    """
    with open(f"{PROBLEMS_DIR}/ori/{latex_path}", 'r', encoding='utf-8') as file:
        content = file.read()

    stem = os.path.splitext(latex_path)[0]
    month, year, subtest = stem.split("_")[1:4]

    content_lines = content.splitlines()
    boundaries = [
        pos for pos, text in enumerate(content_lines)
        if text.startswith('\\begin{enumerate}')
    ]
    boundaries.append(len(content_lines) - 1)

    problem_list = []
    # Consecutive boundary pairs delimit one problem each.
    for number, (start, stop) in enumerate(zip(boundaries, boundaries[1:]), start=1):
        record = {"year": year, "month": month, "subtest": subtest, "idx": number}
        record.update(parsing_problem(content_lines[start:stop]))
        if record["subtest"] in ("gen1", "gen2"):
            if record["subtest"] == "gen2":
                record["idx"] += 10
            record["subtest"] = "gen"
        problem_list.append(record)

    return problem_list


def process_all_files():
    """Convert every file in ``{PROBLEMS_DIR}/ori`` into one JSONL file.

    The directory entries are processed in sorted order. Each problem record
    returned by ``extract_latex_content`` is serialized with ``orjson`` and
    written as one line to ``{PROBLEMS_DIR}/HMMT_problems.jsonl``, which is
    opened in write mode (an existing file is overwritten).
    """
    # List the inputs before the output file is opened, as a listing failure
    # must not leave a freshly truncated output file behind.
    source_files = sorted(os.listdir(f"{PROBLEMS_DIR}/ori"))
    dump_options = orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY
    with open(f"{PROBLEMS_DIR}/HMMT_problems.jsonl", 'w', encoding='utf-8') as w_file:
        for solution_file in source_files:
            for problem_dict in extract_latex_content(solution_file):
                # orjson returns bytes; decode before writing to the text file.
                encoded = orjson.dumps(problem_dict, option=dump_options)
                w_file.write(encoded.decode('utf-8') + "\n")


# Script entry point: convert all LaTeX files to JSONL. The call is not guarded
# by `if __name__ == "__main__"`, so it also runs when this module is imported.
process_all_files()
