{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os, sys\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from scipy.stats import wilcoxon\n",
    "\n",
    "project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
    "if project_root not in sys.path:\n",
    "    sys.path.insert(0, project_root)\n",
    "\n",
    "from utils.variant_grouping import classify\n",
    "\n",
    "# === Pretraining Dataset Names ===\n",
    "dataset_name = 'Wikipedia'\n",
    "#dataset_name = 'BookCorpus'\n",
    "#dataset_name = 'Common Crawl'\n",
    "#dataset_name = 'RefinedWeb'\n",
    "#dataset_name = 'RedPajama'\n",
    "#dataset_name = 'Dolma'\n",
    "\n",
    "# === Load Files ===\n",
    "with open('../#Resources/AmE_BrE_variations.json', 'r', encoding='utf-8') as f:\n",
    "    variant_data = json.load(f)\n",
    "\n",
    "with open(f'./#Datasets/{dataset_name}/word_frequency.json', 'r', encoding='utf-8') as f:\n",
    "    word_freq = json.load(f)\n",
    "\n",
    "# === Enrich With Frequency + Group ===\n",
    "enriched = []\n",
    "for entry in variant_data:\n",
    "    us_word = entry[\"us\"]\n",
    "    uk_word = entry[\"uk\"]\n",
    "    _id = entry[\"_id\"]\n",
    "\n",
    "    us_freq = word_freq.get(us_word, 0)\n",
    "    uk_freq = word_freq.get(uk_word, 0)\n",
    "    total = us_freq + uk_freq\n",
    "    p_us = us_freq / total if total else None\n",
    "    p_uk = uk_freq / total if total else None\n",
    "\n",
    "    group, diff_type, category = classify(us_word, uk_word)\n",
    "\n",
    "    enriched.append({\n",
    "        \"_id\": _id,\n",
    "        \"us\": us_word,\n",
    "        \"uk\": uk_word,\n",
    "        \"us_freq\": us_freq,\n",
    "        \"uk_freq\": uk_freq,\n",
    "        \"total\": total,\n",
    "        \"p_us\": p_us,\n",
    "        \"p_uk\": p_uk,\n",
    "        \"group\": group,\n",
    "        \"type\": diff_type,\n",
    "        \"category\": category\n",
    "    })\n",
    "\n",
    "# === Save Output ===\n",
    "with open(f'./Results/{dataset_name}_us_uk_probabilities.json', 'w', encoding='utf-8') as f:\n",
    "    json.dump(enriched, f, indent=2)\n",
    "\n",
    "print(\"Corpus Name: \", dataset_name)\n",
    "print(\"-\" * 40)\n",
    "\n",
    "# === Wilcoxon Test ===\n",
    "df = pd.DataFrame(enriched).dropna(subset=[\"p_us\", \"p_uk\"])\n",
    "print(\"Wilcoxon Signed-Rank Test\")\n",
    "stat, p_value = wilcoxon(df[\"us_freq\"], df[\"uk_freq\"])\n",
    "print(f\"  Statistic: {stat:.4f}\")\n",
    "print(f\"  p-value: {p_value:.6f}\")\n",
    "print(\"-\" * 40)\n",
    "\n",
    "# === Print Average Probabilities per Category ===\n",
    "for cat in [\"Orthographic/Spelling\", \"Vocabulary\"]:\n",
    "    avg_us = df[df[\"category\"] == cat][\"p_us\"].mean()\n",
    "    avg_uk = df[df[\"category\"] == cat][\"p_uk\"].mean()\n",
    "    print(f\"Category: {cat}\")\n",
    "    print(f\"  Avg. AmE variant probability: {avg_us*100:.2f}\")\n",
    "    print(f\"  Avg. BrE variant probability: {avg_uk*100:.2f}\")\n",
    "    print(\"-\" * 40)\n",
    "\n",
    "# === Plotting Prep ===\n",
    "sns.set(style=\"whitegrid\", font_scale=1.3)\n",
    "df_long = pd.melt(\n",
    "    df,\n",
    "    id_vars=[\"group\", \"category\", \"type\"],\n",
    "    value_vars=[\"p_us\", \"p_uk\"],\n",
    "    var_name=\"variant\",\n",
    "    value_name=\"probability\"\n",
    ")\n",
    "df_long[\"variant\"] = df_long[\"variant\"].replace({\"p_us\": \"AmE\", \"p_uk\": \"BrE\"})\n",
    "\n",
    "# === Violin Plot ===\n",
    "plt.figure(figsize=(10, 6))\n",
    "sns.violinplot(\n",
    "    data=df_long,\n",
    "    x=\"category\",\n",
    "    y=\"probability\",\n",
    "    hue=\"variant\",\n",
    "    split=True,\n",
    "    inner=\"quartile\",\n",
    "    palette=\"Set2\"\n",
    ")\n",
    "plt.title(\"Distribution of Variant Probabilities by Category\", fontsize=20, weight='bold')\n",
    "plt.xlabel(\"Category\", fontsize=18, weight='bold')\n",
    "plt.ylabel(\"Variant Probability\", fontsize=18, weight='bold')\n",
    "plt.xticks(fontsize=16, weight='bold')\n",
    "plt.yticks(fontsize=16, weight='bold')\n",
    "plt.legend(title=\"Variant\", title_fontsize=18, fontsize=16)\n",
    "plt.grid(True, linestyle=\"--\", alpha=0.6)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# === Heatmap ===\n",
    "heatmap_data = df.groupby(\"type\")[[\"p_us\", \"p_uk\"]].mean().rename(\n",
    "    columns={\"p_us\": \"AmE\", \"p_uk\": \"BrE\"}\n",
    ").sort_index()\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "sns.heatmap(\n",
    "    heatmap_data,\n",
    "    annot=True,\n",
    "    fmt=\".2f\",\n",
    "    cmap=\"coolwarm\",\n",
    "    linewidths=0.5,\n",
    "    cbar_kws={\"label\": \"Avg. Probability\"}\n",
    ")\n",
    "plt.title(\"Mean Variant Probability by Type\", fontsize=16, weight='bold')\n",
    "plt.ylabel(\"Difference Type\", fontsize=16, weight='bold')\n",
    "plt.xticks(fontsize=18, weight='bold')\n",
    "plt.yticks(fontsize=18)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
