# llmHMER/QwenHMER/cdm_evaluator.py

import json
import os
import shutil
import subprocess
import warnings

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from config_utils import  run_command
warnings.filterwarnings('ignore')

# cdm_evaluator.py

def plot_exp_metrics(exp_summary_list: list, output_dir: str = 'exp_output', plot_filename: str = 'exp_metrics_plot.png'):
    """
    Draw accuracy-vs-checkpoint curves for the edit-distance metrics.

    Three side-by-side panels are produced (acc0, acc1, acc2), each holding
    one curve per dataset, and the figure is written to
    ``output_dir/plot_filename``. ``output_dir`` is created when missing.

    Every element of ``exp_summary_list`` is a dict; the keys read here are:
    {
        "dataset_name": xxx,     # groups entries into one curve
        "checkpoint_step": 5,    # x coordinate, must be added by the caller
        "acc0": xxx,             # y coordinate of panel 1
        "acc1": xxx,             # y coordinate of panel 2
        "acc2": xxx,             # y coordinate of panel 3
    }
    Other keys (total_count, ed0_count, ...) may be present and are ignored.
    """
    os.makedirs(output_dir, exist_ok=True)

    metric_keys = ("acc0", "acc1", "acc2")

    # Group the flat record list into per-dataset series.
    grouped = {}
    for record in exp_summary_list:
        name = record['dataset_name']
        ckpt = record['checkpoint_step']
        bucket = grouped.setdefault(name, {"steps": [], "acc0": [], "acc1": [], "acc2": []})
        bucket["steps"].append(ckpt)
        for key in metric_keys:
            bucket[key].append(record[key])

    _, axes = plt.subplots(1, 3, figsize=(21, 7))
    palette = plt.cm.tab10.colors
    markers = ('o', 's', '^')

    # Datasets are visited in sorted-name order so colors are assigned stably.
    for position, name in enumerate(sorted(grouped)):
        series = grouped[name]
        xs = np.array(series["steps"])

        # Records may arrive in any order; sort by step so the line is monotone in x.
        order = np.argsort(xs)
        xs = xs[order]

        line_color = palette[position % len(palette)]
        for axis, key, marker in zip(axes, metric_keys, markers):
            ys = np.array(series[key])[order]
            axis.plot(xs, ys, marker=marker, label=name, color=line_color)

    panel_titles = (
        'Accuracy (Edit Distance <= 0)',
        'Accuracy (Edit Distance <= 1)',
        'Accuracy (Edit Distance <= 2)',
    )
    for axis, title in zip(axes, panel_titles):
        axis.set_title(title)
        axis.set_xlabel('Checkpoint Step')
        axis.set_ylabel('Accuracy %')
        axis.grid(True)
        axis.legend()

    plt.tight_layout()
    plot_path = os.path.join(output_dir, plot_filename)
    plt.savefig(plot_path)
    plt.close()
    print(f"[EXP] Edit Distance metrics plot saved to {plot_path}")



# Default location of the UniMERNet CDM evaluation script; override via the
# ``eval_script`` argument of cdm_proc_folder on machines with another layout.
_CDM_EVAL_SCRIPT = "/home/user/githubrepo/UniMERNet/cdm/evaluation.py"

# Substrings marking JSON files in the input folder that are NOT inference
# results to be evaluated (good/bad case dumps and files written by this module).
_CDM_SKIP_TAGS = ("good", "bad", "correct")


def _cdm_load_json(path: str):
    """Read and parse the UTF-8 JSON file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _cdm_dump_json(obj, path: str) -> None:
    """Write *obj* to *path* as indented, non-ASCII-escaped UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, ensure_ascii=False, indent=4)


def _cdm_is_eval_input(filename: str) -> bool:
    """Return True when *filename* is a JSON file that should be sent to the evaluator."""
    if not filename.endswith('.json'):
        return False
    return not any(tag in filename for tag in _CDM_SKIP_TAGS)


def _cdm_export_correct_subsets(input_folder: str) -> None:
    """Cross-reference CDM-correct ids with the good/bad case files of *input_folder*.

    Writes ``all_correct.json`` (results whose ``img_id`` is in the correct
    list) and ``correct_but_not_matched.json`` (results that are correct but
    also listed in the bad-case file) into *input_folder*.
    """
    # Start from empty collections so a folder that lacks one of the optional
    # files yields empty outputs instead of an UnboundLocalError.
    correct_list, good_list, bad_list, results = [], [], [], []

    for filename in os.listdir(input_folder):
        path = os.path.join(input_folder, filename)
        # Independent checks on purpose: one file name may match several roles.
        if "correct_list.json" in filename:
            correct_list = _cdm_load_json(path)
        if 'good' in filename:
            good_list = [item['img_id'] for item in _cdm_load_json(path)]
        if 'bad' in filename:
            bad_list = [item['img_id'] for item in _cdm_load_json(path)]
        if 'result' in filename and '.json' in filename:
            results = _cdm_load_json(path)

    print("len(correct_list): ", len(correct_list))
    print("len(bad_list): ", len(bad_list))
    print("len(good_list): ", len(good_list))
    print("len(results): ", len(results))

    # Set lookups keep the filtering linear; list order of the outputs is preserved.
    bad_ids = set(bad_list)
    correct_but_bad = [item for item in correct_list if item in bad_ids]
    print("correct_but_bad: ", correct_but_bad)

    correct_ids = set(correct_list)
    correct_but_bad_ids = set(correct_but_bad)
    correct_output_data = [r for r in results if r['img_id'] in correct_ids]
    correct_but_bad_output_data = [r for r in results if r['img_id'] in correct_but_bad_ids]

    _cdm_dump_json(correct_output_data, os.path.join(input_folder, "all_correct.json"))
    _cdm_dump_json(correct_but_bad_output_data, os.path.join(input_folder, "correct_but_not_matched.json"))


def cdm_proc_folder(input_folder: str, output_dir: str = 'cdm_output', summary_file: str = 'cdm_summary.json', dataset_name: str = "", cp_output_path="", eval_script: str = _CDM_EVAL_SCRIPT) -> list:
    """
    Run the CDM evaluation script on every inference-result JSON in a folder
    and record the metadata of the evaluation results.

    For each evaluated file ``<name>.json`` the metrics are read from
    ``<output_dir>/<name>/metrics_res.json``. The ``correct_list`` found there
    is saved as ``correct_list.json`` in ``input_folder``; afterwards
    ``all_correct.json`` and ``correct_but_not_matched.json`` are derived from
    it and from the good/bad/result files present in ``input_folder``.

    :param input_folder:   folder containing several *.json inference result files;
                           JSON files whose name contains "good", "bad" or "correct"
                           are not evaluated. If the folder name looks like
                           ``checkpoint-<N>``, N is recorded as the checkpoint step,
                           otherwise 0 is used.
    :param output_dir:     cdm output evaluation result directory (created if missing)
    :param summary_file:   JSON file name (inside output_dir) to save the evaluation metadata
    :param dataset_name:   dataset label stored in every summary entry; "unknown" when empty
    :param cp_output_path: optional directory that receives a copy of metrics_res.json;
                           skipped when empty
    :param eval_script:    path of the CDM ``evaluation.py`` script to execute
    :return: list of summary dicts with keys dataset_name, checkpoint_step,
             metrics_res_path, mean_score, exp_rate, invalid_count
    :raises FileNotFoundError: when there is no JSON file to evaluate, or when the
                               evaluator did not produce metrics_res.json
    """
    os.makedirs(output_dir, exist_ok=True)
    summary_path = os.path.join(output_dir, summary_file)
    evaluation_summary = []

    # Only files that will actually be evaluated count; a folder holding just
    # good/bad/correct dumps has nothing to process.
    eval_files = [name for name in os.listdir(input_folder) if _cdm_is_eval_input(name)]
    if not eval_files:
        raise FileNotFoundError(f"No JSON files to evaluate found in the input folder: {input_folder}")

    # The last path component is expected to be "checkpoint-<N>"; normpath
    # drops a trailing separator that would otherwise give an empty basename.
    folder_name = os.path.basename(os.path.normpath(input_folder))
    try:
        step = int(folder_name.replace("checkpoint-", ""))
    except ValueError:
        step = 0  # default value

    if not dataset_name:
        dataset_name = "unknown"

    # Use half of the cores, at least one; cpu_count() may return None.
    used_cores = max((os.cpu_count() or 1) // 2, 1)

    for filename in tqdm(eval_files, desc="CDM Evaluation"):
        print("Processing file:", filename)
        in_file = os.path.join(input_folder, filename)
        base_name = os.path.splitext(filename)[0]
        metrics_res_path = os.path.join(output_dir, base_name, "metrics_res.json")

        # Argument list (no shell) so paths with spaces or shell
        # metacharacters are passed through verbatim.
        cmd = ["python", eval_script, "-i", in_file, "-o", output_dir, "-p", str(used_cores)]
        print("[CDM Eval] CMD:", " ".join(cmd))
        completed = subprocess.run(cmd, check=False)
        if completed.returncode != 0:
            # Not fatal by itself: the metrics-file check below decides.
            print(f"[CDM Eval] evaluation exited with status {completed.returncode}")

        if not os.path.exists(metrics_res_path):
            error_msg = f"metrics_res.json not found, path: {metrics_res_path}"
            print(error_msg)
            raise FileNotFoundError(error_msg)

        metric_dic = _cdm_load_json(metrics_res_path)
        invalid_list = metric_dic.get("invalid_list", [])
        print("mean_score: ", metric_dic.get("mean_score", "N/A"))
        print("exp_rate: ", metric_dic.get("exp_rate", "N/A"))
        print("invalid_list count: ", len(invalid_list))
        print("invalid_list: ", invalid_list)

        if cp_output_path:
            # The copy is taken before the annotation below, so it holds the
            # evaluator's unmodified output. Every evaluated file copies to the
            # same name, so the last one wins.
            os.makedirs(cp_output_path, exist_ok=True)
            shutil.copy(metrics_res_path, cp_output_path)
            metric_dic["cdm_output_path"] = output_dir
            _cdm_dump_json(metric_dic, metrics_res_path)

        # Consumed by the post-processing step; overwritten per evaluated file.
        correct_list = metric_dic.get("correct_list", [])
        _cdm_dump_json(correct_list, os.path.join(input_folder, "correct_list.json"))

        # record the evaluation result
        evaluation_summary.append({
            "dataset_name": dataset_name,
            "checkpoint_step": step,
            "metrics_res_path": metrics_res_path,
            "mean_score": metric_dic.get("mean_score", 0),
            "exp_rate": metric_dic.get("exp_rate", 0),
            "invalid_count": len(invalid_list),
        })

    _cdm_export_correct_subsets(input_folder)

    # save the evaluation metadata
    _cdm_dump_json(evaluation_summary, summary_path)
    print(f"CDM evaluation summary saved to {summary_path}")

    return evaluation_summary



def plot_cdm_metrics(evaluation_summary: list, output_dir: str = 'cdm_output', plot_filename: str = 'cdm_metrics_plot.png'):
    """
    plot the visualization chart based on the evaluation metadata.

    Three panels are drawn against the checkpoint step, one curve per
    dataset: mean score, exp rate and number of invalid items. The figure is
    saved to ``output_dir/plot_filename``.

    :param evaluation_summary: list of dicts; each must provide the keys
                               dataset_name, checkpoint_step, mean_score,
                               exp_rate and invalid_count
    :param output_dir:    output directory for plotting (created if missing)
    :param plot_filename: plot file name
    """
    # Make sure the target directory exists so savefig cannot fail on a
    # missing path (same guarantee plot_exp_metrics gives).
    os.makedirs(output_dir, exist_ok=True)

    # organize the data structure
    data = {}
    for entry in evaluation_summary:
        dataset = entry['dataset_name']
        step = entry['checkpoint_step']
        if dataset not in data:
            data[dataset] = {"steps": [], "mean_scores": [], "exp_rates": [], "invalid_counts": []}
        data[dataset]['steps'].append(step)
        data[dataset]['mean_scores'].append(entry['mean_score'])
        data[dataset]['exp_rates'].append(entry['exp_rate'])
        data[dataset]['invalid_counts'].append(entry['invalid_count'])

    # create the chart
    _, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(21, 7))
    colors = plt.cm.tab10.colors  # use Tab10 color mapping
    # Colors follow the sorted dataset names, so a dataset keeps its color
    # regardless of the order in which its entries appear in the summary.
    color_map = {
        dataset: colors[idx % len(colors)]
        for idx, dataset in enumerate(sorted(data))
    }

    for dataset, metrics in data.items():
        # sort by step so each curve runs left to right
        sorted_indices = np.argsort(metrics['steps'])
        steps = np.array(metrics['steps'])[sorted_indices]
        mean_scores = np.array(metrics['mean_scores'])[sorted_indices]
        exp_rates = np.array(metrics['exp_rates'])[sorted_indices]
        invalid_counts = np.array(metrics['invalid_counts'])[sorted_indices]

        ax1.plot(steps, mean_scores, marker='o', linestyle='-', label=dataset, color=color_map[dataset])
        ax2.plot(steps, exp_rates, marker='s', linestyle='--', label=dataset, color=color_map[dataset])
        ax3.plot(steps, invalid_counts, marker='^', linestyle='-.', label=dataset, color=color_map[dataset])

    # set the chart properties
    ax1.set_title('Mean Score vs Step')
    ax1.set_xlabel('Checkpoint Step')
    ax1.set_ylabel('Mean Score')
    ax1.grid(True)
    ax1.legend(title='Dataset')

    ax2.set_title('Exp Rate vs Step')
    ax2.set_xlabel('Checkpoint Step')
    ax2.set_ylabel('Exp Rate (%)')
    ax2.grid(True)
    ax2.legend(title='Dataset')

    ax3.set_title('Invalid Count vs Step')
    ax3.set_xlabel('Checkpoint Step')
    ax3.set_ylabel('Number of Invalid Items')
    ax3.grid(True)
    ax3.legend(title='Dataset')

    plt.tight_layout()
    plot_path = os.path.join(output_dir, plot_filename)
    plt.savefig(plot_path)
    plt.close()
    print(f"Metrics plot saved to {plot_path}")