from datasets import load_dataset, DatasetDict, Dataset
import os, orjson
# from tqdm.auto import tqdm
from bs4 import BeautifulSoup

# Root directory shared by the raw HTML pages and the generated JSON files.
_CONTEST_PROBLEM_ROOT = "/fs/cml-projects/E2H/Codeforces/contest_problem"
# Input directory: one downloaded HTML page per problem.
LOAD_PATH = f"{_CONTEST_PROBLEM_ROOT}/contest_problem_html"
# Output directory: one JSON file per problem.
SAVE_PATH = f"{_CONTEST_PROBLEM_ROOT}/contest_problem_json"

def get_testcases(module):
    """Extract the sample tests from a problem's ``sample-tests`` element.

    Args:
        module: A BeautifulSoup tag whose ``div.input`` / ``div.output``
            descendants each contain a ``<pre>`` block with the sample text.

    Returns:
        ``{"input": [...], "output": [...]}``: the text of the first ``<pre>``
        of every input/output block, in document order, so ``input[i]`` is
        paired with ``output[i]``.

    Raises:
        ValueError: if the number of input and output blocks differs, or a
            block has no ``<pre>`` element.
    """

    def pre_text(div):
        # Text of the block's first <pre>; text pieces (e.g. separated by
        # <br> or per-line child tags) are stripped and joined by newlines.
        pre = div.find(name="pre")
        if pre is None:
            raise ValueError("Sample block without a <pre> element!")
        return pre.get_text(separator='\n', strip=True)

    input_divs = module.find_all(name="div", class_="input")
    output_divs = module.find_all(name="div", class_="output")
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let mis-paired samples through silently.
    if len(input_divs) != len(output_divs):
        raise ValueError(
            f"I/O mismatch: {len(input_divs)} input(s) vs {len(output_divs)} output(s)!"
        )
    return {
        "input": [pre_text(div) for div in input_divs],
        "output": [pre_text(div) for div in output_divs],
    }


# Convert every downloaded Codeforces problem page into a JSON file holding
# its statement sections and sample tests.
problem_dataset = load_dataset("mcding-org/Easy2Hard-Codeforces", 'problem-v1', cache_dir="/fs/cml-projects/E2H/Huggingface_cache")

# Statement sections stored under their own key (with the "section-title"
# heading removed); every other sibling of the header is appended to "main".
_TEXT_SECTIONS = ("input-specification", "output-specification", "note")


def _parse_problem_statement(html_text):
    """Split one problem page into a JSON-serializable dict of sections.

    Args:
        html_text: Raw HTML of one problem page.

    Returns:
        None when the page has no ``div.problem-statement`` or no ``div.header``
        inside it; otherwise a dict with a "main" list of text blocks plus,
        when present, "sample-tests" (from ``get_testcases``) and the text of
        "input-specification", "output-specification" and "note".
    """
    soup = BeautifulSoup(html_text, features="html.parser")
    statement = soup.find(name="div", class_="problem-statement")
    if statement is None:
        return None
    header = statement.find(name="div", class_="header")
    if header is None:
        return None

    sections = {"main": []}
    # Walk the header's following siblings; each one is a statement section.
    module = header.find_next_sibling()
    while module is not None:
        # `or [None]` also covers an empty class list, where `[0]` would raise.
        class_name = (module.get("class") or [None])[0]
        if class_name == "sample-tests":
            sections[class_name] = get_testcases(module)
        elif class_name in _TEXT_SECTIONS:
            # Drop the section heading so only the body text is kept; a
            # section without a heading is kept as-is instead of crashing.
            subtitle = module.find(name="div", class_="section-title")
            if subtitle is not None:
                subtitle.decompose()
            sections[class_name] = module.get_text()
        else:
            sections["main"].append(module.get_text())
        module = module.find_next_sibling()
    return sections


os.makedirs(SAVE_PATH, exist_ok=True)

for example in problem_dataset["train"]:
    contest_id = example["contestId"]
    problem_index = example["index"]
    html_name = f"contestID_{contest_id}_index_{problem_index}.html"

    try:
        # Read eagerly so the file is closed before parsing; pin UTF-8
        # instead of relying on the platform's locale encoding.
        with open(os.path.join(LOAD_PATH, html_name), "r", encoding="utf-8") as f:
            html_text = f.read()
    except (FileNotFoundError, UnicodeDecodeError):
        print(f"Fail to get problem: {html_name}")
        continue

    json_dict = _parse_problem_statement(html_text)
    if json_dict is None:
        print(f"Fail to get problem: {html_name}")
        continue

    # NOTE: outputs are named "contestId_..." while inputs are "contestID_...";
    # kept unchanged because downstream consumers may depend on these names.
    json_path = os.path.join(SAVE_PATH, f"contestId_{contest_id}_index_{problem_index}.json")
    # orjson.dumps returns UTF-8 bytes; write them directly rather than
    # decoding and re-encoding with the locale's default text encoding.
    with open(json_path, "wb") as wf:
        wf.write(orjson.dumps(json_dict, option=orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY))