# llmHMER/QwenHMER/dataset_utils.py

import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

import editdistance
import yaml

# Base counts
base_counts = {
    'crohme_2014': 986,
    'crohme_2016': 1147,
    'crohme_2019': 1199,
    'hme100k_test': 24607,
    "crohme_train":8834,
    "hme100k_train":74502,
}

# Test and validation specific counts
test_val_counts = {
    # 'crohme2023_CROHME2014_test': 986,
    # 'crohme2023_CROHME2023_test': 2300,
    # 'crohme2023_CROHME2019_test': 1199,
    # 'crohme2023_CROHME2016_test': 1147,
    # 'crohme2023_CROHME2023_val': 555,
    # 'crohme2023_train': 21038,
    "unimer_net_cpe": 5921,
    "unimer_net_hwe": 6332,
    "unimer_net_sce": 4742,
    "unimer_net_spe": 6762,
}

# Variants to generate
variants = ['', '_nobox', '_nospace', '_nospace_nobox','_nobox_white']

# Generate full dictionary
crohme_count_dict = {**base_counts, **test_val_counts}


# Add variants for each base CROHME dataset
for base_key in ['crohme_2014', 'crohme_2016', 'crohme_2019',"hme100k_test","hme100k_train","crohme_train"]:
    base_value = base_counts[base_key]
    for variant in variants:
        if variant:  # Skip empty variant as it's already in base_counts
            crohme_count_dict[f'{base_key}{variant}'] = base_value
            
# ast_variants = ['ast','tree']

# for base_key in ['crohme_2014','crohme_2016','crohme_2019','hme100k_test']:
#     base_value = crohme_count_dict[base_key]
#     for variant in ast_variants:
#         if variant:
#             crohme_count_dict[f'1shot-{variant}_nobox_white_{base_key}'] = base_value

# prefix_keys = ['1shot-bt-ast','1shot-bt-ast_no11']

# for prefix_key in prefix_keys:
#     for base_key in ['crohme_2014','crohme_2016','crohme_2019','hme100k_test']:
#         base_value = crohme_count_dict[base_key]
#         crohme_count_dict[f'{prefix_key}_{base_key}'] = base_value

# crohme_count_dict = {
#     'crohme_2014': 986,
#     'crohme_2016': 1147,
#     'crohme_2019': 1199,
#     'hme100k_test': 24607,
#     'test/CROHME2023_test': 2300,
#     'test/CROHME2019_test': 1199,
#     'val/CROHME2016_test': 1147,
#     'val/CROHME2023_val': 555,
#     'crohme_2014_nobox': 986,
#     'crohme_2014_nospace': 986,
#     'crohme_2014_nospace_nobox': 986,
#     'crohme_2016_nobox': 1147,
#     'crohme_2016_nospace': 1147,
#     'crohme_2016_nospace_nobox': 1147,
#     'crohme_2019_nobox': 1199,
#     'crohme_2019_nospace': 1199,
#     'crohme_2019_nospace_nobox': 1199
# }
# for year in ['2014', '2016', '2019']:
#     crohme_count_dict[f'crohme_{year}_nobox'] = crohme_count_dict[f'crohme_{year}']
#     crohme_count_dict[f'crohme_{year}_nospace'] = crohme_count_dict[f'crohme_{year}']
#     crohme_count_dict[f'crohme_{year}_nospace_nobox'] = crohme_count_dict[f'crohme_{year}']
#
# print(crohme_count_dict)




#######################################################################
# 2) read and write the CROHME dataset
#######################################################################

def read_crohme_captions(year: str, base_dir: str) -> tuple[Dict[str, str], List[str]]:
    """
    read the caption.txt of the CROHME dataset of the specified year (or 'train')
    
    Args:
        year: '2014'/'2016'/'2019'/'train'
        base_dir: root directory of the dataset
        
    Returns:
        tuple: (captions_dict, file_list)
            - captions_dict: {filename: caption, ...} dictionary
            - file_list: file name list
    """
    caption_file = os.path.join(base_dir, year, 'caption.txt')
    with open(caption_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    file_list = []
    captions = {}
    for line in lines:
        parts = line.strip().split()
        filename = parts[0]
        file_list.append(filename)
        caption = " ".join(parts[1:])
        captions[filename] = caption
    return captions, file_list




def get_all_crohme_captions(base_dir: str) -> Dict[str, str]:
    """
    read the caption.txt of the CROHME dataset of the specified year (or 'train')
    
    Args:
        base_dir: root directory of the dataset
        
    Returns:
        Dict[str, str]: merged {filename: caption, ...}
    """
    c2014, _ = read_crohme_captions('2014', base_dir)
    c2016, _ = read_crohme_captions('2016', base_dir)
    c2019, _ = read_crohme_captions('2019', base_dir)
    ctrain, _ = read_crohme_captions('train', base_dir)
    
    # merge all dictionaries
    result = {}
    result.update(c2014)
    result.update(c2016)
    result.update(c2019)
    result.update(ctrain)
    
    return result






#######################################################################
# 3) read and write hme100k dataset
#######################################################################

def read_hme100k_captions(split: str, base_dir: str) -> Dict[str, str]:
    """
    read the caption.txt of the hme100k dataset
    :param split: 'train' / 'test'
    :param base_dir: root directory of the dataset
    :return: {filename_no_ext: caption, ...}
    """
    caption_file = os.path.join(base_dir, split, 'caption.txt')
    with open(caption_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    captions = {}
    for line in lines:
        parts = line.strip().split()
        filename_no_ext = parts[0].split('.')[0]  # remove .jpg
        caption = " ".join(parts[1:])
        captions[filename_no_ext] = caption
    return captions

def read_all_hme100k_captions(base_dir: str) -> Dict[str, str]:
    """
    read the caption.txt of the hme100k dataset
    :param base_dir: root directory of the dataset
    :return: merged {filename_no_ext: caption, ...}
    """
    train_captions = read_hme100k_captions('train', base_dir)
    test_captions = read_hme100k_captions('test', base_dir)
    return {**train_captions, **test_captions}



def preprocess_caption_dic(dic:Dict[str, str]):
    for key, val in dic.items():
        val = " ".join(val.strip().split())
        dic[key] = val
    return dic

def get_all_captions(crohme_dir,hme100k_dir,dir_list: List[str] = []) -> Dict[str, Dict[str, str]]:
    """

    """
    hme100k_test_dic = read_hme100k_captions('test', hme100k_dir)
    hme100k_train_dic = read_hme100k_captions('train', hme100k_dir)
    crohme2014_dic,_ = read_crohme_captions('2014', crohme_dir)
    crohme2016_dic,_ = read_crohme_captions('2016', crohme_dir)
    crohme2019_dic,_ = read_crohme_captions('2019', crohme_dir)
    crohme_train_dic,_ = read_crohme_captions('train', crohme_dir)
    final_dic = {}

    final_dic['crohme_2014'] = crohme2014_dic
    final_dic['crohme_2016'] = crohme2016_dic
    final_dic['crohme_2019'] = crohme2019_dic
    final_dic['crohme_train'] = crohme_train_dic
    final_dic['hme100k_test'] = hme100k_test_dic
    final_dic['hme100k_train'] = hme100k_train_dic

    variants = ['', '_nobox', '_nospace', '_nospace_nobox', '_nobox_white']
    for base_key in ['crohme_2014', 'crohme_2016', 'crohme_2019', 'hme100k_test','crohme_train','hme100k_train']:
        base_value = final_dic[base_key]
        for variant in variants:
            if variant:
                final_dic[f'{base_key}{variant}'] = base_value
    ast_variants = ['ast','tree']
    for base_key in ['crohme_2014','crohme_2016','crohme_2019','hme100k_test']:
        base_value = final_dic[base_key]
        for variant in ast_variants:
            if variant:
                final_dic[f'1shot-{variant}_nobox_white_{base_key}'] = base_value
    prefix_keys = ['1shot-bt-ast','1shot-bt-ast_no11']
    for prefix_key in prefix_keys:
        for base_key in ['crohme_2014','crohme_2016','crohme_2019','hme100k_test']:
            base_value = final_dic[base_key]
            final_dic[f'{prefix_key}_{base_key}'] = base_value
            

    for folder in dir_list:
        _, dic = get_crohme_2023_ids_dict(folder)
        for key, value in dic.items():
            final_dic[key] = value
            # print(value)
    # print(final_dic.keys())
    # for key, value in final_dic.items():
    #     print(key,len(value))
    return final_dic

#######################################################################
# 4) general
#######################################################################

def extract_inboxed_content(latex_str: str) -> str:
    """
    extract the content inside \boxed{} from the LaTeX string
    :param latex_str: LaTeX string
    :return: the content inside \boxed{}
    """
    match = re.search(r'\\boxed\{(.*)\}', latex_str)
    if not match:
        raise ValueError(f"No \\boxed{{}} found in: {latex_str}")
    return match.group(1)


def compute_edit_distance(str_a, str_b) -> int:
    """
    compute the edit distance between two inputs, support string and list
    
    Args:
        str_a: the first input, can be string or list of tokens
        str_b: the second input, can be string or list of tokens
        
    Returns:
        int: edit distance
        
    Examples:
        >>> compute_edit_distance("a b c", "a b d")
        1
        >>> compute_edit_distance(["a", "b", "c"], ["a", "b", "d"])
        1
    """
    # convert to list format (if the input is a string)
    if isinstance(str_a, str):
        str_a = str_a.strip().split()
    if isinstance(str_b, str):
        str_b = str_b.strip().split()
    
    # ensure the input is list type
    if not isinstance(str_a, list) or not isinstance(str_b, list):
        raise TypeError("Inputs must be strings or lists of tokens")
        
    return editdistance.eval(str_a, str_b)



import os
from pathlib import Path


def extract_log_files_from_sh(sh_file_path: str) -> list:
    """
    extract all log file paths from the .sh file
    :param sh_file_path: sh file path
    :return: list of log file paths
    """
    log_files = []
    
    # read the sh file
    with open(sh_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # check if it contains '>' character
            if '>' in line:
                # split by '>', the second part is the log file path
                parts = line.split('>')
                # remove the last file name of the path
                # log_file_path = os.path.dirname(log_file_path)
                # log_file_path = os.path.join(log_file_path, "trainer_log.jsonl")
                # print(str(log_file_)path)
                log_files.append(Path(log_file_path))
    
    return log_files

def check_logs_exist_from_sh(sh_file_path: str) -> bool:
    """
    check if all log files exist based on the sh file path
    :param sh_file_path: sh file path
    :return: if all log files exist, return True; otherwise return False
    """
    log_files = extract_log_files_from_sh(sh_file_path)
    print(log_files)
    dir_path = os.path.dirname(sh_file_path)
    if not os.path.exists(dir_path) or not os.path.exists(os.path.join(dir_path,"exp_summary_list.json")):
        print(f"Log file directory not found: {dir_path}")
        return False
    exp_data = json.load(open(os.path.join(dir_path,"exp_summary_list.json"),"r"))
    if len(exp_data) == 0:
        print(f"No experiment data found in: {dir_path}")
        return False
    # generator_jsonl_file = os.path.join(dir_path,"generated_predictions.jsonl")
    # if not os.path.exists(generator_jsonl_file) or len(open(generator_jsonl_file,"r").readlines()) == 0:
    #     print(f"Generator JSONL file not found: {generator_jsonl_file}")
    #     return False
    all_logs_exist = True
    for log_file in log_files:
        if not log_file.exists():
            # print(f"Log file not found: {log_file}")
            all_logs_exist = False
            break
        # check the content of the log file
        with open(log_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            # print(lines)
            for line in lines:
                if "raise" in line and "error" in line.lower():
                    all_logs_exist = False
                    break
                # print(line)
    
    # output the status
    print(all_logs_exist,"all_logs_exist")
    return all_logs_exist





def read_json_file(file_path: str) -> list[dict]:
    """
    Read and parse JSON file.

    Args:
        file_path (str): Path to the JSON file

    Returns:
        List[dict]: Parsed JSON content

    Raises:
        FileNotFoundError: If file does not exist
        json.JSONDecodeError: If file is not valid JSON
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"File not found: {file_path}")
    except json.JSONDecodeError as e:
        # pass the necessary parameters correctly
        raise json.JSONDecodeError(f"Invalid JSON file: {file_path} - {e.msg}", e.doc, e.pos) from e


def get_last_image_names(file_path: str) -> list[str]:
    """
    Extract filenames (without extension) of the last image from each item in JSON data.

    Args:
        file_path (str): Path to the JSON file containing image information

    Returns:
        List[str]: List of filenames without extensions

    Example:
        >>> get_last_image_names("path/to/data.json")
        ['514_em_341', 'RIT_2014_248']
    """
    try:
        # Read JSON data
        data = read_json_file(file_path)

        # Process each item
        result = []
        for item in data:
            if not item.get("images"):
                continue

            # Get the last image path
            last_image = item["images"][-1]

            # Convert to Path object for reliable path manipulation
            image_path = Path(last_image)

            # Get filename without extension
            filename = image_path.stem

            result.append(filename)

        return result

    except Exception as e:
        raise Exception(f"Error processing file: {str(e)}")
    
# dataset path configuration
# can be overridden by environment variables or configuration files
import os

# default path
DEFAULT_DATASET_PATHS = {
    # CROHME dataset path
    'crohme_2014': '/home/user/workspaces/llmHMER/hmer_dataset/nobox_white/crohme_2014.json',
    'crohme_2016': '/home/user/workspaces/llmHMER/hmer_dataset/nobox_white/crohme_2016.json',
    'crohme_2019': '/home/user/workspaces/llmHMER/hmer_dataset/nobox_white/crohme_2019.json',
    'crohme_train': '/home/user/workspaces/llmHMER/hmer_dataset/nobox_white/crohme_train.json',
    
    # HME100K dataset path
    'hme100k_test': '/home/user/workspaces/llmHMER/hmer_dataset/nobox_white/hme100k_test.json',
    'hme100k_train': '/home/user/workspaces/llmHMER/hmer_dataset/nobox_white/hme100k_train.json',
    
    # LlamaFactory dataset information file
    'llamafactory_data': '/home/user/workspaces/llmHMER/LLaMA-Factory/data/dataset_info.json',
    
    # backup directory
    'backup_data_utils': './backup_data_utils',
}

# load the configuration from environment variables, override the default values
def load_dataset_paths_from_env():
    """load the dataset path configuration from environment variables"""
    paths = DEFAULT_DATASET_PATHS.copy()
    for key in paths.keys():
        env_var = f"LLMHMER_DATASET_{key.upper()}"
        if env_var in os.environ:
            paths[key] = os.environ[env_var]
    return paths

# load the configuration
DATASET_PATHS = load_dataset_paths_from_env()

# use the configuration path instead of hard-coded path
crohme_2014_llamafactory_data_path = DATASET_PATHS['crohme_2014']
crohme_2016_llamafactory_data_path = DATASET_PATHS['crohme_2016']
crohme_2019_llamafactory_data_path = DATASET_PATHS['crohme_2019']
hme100k_test_llamafactory_data_path = DATASET_PATHS['hme100k_test']
crohme_train_llamafactory_data_path = DATASET_PATHS['crohme_train']
hme100k_train_llamafactory_data_path = DATASET_PATHS['hme100k_train']
# crohme2023_folder = '/data/user/llmHMER/hmer_dataset/crohme2023_white'
# unimernet_fodler = '/data/user/llmHMER/hmer_dataset/unimernet'

# def get_all_ids() -> Dict[str, List[str]]:
#     """
#     extract all ids from the caption dictionary
#     :param captions: {filename: caption, ...}
#     :return: all ids
#     """
#     dic = {}
#     dic["crohme_2014"] = get_last_image_names(crohme_2014_llamafactory_data_path)
#     dic["crohme_2016"] = get_last_image_names(crohme_2016_llamafactory_data_path)
#     dic["crohme_2019"] = get_last_image_names(crohme_2019_llamafactory_data_path)
#     dic["hme100k_test"] = get_last_image_names(hme100k_test_llamafactory_data_path)
#     for year in ['2014', '2016', '2019']:
#         dic[f'crohme_{year}_nobox'] = dic[f'crohme_{year}']
#         dic[f'crohme_{year}_nospace'] = dic[f'crohme_{year}']
#         dic[f'crohme_{year}_nospace_nobox'] = dic[f'crohme_{year}']
#     return dic

def aprompt2imgcaption(dic):
    img_path = dic['images'][0]
    caption = dic['messages'][1]['value']
    caption = " ".join(caption.strip().split())
    img_name = os.path.basename(img_path).split(".")[0]
    return img_name, caption

def get_crohme_2023_ids_dict(folder):
    ret_dic = {}
    ret_ids = {}
    for prompt_file in os.listdir(folder):
        if not prompt_file.endswith(".json"):
            continue
        # if "train" in prompt_file:
            # continue
        this_dic = {}
        this_list = []
        data = json.load(open(os.path.join(folder,prompt_file),"r"))
        for item in data:
            img_name, caption = aprompt2imgcaption(item)
            this_dic[img_name] = caption
            this_list.append(img_name)
        ret_dic[prompt_file.split('.')[0]] = this_dic
        ret_ids[prompt_file.split('.')[0]] = this_list
        print('loaded', prompt_file,"which length is ",len(data))

    if len(ret_ids) == 0:
        print("no ids found in", folder)
        raise ValueError("no ids found in", folder)
    # print(ret_dic)
    return ret_ids,ret_dic

def get_all_ids() -> Dict[str, List[str]]:
    """
    extract all ids from the caption dictionary
    :return: all ids
    """
    # base dataset path mapping
    base_paths = {
        'crohme_2014': crohme_2014_llamafactory_data_path,
        'crohme_2016': crohme_2016_llamafactory_data_path,
        'crohme_2019': crohme_2019_llamafactory_data_path,
        'hme100k_test': hme100k_test_llamafactory_data_path,
        'crohme_train': crohme_train_llamafactory_data_path,
        'hme100k_train': hme100k_train_llamafactory_data_path,
    }

    # variant suffix
    variants = ['_nobox', '_nospace', '_nospace_nobox','_nobox_white']

    # crohmer2023_ids,_ = get_crohme_2023_ids_dict(crohme2023_folder)
    # unimernet_ids,_ = get_crohme_2023_ids_dict(unimernet_fodler)
    # initialize the result dictionary
    dic = {}
    
    # process the base dataset
    for base_name, path in base_paths.items():
        base_ids = get_last_image_names(path)
        print(f"load {base_name} , which length is {len(base_ids)}")
        dic[base_name] = base_ids
        # if the dataset is CROHME, add the variant
        if base_name.startswith('crohme_') or base_name.startswith('hme100k_'):
            for variant in variants:
                dic[f'{base_name}{variant}'] = base_ids
                
    # ast_variants = ['ast','tree']
    # for base_key in ['crohme_2014','crohme_2016','crohme_2019','hme100k_test']:
    #     base_value = dic[base_key]
    #     for variant in ast_variants:
    #         if variant:
    #             dic[f'1shot-{variant}_nobox_white_{base_key}'] = base_value
    # prefix_keys = ['1shot-bt-ast','1shot-bt-ast_no11']
    # for prefix_key in prefix_keys:
    #     for base_key in ['crohme_2014','crohme_2016','crohme_2019','hme100k_test']:
    #         base_value = dic[base_key]
    #         dic[f'{prefix_key}_{base_key}'] = base_value
                

    # extend dic to unimernet and 2023
    dic = {**dic}
    return dic


def sort_checkpoint_folders(folder_list):
    # define a function to extract the number after the checkpoint
    def extract_number(folder_name):
        try:
            # split by '-', and extract the number after the checkpoint
            return int(folder_name.split('-')[1])
        except (IndexError, ValueError):
            # if not the expected format, return -1 (such items will be placed at the front)
            return -1

    # use sorted function, key parameter uses extract_number function
    sorted_folders = sorted(folder_list, key=extract_number)
    return sorted_folders


def preprocess_llamafactory_data(folder_path: str) -> dict:
    """
    preprocess the data in the specified folder, generate the dataset format information required by LlamaFactory
    
    Args:
        folder_path: data folder path
        
    Returns:
        dict: dictionary like {key: dataset_info, ...}
    """
    result = {}
    for file in os.listdir(folder_path):
        if file.endswith(".json"):
            # build the full file path
            full_path = os.path.join(folder_path, file)
            # file name, without .json suffix
            key = f"{os.path.splitext(file)[0]}"
            
            # create the dataset information dictionary
            dataset_info = {
                "file_name": full_path,
                "formatting": "sharegpt",
                "columns": {
                    "messages": "messages",
                    "images": "images"
                },
                "tag": {
                    "role_tag": "from",
                    "content_tag": "value",
                    "user_tag": "human",
                    "assistant_tag": "gpt"
                }
            }
            
            result[key] = dataset_info
    
    return result


def update_dataset_counts(folder_list: list, count_dict: dict) -> dict:
    """
    update the dataset statistics information
    
    Args:
        folder_list: data folder list
        count_dict: existing statistics dictionary
        
    Returns:
        dict: updated statistics dictionary
    """
    updated_dict = count_dict.copy()
    
    for folder in folder_list:
        for file in os.listdir(folder):
            if not file.endswith('.json'):
                continue
                
            file_path = os.path.join(folder, file)
            try:
                this_data = json.load(open(file_path, "r"))
                dataset_key = file.split('.')[0]
                
                if dataset_key not in updated_dict:
                    updated_dict[dataset_key] = len(this_data)
                    print(f"Added count for {dataset_key}: {len(this_data)}")
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
    
    return updated_dict


def update_test_ids(folder_list: list, id_dict: dict) -> dict:
    """
    update the test ID dictionary
    
    Args:
        folder_list: data folder list
        id_dict: existing ID dictionary
        
    Returns:
        dict: updated ID dictionary
    """
    updated_dict = id_dict.copy()
    
    for folder in folder_list:
        try:
            this_ids, _ = get_crohme_2023_ids_dict(folder)
            for key, value in this_ids.items():
                if key not in updated_dict:
                    updated_dict[key] = value
                    print(f"Added IDs for {key}: {len(value)} items")
        except Exception as e:
            print(f"Error processing folder {folder}: {e}")
    
    return updated_dict


def update_captions(folder_list: list, caption_dict: dict) -> dict:
    """
    update the caption dictionary
    
    Args:
        folder_list: data folder list
        caption_dict: existing caption dictionary
        
    Returns:
        dict: updated caption dictionary
    """
    updated_dict = caption_dict.copy()
    
    for folder in folder_list:
        try:
            _, this_captions = get_crohme_2023_ids_dict(folder)
            for key, value in this_captions.items():
                if key not in updated_dict:
                    updated_dict[key] = value
                    print(f"Added captions for {key}")
        except Exception as e:
            print(f"Error processing folder {folder} captions: {e}")
    
    return updated_dict


def register_llamafactory_data():
    """register the dataset information required by LlamaFactory"""
    global crohme_count_dict, all_test_id, all_caption_dict
    
    # data folder list
    folder_list = [
        "/home/user/workspaces/llmHMER/hmer_dataset/nobox_white", 
        "/home/user/workspaces/llmHMER/notebook/split/hme100k_3_fold_data",
        "/home/user/workspaces/llmHMER/hmer_dataset/im2latex_v2_100k",
        "/home/user/workspaces/llmHMER/hmer_dataset/test/0315_2fix_error_methods",
        "/home/user/workspaces/llmHMER/hmer_dataset/crohme_2023_final",
        "/home/user/workspaces/llmHMER/hmer_dataset/test/0318_refine_fixandfix",
        "/home/user/workspaces/llmHMER/hmer_dataset/test/0319_new_data",
        "/home/user/workspaces/llmHMER/hmer_dataset/test/0325_crohme2023_tdv2_2",
        "/home/user/workspaces/llmHMER/hmer_dataset/mathwritting",
        "/home/user/workspaces/llmHMER/hmer_dataset/test/0410_split",
        "/home/user/workspaces/llmHMER/hmer_dataset/0410final",
        "/home/user/workspaces/llmHMER/hmer_dataset/0410final/fix",
        "/home/user/workspaces/llmHMER/hmer_dataset/0410final/bttr",
        "/home/user/workspaces/llmHMER/hmer_dataset/test/unimernet",
        "/home/user/workspaces/llmHMER/hmer_dataset/0410final/can",
        "/home/user/workspaces/llmHMER/hmer_dataset/test/0415_original_hme100k",
        "/home/user/workspaces/llmHMER/hmer_dataset/0410final/error_data",
        # "/home/user/workspaces/data-user/llmHMER/hmer_dataset/test/0423_split",
        "/home/user/workspaces/llmHMER/hmer_dataset/test/0425_split",
        "/home/user/workspaces/llmHMER/hmer_dataset/0425final",
        "/home/user/workspaces/llmHMER/hmer_dataset/0425final/can",
        "/home/user/workspaces/llmHMER/hmer_dataset/0425final/error_data",
        "/home/user/workspaces/llmHMER/hmer_dataset/test/0428_split"
    ]
    
    # folders that do not need to be registered
    tobe_delete_folder_list = [
        "/home/user/workspaces/llmHMER/hmer_dataset/test/0408_split",
        "/home/user/workspaces/llmHMER/hmer_dataset/0407final",
        "/home/user/workspaces/llmHMER/hmer_dataset/0407final/bttr",
        "/home/user/workspaces/llmHMER/hmer_dataset/0407final/can",
        "/home/user/workspaces/llmHMER/hmer_dataset/archive/nospace",
        # "/home/user/workspaces/llmHMER/hmer_dataset/unimernet"
    ]
    
    # backup path
    backup_data_utils_path = DATASET_PATHS['backup_data_utils']

    backup_folder_list_path = os.path.join(backup_data_utils_path, "folder_list.json")
    
    # if the backup folder list does not exist, create an empty list, and reload if the data folder list changes
    backup_folder_list = [] if not os.path.exists(backup_folder_list_path) else json.load(open(backup_folder_list_path, "r"))
        
    # check if need to load new data
    need_load_new_data = not os.path.exists(backup_data_utils_path) or len(os.listdir(backup_data_utils_path)) == 0 or backup_folder_list != folder_list
    if need_load_new_data:
        os.makedirs(backup_data_utils_path, exist_ok=True)
    
    # register the dataset to LlamaFactory
    llamafactory_data_path = DATASET_PATHS['llamafactory_data']
    llama_factory_data = json.load(open(llamafactory_data_path, "r"))
    
    # process the dataset of each folder
    for folder in folder_list:
        process_data = preprocess_llamafactory_data(folder)
        print(f"process the folder: {folder}, found dataset: {len(process_data.keys())}")
        
        # update the LlamaFactory dataset information
        for key, value in process_data.items():
            if key not in llama_factory_data:
                llama_factory_data[key] = value
                print(f"register new dataset: {key}")
                need_load_new_data = True
        for tobe_delete_folder in tobe_delete_folder_list:
            if tobe_delete_folder in llama_factory_data:
                del llama_factory_data[tobe_delete_folder]
                print(f"delete dataset: {tobe_delete_folder}")
    
    # save the LlamaFactory dataset information
    with open(llamafactory_data_path, "w") as f:
        json.dump(llama_factory_data, f, indent=2)
    
    # if need to load new data
    if need_load_new_data:
        # update the dataset count
        crohme_count_dict = update_dataset_counts(folder_list, crohme_count_dict)
        
        # update the test ID
        all_test_id = update_test_ids(folder_list, all_test_id)
        print(f"updated test ID count: {len(all_test_id.keys())}")
        
        # filter out the folders that do not need to be registered
        for folder in tobe_delete_folder_list:
            if folder in folder_list:
                folder_list.remove(folder)
        
        # update the caption dictionary
        all_caption_dict = update_captions(folder_list, all_caption_dict)
        print(f"updated caption dictionary count: {len(all_caption_dict.keys())}")
        
        # save the updated data
        with open(os.path.join(backup_data_utils_path, "all_test_id.json"), "w") as f:
            json.dump(all_test_id, f, indent=2)
        with open(os.path.join(backup_data_utils_path, "all_caption_dict.json"), "w") as f:
            json.dump(all_caption_dict, f, indent=2)
        with open(os.path.join(backup_data_utils_path, "crohme_count_dict.json"), "w") as f:
            json.dump(crohme_count_dict, f, indent=2)
        # save folder_list
        with open(os.path.join(backup_data_utils_path, "folder_list.json"), "w") as f:
            json.dump(folder_list, f, indent=2)
    else:
        # load the existing data directly
        print("no need to load new data, load from backup")
        crohme_count_dict = json.load(open(os.path.join(backup_data_utils_path, "crohme_count_dict.json"), "r"))
        all_test_id = json.load(open(os.path.join(backup_data_utils_path, "all_test_id.json"), "r"))
        all_caption_dict = json.load(open(os.path.join(backup_data_utils_path, "all_caption_dict.json"), "r"))

all_test_id = get_all_ids()
all_caption_dict = get_all_captions(
    crohme_dir="/home/user/workspaces/HMER_Dataset/crohme-rgb-white/HMER/CROHME",
    hme100k_dir="/home/user/workspaces/HMER_Dataset/hme100k"
)
register_llamafactory_data()
if __name__ == "__main__":
    # dic = get_all_ids()
    # print(dic.keys())
    # for key, value in dic.items():
    #     print(key,len(value))
    print(crohme_count_dict)
    # register_llamafactory_data()
    print(crohme_count_dict)
    
    
    # all_caption_dict = get_all_captions(crohme_dir="/home/user/workspaces/HMER_Dataset/crohme-rgb/HMER/CROHME",
    #                                     hme100k_dir="/home/user/workspaces/HMER_Dataset/hme100k",
    #                                     dir_list=[crohme2023_folder, unimernet_fodler])
    
    # print(data.keys())
    # for key, value in data.items():
    #     print(key,value)
        
    # initialize the global variables
    

    # ensure the registration of data
    register_llamafactory_data()
