{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mr_eval.utils.utils import *\n",
    "import os\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "\n",
    "# Directory whose *.jsonl files are loaded as the dataset. The default is the\n",
    "# original cluster path; set the PRMTEST_DATA_DIR environment variable to\n",
    "# point the notebook at another copy of the data without editing this cell.\n",
    "data_dir = os.environ.get(\n",
    "    \"PRMTEST_DATA_DIR\",\n",
    "    \"/mnt/petrelfs/songmingyang/code/reasoning/MR_Hallucination/mr_eval/tasks/prmtest_classified/data\",\n",
    ")\n",
    "dataset_type = \"dir_of_jsonl\"\n",
    "\n",
    "# Short label printed in the table header for each classification key.\n",
    "classification_name_dict = dict(\n",
    "    domain_inconsistency=\"DC.\",\n",
    "    redundency=\"NR.\",\n",
    "    multi_solutions=\"MS.\",\n",
    "    deception=\"DR.\",\n",
    "    confidence=\"CI.\",\n",
    "    step_contradiction=\"SC.\",\n",
    "    circular=\"NCL.\",\n",
    "    missing_condition=\"PS.\",\n",
    "    counterfactual=\"ES.\"\n",
    ")\n",
    "# Classification keys in the column order used by the printed tables.\n",
    "classifications = [\n",
    "    \"redundency\",\n",
    "    \"circular\",\n",
    "    \"counterfactual\",\n",
    "    \"step_contradiction\",\n",
    "    \"domain_inconsistency\",\n",
    "    \"confidence\",\n",
    "    \"missing_condition\",\n",
    "    \"deception\",\n",
    "    \"multi_solutions\",\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def get_steps_info(raw_data,classifications):\n",
    "    \"\"\"Average the step count and the error-step count, overall and per classification.\n",
    "\n",
    "    Args:\n",
    "        raw_data: iterable of samples; each sample is read through the keys\n",
    "            'modified_process', 'error_steps' and 'classification'.\n",
    "        classifications: names to report on. A sample whose 'classification'\n",
    "            is not among them raises KeyError.\n",
    "\n",
    "    Returns:\n",
    "        (total, per_class). total maps 'total_step_length' and\n",
    "        'error_step_length' to the mean over all samples; per_class maps\n",
    "        every classification to a dict with the same two keys. A mean over\n",
    "        zero samples is reported as -1.\n",
    "    \"\"\"\n",
    "    def _mean_or_missing(values):\n",
    "        # -1 is the marker for \"no samples\".\n",
    "        return np.mean(values) if len(values) > 0 else -1\n",
    "\n",
    "    total_list = dict(total_step_length=[], error_step_length=[])\n",
    "    classification_list = {\n",
    "        classification: dict(total_step_length=[], error_step_length=[])\n",
    "        for classification in classifications\n",
    "    }\n",
    "\n",
    "    for item in raw_data:\n",
    "        steps_num = len(item[\"modified_process\"])\n",
    "        error_num = len(item[\"error_steps\"])\n",
    "        # Record the sample once in the overall bucket and once in its class bucket.\n",
    "        for bucket in (total_list, classification_list[item[\"classification\"]]):\n",
    "            bucket[\"total_step_length\"].append(steps_num)\n",
    "            bucket[\"error_step_length\"].append(error_num)\n",
    "\n",
    "    total_list = {metric: _mean_or_missing(values) for metric, values in total_list.items()}\n",
    "    for classification in classifications:\n",
    "        classification_list[classification] = {\n",
    "            metric: _mean_or_missing(values)\n",
    "            for metric, values in classification_list[classification].items()\n",
    "        }\n",
    "\n",
    "    return total_list, classification_list\n",
    "\n",
    "def get_first_error_loc(raw_data,classifications):\n",
    "    \"\"\"Average the first entry of 'error_steps', overall and per classification.\n",
    "\n",
    "    Samples whose 'error_steps' is empty, or whose first entry equals -1, are\n",
    "    left out of the averages. An average over no samples is returned as -1.\n",
    "\n",
    "    Returns:\n",
    "        (total, per_class). total is {'first_error_loc': mean}; per_class maps\n",
    "        every name in classifications to a dict of the same shape.\n",
    "    \"\"\"\n",
    "    def _mean_or_missing(values):\n",
    "        return np.mean(values) if len(values) > 0 else -1\n",
    "\n",
    "    total_list = {\"first_error_loc\": []}\n",
    "    classification_list = {name: {\"first_error_loc\": []} for name in classifications}\n",
    "\n",
    "    for item in raw_data:\n",
    "        error_steps = item[\"error_steps\"]\n",
    "        loc = error_steps[0] if len(error_steps) > 0 else -1\n",
    "        # Looked up unconditionally so an unknown classification still raises KeyError.\n",
    "        class_bucket = classification_list[item[\"classification\"]]\n",
    "        if loc != -1:\n",
    "            total_list[\"first_error_loc\"].append(loc)\n",
    "            class_bucket[\"first_error_loc\"].append(loc)\n",
    "\n",
    "    for metric in list(total_list):\n",
    "        total_list[metric] = _mean_or_missing(total_list[metric])\n",
    "    for name in classifications:\n",
    "        per_class = classification_list[name]\n",
    "        for metric in list(per_class):\n",
    "            per_class[metric] = _mean_or_missing(per_class[metric])\n",
    "\n",
    "    return total_list, classification_list\n",
    "\n",
    "def get_quesiton_length(raw_data,classifications):\n",
    "    \"\"\"Average len() of each sample's 'modified_question', overall and per classification.\n",
    "\n",
    "    Returns:\n",
    "        (total, per_class). total is {'question_length': mean}; per_class maps\n",
    "        every name in classifications to a dict of the same shape. An average\n",
    "        over no samples is returned as -1.\n",
    "    \"\"\"\n",
    "    def _mean_or_missing(values):\n",
    "        return np.mean(values) if len(values) > 0 else -1\n",
    "\n",
    "    total_list = {\"question_length\": []}\n",
    "    classification_list = {name: {\"question_length\": []} for name in classifications}\n",
    "\n",
    "    for item in raw_data:\n",
    "        length = len(item[\"modified_question\"])\n",
    "        total_list[\"question_length\"].append(length)\n",
    "        classification_list[item[\"classification\"]][\"question_length\"].append(length)\n",
    "\n",
    "    for metric in list(total_list):\n",
    "        total_list[metric] = _mean_or_missing(total_list[metric])\n",
    "    for name in classifications:\n",
    "        per_class = classification_list[name]\n",
    "        for metric in list(per_class):\n",
    "            per_class[metric] = _mean_or_missing(per_class[metric])\n",
    "\n",
    "    return total_list, classification_list\n",
    "        \n",
    "def get_total_num(raw_data,classifications):\n",
    "    \"\"\"Count the samples overall and per classification.\n",
    "\n",
    "    Returns:\n",
    "        (total, per_class). total is {'total_num': count}; per_class maps every\n",
    "        name in classifications to {'total_num': count}. A sample whose\n",
    "        'classification' is not in classifications raises KeyError.\n",
    "    \"\"\"\n",
    "    total_list = {\"total_num\": 0}\n",
    "    classification_list = {name: {\"total_num\": 0} for name in classifications}\n",
    "\n",
    "    for item in raw_data:\n",
    "        total_list[\"total_num\"] = total_list[\"total_num\"] + 1\n",
    "        class_counter = classification_list[item[\"classification\"]]\n",
    "        class_counter[\"total_num\"] = class_counter[\"total_num\"] + 1\n",
    "\n",
    "    return total_list, classification_list\n",
    "\n",
    "def merge_res_to_base(base_total_dict, base_classification_dict, merge_total_dict, merge_classification_dict):\n",
    "    \"\"\"Copy every metric from the merge_* dicts into the base_* dicts.\n",
    "\n",
    "    Both base dicts are modified in place and also returned. Existing metric\n",
    "    names are overwritten. Only classifications already present in\n",
    "    base_classification_dict are merged; each of them must also be a key of\n",
    "    merge_classification_dict, otherwise KeyError is raised.\n",
    "    \"\"\"\n",
    "    base_total_dict.update(merge_total_dict)\n",
    "    for classification, base_metrics in base_classification_dict.items():\n",
    "        base_metrics.update(merge_classification_dict[classification])\n",
    "    return base_total_dict, base_classification_dict\n",
    "\n",
    "### Visualization\n",
    "def print_res_to_excel(total_statistic_data, classification_statistic_data, split_token=\"\\t\", return_token=\"\\n\",classifications=classifications):\n",
    "    \"\"\"Print the statistics as a delimited text table, one row per metric.\n",
    "\n",
    "    The header row is 'Models', 'Total' and then the short label of every\n",
    "    classification. Each following row holds a metric name, its overall value\n",
    "    and its value per classification. Float values are rounded to one decimal,\n",
    "    a value of -1 is printed as 'N/A', and every underscore in the finished\n",
    "    table is replaced by a space.\n",
    "\n",
    "    Args:\n",
    "        total_statistic_data: metric name -> overall value.\n",
    "        classification_statistic_data: classification -> {metric name -> value}.\n",
    "        split_token: text placed between cells.\n",
    "        return_token: text appended to every row.\n",
    "        classifications: columns to print, in order. The default is the\n",
    "            module-level classifications object as it was when this def was\n",
    "            executed; pass None to use the keys of classification_statistic_data.\n",
    "    \"\"\"\n",
    "    def _cell(value):\n",
    "        # Round floats first, then show the -1 marker as N/A.\n",
    "        shown = round(value, 1) if isinstance(value, float) else value\n",
    "        return shown if shown != -1 else \"N/A\"\n",
    "\n",
    "    metrics = list(total_statistic_data.keys())\n",
    "    if classifications is None:\n",
    "        classifications = list(classification_statistic_data.keys())\n",
    "\n",
    "    header = f\"Models{split_token}Total\"\n",
    "    for name in classifications:\n",
    "        header += f\"{split_token}{classification_name_dict[name]}\"\n",
    "    rows = [header]\n",
    "\n",
    "    for metric in metrics:\n",
    "        row = f\"{metric}{split_token}{_cell(total_statistic_data[metric])}\"\n",
    "        for name in classifications:\n",
    "            row += f\"{split_token}{_cell(classification_statistic_data[name][metric])}\"\n",
    "        rows.append(row)\n",
    "\n",
    "    table = \"\".join(row + return_token for row in rows)\n",
    "    print(table.replace(\"_\",\" \"))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every *.jsonl file under data_dir into one flat list of samples.\n",
    "# Sorting makes the load order independent of the directory listing order.\n",
    "# process_jsonl is not defined in this notebook; presumably it comes from the\n",
    "# star import of mr_eval.utils.utils.\n",
    "data_files = sorted(\n",
    "    os.path.join(data_dir, file) for file in os.listdir(data_dir) if file.endswith(\".jsonl\")\n",
    ")\n",
    "raw_data = []\n",
    "for file in data_files:\n",
    "    raw_data.extend(process_jsonl(file))\n",
    "\n",
    "# Classifications that actually occur in the data. They get their own name so\n",
    "# the ordered module-level `classifications` list is not replaced by an\n",
    "# unordered set; print_res_to_excel captures that list as its default column\n",
    "# order whenever its definition cell is (re-)run.\n",
    "data_classifications = sorted({item[\"classification\"] for item in raw_data})\n",
    "total_statistic_data = {}\n",
    "classification_statistic_data = {classification: {} for classification in data_classifications}\n",
    "\n",
    "# Each helper returns (overall, per-classification) values for one group of metrics.\n",
    "step_info_total, step_info_classification = get_steps_info(raw_data, data_classifications)\n",
    "first_error_loc_total, first_error_loc_classification = get_first_error_loc(raw_data, data_classifications)\n",
    "question_length_total, question_length_classification = get_quesiton_length(raw_data, data_classifications)\n",
    "total_num_total, total_num_classification = get_total_num(raw_data, data_classifications)\n",
    "\n",
    "# The merge order determines the row order of the printed tables.\n",
    "for merge_total, merge_classification in (\n",
    "    (step_info_total, step_info_classification),\n",
    "    (first_error_loc_total, first_error_loc_classification),\n",
    "    (question_length_total, question_length_classification),\n",
    "    (total_num_total, total_num_classification),\n",
    "):\n",
    "    total_statistic_data, classification_statistic_data = merge_res_to_base(\n",
    "        total_statistic_data, classification_statistic_data, merge_total, merge_classification\n",
    "    )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Models\tTotal\tNR.\tNCL.\tES.\tSC.\tDC.\tCI.\tPS.\tDR.\tMS.\n",
      "total step length\t13.4\t15.3\t10.3\t13.8\t14.2\t13.3\t14.2\t12.7\t13.4\t14.1\n",
      "error step length\t2.1\t2.0\t2.8\t2.8\t1.6\t1.8\t1.7\t2.5\t2.3\t0.0\n",
      "first error loc\t7.8\t7.8\t4.9\t8.0\t9.1\t6.8\t11.4\t6.2\t8.3\tN/A\n",
      "question length\t152.7\t153.6\t152.5\t153.5\t149.7\t152.5\t152.7\t158.0\t153.5\t132.2\n",
      "total num\t6216\t758\t758\t757\t758\t757\t757\t756\t750\t165\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Default separators: a tab between cells and a newline after every row.\n",
    "print_res_to_excel(\n",
    "    total_statistic_data=total_statistic_data,\n",
    "    classification_statistic_data=classification_statistic_data,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Models&Total&NR.&NCL.&ES.&SC.&DC.&CI.&PS.&DR.&MS.\\\\\n",
      "total step length&13.4&15.3&10.3&13.8&14.2&13.3&14.2&12.7&13.4&14.1\\\\\n",
      "error step length&2.1&2.0&2.8&2.8&1.6&1.8&1.7&2.5&2.3&0.0\\\\\n",
      "first error loc&7.8&7.8&4.9&8.0&9.1&6.8&11.4&6.2&8.3&N/A\\\\\n",
      "question length&152.7&153.6&152.5&153.5&149.7&152.5&152.7&158.0&153.5&132.2\\\\\n",
      "total num&6216&758&758&757&758&757&757&756&750&165\\\\\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Same table with '&' between cells and a double backslash before each line\n",
    "# break (the row format of a LaTeX tabular).\n",
    "print_res_to_excel(\n",
    "    total_statistic_data=total_statistic_data,\n",
    "    classification_statistic_data=classification_statistic_data,\n",
    "    split_token=\"&\",\n",
    "    return_token=\"\\\\\\\\\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "smoe",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
