{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot For JudgeBench"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Round Number Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from multi_llm_debate.analysis.utils import count_files_per_directory, combine_plots\n",
    "from multi_llm_debate.analysis.plot_correct_rate_distribution import (\n",
    "    plot_file_count_distribution,\n",
    ")\n",
    "from matplotlib import pyplot as plt\n",
    "from typing import Tuple\n",
    "\n",
    "\n",
    "def main(\n",
    "    model_dir_path: str, task_name: str = \"Judge Bench\"\n",
    ") -> Tuple[plt.Figure, plt.Axes]:\n",
    "    \"\"\"Plot the file-count distribution of one model run directory.\n",
    "\n",
    "    Args:\n",
    "        model_dir_path: Path to the model directory containing data. Its last\n",
    "            path component is forwarded to the plot as the model configuration.\n",
    "        task_name: Task label forwarded to the plot. Defaults to\n",
    "            \"Judge Bench\", the value that was previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        A tuple containing the figure and axes of the plot.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    print(f\"Analyzing files in: {model_dir_path}\")\n",
    "\n",
    "    # Get file count distribution\n",
    "    distribution = count_files_per_directory(model_dir_path)\n",
    "\n",
    "    # normpath strips a trailing separator first, so a path written as\n",
    "    # \"runs/model/\" still yields \"model\" instead of an empty string.\n",
    "    model_config = os.path.basename(os.path.normpath(model_dir_path))\n",
    "\n",
    "    # Create visualization\n",
    "    figure, ax = plot_file_count_distribution(\n",
    "        distribution=distribution,\n",
    "        model_config=model_config,\n",
    "        task_name=task_name,\n",
    "    )\n",
    "    return figure, ax\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # NOTE(review): this run directory lives under big_bench_pruning while the\n",
    "    # plot is labelled \"Judge Bench\" - confirm which of the two is intended.\n",
    "    fig1, ax1 = main(\n",
    "        model_dir_path=\"../data/big_bench_pruning/gemini-2_0-flash-001(11)\"\n",
    "    )\n",
    "\n",
    "    # Tighten the layout, then render the figure in the notebook\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Accuracy by Round"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### JudgeBench"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    import logging\n",
    "    from pathlib import Path\n",
    "\n",
    "    from multi_llm_debate.run.judge_bench.utils import (\n",
    "        compare_judge_bench_response,\n",
    "        extract_caption_a_b_answer,\n",
    "        load_judge_bench_dataset,\n",
    "    )\n",
    "    from multi_llm_debate.analysis.plot_correct_rate_distribution import (\n",
    "        combine_correct_rate_plots,\n",
    "    )\n",
    "\n",
    "    # Only this cell's own logger is set to ERROR here\n",
    "    logger = logging.getLogger(__name__)\n",
    "    logger.setLevel(logging.ERROR)\n",
    "\n",
    "    df = load_judge_bench_dataset()\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/judge_bench\")\n",
    "    task_name = \"Judge Bench\"\n",
    "\n",
    "    model_dirs = [\n",
    "        Path(\"../data/judge_bench/gemma-3-4b-it(7)\"),\n",
    "        Path(\"../data/judge_bench/Llama-3_1-8B-Instruct(7)\"),\n",
    "        Path(\"../data/judge_bench/Qwen2_5-7B-Instruct(7)\"),\n",
    "        Path(\"../data/judge_bench/gemini-2_0-flash-001(7)\"),\n",
    "    ]\n",
    "\n",
    "    # Presumably display labels, listed in the same order as model_dirs.\n",
    "    # The first entry used to read \"Gemini-3-4B\" although its run directory\n",
    "    # is gemma-3-4b-it.\n",
    "    model_configs = [\n",
    "        \"Gemma-3-4B\",\n",
    "        \"Llama-3-1-8B\",\n",
    "        \"Qwen-2.5-7B\",\n",
    "        \"Gemini-2.0-Flash\",\n",
    "    ]\n",
    "    extract_func = extract_caption_a_b_answer\n",
    "    compare_func = compare_judge_bench_response\n",
    "    max_rounds = 6\n",
    "    rows = 2\n",
    "    columns = 3\n",
    "    show_plot = True\n",
    "    combined_title = \"Distribution of the number of correct agents across debate rounds for the JudgeBench dataset\"\n",
    "    file_name = \"judge_bench_combined_correct_rate_plots.png\"\n",
    "    progress_bar = True\n",
    "    combine_correct_rate_plots(\n",
    "        df,\n",
    "        model_dirs,\n",
    "        model_configs,\n",
    "        OUTPUT_DIR,\n",
    "        extract_func,\n",
    "        compare_func,\n",
    "        task_name=task_name,\n",
    "        max_rounds=max_rounds,\n",
    "        rows=rows,\n",
    "        columns=columns,\n",
    "        show_plot=show_plot,\n",
    "        combined_title=combined_title,\n",
    "        file_name=file_name,\n",
    "        progress_bar=progress_bar,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TruthfulQA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    import os\n",
    "    import logging\n",
    "    from pathlib import Path\n",
    "\n",
    "    from multi_llm_debate.analysis.plot_correct_rate_distribution import (\n",
    "        combine_correct_rate_plots,\n",
    "    )\n",
    "    from multi_llm_debate.run.truthful_qa.utils import (\n",
    "        extract_caption_a_b_c_answer,\n",
    "        compare_truthful_qa_response,\n",
    "        load_truthful_qa_dataset,\n",
    "    )\n",
    "\n",
    "    # Only this cell's own logger is set to ERROR here\n",
    "    logger = logging.getLogger(__name__)\n",
    "    logger.setLevel(logging.ERROR)\n",
    "\n",
    "    df = load_truthful_qa_dataset()\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/truthful_qa\")\n",
    "    task_name = \"TruthfulQA\"\n",
    "\n",
    "    model_dirs = [\n",
    "        Path(\"../data/truthful_qa/gemma-3-4b-it(7)\"),\n",
    "        Path(\"../data/truthful_qa/Llama-3_1-8B-Instruct(7)\"),\n",
    "        Path(\"../data/truthful_qa/Qwen2_5-7B-Instruct(7)\"),\n",
    "        Path(\"../data/truthful_qa/gemini-2_0-flash-001(7)\"),\n",
    "    ]\n",
    "\n",
    "    # Presumably display labels, listed in the same order as model_dirs.\n",
    "    # The first entry used to read \"Gemini-3-4B\" although its run directory\n",
    "    # is gemma-3-4b-it.\n",
    "    model_configs = [\n",
    "        \"Gemma-3-4B\",\n",
    "        \"Llama-3-1-8B\",\n",
    "        \"Qwen-2.5-7B\",\n",
    "        \"Gemini-2.0-Flash\",\n",
    "    ]\n",
    "    # Use the extractor imported from truthful_qa.utils in this cell. The cell\n",
    "    # previously referenced extract_caption_a_b_answer, which it never imports\n",
    "    # and which only resolved when the JudgeBench cell had been run before it.\n",
    "    extract_func = extract_caption_a_b_c_answer\n",
    "    compare_func = compare_truthful_qa_response\n",
    "    max_rounds = 6\n",
    "    rows = 2\n",
    "    columns = 3\n",
    "    show_plot = True\n",
    "    combined_title = \"Distribution of the number of correct agents across debate rounds for the TruthfulQA dataset\"\n",
    "    file_name = \"truthful_qa_combined_correct_rate_plots.png\"\n",
    "    progress_bar = True\n",
    "    combine_correct_rate_plots(\n",
    "        df,\n",
    "        model_dirs,\n",
    "        model_configs,\n",
    "        OUTPUT_DIR,\n",
    "        extract_func,\n",
    "        compare_func,\n",
    "        task_name=task_name,\n",
    "        max_rounds=max_rounds,\n",
    "        rows=rows,\n",
    "        columns=columns,\n",
    "        show_plot=show_plot,\n",
    "        combined_title=combined_title,\n",
    "        file_name=file_name,\n",
    "        progress_bar=progress_bar,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    # NOTE(review): the next code cell in this notebook repeats this LLMBar\n",
    "    # analysis line for line; one of the two copies can probably be removed.\n",
    "    import os\n",
    "    import logging\n",
    "    from pathlib import Path\n",
    "\n",
    "    from multi_llm_debate.analysis.plot_correct_rate_distribution import (\n",
    "        combine_correct_rate_plots,\n",
    "    )\n",
    "    from multi_llm_debate.run.llm_bar.utils import (\n",
    "        extract_1_2_answer,\n",
    "        compare_llm_bar_response,\n",
    "        load_llm_bar_dataset,\n",
    "    )\n",
    "\n",
    "    # Only this cell's own logger is set to ERROR here\n",
    "    logger = logging.getLogger(__name__)\n",
    "    logger.setLevel(logging.ERROR)\n",
    "\n",
    "    df = load_llm_bar_dataset(\"../datasets/LLMBar\")\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/llm_bar\")\n",
    "    task_name = \"LLMBar\"\n",
    "\n",
    "    model_dirs = [\n",
    "        Path(\"../data/llm_bar/gemma-3-4b-it(7)\"),\n",
    "        Path(\"../data/llm_bar/Llama-3_1-8B-Instruct(7)\"),\n",
    "        Path(\"../data/llm_bar/Qwen2_5-7B-Instruct(7)\"),\n",
    "        Path(\"../data/llm_bar/gemini-2_0-flash-001(7)\"),\n",
    "    ]\n",
    "\n",
    "    # Presumably display labels, listed in the same order as model_dirs.\n",
    "    # The first entry used to read \"Gemini-3-4B\" although its run directory\n",
    "    # is gemma-3-4b-it.\n",
    "    model_configs = [\n",
    "        \"Gemma-3-4B\",\n",
    "        \"Llama-3-1-8B\",\n",
    "        \"Qwen-2.5-7B\",\n",
    "        \"Gemini-2.0-Flash\",\n",
    "    ]\n",
    "    extract_func = extract_1_2_answer\n",
    "    compare_func = compare_llm_bar_response\n",
    "    max_rounds = 6\n",
    "    rows = 2\n",
    "    columns = 3\n",
    "    show_plot = True\n",
    "    combined_title = \"Distribution of the number of correct agents across debate rounds for the LLMBar dataset\"\n",
    "    file_name = \"llm_bar_combined_correct_rate_plots.png\"\n",
    "    progress_bar = True\n",
    "    combine_correct_rate_plots(\n",
    "        df,\n",
    "        model_dirs,\n",
    "        model_configs,\n",
    "        OUTPUT_DIR,\n",
    "        extract_func,\n",
    "        compare_func,\n",
    "        task_name=task_name,\n",
    "        max_rounds=max_rounds,\n",
    "        rows=rows,\n",
    "        columns=columns,\n",
    "        show_plot=show_plot,\n",
    "        combined_title=combined_title,\n",
    "        file_name=file_name,\n",
    "        progress_bar=progress_bar,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LLMBar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    import os\n",
    "    import logging\n",
    "    from pathlib import Path\n",
    "\n",
    "    from multi_llm_debate.analysis.plot_correct_rate_distribution import (\n",
    "        combine_correct_rate_plots,\n",
    "    )\n",
    "    from multi_llm_debate.run.llm_bar.utils import (\n",
    "        extract_1_2_answer,\n",
    "        compare_llm_bar_response,\n",
    "        load_llm_bar_dataset,\n",
    "    )\n",
    "\n",
    "    # Only this cell's own logger is set to ERROR here\n",
    "    logger = logging.getLogger(__name__)\n",
    "    logger.setLevel(logging.ERROR)\n",
    "\n",
    "    df = load_llm_bar_dataset(\"../datasets/LLMBar\")\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/llm_bar\")\n",
    "    task_name = \"LLMBar\"\n",
    "\n",
    "    model_dirs = [\n",
    "        Path(\"../data/llm_bar/gemma-3-4b-it(7)\"),\n",
    "        Path(\"../data/llm_bar/Llama-3_1-8B-Instruct(7)\"),\n",
    "        Path(\"../data/llm_bar/Qwen2_5-7B-Instruct(7)\"),\n",
    "        Path(\"../data/llm_bar/gemini-2_0-flash-001(7)\"),\n",
    "    ]\n",
    "\n",
    "    # Presumably display labels, listed in the same order as model_dirs.\n",
    "    # The first entry used to read \"Gemini-3-4B\" although its run directory\n",
    "    # is gemma-3-4b-it.\n",
    "    model_configs = [\n",
    "        \"Gemma-3-4B\",\n",
    "        \"Llama-3-1-8B\",\n",
    "        \"Qwen-2.5-7B\",\n",
    "        \"Gemini-2.0-Flash\",\n",
    "    ]\n",
    "    extract_func = extract_1_2_answer\n",
    "    compare_func = compare_llm_bar_response\n",
    "    max_rounds = 6\n",
    "    rows = 2\n",
    "    columns = 3\n",
    "    show_plot = True\n",
    "    combined_title = \"Distribution of the number of correct agents across debate rounds for the LLMBar dataset\"\n",
    "    file_name = \"llm_bar_combined_correct_rate_plots.png\"\n",
    "    progress_bar = True\n",
    "    combine_correct_rate_plots(\n",
    "        df,\n",
    "        model_dirs,\n",
    "        model_configs,\n",
    "        OUTPUT_DIR,\n",
    "        extract_func,\n",
    "        compare_func,\n",
    "        task_name=task_name,\n",
    "        max_rounds=max_rounds,\n",
    "        rows=rows,\n",
    "        columns=columns,\n",
    "        show_plot=show_plot,\n",
    "        combined_title=combined_title,\n",
    "        file_name=file_name,\n",
    "        progress_bar=progress_bar,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BIG_Bench"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    import os\n",
    "    from pathlib import Path\n",
    "\n",
    "    import pandas as pd\n",
    "\n",
    "    from multi_llm_debate.analysis.plot_accuracy import (\n",
    "        process_model_majority_aggregated,\n",
    "    )\n",
    "    from multi_llm_debate.run.judge_bench.utils import (\n",
    "        compare_judge_bench_response,\n",
    "        extract_caption_a_b_answer,\n",
    "        load_judge_bench_dataset,\n",
    "    )\n",
    "\n",
    "    max_round_number = 10\n",
    "\n",
    "    # Reuse the cached CSV when it exists; otherwise load the dataset and cache it.\n",
    "    df_path = Path(\"../output/judge_bench/processed_data.csv\")\n",
    "    if df_path.exists():\n",
    "        df = pd.read_csv(df_path)\n",
    "    else:\n",
    "        df = load_judge_bench_dataset(dataset_path=\"../datasets/JudgeBench\")\n",
    "        os.makedirs(\"../output/judge_bench\", exist_ok=True)\n",
    "        df.to_csv(df_path, index=False)\n",
    "    task_name = \"Judge Bench\"\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/judge_bench\")\n",
    "\n",
    "    # Same call for every run directory, issued in the listed order.\n",
    "    for judge_bench_run in (\n",
    "        \"../data/judge_bench/Llama-3_2-3B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Qwen2_5-3B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Llama-3_1-8B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Qwen2_5-7B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Llama-3_1-8B-Instruct(6)+Qwen2_5-7B-Instruct(5)\",\n",
    "    ):\n",
    "        process_model_majority_aggregated(\n",
    "            model_dir=Path(judge_bench_run),\n",
    "            output_dir=OUTPUT_DIR,\n",
    "            task_name=task_name,\n",
    "            dataframe=df,\n",
    "            max_round_number=max_round_number,\n",
    "            extract_func=extract_caption_a_b_answer,\n",
    "            compare_func=compare_judge_bench_response,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    import os\n",
    "    from pathlib import Path\n",
    "\n",
    "    import pandas as pd\n",
    "\n",
    "    from multi_llm_debate.analysis.plot_accuracy import (\n",
    "        process_model_majority_aggregated,\n",
    "    )\n",
    "    from multi_llm_debate.run.judge_bench.utils import (\n",
    "        compare_judge_bench_response,\n",
    "        extract_caption_a_b_answer,\n",
    "        load_judge_bench_dataset,\n",
    "    )\n",
    "\n",
    "    max_round_number = 10\n",
    "\n",
    "    # Reuse the cached CSV when it exists; otherwise load the dataset and cache it.\n",
    "    df_path = Path(\"../output/judge_bench/processed_data.csv\")\n",
    "    if df_path.exists():\n",
    "        df = pd.read_csv(df_path)\n",
    "    else:\n",
    "        df = load_judge_bench_dataset(dataset_path=\"../datasets/JudgeBench\")\n",
    "        os.makedirs(\"../output/judge_bench\", exist_ok=True)\n",
    "        df.to_csv(df_path, index=False)\n",
    "    task_name = \"Judge Bench\"\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/judge_bench\")\n",
    "\n",
    "    # Same call for every run directory, issued in the listed order.\n",
    "    for judge_bench_run in (\n",
    "        \"../data/judge_bench/Llama-3_2-3B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Qwen2_5-3B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Llama-3_1-8B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Qwen2_5-7B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Llama-3_1-8B-Instruct(6)+Qwen2_5-7B-Instruct(5)\",\n",
    "    ):\n",
    "        process_model_majority_aggregated(\n",
    "            model_dir=Path(judge_bench_run),\n",
    "            output_dir=OUTPUT_DIR,\n",
    "            task_name=task_name,\n",
    "            dataframe=df,\n",
    "            max_round_number=max_round_number,\n",
    "            extract_func=extract_caption_a_b_answer,\n",
    "            compare_func=compare_judge_bench_response,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    from multi_llm_debate.analysis.plot_accuracy import (\n",
    "        process_model_majority_aggregated,\n",
    "    )\n",
    "    import os\n",
    "    import pandas as pd\n",
    "    from pathlib import Path\n",
    "    from multi_llm_debate.run.truthful_qa.utils import (\n",
    "        load_truthful_qa_dataset,\n",
    "        preprocess_truthful_qa_dataframe,\n",
    "        extract_caption_a_b_c_answer,\n",
    "        compare_truthful_qa_response,\n",
    "    )\n",
    "    from multi_llm_debate.run.llm_bar.utils import (\n",
    "        extract_1_2_answer,\n",
    "        compare_llm_bar_response,\n",
    "        preprocess_llm_bar_dataframe,\n",
    "        load_llm_bar_dataset,\n",
    "    )\n",
    "\n",
    "    # Defined here so the cell no longer depends on a value left behind by an\n",
    "    # earlier cell; 10 is the value the JudgeBench cells assign.\n",
    "    max_round_number = 10\n",
    "\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/truthful_qa\")\n",
    "    # Reuse the cached CSV when it exists; otherwise load, preprocess and cache it.\n",
    "    df_path = Path(\"../output/truthful_qa/processed_data.csv\")\n",
    "    if not df_path.exists():\n",
    "        df = load_truthful_qa_dataset(dataset_path=\"../datasets/TruthfulQA\")\n",
    "        df = preprocess_truthful_qa_dataframe(df)\n",
    "        os.makedirs(\"../output/truthful_qa\", exist_ok=True)\n",
    "        df.to_csv(df_path, index=False)\n",
    "    else:\n",
    "        df = pd.read_csv(df_path)\n",
    "    task_name = \"Truthful QA\"\n",
    "\n",
    "    # Same call for every run directory, issued in the listed order.\n",
    "    for truthful_qa_run in (\n",
    "        \"../data/truthful_qa/Llama-3_1-8B-Instruct(11)\",\n",
    "        \"../data/truthful_qa/Qwen2_5-7B-Instruct(11)\",\n",
    "    ):\n",
    "        process_model_majority_aggregated(\n",
    "            model_dir=Path(truthful_qa_run),\n",
    "            output_dir=OUTPUT_DIR,\n",
    "            task_name=task_name,\n",
    "            dataframe=df,\n",
    "            max_round_number=max_round_number,\n",
    "            extract_func=extract_caption_a_b_c_answer,\n",
    "            compare_func=compare_truthful_qa_response,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    from multi_llm_debate.analysis.plot_accuracy import (\n",
    "        process_model_majority_aggregated,\n",
    "    )\n",
    "    import os\n",
    "    import pandas as pd\n",
    "    from pathlib import Path\n",
    "    from multi_llm_debate.run.llm_bar.utils import (\n",
    "        extract_1_2_answer,\n",
    "        compare_llm_bar_response,\n",
    "        preprocess_llm_bar_dataframe,\n",
    "        load_llm_bar_dataset,\n",
    "    )\n",
    "\n",
    "    # Defined here so the cell no longer depends on a value left behind by an\n",
    "    # earlier cell; 10 is the value the JudgeBench cells assign.\n",
    "    max_round_number = 10\n",
    "\n",
    "    # Reuse the cached CSV when it exists; otherwise load, preprocess and cache it.\n",
    "    df_path = Path(\"../output/llm_bar/processed_data.csv\")\n",
    "    if not df_path.exists():\n",
    "        df = load_llm_bar_dataset(dataset_path=\"../datasets/LLMBar\")\n",
    "        df = preprocess_llm_bar_dataframe(df)\n",
    "        os.makedirs(\"../output/llm_bar\", exist_ok=True)\n",
    "        df.to_csv(df_path, index=False)\n",
    "    else:\n",
    "        df = pd.read_csv(df_path)\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/llm_bar\")\n",
    "    task_name = \"LLM Bar\"\n",
    "\n",
    "    # Same call for every run directory, issued in the listed order.\n",
    "    for llm_bar_run in (\n",
    "        \"../data/llm_bar/Llama-3_1-8B-Instruct(11)\",\n",
    "        \"../data/llm_bar/Qwen2_5-7B-Instruct(11)\",\n",
    "    ):\n",
    "        process_model_majority_aggregated(\n",
    "            model_dir=Path(llm_bar_run),\n",
    "            output_dir=OUTPUT_DIR,\n",
    "            task_name=task_name,\n",
    "            dataframe=df,\n",
    "            max_round_number=max_round_number,\n",
    "            extract_func=extract_1_2_answer,\n",
    "            compare_func=compare_llm_bar_response,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "\n",
    "from multi_llm_debate.analysis.calculate_correct_rate_by_round import (\n",
    "    calculate_correct_rate_by_round,\n",
    ")\n",
    "from multi_llm_debate.analysis.calculate_task_accuracy import analyze_task_accuracy\n",
    "from multi_llm_debate.run.judge_bench.utils import (\n",
    "    extract_caption_a_b_answer,\n",
    "    compare_judge_bench_response,\n",
    ")\n",
    "\n",
    "# Maximum round number for your correct_rate_by_round function\n",
    "MAX_ROUND_NUMBER = 10\n",
    "\n",
    "\n",
    "def create_plot_absolute_only(\n",
    "    absolute_by_round: Dict[float, np.ndarray],\n",
    "    model_name: str,\n",
    "    task_name: str = \"Judge Bench\",\n",
    ") -> None:\n",
    "    \"\"\"Plot the absolute correct rate by round, one line per accuracy value.\n",
    "\n",
    "    Args:\n",
    "        absolute_by_round: Maps an accuracy value to its sequence of absolute\n",
    "            correct rates; the value at position i is drawn at x = i. Entries\n",
    "            that are None or empty are skipped.\n",
    "        model_name: Name of the model for the plot title.\n",
    "        task_name: Name of the task (default is \"Judge Bench\").\n",
    "    \"\"\"\n",
    "    # Sort by accuracy (the dictionary key) so colours and legend order are stable\n",
    "    sorted_items = sorted(absolute_by_round.items(), key=lambda x: x[0])\n",
    "\n",
    "    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;\n",
    "    # plt.get_cmap accepts the same (name, lut) arguments. max(..., 1) keeps\n",
    "    # the lookup table non-empty when there is nothing to plot.\n",
    "    accuracy_colors = plt.get_cmap(\"tab20\", max(len(sorted_items), 1))\n",
    "\n",
    "    plt.figure(figsize=(10, 6))\n",
    "\n",
    "    legend_handles = []\n",
    "\n",
    "    # Plot a line for each unique accuracy value\n",
    "    for idx, (accuracy, absolute_values) in enumerate(sorted_items):\n",
    "        # Skip empty arrays or None values\n",
    "        if absolute_values is None or len(absolute_values) == 0:\n",
    "            continue\n",
    "\n",
    "        # x positions 0..len-1, one per entry of absolute_values\n",
    "        rounds = np.arange(len(absolute_values))\n",
    "\n",
    "        (line,) = plt.plot(\n",
    "            rounds,\n",
    "            absolute_values,\n",
    "            color=accuracy_colors(idx),\n",
    "            linestyle=\"-\",  # Solid line for absolute\n",
    "            linewidth=2,\n",
    "            label=f\"Acc={accuracy:.2f} (absolute)\",\n",
    "        )\n",
    "        legend_handles.append(line)\n",
    "\n",
    "    # Title and labels\n",
    "    plt.title(f\"Absolute Correct Rate by Round: {model_name} - {task_name}\", pad=15)\n",
    "    plt.xlabel(\"Round Number\")\n",
    "    plt.ylabel(\"Correct Rate\")\n",
    "    plt.grid(True, linestyle=\"--\", alpha=0.7)\n",
    "\n",
    "    # Only draw a legend when at least one line was plotted\n",
    "    if legend_handles:\n",
    "        plt.legend(handles=legend_handles)\n",
    "\n",
    "    # Set y-axis limits and ticks\n",
    "    plt.ylim(0, 1)\n",
    "    plt.yticks(np.arange(0, 1.1, 0.1))\n",
    "    # At most 11 x ticks (0..10), fewer when MAX_ROUND_NUMBER is below 10\n",
    "    plt.xticks(range(min(11, MAX_ROUND_NUMBER + 1)))\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def process_model_absolute_only(\n",
    "    model_dir: Path, task_name: str = \"Judge Bench\", dataframe=None\n",
    ") -> None:\n",
    "    \"\"\"Compute and plot the absolute correct rate by round for one model run.\n",
    "\n",
    "    This function:\n",
    "    1) Analyzes accuracy with ``analyze_task_accuracy``\n",
    "    2) Calculates correct rates by round for each unique non-negative accuracy\n",
    "       value and keeps the row whose ``metric`` is \"absolute\"\n",
    "    3) Plots the collected series with ``create_plot_absolute_only``\n",
    "\n",
    "    Args:\n",
    "        model_dir: Path to the model directory containing debate data.\n",
    "        task_name: Task label used in the plot title.\n",
    "        dataframe: Dataset handed to ``analyze_task_accuracy``. When omitted,\n",
    "            the notebook-level ``df`` is used, which is what this function\n",
    "            always read before the parameter existed.\n",
    "    \"\"\"\n",
    "    if dataframe is None:\n",
    "        # Backward-compatible fallback to the notebook global\n",
    "        dataframe = df\n",
    "\n",
    "    model_name = model_dir.name\n",
    "    print(f\"\\nProcessing model: {model_name}\")\n",
    "\n",
    "    # 1) Analyze accuracy\n",
    "    result_df = analyze_task_accuracy(\n",
    "        model_dir=model_dir,\n",
    "        dataframe=dataframe,\n",
    "        extract_func=extract_caption_a_b_answer,\n",
    "        compare_func=compare_judge_bench_response,\n",
    "    )\n",
    "\n",
    "    # 2) Get all unique accuracy values from the result dataframe\n",
    "    unique_accuracies = result_df[\"accuracy\"].unique()\n",
    "\n",
    "    # 3) Create a dictionary to store the absolute metric by round for each accuracy\n",
    "    absolute_by_round = {}\n",
    "\n",
    "    length = len(result_df)\n",
    "    # 4) For each unique accuracy value, calculate absolute correct rates by round\n",
    "    for accuracy in unique_accuracies:\n",
    "        # Negative accuracies are skipped - presumably a sentinel for tasks\n",
    "        # without a usable result; TODO confirm against analyze_task_accuracy\n",
    "        if accuracy < 0:\n",
    "            continue\n",
    "\n",
    "        # Filter tasks by accuracy\n",
    "        filtered_df = result_df[result_df[\"accuracy\"] == accuracy]\n",
    "\n",
    "        # Calculate and print the percentage of tasks with this accuracy\n",
    "        accuracy_percentage = (len(filtered_df) / length) * 100\n",
    "        print(f\"Accuracy = {accuracy:.2f}: {accuracy_percentage:.2f}% of total tasks\")\n",
    "\n",
    "        try:\n",
    "            # Calculate absolute correct rates for this accuracy\n",
    "            cr_filtered_df = calculate_correct_rate_by_round(\n",
    "                filtered_df,\n",
    "                model_dir,\n",
    "                max_round_number=MAX_ROUND_NUMBER,\n",
    "                extract_func=extract_caption_a_b_answer,\n",
    "                compare_func=compare_judge_bench_response,\n",
    "            )\n",
    "\n",
    "            # Check if we have results for the absolute metric\n",
    "            absolute_rows = cr_filtered_df[cr_filtered_df[\"metric\"] == \"absolute\"]\n",
    "            if not absolute_rows.empty:\n",
    "                # First matching row, columns from position 2 onward - presumably\n",
    "                # the per-round values; TODO confirm the column layout\n",
    "                absolute_by_round[accuracy] = absolute_rows.iloc[0, 2:].values\n",
    "\n",
    "        except Exception as e:\n",
    "            # Best effort: report the failing accuracy bucket and keep going\n",
    "            print(f\"Error processing accuracy {accuracy}: {e}\")\n",
    "            continue\n",
    "\n",
    "    # 5) Create the plot for absolute correct rates\n",
    "    create_plot_absolute_only(absolute_by_round, model_name, task_name)\n",
    "\n",
    "\n",
    "# ==========================\n",
    "# 3. MAIN SCRIPT\n",
    "# ==========================\n",
    "if __name__ == \"__main__\":\n",
    "    from multi_llm_debate.run.judge_bench.utils import load_judge_bench_dataset\n",
    "    import os\n",
    "\n",
    "    # Load the dataset and (re)write the processed CSV on every run\n",
    "    df = load_judge_bench_dataset(dataset_path=\"../datasets/JudgeBench\")\n",
    "    os.makedirs(\"../output/judge_bench\", exist_ok=True)\n",
    "    df_path = Path(\"../output/judge_bench/processed_data.csv\")\n",
    "    df.to_csv(df_path, index=False)\n",
    "\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/judge_bench\")\n",
    "    # Set explicitly: the cell used to read whatever task_name an earlier cell\n",
    "    # had left behind, which is \"LLM Bar\" when the notebook runs top to bottom.\n",
    "    task_name = \"Judge Bench\"\n",
    "\n",
    "    # Same call for every run directory, issued in the listed order.\n",
    "    for judge_bench_run in (\n",
    "        \"../data/judge_bench/Llama-3_2-3B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Qwen2_5-3B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Llama-3_1-8B-Instruct(11)\",\n",
    "        \"../data/judge_bench/Qwen2_5-7B-Instruct(11)\",\n",
    "    ):\n",
    "        process_model_absolute_only(Path(judge_bench_run), task_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    from pathlib import Path\n",
    "    from multi_llm_debate.run.judge_bench.utils import (\n",
    "        compare_judge_bench_response,\n",
    "        extract_caption_a_b_answer,\n",
    "    )\n",
    "    from multi_llm_debate.distribution_model.visualize_model import run_visualization\n",
    "\n",
    "    FIT_METHOD = \"direct\"  # \"direct\" or \"em\" optimization approach\n",
    "    N_RESTARTS = 2  # Number of random restarts for more stable fitting\n",
    "    ENFORCE_INCREASING = False  # Enforce non-decreasing expected success probability\n",
    "    MAX_ROUNDS = None  # or an int\n",
    "\n",
    "    # Analysis settings\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/judge_bench\")\n",
    "    task_name = \"Judge Bench\"\n",
    "\n",
    "    # Run the same visualization once per model configuration. Each name is\n",
    "    # both the sub-directory of ../data/judge_bench holding debate_rounds.csv\n",
    "    # and the `model_config` argument, so the two can no longer drift apart.\n",
    "    for model_config in (\n",
    "        \"Llama-3_2-3B-Instruct(11)\",\n",
    "        \"Qwen2_5-3B-Instruct(11)\",\n",
    "        \"Llama-3_1-8B-Instruct(11)\",\n",
    "        \"Qwen2_5-7B-Instruct(11)\",\n",
    "        \"Llama-3_1-8B-Instruct(6)+Qwen2_5-7B-Instruct(5)\",\n",
    "    ):\n",
    "        run_visualization(\n",
    "            answers_csv_path=Path(\"../output/judge_bench/processed_data.csv\"),\n",
    "            debates_csv_path=(\n",
    "                Path(\"../data/judge_bench\") / model_config / \"debate_rounds.csv\"\n",
    "            ),\n",
    "            output_dir=OUTPUT_DIR,\n",
    "            max_rounds=MAX_ROUNDS,\n",
    "            fitting_method=FIT_METHOD,\n",
    "            n_restarts=N_RESTARTS,\n",
    "            verbose=False,\n",
    "            enforce_increasing_success=ENFORCE_INCREASING,\n",
    "            extract_func=extract_caption_a_b_answer,\n",
    "            compare_func=compare_judge_bench_response,\n",
    "            model_config=model_config,\n",
    "            row_number=2,\n",
    "            task_name=task_name,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    from pathlib import Path\n",
    "    from multi_llm_debate.distribution_model.visualize_model import run_visualization\n",
    "    from multi_llm_debate.run.truthful_qa.utils import (\n",
    "        extract_caption_a_b_c_answer,\n",
    "        compare_truthful_qa_response,\n",
    "    )\n",
    "\n",
    "    FIT_METHOD = \"direct\"  # \"direct\" or \"em\" optimization approach\n",
    "    N_RESTARTS = 2  # Number of random restarts for more stable fitting\n",
    "    ENFORCE_INCREASING = False  # Enforce non-decreasing expected success probability\n",
    "    MAX_ROUNDS = None  # or an int\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/truthful_qa\")\n",
    "    task_name = \"Truthful QA\"\n",
    "\n",
    "    # Run the same visualization once per model configuration. Each name is\n",
    "    # both the sub-directory of ../data/truthful_qa holding debate_rounds.csv\n",
    "    # and the `model_config` argument passed to run_visualization.\n",
    "    for model_config in (\n",
    "        \"Llama-3_1-8B-Instruct(11)\",\n",
    "        \"Qwen2_5-7B-Instruct(11)\",\n",
    "    ):\n",
    "        run_visualization(\n",
    "            answers_csv_path=Path(\"../output/truthful_qa/processed_data.csv\"),\n",
    "            debates_csv_path=(\n",
    "                Path(\"../data/truthful_qa\") / model_config / \"debate_rounds.csv\"\n",
    "            ),\n",
    "            output_dir=OUTPUT_DIR,\n",
    "            max_rounds=MAX_ROUNDS,\n",
    "            fitting_method=FIT_METHOD,\n",
    "            n_restarts=N_RESTARTS,\n",
    "            verbose=False,\n",
    "            enforce_increasing_success=ENFORCE_INCREASING,\n",
    "            extract_func=extract_caption_a_b_c_answer,\n",
    "            compare_func=compare_truthful_qa_response,\n",
    "            model_config=model_config,\n",
    "            row_number=2,\n",
    "            task_name=task_name,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    from pathlib import Path\n",
    "    from multi_llm_debate.distribution_model.visualize_model import run_visualization\n",
    "    from multi_llm_debate.run.llm_bar.utils import (\n",
    "        extract_1_2_answer,\n",
    "        compare_llm_bar_response,\n",
    "    )\n",
    "\n",
    "    FIT_METHOD = \"direct\"  # \"direct\" or \"em\" optimization approach\n",
    "    N_RESTARTS = 2  # Number of random restarts for more stable fitting\n",
    "    ENFORCE_INCREASING = False  # Enforce non-decreasing expected success probability\n",
    "    MAX_ROUNDS = None  # or an int\n",
    "    OUTPUT_DIR = Path(\"../output/visualizations/llm_bar\")\n",
    "    task_name = \"LLMBar\"\n",
    "\n",
    "    # Run the same visualization once per model configuration. Each name is\n",
    "    # both the sub-directory of ../data/llm_bar holding debate_rounds.csv and\n",
    "    # the `model_config` argument passed to run_visualization.\n",
    "    for model_config in (\n",
    "        \"Llama-3_1-8B-Instruct(11)\",\n",
    "        \"Qwen2_5-7B-Instruct(11)\",\n",
    "    ):\n",
    "        run_visualization(\n",
    "            answers_csv_path=Path(\"../output/llm_bar/processed_data.csv\"),\n",
    "            debates_csv_path=(\n",
    "                Path(\"../data/llm_bar\") / model_config / \"debate_rounds.csv\"\n",
    "            ),\n",
    "            output_dir=OUTPUT_DIR,\n",
    "            max_rounds=MAX_ROUNDS,\n",
    "            fitting_method=FIT_METHOD,\n",
    "            n_restarts=N_RESTARTS,\n",
    "            verbose=True,  # unlike the other task cells, this one logs verbosely\n",
    "            enforce_increasing_success=ENFORCE_INCREASING,\n",
    "            extract_func=extract_1_2_answer,\n",
    "            compare_func=compare_llm_bar_response,\n",
    "            model_config=model_config,\n",
    "            row_number=2,\n",
    "            task_name=task_name,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
