# result_statistic.py
import json
import logging
import os

import utils.safe_statistic

# Configure the root logger at ERROR level and create this module's logger.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

# Default input and output paths (relative to the working directory).
# result_dir_all: directory whose sub-directories each hold per-instance
#   result JSON files.
# trajectory_analyze_dir_all: directory that receives one
#   "<model>_trajectory_analyze.json" summary file per model.
result_dir_all = "../data/temp/receval_result"
trajectory_analyze_dir_all = "../data/trajectory-evaluation_by_llm"


# Metric names. Each is read from the input JSON as "avg_<metric>" and is
# written to the summary both as "<metric>_list" (raw) and "<metric>" (cleaned).
_METRICS = ("intra_correctness", "inter_correctness", "informativeness")


def _is_missing_score(score):
    """Return True when *score* marks an absent value ('None' string or None)."""
    return score is None or score == 'None'


def _load_instance_scores(result_dir):
    """Read every ``*.json`` file in *result_dir* and collect its raw scores.

    Returns:
        A ``(instance_ids, scores)`` tuple. ``instance_ids`` holds the file
        names without the ``.json`` suffix; ``scores`` maps each metric name
        to a list aligned index-for-index with ``instance_ids``.

    Raises:
        KeyError: if a result file lacks one of the ``avg_<metric>`` keys.
    """
    instance_ids = []
    scores = {metric: [] for metric in _METRICS}
    # Sorted so the output lists are reproducible; os.listdir order is arbitrary.
    for fname in sorted(os.listdir(result_dir)):
        if not fname.endswith(".json"):
            continue
        with open(os.path.join(result_dir, fname), 'r', encoding='utf-8') as f:
            receval_result = json.load(f)
        # Look up all scores before appending anything, so a file with a
        # missing key cannot leave the lists with different lengths.
        row = [receval_result[f"avg_{metric}"] for metric in _METRICS]
        instance_ids.append(fname[:-5])
        for metric, value in zip(_METRICS, row):
            scores[metric].append(value)
    return instance_ids, scores


def _summarize_metric(instance_ids, raw_scores):
    """Drop missing scores and build the summary dict for one metric."""
    kept = [
        (instance_id, score)
        for instance_id, score in zip(instance_ids, raw_scores)
        if not _is_missing_score(score)
    ]
    kept_scores = [score for _, score in kept]
    return {
        'total_count': len(kept_scores),
        'instance_id_list': [instance_id for instance_id, _ in kept],
        'score_list': kept_scores,
        # NOTE(review): empty-list behavior is delegated to utils.safe_statistic,
        # which is not visible from this file.
        'mean': utils.safe_statistic.mean(kept_scores),
        'median': utils.safe_statistic.median(kept_scores),
    }


def result_statistic(result_dir_all, trajectory_analyze_dir_all):
    """Aggregate per-instance scores into one summary JSON file per model.

    Each sub-directory of *result_dir_all* must be named
    ``<model>_<type>[_...]``; the first two ``_``-separated parts select the
    model and the run type (the type ``empty`` is stored as ``empty_patch``).
    Every ``*.json`` file inside contributes its ``avg_intra_correctness``,
    ``avg_inter_correctness`` and ``avg_informativeness`` values.

    For each model, ``<model>_trajectory_analyze.json`` is written to
    *trajectory_analyze_dir_all* (created if absent). It maps each run type to
    a dict with the raw lists under ``all_data`` plus, per metric, the count,
    instance ids, scores, mean and median of the non-missing values.

    Entries of *result_dir_all* that are not directories are ignored.
    Directories whose name contains no ``_`` are skipped and logged as errors.

    Args:
        result_dir_all: Directory containing one sub-directory per model/type.
        trajectory_analyze_dir_all: Output directory for the summary files.

    Raises:
        KeyError: if a result file lacks one of the expected score keys.
    """
    log = logging.getLogger(__name__)
    os.makedirs(trajectory_analyze_dir_all, exist_ok=True)

    all_data = {}
    for model_type_name in sorted(os.listdir(result_dir_all)):
        result_dir = os.path.join(result_dir_all, model_type_name)
        # Stray files (e.g. .DS_Store, notes) are not result directories.
        if not os.path.isdir(result_dir):
            continue

        # Validate the name before reading any file so bad input fails fast.
        name_parts = model_type_name.split("_")
        if len(name_parts) < 2:
            log.error(
                "Skipping %s: directory name is not of the form <model>_<type>",
                result_dir,
            )
            continue
        model_name, run_type = name_parts[0], name_parts[1]
        if run_type == 'empty':
            # "<model>_empty_patch" splits into "empty" + "patch"; restore it.
            run_type = 'empty_patch'

        instance_ids, raw_scores = _load_instance_scores(result_dir)

        trajectory_analyze = {
            'all_data': {
                'total_count': len(instance_ids),
                'instance_id_list': instance_ids,
                **{f"{metric}_list": raw_scores[metric] for metric in _METRICS},
            },
        }
        for metric in _METRICS:
            trajectory_analyze[metric] = _summarize_metric(
                instance_ids, raw_scores[metric]
            )

        all_data.setdefault(model_name, {})[run_type] = trajectory_analyze

    for model_name, trajectory_analyze_data in all_data.items():
        trajectory_analyze_file = os.path.join(
            trajectory_analyze_dir_all, f"{model_name}_trajectory_analyze.json"
        )
        with open(trajectory_analyze_file, 'w', encoding='utf-8') as f:
            json.dump(trajectory_analyze_data, f, ensure_ascii=False, indent=4)


def main(result_dir_all=result_dir_all, trajectory_analyze_dir_all=trajectory_analyze_dir_all):
    """Run the score aggregation, defaulting to the module-level paths.

    Args:
        result_dir_all: Input directory forwarded to ``result_statistic``.
        trajectory_analyze_dir_all: Output directory forwarded to
            ``result_statistic``.
    """
    result_statistic(
        result_dir_all=result_dir_all,
        trajectory_analyze_dir_all=trajectory_analyze_dir_all,
    )


# When executed as a script, run the aggregation with the default paths.
if __name__ == '__main__':
    main()
