# llmHMER/QwenHMER/inference_runner.py

import json
import os
from typing import Dict, List

from dataset_utils import compute_edit_distance, extract_inboxed_content
from tqdm import tqdm


def cal_CER(pred,gt):
    """
    calculate the error rate between one prediction and its ground truth

    Both strings are split on whitespace before comparison, so the edit
    distance is measured over whitespace-separated tokens, not single characters.

    Args:
        pred: the predicted string
        gt: the ground truth string

    Returns:
        float: the edit distance divided by the number of ground truth tokens.
            When the ground truth has no tokens the ratio is undefined, so 0.0
            is returned if the prediction is empty as well and 1.0 otherwise.
    """
    pred_tokens = pred.split()
    gt_tokens = gt.split()
    if not gt_tokens:
        # an empty ground truth used to raise ZeroDivisionError
        return 0.0 if not pred_tokens else 1.0
    return compute_edit_distance(pred_tokens, gt_tokens) / len(gt_tokens)


def cal_whole_CER(data):
    """
    calculate the character error rate of the whole dataset

    The 'pred' and 'gt' values are handed to compute_edit_distance unchanged
    (they are not split here) and len(gt) is used as the reference length, so
    the unit of the rate is whatever compute_edit_distance compares for that
    input type.

    Args:
        data: a list of data, each item should have 'pred' and 'gt' keys

    Returns:
        float: the sum of all edit distances divided by the sum of all ground
            truth lengths. Empty data returns 1.0. When every ground truth is
            empty the ratio is undefined, so 0.0 is returned if no edits were
            needed at all and 1.0 otherwise.
    """
    if not data:
        return 1.0  # empty data returns 100% error rate

    total_distance = 0
    total_length = 0
    for item in data:
        gt = item['gt']
        total_distance += compute_edit_distance(item['pred'], gt)
        total_length += len(gt)

    if total_length == 0:
        # all ground truths are empty; avoid ZeroDivisionError
        return 0.0 if total_distance == 0 else 1.0
    return total_distance / total_length

def split_predictions_by_dataset(
        prediction_file: str,
        eval_datasets: List[str],
        crohme_count_dict: Dict[str, int],
        output_dir: str
) -> None:
    """
    split the unified generated_predictions.jsonl into multiple sub-dataset files based on eval_datasets and crohme_count_dict

    The predictions are assumed to be stored dataset after dataset, in the same
    order as eval_datasets; each dataset receives the next `count` records and
    is written to <output_dir>/<dataset>_predictions.jsonl. Datasets whose count
    is missing or 0 are skipped with a warning and get no output file.

    :param prediction_file: the unified prediction result file path
    :param eval_datasets:   the list of datasets to be split, in order
    :param crohme_count_dict: a dictionary of the number of predictions for each dataset
    :param output_dir:      the output directory of the split files (created if missing)
    :raises ValueError: if the number of predictions differs from the sum of the dataset counts
    :raises AssertionError: if the slices do not cover every prediction
                            (only possible when eval_datasets contains duplicates)
    """
    # read all predictions; the file is read once and closed by the context manager
    with open(prediction_file, 'r', encoding='utf-8') as f:
        raw_lines = f.readlines()
    # blank lines carry no prediction and would make json.loads fail
    predictions = [json.loads(line) for line in raw_lines if line.strip()]

    total_required = 0
    dataset_counts = {}
    for dataset in eval_datasets:
        count = crohme_count_dict.get(dataset, 0)  # directly use the full dataset name as the key
        if count == 0:
            print(f"Warning: Dataset {dataset} not found in crohme_count_dict. Skipping.")
            continue
        dataset_counts[dataset] = count
        total_required += count

    if total_required != len(predictions):
        # output the current dataset information and prediction file information
        print(f"Dataset counts: {dataset_counts}")
        print(f"prediction_file: {prediction_file}")
        print(f"len_predictions: {len(raw_lines)}")
        raise ValueError(
            f"Total predictions ({len(predictions)}) does not match the sum of dataset counts ({total_required})."
        )

    os.makedirs(output_dir, exist_ok=True)

    # split predictions
    current_index = 0
    for dataset, count in dataset_counts.items():
        dataset_predictions = predictions[current_index:current_index + count]
        current_index += count
        output_file = os.path.join(output_dir, f"{dataset}_predictions.jsonl")
        with open(output_file, 'w', encoding='utf-8') as f_out:
            f_out.writelines(
                json.dumps(pred, ensure_ascii=False) + '\n' for pred in dataset_predictions
            )
        print(f"Saved {count} predictions to {output_file}")

    # verify the total number after splitting; raised explicitly so the check
    # is not stripped when Python runs with -O
    if current_index != len(predictions):
        raise AssertionError("Mismatch in total predictions after splitting.")
    print("Successfully split predictions into datasets.")

def _extract_boxed_or_raw(text: str) -> str:
    """return extract_inboxed_content(text), or text unchanged when that call raises ValueError"""
    try:
        return extract_inboxed_content(text)
    except ValueError:
        return text


def _parse_checkpoint_step(prediction_file: str) -> int:
    """return the step number of the first 'checkpoint-<step>' path component, or 0 if there is none"""
    # accept both separators so paths built on Windows are handled as well
    components = prediction_file.replace("\\", "/").split("/")
    candidates = [part for part in components if part.startswith("checkpoint")]
    for folder in candidates:
        try:
            return int(folder.replace("checkpoint-", ""))
        except ValueError:
            # e.g. a parent folder named "checkpoints"; try the next candidate
            continue
    if candidates:
        print(f"Warning: Could not extract step number from {prediction_file}")
    return 0


def run_inference_on_dataset(
    prediction_file: str,
    image_id_list: List[str],
    gt_captions_dict: Dict[str, str],
    dataset_name: str,
    crohme_count_dict: Dict[str, int],
    output_prefix: str = "",
    output_dir: str = "./",
    is_nobox: bool = True,
) -> Dict:
    """
    analyze and save the prediction results of a single dataset

    Each line of prediction_file is paired by position with image_id_list. The
    label and prediction are whitespace-normalized, compared with a token-level
    edit distance and written to result/good_case/bad_case JSON files plus an
    accuracy text file inside output_dir.

    Args:
        prediction_file: .jsonl file, each line contains { "label":..., "predict":... }
            (a "preds" list is also accepted, its first element is used)
        image_id_list: the list of image IDs for this dataset, in order, aligned with the lines in prediction_file
        gt_captions_dict: {img_id: gt_caption, ...}; only used to check that every image ID is known
        dataset_name: the name of the dataset; if found in crohme_count_dict that count is the
            accuracy denominator, otherwise the number of processed lines is used
        crohme_count_dict: a dictionary of statistics, used to calculate accuracy
        output_prefix: the prefix of the output file
        output_dir: the output directory
        is_nobox: when False, the content of \\boxed{} is extracted from label and prediction
            independently; a string whose extraction fails is used unchanged

    Returns:
        Dict: a dictionary containing accuracy statistics

    Raises:
        ValueError: if the prediction file is empty, has more lines than image_id_list,
            or an image ID is missing from gt_captions_dict
        json.JSONDecodeError: if a line is not valid JSON
        KeyError: if a line has no 'label' or neither 'preds' nor 'predict'
    """
    data_list = []
    exp_count = {0: 0, 1: 0, 2: 0}

    try:
        with open(prediction_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except Exception as e:
        print(f"Error reading prediction file {prediction_file}: {e}")
        raise

    print(f"[{dataset_name}] Total lines in result file: {len(lines)}")

    if not lines:
        print(f"Warning: Empty prediction file {prediction_file}")
        raise ValueError(f"Empty prediction file {prediction_file}")

    if len(lines) > len(image_id_list):
        raise ValueError(f"More predictions ({len(lines)}) than image IDs ({len(image_id_list)})")

    # for these dataset names only the last line of label/prediction is evaluated
    last_line_only = any(
        marker in dataset_name for marker in ("ast", "tree", "can", "bttr", "error")
    )

    for idx, line in enumerate(tqdm(lines, desc=f"Processing {dataset_name}")):
        # idx is always valid: len(lines) <= len(image_id_list) was checked above
        img_id = image_id_list[idx]
        try:
            if img_id not in gt_captions_dict:
                print(f"Warning: Image ID {img_id} not found in gt_captions_dict")
                raise ValueError(f"Image ID {img_id} not found in gt_captions_dict")

            # parse the prediction result
            line_data = json.loads(line)
            raw_label = line_data['label']

            # compatible with different prediction result formats
            if "preds" in line_data:
                raw_pred = line_data['preds'][0]
            else:
                raw_pred = line_data['predict']

            label_text = raw_label
            pred_text = raw_pred
            if last_line_only:
                label_text = label_text.strip().split("\n")[-1]
                pred_text = pred_text.strip().split("\n")[-1]

            # standardize spaces
            label_clean = " ".join(label_text.split())
            pred_clean = " ".join(pred_text.split())

            if not is_nobox:
                # each string falls back on its own, so a label without \boxed{}
                # no longer prevents the prediction from being unboxed
                label_clean = _extract_boxed_or_raw(label_clean)
                pred_clean = _extract_boxed_or_raw(pred_clean)

            # calculate the edit distance
            distance = compute_edit_distance(label_clean.split(), pred_clean.split())
            for threshold in exp_count:
                if distance <= threshold:
                    exp_count[threshold] += 1

            data_list.append({
                "img_id": img_id,
                "gt": label_clean,
                "pred": pred_clean,
                "distance": distance,
                "raw_gt": raw_label,
                "raw_pred": raw_pred,
            })

        except json.JSONDecodeError as e:
            print(f"Error parsing JSON at line {idx+1}: {e}")
            raise
        except Exception as e:
            print(f"Error processing line {idx+1}: {e}")
            raise

    # ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # defensive: every line either appends a record or raises, so this is not expected to trigger
    if not data_list:
        print(f"Warning: No valid data processed for {dataset_name}")
        return {
            "dataset_name": dataset_name,
            "total_count": 0,
            "ed0_count": 0,
            "ed1_count": 0,
            "ed2_count": 0,
            "acc0": 0,  # <=0
            "acc1": 0,  # <=1
            "acc2": 0,  # <=2
            "whole_cer": 1.0,
            "checkpoint_step": 0
        }

    # save the processed results
    out_file = os.path.join(output_dir, f"{output_prefix}result_{dataset_name}.json")
    with open(out_file, 'w', encoding='utf-8') as f:
        json.dump(data_list, f, indent=4, ensure_ascii=False)

    # save good case (exact match) and bad case (everything else)
    good_case_file = os.path.join(output_dir, f"{output_prefix}good_case_{dataset_name}.json")
    bad_case_file = os.path.join(output_dir, f"{output_prefix}bad_case_{dataset_name}.json")
    good_case_list = [item for item in data_list if item['distance'] <= 0]
    bad_case_list = [item for item in data_list if item['distance'] > 0]
    with open(good_case_file, 'w', encoding='utf-8') as f:
        json.dump(good_case_list, f, indent=4, ensure_ascii=False)
    with open(bad_case_file, 'w', encoding='utf-8') as f:
        json.dump(bad_case_list, f, indent=4, ensure_ascii=False)

    # if the dataset is in the statistics dictionary, use the number in the statistics dictionary,
    # otherwise fall back to the number of processed lines
    total_count = crohme_count_dict.get(dataset_name, len(data_list))

    exp_results = {i: round(100.0 * exp_count[i] / total_count, 3) if total_count > 0 else 0 for i in range(3)}
    print(f"[{dataset_name}] Accuracy results: {exp_results}")

    whole_cer = cal_whole_CER(data_list)
    print(f"[{dataset_name}] Whole CER: {whole_cer}")

    # write the results to file; best effort, a failure here must not discard the statistics
    accuracy_file = os.path.join(output_dir, f"{output_prefix}result_{dataset_name}_accuracy.txt")
    try:
        with open(accuracy_file, 'w', encoding='utf-8') as f:
            f.write(f"Accuracy results for {dataset_name}:\n")
            for i in range(3):
                f.write(f"Edit distance <= {i}: {exp_count[i]} / {total_count} ({exp_results[i]}%)\n")
            f.write("\n")
            f.write(f"Total count: {total_count}\n")
            f.write(f"Total correct: {exp_count[0]}\n")
            f.write(f"Total correct (<=1): {exp_count[1]}\n")
            f.write(f"Total correct (<=2): {exp_count[2]}\n")
            f.write(f"Whole CER: {whole_cer}\n")
            f.write("\n")
    except OSError as e:
        print(f"Error writing accuracy results: {e}")

    # get the checkpoint step information
    step = _parse_checkpoint_step(prediction_file)

    # return the statistics results
    return {
        "dataset_name": dataset_name,
        "total_count": total_count,
        "ed0_count": exp_count[0],
        "ed1_count": exp_count[1],
        "ed2_count": exp_count[2],
        "acc0": exp_results[0],  # <=0
        "acc1": exp_results[1],  # <=1
        "acc2": exp_results[2],  # <=2
        "whole_cer": whole_cer,
        "checkpoint_step": step
    }


# llmHMER/QwenHMER/inference_runner.py

def aggregate_inference_results(
        dataset_split_file: str,
        output_folder: str,
        dataset_name: str,
        image_id_list: List[str],
        gt_captions_dict: Dict[str, str],
        crohme_count_dict: Dict[str, int]
):
    """
    evaluate one split sub-dataset prediction file through run_inference_on_dataset

    The output folder is created when missing and the evaluation files are
    written there without any file name prefix.

    :param dataset_split_file: the path of the split sub-dataset prediction file
    :param output_folder:      the output directory of the evaluation results
    :param dataset_name:       the name of the dataset
    :param image_id_list:      the list of img_id for the current dataset
    :param gt_captions_dict:   the dictionary of ground truth caption for the current dataset
    :param crohme_count_dict:  the statistics dictionary
    :return: the statistics dictionary produced by run_inference_on_dataset
    """
    os.makedirs(output_folder, exist_ok=True)

    # no prefix (such as a run date) is put in front of the output file names
    evaluation_kwargs = dict(
        prediction_file=dataset_split_file,
        image_id_list=image_id_list,
        gt_captions_dict=gt_captions_dict,
        dataset_name=dataset_name,
        crohme_count_dict=crohme_count_dict,
        output_prefix="",
        output_dir=output_folder,
    )
    stats = run_inference_on_dataset(**evaluation_kwargs)

    print(f"Aggregated inference results for {dataset_name} saved to {output_folder}")
    # handed back so the caller can aggregate over several datasets
    return stats

def main():
    """
    export the zero-shot predictions of each model as one JSON file per dataset

    For every model the unified generated_predictions.jsonl is split per
    dataset, each record is paired by position with its image ID from
    all_test_id, and the list of {"img_id", "gt", "pred"} records is written to
    <model_name>_<dataset>.json in the "original" output directory. Models
    without a prediction file are skipped.
    """
    from dataset_utils import register_llamafactory_data

    # register_all_datasets()
    register_llamafactory_data()

    # imported after the registration call above, as in the original code
    # NOTE(review): presumably these dictionaries are filled by the registration - confirm
    from dataset_utils import all_test_id, crohme_count_dict

    eval_datasets = ["crohme2023_CROHME2014_test","crohme2023_CROHME2016_test","crohme2023_CROHME2019_test","crohme2023_CROHME2023_test","crohme_2014_nobox_white","crohme_2016_nobox_white","crohme_2019_nobox_white","hme100k_test_nobox_white","hme100k_train_nobox_white","crohme_train_nobox_white"]
    export_dir = os.path.join("/home/user/workspaces/llmHMER/hmer_dataset/zeroshot","original")

    # other models that can be listed here: "qwen2_vl-2b", "qwen2_vl-7b"
    for model_name in ["qwen2.5_vl-3b","qwen2.5_vl-7b"]:
        zeroshot_dir = f"/home/user/workspaces/llmHMER/LLaMA-Factory/saves/{model_name}/full/sft/predict/zeroshot"
        predict_file = os.path.join(zeroshot_dir,"generated_predictions.jsonl")
        print(predict_file)
        if not os.path.exists(predict_file):
            print(f"Skipping {model_name}, no generated_predictions.jsonl found.")
            continue

        split_predictions_by_dataset(
            prediction_file=predict_file,
            eval_datasets=eval_datasets,
            crohme_count_dict=crohme_count_dict,
            output_dir=zeroshot_dir  # save the split files in the respective checkpoint folders
        )

        # the export directory is not guaranteed to exist yet
        os.makedirs(export_dir, exist_ok=True)

        for dataset in eval_datasets:
            if not crohme_count_dict.get(dataset, 0):
                # the splitter writes no file for such a dataset; reading one
                # would fail or pick up a stale file from an earlier run
                print(f"Skipping {dataset}: no split predictions were written for it.")
                continue

            split_file = os.path.join(zeroshot_dir,f"{dataset}_predictions.jsonl")
            with open(split_file, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            dataset_ids = all_test_id[dataset]
            data_list = []
            for idx, line in enumerate(tqdm(lines, desc=f"Processing {dataset}")):
                line_data = json.loads(line)
                # same format compatibility as run_inference_on_dataset
                if "preds" in line_data:
                    pred_raw = line_data['preds'][0]
                else:
                    pred_raw = line_data['predict']
                data_list.append({
                    "img_id": dataset_ids[idx],
                    "gt": line_data['label'],
                    "pred": pred_raw
                })

            out_file = os.path.join(export_dir,f"{model_name}_{dataset}.json")
            with open(out_file, 'w', encoding='utf-8') as f:
                json.dump(data_list, f, indent=4, ensure_ascii=False)
        
# call main() only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
