{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "facebf91",
   "metadata": {},
   "source": [
    "# Prediction Disagreement Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f675b6d9",
   "metadata": {},
   "source": [
    "## Moral Except QA Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "34c0c2ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import os\n",
    "import re\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from scipy.cluster.hierarchy import linkage, dendrogram\n",
    "from sklearn.manifold import MDS\n",
    "from scipy.spatial.distance import squareform"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "120e5af1",
   "metadata": {},
   "source": [
    "##### Data Prep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "59145e97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# file mappings\n",
    "\n",
    "## llama31\n",
    "## llama32\n",
    "## mistral\n",
    "#______________\n",
    "## phi\n",
    "## olmo7b\n",
    "## qwen\n",
    "## deepseek\n",
    "#______________\n",
    "## olmo32b\n",
    "\n",
    "\n",
    "llama31_files = {\n",
    "    \"English\": \"binary_eval_results_scenario_meta-llama_Llama-3_1-8B-Instruct.json\",\n",
    "    \"Urdu\": \"binary_eval_results_Urdu_meta-llama_Llama-3_1-8B-Instruct.json\",\n",
    "    \"Chinese\": \"binary_eval_results_Chinese_meta-llama_Llama-3_1-8B-Instruct.json\",\n",
    "    \"Hindi\": \"binary_eval_results_Hindi_meta-llama_Llama-3_1-8B-Instruct.json\",\n",
    "    \"Spanish\": \"binary_eval_results_Spanish_meta-llama_Llama-3_1-8B-Instruct.json\",\n",
    "    \"German\": \"binary_eval_results_German_meta-llama_Llama-3_1-8B-Instruct.json\"\n",
    "}\n",
    "\n",
    "llama32_files = {\n",
    "    \"English\": \"binary_eval_results_scenario_meta-llama_Llama-3_2-3B-Instruct.json\",\n",
    "    \"Urdu\": \"binary_eval_results_Urdu_meta-llama_Llama-3_2-3B-Instruct.json\",\n",
    "    \"Chinese\": \"binary_eval_results_Chinese_meta-llama_Llama-3_2-3B-Instruct.json\",\n",
    "    \"Hindi\": \"binary_eval_results_Hindi_meta-llama_Llama-3_2-3B-Instruct.json\",\n",
    "    \"Spanish\": \"binary_eval_results_Spanish_meta-llama_Llama-3_2-3B-Instruct.json\",\n",
    "    \"German\": \"binary_eval_results_German_meta-llama_Llama-3_2-3B-Instruct.json\"\n",
    "}\n",
    "\n",
    "mistral_files = {\n",
    "    \"English\": \"binary_eval_results_scenario_mistralai_Mistral-7B-Instruct-v0_3.json\",\n",
    "    \"Urdu\": \"binary_eval_results_Urdu_mistralai_Mistral-7B-Instruct-v0_3.json\",\n",
    "    \"Chinese\": \"binary_eval_results_Chinese_mistralai_Mistral-7B-Instruct-v0_3.json\",\n",
    "    \"Hindi\": \"binary_eval_results_Hindi_mistralai_Mistral-7B-Instruct-v0_3.json\",\n",
    "    \"Spanish\": \"binary_eval_results_Spanish_mistralai_Mistral-7B-Instruct-v0_3.json\",\n",
    "    \"German\": \"binary_eval_results_German_mistralai_Mistral-7B-Instruct-v0_3.json\"\n",
    "}\n",
    "\n",
    "phi_files = {\n",
    "    \"English\": \"binary_eval_results_scenario_microsoft_Phi-4-mini-instruct.json\",\n",
    "    \"Urdu\": \"binary_eval_results_Urdu_microsoft_Phi-4-mini-instruct.json\",\n",
    "    \"Chinese\": \"binary_eval_results_Chinese_microsoft_Phi-4-mini-instruct.json\",\n",
    "    \"Hindi\": \"binary_eval_results_Hindi_microsoft_Phi-4-mini-instruct.json\",\n",
    "    \"Spanish\": \"binary_eval_results_Spanish_microsoft_Phi-4-mini-instruct.json\",\n",
    "    \"German\": \"binary_eval_results_German_microsoft_Phi-4-mini-instruct.json\"\n",
    "}\n",
    "\n",
    "olmo7b_files = {\n",
    "    \"English\": \"binary_eval_results_scenario_allenai_OLMo-7B-Instruct.json\",\n",
    "    \"Urdu\": \"binary_eval_results_Urdu_allenai_OLMo-7B-Instruct.json\",\n",
    "    \"Chinese\": \"binary_eval_results_Chinese_allenai_OLMo-7B-Instruct.json\",\n",
    "    \"Hindi\": \"binary_eval_results_Hindi_allenai_OLMo-7B-Instruct.json\",\n",
    "    \"Spanish\": \"binary_eval_results_Spanish_allenai_OLMo-7B-Instruct.json\",\n",
    "    \"German\": \"binary_eval_results_German_allenai_OLMo-7B-Instruct.json\"\n",
    "}\n",
    "\n",
    "qwen_files = {\n",
    "    \"English\": \"binary_eval_results_scenario_Qwen_Qwen2_5-7B-Instruct.json\",\n",
    "    \"Urdu\": \"binary_eval_results_Urdu_Qwen_Qwen2_5-7B-Instruct.json\",\n",
    "    \"Chinese\": \"binary_eval_results_Chinese_Qwen_Qwen2_5-7B-Instruct.json\",\n",
    "    \"Hindi\": \"binary_eval_results_Hindi_Qwen_Qwen2_5-7B-Instruct.json\",\n",
    "    \"Spanish\": \"binary_eval_results_Spanish_Qwen_Qwen2_5-7B-Instruct.json\",\n",
    "    \"German\": \"binary_eval_results_German_Qwen_Qwen2_5-7B-Instruct.json\"\n",
    "}\n",
    "\n",
    "deepseek_files = {\n",
    "    \"English\": \"binary_eval_results_scenario_deepseek-ai_DeepSeek-R1-Distill-Llama-8B.json\",\n",
    "    \"Urdu\": \"parsed_binary_eval_results_Urdu_deepseek-ai_DeepSeek-R1-Distill-Llama-8B.json\",\n",
    "    \"Chinese\": \"parsed_binary_eval_results_Chinese_deepseek-ai_DeepSeek-R1-Distill-Llama-8B.json\",\n",
    "    \"Hindi\": \"parsed_binary_eval_results_Hindi_deepseek-ai_DeepSeek-R1-Distill-Llama-8B.json\",\n",
    "    \"Spanish\": \"binary_eval_results_Spanish_deepseek-ai_DeepSeek-R1-Distill-Llama-8B.json\",\n",
    "    \"German\": \"binary_eval_results_German_deepseek-ai_DeepSeek-R1-Distill-Llama-8B.json\"\n",
    "}\n",
    "\n",
    "olmo32b_files = {\n",
    "    \"English\": \"binary_eval_results_scenario_allenai_OLMo-2-0325-32B-Instruct.json\",\n",
    "    \"Urdu\": \"binary_eval_results_Urdu_allenai_OLMo-2-0325-32B-Instruct.json\",\n",
    "    \"Chinese\": \"binary_eval_results_Chinese_allenai_OLMo-2-0325-32B-Instruct.json\",\n",
    "    \"Hindi\": \"binary_eval_results_Hindi_allenai_OLMo-2-0325-32B-Instruct.json\",\n",
    "    \"Spanish\": \"binary_eval_results_Spanish_allenai_OLMo-2-0325-32B-Instruct.json\",\n",
    "    \"German\": \"binary_eval_results_German_allenai_OLMo-2-0325-32B-Instruct.json\"\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a6e32579",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_csv(file_map, output_csv, base_dir=None):\n",
    "    df_all = pd.DataFrame()\n",
    "    \n",
    "    for lang_code, path in file_map.items():\n",
    "        full_path = os.path.join(base_dir, path) if base_dir else path\n",
    "        df = pd.read_json(full_path)\n",
    "\n",
    "        # Dynamically determine the correct question column\n",
    "        if \"scenario\" in df.columns:\n",
    "            question_key = \"scenario\"\n",
    "        elif \"question\" in df.columns:\n",
    "            question_key = \"question\"\n",
    "        else:\n",
    "            raise ValueError(f\"Neither 'scenario' nor 'question' found in {full_path}\")\n",
    "\n",
    "        if df_all.empty:\n",
    "            df_all[\"question\"] = df[question_key]\n",
    "            df_all[\"reference\"] = df[\"reference\"]\n",
    "\n",
    "        df_all[lang_code] = df[\"parsed_answer\"]\n",
    "    \n",
    "    df_all.to_csv(output_csv, index=False, encoding='utf-8')\n",
    "    print(f\"Saved: {output_csv} (rows: {len(df_all)})\")\n",
    "    return df_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "95f015e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: MEQ_combined_results/meq_llama31_predictions_combined.csv (rows: 148)\n",
      "Saved: MEQ_combined_results/meq_llama32_predictions_combined.csv (rows: 148)\n",
      "Saved: MEQ_combined_results/meq_mistral_predictions_combined.csv (rows: 148)\n",
      "Saved: MEQ_combined_results/meq_phi_predictions_combined.csv (rows: 148)\n",
      "Saved: MEQ_combined_results/meq_olmo7b_predictions_combined.csv (rows: 148)\n",
      "Saved: MEQ_combined_results/meq_qwen_predictions_combined.csv (rows: 148)\n",
      "Saved: MEQ_combined_results/meq_deepseek_predictions_combined.csv (rows: 148)\n",
      "Saved: MEQ_combined_results/meq_olmo32b_predictions_combined.csv (rows: 148)\n"
     ]
    }
   ],
   "source": [
    "# Make sure the output directory exists\n",
    "combined_output_dir = \"MEQ_combined_results\"\n",
    "os.makedirs(combined_output_dir, exist_ok=True)\n",
    "\n",
    "model_file_maps = {\n",
    "    \"llama31\": llama31_files,\n",
    "    \"llama32\": llama32_files,\n",
    "    \"mistral\": mistral_files,\n",
    "    \"phi\": phi_files,\n",
    "    \"olmo7b\": olmo7b_files,\n",
    "    \"qwen\": qwen_files,\n",
    "    \"deepseek\": deepseek_files,\n",
    "    \"olmo32b\": olmo32b_files\n",
    "}\n",
    "\n",
    "# get the combined files\n",
    "for model_name, file_map in model_file_maps.items():\n",
    "    output_csv = os.path.join(combined_output_dir, f\"meq_{model_name}_predictions_combined.csv\")\n",
    "\n",
    "    if model_name == \"olmo32b\":\n",
    "        build_csv(file_map, output_csv, base_dir = \"MEQ_OLMo2_Result\")\n",
    "    else:\n",
    "        build_csv(file_map, output_csv, base_dir = \"MEQ_Results\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e45e51f",
   "metadata": {},
   "source": [
    "##### Prediction Disagreement Matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "0436c78e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# creates matrix\n",
    "\n",
    "def disagreement_matrix_symmetric(df, pred_cols):\n",
    "    matrix = pd.DataFrame(index=pred_cols, columns=pred_cols, dtype=int)\n",
    "\n",
    "    for i in pred_cols:\n",
    "        for j in pred_cols:\n",
    "            if i == j:\n",
    "                matrix.loc[i, j] = 0\n",
    "            else:\n",
    "                matrix.loc[i, j] = (df[i] != df[j]).sum()\n",
    "    \n",
    "    return matrix\n",
    "\n",
    "# creates visual\n",
    "def plot_disagreement_heatmap(disagreement_df, title=\"Prediction Disagreement Matrix\", save_path=None, show=False):\n",
    "    plt.figure(figsize=(6, 6))\n",
    "    sns.heatmap(\n",
    "        disagreement_df.astype(int), \n",
    "        annot=True, \n",
    "        fmt=\"d\", \n",
    "        cmap=\"YlGnBu\", \n",
    "        linewidths=0.5, \n",
    "        square=True,\n",
    "        cbar_kws={\"label\": \"Disagreement Count\"}\n",
    "    )\n",
    "    plt.title(title, fontsize=14)\n",
    "    plt.xticks(rotation=45)\n",
    "    plt.yticks(rotation=0)\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, bbox_inches='tight')\n",
    "\n",
    "    if show:\n",
    "        plt.show()\n",
    "\n",
    "    plt.close()\n",
    "\n",
    "cols = [\"English\", \"Urdu\", \"Chinese\", \"Hindi\", \"Spanish\", \"German\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "ee61148a",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"disagreement_plots/MEQ\", exist_ok=True)\n",
    "\n",
    "combined_dir = \"MEQ_combined_results\"\n",
    "all_matrices = {}\n",
    "\n",
    "# List all relevant combined CSVs\n",
    "for fname in os.listdir(combined_dir):\n",
    "    if not fname.endswith(\"_predictions_combined.csv\"):\n",
    "        continue\n",
    "    \n",
    "    path = os.path.join(combined_dir, fname)\n",
    "    df = pd.read_csv(path)\n",
    "\n",
    "    lang_cols = [col for col in df.columns if col not in [\"question\", \"reference\"]]\n",
    "    if not lang_cols:\n",
    "        continue\n",
    "\n",
    "    mat = disagreement_matrix_symmetric(df, lang_cols)\n",
    "    mat_df = pd.DataFrame(mat, index=lang_cols, columns=lang_cols)\n",
    "\n",
    "    base = fname.replace(\"_predictions_combined.csv\", \"\")\n",
    "    parts = base.split(\"_\")\n",
    "    paradigm = parts[0]\n",
    "    model = \"_\".join(parts[1:])\n",
    "\n",
    "    if paradigm not in all_matrices:\n",
    "        all_matrices[paradigm] = {}\n",
    "    all_matrices[paradigm][model] = mat_df\n",
    "\n",
    "    # Save the heatmap directly\n",
    "    plot_path = f\"disagreement_plots/MEQ/{base}_heatmap.png\"\n",
    "    plot_disagreement_heatmap(\n",
    "        mat_df, \n",
    "        title=f'PDM - {base}',\n",
    "        save_path=plot_path,\n",
    "        show=False  # set to True if you want interactive viewing\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbf2067b",
   "metadata": {},
   "source": [
    "##### Clustering and Patterns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea53abe5",
   "metadata": {},
   "source": [
    "**Multidimensional Scaling (MDS)**\n",
    "\n",
    "MDS plots for our work would visualise how similarly different language-specific LLM predictions behave by mapping their pairwise disagreement distances into 2D space - it highlights which languages align in model behavior and which ones diverge, offering insight into cross-lingual consistency and reliability.\n",
    "\n",
    "Distance matrix = disagreement matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7e770d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mds(disagreement_df, title=\"MDS of Prediction Disagreement\", save_path=None, show=False):\n",
    "    mds = MDS(dissimilarity=\"precomputed\", random_state=42)\n",
    "    coords = mds.fit_transform(disagreement_df.values)\n",
    "    df_mds = pd.DataFrame(coords, index=disagreement_df.index, columns=[\"x\", \"y\"])\n",
    "    \n",
    "    plt.figure(figsize=(6, 5))\n",
    "    sns.scatterplot(data=df_mds, x=\"x\", y=\"y\", hue=df_mds.index, s=100, legend=False)\n",
    "    \n",
    "    for label, row in df_mds.iterrows():\n",
    "        plt.text(row.x + 0.01, row.y + 0.01, label)\n",
    "    \n",
    "    plt.title(title)\n",
    "    plt.axis(\"equal\")\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, bbox_inches='tight')\n",
    "\n",
    "    if show:\n",
    "        plt.show()\n",
    "    \n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f3a98584",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"disagreement_plots/MEQ\", exist_ok=True)\n",
    "\n",
    "combined_dir = \"MEQ_combined_results\"\n",
    "\n",
    "for fname in os.listdir(combined_dir):\n",
    "    if not fname.endswith(\"_predictions_combined.csv\"):\n",
    "        continue\n",
    "\n",
    "    path = os.path.join(combined_dir, fname)\n",
    "    df = pd.read_csv(path)\n",
    "\n",
    "    lang_cols = [col for col in df.columns if col not in [\"question\", \"reference\"]]\n",
    "    if not lang_cols:\n",
    "        continue\n",
    "\n",
    "    mat = disagreement_matrix_symmetric(df, lang_cols)\n",
    "    mat_df = pd.DataFrame(mat, index=lang_cols, columns=lang_cols)\n",
    "\n",
    "    plot_path = f\"disagreement_plots/MEQ/{fname.replace('_predictions_combined.csv', '')}_mds.png\"\n",
    "\n",
    "    plot_mds(\n",
    "        mat_df, \n",
    "        title=f'MDS - {fname.replace(\"_predictions_combined.csv\", \"\")}',\n",
    "        save_path=plot_path,\n",
    "        show=False\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8339302f",
   "metadata": {},
   "source": [
    "Takeaways:\n",
    "\n",
    "- Potential relatively higher-resource alignment: English, German, and Chinese often form tight cores.\n",
    "\n",
    "- Hindi and Urdu consistently break from the cluster which may indicate: cultural framing mismatch, poor translation quality, or less consistent reasoning behavior in LLMs\n",
    "\n",
    "- Model behavior shifts: LLaMA 3.2 causes English to behave differently from 3.1, and Mistral reinforces Urdu/Hindi as the least aligned.-"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71304b0d",
   "metadata": {},
   "source": [
    "**Hierarchical Clustering (Agglomerative)**\n",
    "\n",
    "Another way to visualise which languages behave similarly - this time using a dendrogram."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a666cfce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_dendrogram(disagreement_df, title=\"Language Clustering (Disagreement)\", save_path=None, show=False):\n",
    "    # Convert to condensed distance matrix\n",
    "    condensed = squareform(disagreement_df.values)\n",
    "    linkage_matrix = linkage(condensed, method=\"ward\")\n",
    "    \n",
    "    plt.figure(figsize=(8, 5))\n",
    "    dendrogram(linkage_matrix, labels=disagreement_df.index.tolist())\n",
    "    plt.title(title)\n",
    "    plt.ylabel(\"Distance\")\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, bbox_inches='tight')\n",
    "    \n",
    "    if show:\n",
    "        plt.show()\n",
    "    \n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ac687170",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"disagreement_plots/MEQ\", exist_ok=True)\n",
    "\n",
    "combined_dir = \"MEQ_combined_results\"\n",
    "\n",
    "for fname in os.listdir(combined_dir):\n",
    "    if not fname.endswith(\"_predictions_combined.csv\"):\n",
    "        continue\n",
    "\n",
    "    path = os.path.join(combined_dir, fname)\n",
    "    df = pd.read_csv(path)\n",
    "\n",
    "    lang_cols = [col for col in df.columns if col not in [\"question\", \"reference\"]]\n",
    "    if not lang_cols:\n",
    "        continue\n",
    "\n",
    "    mat = disagreement_matrix_symmetric(df, lang_cols)\n",
    "    mat_df = pd.DataFrame(mat, index=lang_cols, columns=lang_cols)\n",
    "\n",
    "    plot_path = f\"disagreement_plots/MEQ/{fname.replace('_predictions_combined.csv', '')}_dendrogram.png\"\n",
    "    \n",
    "    plot_dendrogram(\n",
    "        mat_df, \n",
    "        title=f'Dendrogram - {fname.replace(\"_predictions_combined.csv\", \"\")}',\n",
    "        save_path=plot_path,\n",
    "        show=False  # Set to True if you want to display as well\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1418735f",
   "metadata": {},
   "source": [
    "Takeaways: \n",
    "\n",
    "- Consistent Clustering: English, German, Chinese, and Spanish form a relatively stable agreement group across models.\n",
    "\n",
    "- Outlier Behavior: Urdu and Hindi consistently behave as outliers — either due to translation variation, low-resource alignment, or cultural framing issues in LLMs.\n",
    "\n",
    "- Model Differences: LLaMA 3.1 seems more stable than 3.2 in handling English. Mistral appears to treat Urdu and Hindi as significantly more divergent than LLaMA models do.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07526747",
   "metadata": {},
   "source": [
    "### Consolidated MEQ PLOTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "49edaa39",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"disagreement_plots/MEQ\", exist_ok=True)\n",
    "\n",
    "for paradigm, model_dict in all_matrices.items():\n",
    "    mats = list(model_dict.values())\n",
    "    if not mats:\n",
    "        continue\n",
    "    \n",
    "    # Find common languages\n",
    "    lang_sets = [set(m.index) for m in mats]\n",
    "    common_langs = set.intersection(*lang_sets)\n",
    "    if not common_langs:\n",
    "        continue\n",
    "    common_langs = sorted(list(common_langs))\n",
    "\n",
    "    # Restrict all matrices to common languages\n",
    "    mats_common = [m.loc[common_langs, common_langs] for m in mats]\n",
    "    \n",
    "    # Compute mean disagreement matrix\n",
    "    mean_mat = np.mean([m.values for m in mats_common], axis=0)\n",
    "    mean_df = pd.DataFrame(mean_mat, index=common_langs, columns=common_langs)\n",
    "\n",
    "    # Plot and save heatmap\n",
    "    plot_disagreement_heatmap(\n",
    "        mean_df, \n",
    "        title=f\"Consolidated Disagreement - {paradigm}\",\n",
    "        save_path=f\"disagreement_plots/MEQ/consolidated_{paradigm}_matrix.png\"\n",
    "    )\n",
    "\n",
    "    # Plot and save MDS\n",
    "    plot_mds(\n",
    "        mean_df, \n",
    "        title=f\"Consolidated MDS - {paradigm}\",\n",
    "        save_path=f\"disagreement_plots/MEQ/consolidated_{paradigm}_mds.png\"\n",
    "    )\n",
    "\n",
    "    # Plot and save dendrogram\n",
    "    plot_dendrogram(\n",
    "        mean_df, \n",
    "        title=f\"Consolidated Dendrogram - {paradigm}\",\n",
    "        save_path=f\"disagreement_plots/MEQ/consolidated_{paradigm}_dendrogram.png\"\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a94176f",
   "metadata": {},
   "source": [
    "## ETHICS Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6146a53",
   "metadata": {},
   "source": [
    "#### Data Prep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86f3c354",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Spanish', 'German', 'Hindi', 'Chinese', 'English', 'Urdu']\n"
     ]
    }
   ],
   "source": [
    "ethics_file_map = {}\n",
    "\n",
    "def add_to_map(paradigm, model, language, path):\n",
    "    if paradigm not in ethics_file_map:\n",
    "        ethics_file_map[paradigm] = {}\n",
    "    if model not in ethics_file_map[paradigm]:\n",
    "        ethics_file_map[paradigm][model] = {}\n",
    "    ethics_file_map[paradigm][model][language] = path\n",
    "\n",
    "if os.path.isdir('ETHICS_Results_sf'):\n",
    "    for fname in os.listdir('ETHICS_Results_sf'):\n",
    "        m = re.match(r'para_([a-zA-Z_]+)_([a-zA-Z]+)_results_(.+)\\.json', fname)\n",
    "        if m:\n",
    "            paradigm, language, model = m.groups()\n",
    "            add_to_map(paradigm, model, language, f'ETHICS_Results_sf/{fname}')\n",
    "\n",
    "if os.path.isdir('ETHICS_Results_j'):\n",
    "    for fname in os.listdir('ETHICS_Results_j'):\n",
    "        m = re.match(r'(fixed_)?para_([a-zA-Z_]+)_([a-zA-Z]+)_results_(.+)\\.json', fname)\n",
    "        if m:\n",
    "            paradigm, language, model = m.group(2), m.group(3), m.group(4)\n",
    "            add_to_map(paradigm, model, language, f'ETHICS_Results_j/{fname}')\n",
    "\n",
    "languages = [\"English\", \"Chinese\", \"Urdu\", \"Hindi\", \"Spanish\", \"German\"]\n",
    "lang_regex = \"|\".join(languages)\n",
    "if os.path.isdir('ETHICS_Results_z'):\n",
    "    for fname in os.listdir('ETHICS_Results_z'):\n",
    "        m = re.match(rf'ethics_([a-zA-Z_]+)_({lang_regex})_(.+)\\.json', fname)\n",
    "        if m:\n",
    "            paradigm, language, model = m.groups()\n",
    "            add_to_map(paradigm, model, language, f'ETHICS_Results_z/{fname}')\n",
    "\n",
    "# Example usage:\n",
    "# ethics_file_map[\"den\"][\"meta_llama_Llama_3_1_8B_Instruct\"][\"English\"]\n",
    "# ethics_file_map[\"den\"][\"allenai_OLMo_2_0325_32B_Instruct\"][\"English\"]\n",
    "# ethics_file_map[\"virtue\"][\"deepseek-ai_DeepSeek-R1-Distill-Llama-8B\"][\"Urdu\"]\n",
    "# print(list(ethics_file_map[\"cms\"][\"Qwen_Qwen2_5-7B-Instruct\"].keys()))\n",
    "# print(list(ethics_file_map[\"cms\"].keys()))\n",
    "print(list(ethics_file_map[\"cms\"][\"Qwen_Qwen2_5-7B-Instruct\"].keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1802ac82",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: ETHICS_combined_results/virtue_meta_llama_Llama_3_1_8B_Instruct_predictions_combined.csv (rows: 4975)\n",
      "Saved: ETHICS_combined_results/virtue_meta_llama_Llama_3_2_3B_Instruct_predictions_combined.csv (rows: 4975)\n",
      "Saved: ETHICS_combined_results/virtue_microsoft_Phi_4_mini_instruct_predictions_combined.csv (rows: 4975)\n",
      "Saved: ETHICS_combined_results/virtue_mistralai_Mistral_7B_Instruct_v0_3_predictions_combined.csv (rows: 4975)\n",
      "Saved: ETHICS_combined_results/virtue_allenai_OLMo_2_0325_32B_Instruct_predictions_combined.csv (rows: 4975)\n",
      "Saved: ETHICS_combined_results/virtue_allenai_OLMo_7B_Instruct_predictions_combined.csv (rows: 4975)\n",
      "Saved: ETHICS_combined_results/virtue_deepseek-ai_DeepSeek-R1-Distill-Llama-8B_predictions_combined.csv (rows: 4976)\n",
      "Saved: ETHICS_combined_results/virtue_Qwen_Qwen2_5-7B-Instruct_predictions_combined.csv (rows: 4976)\n",
      "Saved: ETHICS_combined_results/justice_meta_llama_Llama_3_2_3B_Instruct_predictions_combined.csv (rows: 2704)\n",
      "Saved: ETHICS_combined_results/justice_meta_llama_Llama_3_1_8B_Instruct_predictions_combined.csv (rows: 2704)\n",
      "Saved: ETHICS_combined_results/justice_microsoft_Phi_4_mini_instruct_predictions_combined.csv (rows: 2704)\n",
      "Saved: ETHICS_combined_results/justice_allenai_OLMo_2_0325_32B_Instruct_predictions_combined.csv (rows: 2704)\n",
      "Saved: ETHICS_combined_results/justice_allenai_OLMo_7B_Instruct_predictions_combined.csv (rows: 2704)\n",
      "Saved: ETHICS_combined_results/justice_mistralai_Mistral_7B_Instruct_v0_3_predictions_combined.csv (rows: 2704)\n",
      "Saved: ETHICS_combined_results/cms_meta_llama_Llama_3_2_3B_Instruct_predictions_combined.csv (rows: 3885)\n",
      "Saved: ETHICS_combined_results/cms_meta_llama_Llama_3_1_8B_Instruct_predictions_combined.csv (rows: 3885)\n",
      "Saved: ETHICS_combined_results/cms_microsoft_Phi_4_mini_instruct_predictions_combined.csv (rows: 3885)\n",
      "Saved: ETHICS_combined_results/cms_allenai_OLMo_7B_Instruct_predictions_combined.csv (rows: 3885)\n",
      "Saved: ETHICS_combined_results/cms_allenai_OLMo_2_0325_32B_Instruct_predictions_combined.csv (rows: 3885)\n",
      "Saved: ETHICS_combined_results/cms_mistralai_Mistral_7B_Instruct_v0_3_predictions_combined.csv (rows: 3885)\n",
      "Saved: ETHICS_combined_results/cms_Qwen_Qwen2_5-7B-Instruct_predictions_combined.csv (rows: 3886)\n",
      "Saved: ETHICS_combined_results/cms_deepseek-ai_DeepSeek-R1-Distill-Llama-8B_predictions_combined.csv (rows: 3886)\n",
      "Saved: ETHICS_combined_results/util_meta_llama_Llama_3_1_8B_Instruct_predictions_combined.csv (rows: 4808)\n",
      "Saved: ETHICS_combined_results/util_meta_llama_Llama_3_2_3B_Instruct_predictions_combined.csv (rows: 4808)\n",
      "Saved: ETHICS_combined_results/util_microsoft_Phi_4_mini_instruct_predictions_combined.csv (rows: 4808)\n",
      "Saved: ETHICS_combined_results/util_allenai_OLMo_2_0325_32B_Instruct_predictions_combined.csv (rows: 4808)\n",
      "Saved: ETHICS_combined_results/util_allenai_OLMo_7B_Instruct_predictions_combined.csv (rows: 4808)\n",
      "Saved: ETHICS_combined_results/util_mistralai_Mistral_7B_Instruct_v0_3_predictions_combined.csv (rows: 4808)\n",
      "Saved: ETHICS_combined_results/util_deepseek-ai_DeepSeek-R1-Distill-Llama-8B_predictions_combined.csv (rows: 4808)\n",
      "Saved: ETHICS_combined_results/util_Qwen_Qwen2_5-7B-Instruct_predictions_combined.csv (rows: 4809)\n",
      "Saved: ETHICS_combined_results/den_meta_llama_Llama_3_1_8B_Instruct_predictions_combined.csv (rows: 3596)\n",
      "Saved: ETHICS_combined_results/den_meta_llama_Llama_3_2_3B_Instruct_predictions_combined.csv (rows: 3596)\n",
      "Saved: ETHICS_combined_results/den_microsoft_Phi_4_mini_instruct_predictions_combined.csv (rows: 3596)\n",
      "Saved: ETHICS_combined_results/den_mistralai_Mistral_7B_Instruct_v0_3_predictions_combined.csv (rows: 3596)\n",
      "Saved: ETHICS_combined_results/den_allenai_OLMo_2_0325_32B_Instruct_predictions_combined.csv (rows: 3596)\n",
      "Saved: ETHICS_combined_results/den_allenai_OLMo_7B_Instruct_predictions_combined.csv (rows: 3596)\n",
      "Saved: ETHICS_combined_results/den_deepseek-ai_DeepSeek-R1-Distill-Llama-8B_predictions_combined.csv (rows: 3597)\n",
      "Saved: ETHICS_combined_results/den_Qwen_Qwen2_5-7B-Instruct_predictions_combined.csv (rows: 3597)\n",
      "Saved: ETHICS_combined_results/justice_scenario_deepseek-ai_DeepSeek-R1-Distill-Llama-8B_predictions_combined.csv (rows: 2705)\n",
      "Saved: ETHICS_combined_results/justice_scenario_Qwen_Qwen2_5-7B-Instruct_predictions_combined.csv (rows: 2705)\n"
     ]
    }
   ],
   "source": [
    "def robust_read_ethics_json(path):\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "    # If the file is a dict with a \"results\" key, use that\n",
    "    if isinstance(data, dict) and \"results\" in data:\n",
    "        return pd.DataFrame(data[\"results\"])\n",
    "    # If it's already a list, just use it\n",
    "    elif isinstance(data, list):\n",
    "        return pd.DataFrame(data)\n",
    "    # If it's a dict of lists, try to convert\n",
    "    elif isinstance(data, dict):\n",
    "        return pd.DataFrame(data)\n",
    "    else:\n",
    "        raise ValueError(f\"Unrecognized JSON structure in {path}\")\n",
    "\n",
    "def build_ethics_csv(paradigm, model, file_map, output_csv):\n",
    "    df_all = pd.DataFrame()\n",
    "    for lang, path in file_map[paradigm][model].items():\n",
    "        df = robust_read_ethics_json(path)\n",
    "        # Try to find the question key\n",
    "        question_key = None\n",
    "        for candidate in [\"scenario\", \"question\"]:\n",
    "            if candidate in df.columns:\n",
    "                question_key = candidate\n",
    "                break\n",
    "        if question_key is None:\n",
    "            raise ValueError(f\"No question/scenario column in {path}\")\n",
    "        if df_all.empty:\n",
    "            df_all[\"question\"] = df[question_key]\n",
    "            # Reference logic\n",
    "            if \"reference\" in df.columns:\n",
    "                df_all[\"reference\"] = df[\"reference\"]\n",
    "            elif \"util\" in path:\n",
    "                df_all[\"reference\"] = 1\n",
    "        df_all[lang] = df[\"parsed_answer\"]\n",
    "    df_all.to_csv(output_csv, index=False, encoding='utf-8')\n",
    "    print(f\"Saved: {output_csv} (rows: {len(df_all)})\")\n",
    "    return df_all\n",
    "\n",
    "# Create output directory if it doesn't exist\n",
    "os.makedirs(\"ETHICS_combined_results\", exist_ok=True)\n",
    "\n",
    "# Save combined CSVs for all paradigm/model pairs (except Qwen and DeepSeek models)\n",
    "for paradigm in ethics_file_map:\n",
    "    for model in ethics_file_map[paradigm]:\n",
    "        # if (\"qwen\" in model.lower()) or (\"deepseek\" in model.lower()):\n",
    "        #     continue\n",
    "        output_csv = f\"ETHICS_combined_results/{paradigm}_{model}_predictions_combined.csv\"\n",
    "        build_ethics_csv(paradigm, model, ethics_file_map, output_csv)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50c78538",
   "metadata": {},
   "source": [
    "#### Prediction Disagreement Matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "2062bbeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"disagreement_plots\", exist_ok=True)\n",
    "\n",
    "combined_dir = \"ETHICS_combined_results\"\n",
    "combined_suffix = \"_predictions_combined.csv\"\n",
    "\n",
    "# Build one disagreement matrix per combined CSV, store it in\n",
    "# all_matrices[paradigm][model], and save its heatmap.\n",
    "# NOTE(review): all_matrices is not created in this cell; it must already exist\n",
    "# from an earlier cell -- confirm it is initialised before this cell runs.\n",
    "# os.listdir returns entries in arbitrary order, so sort for reproducible runs.\n",
    "for fname in sorted(os.listdir(combined_dir)):\n",
    "    if not fname.endswith(combined_suffix):\n",
    "        continue\n",
    "\n",
    "    path = os.path.join(combined_dir, fname)\n",
    "    df = pd.read_csv(path)\n",
    "\n",
    "    # All columns except question/reference are treated as per-language predictions.\n",
    "    lang_cols = [col for col in df.columns if col not in (\"question\", \"reference\")]\n",
    "    if not lang_cols:\n",
    "        continue\n",
    "\n",
    "    mat = disagreement_matrix_symmetric(df, lang_cols)\n",
    "    mat_df = pd.DataFrame(mat, index=lang_cols, columns=lang_cols)\n",
    "\n",
    "    # Strip only the trailing suffix, then split on the first underscore:\n",
    "    # everything before it is the paradigm, everything after it is the model\n",
    "    # (which may itself contain underscores).\n",
    "    base = fname[: -len(combined_suffix)]\n",
    "    paradigm, _sep, model = base.partition(\"_\")\n",
    "\n",
    "    all_matrices.setdefault(paradigm, {})[model] = mat_df\n",
    "\n",
    "    # Save the heatmap directly\n",
    "    plot_path = f\"disagreement_plots/{base}_heatmap.png\"\n",
    "    plot_disagreement_heatmap(\n",
    "        mat_df,\n",
    "        title=f'PDM - {base}',\n",
    "        save_path=plot_path,\n",
    "        show=False,  # set to True if you want interactive viewing\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "867cb458",
   "metadata": {},
   "source": [
    "#### Clustering and Patterns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99bb374c",
   "metadata": {},
   "source": [
    "Multidimensional scaling (MDS) plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "75cff316",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"disagreement_plots\", exist_ok=True)\n",
    "\n",
    "combined_dir = \"ETHICS_combined_results\"\n",
    "combined_suffix = \"_predictions_combined.csv\"\n",
    "\n",
    "# Save an MDS plot of the language disagreement matrix for every combined CSV.\n",
    "# os.listdir returns entries in arbitrary order, so sort for reproducible runs.\n",
    "for fname in sorted(os.listdir(combined_dir)):\n",
    "    if not fname.endswith(combined_suffix):\n",
    "        continue\n",
    "\n",
    "    path = os.path.join(combined_dir, fname)\n",
    "    df = pd.read_csv(path)\n",
    "\n",
    "    # All columns except question/reference are treated as per-language predictions.\n",
    "    lang_cols = [col for col in df.columns if col not in (\"question\", \"reference\")]\n",
    "    if not lang_cols:\n",
    "        continue\n",
    "\n",
    "    mat = disagreement_matrix_symmetric(df, lang_cols)\n",
    "    mat_df = pd.DataFrame(mat, index=lang_cols, columns=lang_cols)\n",
    "\n",
    "    # Compute the file's base name once; it labels both the title and the PNG.\n",
    "    base = fname[: -len(combined_suffix)]\n",
    "    plot_path = f\"disagreement_plots/{base}_mds.png\"\n",
    "\n",
    "    plot_mds(\n",
    "        mat_df,\n",
    "        title=f'MDS - {base}',\n",
    "        save_path=plot_path,\n",
    "        show=False,  # set to True if you want to display\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82fcf45f",
   "metadata": {},
   "source": [
    "Hierarchical Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "0656397b",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"disagreement_plots\", exist_ok=True)\n",
    "\n",
    "combined_dir = \"ETHICS_combined_results\"\n",
    "combined_suffix = \"_predictions_combined.csv\"\n",
    "\n",
    "# Save a dendrogram of the language disagreement matrix for every combined CSV.\n",
    "# os.listdir returns entries in arbitrary order, so sort for reproducible runs.\n",
    "for fname in sorted(os.listdir(combined_dir)):\n",
    "    if not fname.endswith(combined_suffix):\n",
    "        continue\n",
    "\n",
    "    path = os.path.join(combined_dir, fname)\n",
    "    df = pd.read_csv(path)\n",
    "\n",
    "    # All columns except question/reference are treated as per-language predictions.\n",
    "    lang_cols = [col for col in df.columns if col not in (\"question\", \"reference\")]\n",
    "    if not lang_cols:\n",
    "        continue\n",
    "\n",
    "    mat = disagreement_matrix_symmetric(df, lang_cols)\n",
    "    mat_df = pd.DataFrame(mat, index=lang_cols, columns=lang_cols)\n",
    "\n",
    "    # Compute the file's base name once; it labels both the title and the PNG.\n",
    "    base = fname[: -len(combined_suffix)]\n",
    "    plot_path = f\"disagreement_plots/{base}_dendrogram.png\"\n",
    "\n",
    "    plot_dendrogram(\n",
    "        mat_df,\n",
    "        title=f'Dendrogram - {base}',\n",
    "        save_path=plot_path,\n",
    "        show=False,  # Set to True if you want to display as well\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32f9df6d",
   "metadata": {},
   "source": [
    "### Consolidated Ethics Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8ce25960",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"disagreement_plots\", exist_ok=True)\n",
    "\n",
    "# For each paradigm, average the per-model disagreement matrices over the\n",
    "# languages shared by all models, then save a heatmap, an MDS plot and a\n",
    "# dendrogram of the averaged matrix.\n",
    "for paradigm, model_dict in all_matrices.items():\n",
    "    per_model = list(model_dict.values())\n",
    "    if len(per_model) == 0:\n",
    "        continue\n",
    "\n",
    "    # Keep only the languages present in every model's matrix.\n",
    "    shared = set.intersection(*(set(frame.index) for frame in per_model))\n",
    "    if len(shared) == 0:\n",
    "        continue\n",
    "    common_langs = sorted(shared)\n",
    "\n",
    "    # Element-wise mean, across models, of the matrices restricted to those languages.\n",
    "    restricted = [frame.loc[common_langs, common_langs].values for frame in per_model]\n",
    "    mean_df = pd.DataFrame(\n",
    "        np.mean(restricted, axis=0),\n",
    "        index=common_langs,\n",
    "        columns=common_langs,\n",
    "    )\n",
    "\n",
    "    # Heatmap of the averaged matrix\n",
    "    plot_disagreement_heatmap(\n",
    "        mean_df,\n",
    "        title=f\"Consolidated Disagreement - {paradigm}\",\n",
    "        save_path=f\"disagreement_plots/consolidated_{paradigm}_matrix.png\",\n",
    "    )\n",
    "\n",
    "    # MDS embedding of the averaged matrix\n",
    "    plot_mds(\n",
    "        mean_df,\n",
    "        title=f\"Consolidated MDS - {paradigm}\",\n",
    "        save_path=f\"disagreement_plots/consolidated_{paradigm}_mds.png\",\n",
    "    )\n",
    "\n",
    "    # Dendrogram of the averaged matrix\n",
    "    plot_dendrogram(\n",
    "        mean_df,\n",
    "        title=f\"Consolidated Dendrogram - {paradigm}\",\n",
    "        save_path=f\"disagreement_plots/consolidated_{paradigm}_dendrogram.png\",\n",
    "    )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
