{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd\n",
    "import os \n",
    "import json \n",
    "from scipy.stats import spearmanr, pearsonr\n",
    "from datasets import load_dataset\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import sys\n",
    "import wandb \n",
    "from pathlib import Path\n",
    "\n",
    "import ast \n",
    "import itertools\n",
    "\n",
    "import torch \n",
    "\n",
    "from scipy.special import logsumexp\n",
    "\n",
    "# Add the project root to sys.path\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "# Now the import should work\n",
    "from weaver.dataset import VerificationDataset\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the following two commands (a weak supervision run and a supervised naive Bayes run on the same dataset), then copy their W&B run IDs into the cell below:\n",
    "\n",
    "```bash\n",
    "python run.py --config-name weak_supervision model_params.weak_supervision.drop_imbalanced_verifiers='large' data_cfg.dataset_name='AIMO'\n",
    "python run.py --config-name supervised model_params.naive_bayes.drop_imbalanced_verifiers='large' data_cfg.dataset_name='AIMO'\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the data for comparing weak supervision (WS) and naive Bayes (NB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# W&B run IDs produced by the two commands above -- replace with your own runs:\n",
    "#   nb_run_id -> supervised (naive Bayes) run\n",
    "#   ws_run_id -> weak supervision run\n",
    "nb_run_id, ws_run_id = \"kl9ej7cp\", \"bo8iep8h\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch both W&B runs. The rest of the notebook assumes they were fit on the same dataset.\n",
    "WANDB_PROJECT = \"anonymous-research/verification\"\n",
    "api = wandb.Api()\n",
    "nb_run = api.run(f\"{WANDB_PROJECT}/{nb_run_id}\")\n",
    "ws_run = api.run(f\"{WANDB_PROJECT}/{ws_run_id}\")\n",
    "\n",
    "nb_data_cfg = nb_run.config['data_cfg']\n",
    "ws_data_cfg = ws_run.config['data_cfg']\n",
    "# Only NB's data config is used to build the dataset, so WS must have used the same dataset.\n",
    "assert nb_data_cfg.get('dataset_name') == ws_data_cfg.get('dataset_name'), (\n",
    "    \"WS and NB runs use different datasets: \"\n",
    "    f\"{ws_data_cfg.get('dataset_name')!r} vs {nb_data_cfg.get('dataset_name')!r}\"\n",
    ")\n",
    "\n",
    "dataset = VerificationDataset(**nb_data_cfg)\n",
    "\n",
    "# Directory where the per-run result DataFrames are saved.\n",
    "FIGURES_DIR = Path(\"./supervised_analysis/figures3\")\n",
    "# In the oracle setting we only look at test results (they are the same as the train results).\n",
    "ws_test_file = FIGURES_DIR / ws_run_id / \"df_test.csv\"\n",
    "nb_test_file = FIGURES_DIR / nb_run_id / \"df_test.csv\"\n",
    "\n",
    "ws_df = pd.read_csv(ws_test_file)\n",
    "nb_df = pd.read_csv(nb_test_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Modeling is done per dataset, so the first row's WS/NB parameters are used.\n",
    "# NOTE(review): `model_params` is a repr string parsed with `eval` -- only run this on\n",
    "# result CSVs you produced yourself.\n",
    "ws_first, nb_first = ws_df.iloc[0], nb_df.iloc[0]\n",
    "ws_tpr, ws_tnr, ws_fpr, ws_fnr = eval(ws_first['model_params'], {\"array\": np.array})\n",
    "nb_tpr, nb_tnr, nb_fpr, nb_fnr = eval(nb_first['model_params'], {\"array\": np.array})\n",
    "\n",
    "# Per-row selected candidate and whether the selection was correct, for each method.\n",
    "ws_selected_indices, nb_selected_indices = ws_df.selected_idx.values, nb_df.selected_idx.values\n",
    "ws_correct, nb_correct = ws_df.select_acc.values, nb_df.select_acc.values\n",
    "\n",
    "# Both methods must have been fit on exactly the same verifiers.\n",
    "verifier_subset = ast.literal_eval(ws_first.verifier_subset)\n",
    "assert verifier_subset == ast.literal_eval(nb_first.verifier_subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores, labels = dataset.test_data\n",
    "\n",
    "# Both runs must binarize the raw verifier scores with the same threshold.\n",
    "nb_model_cfg = nb_run.config['model_params']['naive_bayes']\n",
    "ws_model_cfg = ws_run.config['model_params']['weak_supervision']\n",
    "binarize_threshold = nb_model_cfg['binarize_threshold']\n",
    "assert binarize_threshold == ws_model_cfg['binarize_threshold']\n",
    "\n",
    "# Presumably (problems, samples, verifiers) -- the last axis indexes verifiers.\n",
    "all_binary_scores = (scores > binarize_threshold).astype(int)\n",
    "print(f\"All verifier marginals: {all_binary_scores.mean(axis=(0, 1))}\")\n",
    "\n",
    "# Keep only the verifiers the models were fit on. Column order follows\n",
    "# dataset.verifier_names, which may differ from the order of verifier_subset.\n",
    "verifier_subset_indices = [idx for idx, name in enumerate(dataset.verifier_names) if name in verifier_subset]\n",
    "binary_scores = all_binary_scores[:, :, verifier_subset_indices]\n",
    "\n",
    "marginals = binary_scores.mean(axis=(0, 1))\n",
    "print(f\"Subset verifier marginals: {marginals}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Look at inverse covariance matrix (dependency structure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per flattened (problem, sample): binarized votes plus the label as the last column,\n",
    "# so the verifier-verifier entries of the precision matrix include the label as a variable.\n",
    "n_verifiers = binary_scores.shape[-1]\n",
    "all_scores = np.hstack([binary_scores.reshape(-1, n_verifiers), labels.flatten()[:, np.newaxis]])\n",
    "cov = np.cov(all_scores, rowvar=False)\n",
    "try:\n",
    "    precision = np.linalg.inv(cov)\n",
    "except np.linalg.LinAlgError:\n",
    "    # e.g. a verifier that always votes the same way makes the covariance singular\n",
    "    precision = np.linalg.pinv(cov)\n",
    "\n",
    "# Drop the label row/column and zero the diagonal so only verifier-verifier terms are shown.\n",
    "inv_cov = precision[:-1, :-1].copy()\n",
    "np.fill_diagonal(inv_cov, 0)\n",
    "\n",
    "# binary_scores columns follow dataset.verifier_names order (see verifier_subset_indices),\n",
    "# which need not match the order of verifier_subset, so derive tick labels from the indices.\n",
    "verifier_labels = [dataset.verifier_names[i] for i in verifier_subset_indices]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "im = ax.imshow(np.abs(inv_cov), cmap='viridis', interpolation='nearest')\n",
    "fig.colorbar(im, ax=ax)\n",
    "ax.set_xticks(np.arange(n_verifiers))\n",
    "ax.set_xticklabels(verifier_labels, rotation=45, ha=\"right\")\n",
    "ax.set_yticks(np.arange(n_verifiers))\n",
    "ax.set_yticklabels(verifier_labels)\n",
    "ax.set_title(\"|inverse covariance| between verifiers (label row/column dropped)\")\n",
    "fig.subplots_adjust(bottom=0.3)  # extra bottom margin for the rotated tick labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compare parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side WS vs. NB per-verifier TPR/TNR.\n",
    "# NOTE(review): assumes the rate vectors are ordered like verifier_subset -- confirm against\n",
    "# the order in which the models were fit.\n",
    "assert len(ws_tpr) == len(nb_tpr) == len(verifier_subset), \"WS/NB parameter count does not match verifier_subset\"\n",
    "\n",
    "param_comparison = pd.DataFrame(\n",
    "    {\n",
    "        \"ws_tpr\": np.asarray(ws_tpr, dtype=float).ravel(),\n",
    "        \"nb_tpr\": np.asarray(nb_tpr, dtype=float).ravel(),\n",
    "        \"ws_tnr\": np.asarray(ws_tnr, dtype=float).ravel(),\n",
    "        \"nb_tnr\": np.asarray(nb_tnr, dtype=float).ravel(),\n",
    "    },\n",
    "    index=pd.Index(verifier_subset, name=\"verifier\"),\n",
    ").assign(\n",
    "    tpr_diff=lambda d: d[\"ws_tpr\"] - d[\"nb_tpr\"],\n",
    "    tnr_diff=lambda d: d[\"ws_tnr\"] - d[\"nb_tnr\"],\n",
    ")\n",
    "param_comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_posterior(tpr, tnr, binary_scores, cb, fpr=None, fnr=None):\n",
    "    \"\"\"Naive-Bayes posterior Pr(y = 1 | votes) for binary verifier votes.\n",
    "\n",
    "    Args:\n",
    "        tpr, tnr: per-verifier true positive / true negative rates, broadcast against\n",
    "            the last axis of ``binary_scores``.\n",
    "        binary_scores: array of 0/1 votes; the last axis indexes verifiers.\n",
    "        cb: class balance Pr(y = 1).\n",
    "        fpr, fnr: optional false positive / false negative rates; default to\n",
    "            ``1 - tnr`` and ``1 - tpr``.\n",
    "\n",
    "    Returns:\n",
    "        Pr(y = 1 | votes), with the verifier (last) axis reduced away.\n",
    "    \"\"\"\n",
    "    false_pos = 1 - tnr if fpr is None else fpr\n",
    "    false_neg = 1 - tpr if fnr is None else fnr\n",
    "    votes = binary_scores\n",
    "    misses = 1 - votes\n",
    "\n",
    "    # Work in log space: sum per-verifier log-likelihoods, then add the log-prior.\n",
    "    log_joint_pos = np.log(votes * tpr + misses * false_neg).sum(axis=-1) + np.log(cb)\n",
    "    log_joint_neg = np.log(votes * false_pos + misses * tnr).sum(axis=-1) + np.log(1 - cb)\n",
    "\n",
    "    # Normalize with logsumexp for numerical stability, then return to probability space.\n",
    "    log_evidence = logsumexp([log_joint_pos, log_joint_neg], axis=0)\n",
    "    return np.exp(log_joint_pos - log_evidence)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the two posteriors on every possible vote pattern over the selected verifiers\n",
    "# (2^m rows for m verifiers), each pattern weighted equally.\n",
    "cb = labels.mean()\n",
    "full_score_distribution = np.array(list(itertools.product([0, 1], repeat=binary_scores.shape[-1])))\n",
    "\n",
    "# Both models use their own stored FPR/FNR; a None value makes get_posterior fall back\n",
    "# to 1 - TNR / 1 - TPR.\n",
    "ws_posterior = get_posterior(ws_tpr, ws_tnr, full_score_distribution, cb, fpr=ws_fpr, fnr=ws_fnr)\n",
    "nb_posterior = get_posterior(nb_tpr, nb_tnr, full_score_distribution, cb, fpr=nb_fpr, fnr=nb_fnr)\n",
    "\n",
    "mse = np.mean((ws_posterior - nb_posterior) ** 2)\n",
    "spearman_rho = spearmanr(ws_posterior, nb_posterior)[0]\n",
    "print(f\"Comparing WS and NB posteriors. MSE: {mse}, Spearman: {spearman_rho}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Assess whether the WS parameters truly minimize the objective"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def measure_misspecification_error(binary_scores_flattened, cb, tpr, tnr, fpr, fnr, label=None):\n",
    "    \"\"\"Mean absolute gap between observed and model-implied pairwise vote distributions.\n",
    "\n",
    "    For every verifier pair (i, j) and vote pair (a, b), compares the empirical\n",
    "    Pr(v_i = a, v_j = b) with the value implied by conditionally independent verifiers:\n",
    "        Pr(v_i = a | y=1) * Pr(v_j = b | y=1) * cb + Pr(v_i = a | y=0) * Pr(v_j = b | y=0) * (1 - cb)\n",
    "    where Pr(v = 1 | y=1) = TPR, Pr(v = 0 | y=1) = FNR, Pr(v = 1 | y=0) = FPR, Pr(v = 0 | y=0) = TNR.\n",
    "\n",
    "    Args:\n",
    "        binary_scores_flattened: 2-D array of 0/1 votes, one row per sample, one column per verifier.\n",
    "        cb: class balance Pr(y = 1).\n",
    "        tpr, tnr, fpr, fnr: per-verifier rates of the model being checked (indexed by verifier).\n",
    "        label: optional prefix for the per-entry printout, e.g. \"WS\" or \"NB\".\n",
    "\n",
    "    Returns:\n",
    "        Mean absolute error over all verifier pairs and vote combinations\n",
    "        (NaN when there are fewer than two verifiers).\n",
    "    \"\"\"\n",
    "    # Pr(v = a | y) lookups built from the rates passed in: index by the vote a, then the verifier.\n",
    "    given_pos = (fnr, tpr)  # y = 1: a = 0 -> FNR, a = 1 -> TPR\n",
    "    given_neg = (tnr, fpr)  # y = 0: a = 0 -> TNR, a = 1 -> FPR\n",
    "    prefix = f\"{label} \" if label else \"\"\n",
    "\n",
    "    n_verifiers = binary_scores_flattened.shape[-1]\n",
    "    total = len(binary_scores_flattened)  # number of samples\n",
    "    misspec_error = []\n",
    "    for vi, vj in itertools.combinations(range(n_verifiers), 2):\n",
    "        scores_i = binary_scores_flattened[:, vi]\n",
    "        scores_j = binary_scores_flattened[:, vj]\n",
    "        for a, b in itertools.product([0, 1], repeat=2):\n",
    "            p_ab = np.sum((scores_i == a) & (scores_j == b)) / total  # Pr(vi = a, vj = b)\n",
    "            estimate = (\n",
    "                given_pos[a][vi] * given_pos[b][vj] * cb\n",
    "                + given_neg[a][vi] * given_neg[b][vj] * (1 - cb)\n",
    "            )\n",
    "            err = np.abs(p_ab - estimate)\n",
    "            misspec_error.append(err)\n",
    "            print(f\"{prefix}Error for Pr(v{vi} = {a}, v{vj} = {b}) = {err}\")\n",
    "\n",
    "    if not misspec_error:\n",
    "        return float(\"nan\")\n",
    "    return np.array(misspec_error).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per flattened (problem, sample); one column per selected verifier.\n",
    "n_verifiers = binary_scores.shape[-1]\n",
    "binary_scores_flattened = binary_scores.reshape(-1, n_verifiers)\n",
    "labels_flattened = labels.flatten()\n",
    "cb = labels_flattened.mean()  # empirical Pr(y = 1)\n",
    "\n",
    "ws_rates = (ws_tpr, ws_tnr, ws_fpr, ws_fnr)\n",
    "nb_rates = (nb_tpr, nb_tnr, nb_fpr, nb_fnr)\n",
    "\n",
    "ws_misspec_error = measure_misspecification_error(binary_scores_flattened, cb, *ws_rates)\n",
    "print(f\"WS misspecification error: {ws_misspec_error}\\n\\n\")\n",
    "\n",
    "nb_misspec_error = measure_misspecification_error(binary_scores_flattened, cb, *nb_rates)\n",
    "print(f\"NB misspecification error: {nb_misspec_error}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check the training objective value with NB parameters (is it lower than with WS parameters?)"
   ]
  },
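  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the NB loss is much lower than the WS loss, WS did not reach the minimum of its own objective. In that case, try initializing `mu_init` close to `nb_mu`.\n",
    "If training from that `mu_init` converges to `nb_mu`, the objective has multiple solutions, i.e. WS does not get enough 'signal' from the verifiers to pick the right one.\n"
   ]
  },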
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from metal import LabelModel  # only needed for this diagnostic section\n",
    "\n",
    "print(f\"WARNING: THIS CODE IS MEANT TO BE EXACTLY SAME AS THE WS COMMAND YOU RAN IN {ws_run_id}\")\n",
    "\n",
    "# Re-fit the WS label model with the WS run's hyperparameters so that its O / P / mask\n",
    "# can be reused below to evaluate the objective at the WS and NB parameters.\n",
    "ws_cfg = ws_run.config['model_params']['weak_supervision']\n",
    "lr = ws_cfg['lr']\n",
    "seed = ws_cfg['seed']\n",
    "n_epochs = ws_cfg['n_epochs']\n",
    "mu_epochs = ws_cfg['mu_epochs']\n",
    "\n",
    "# Without test labels the class balance is fixed at 0.5; otherwise use the configured value.\n",
    "if ws_cfg['use_label_on_test']:\n",
    "    cb_args = ws_cfg['cb_args']['class_balance']\n",
    "else:\n",
    "    cb_args = 0.5\n",
    "\n",
    "if isinstance(cb_args, str) and cb_args == \"labels\":\n",
    "    mean_correctness = labels.mean()\n",
    "    class_balance = np.asarray([1 - mean_correctness, mean_correctness])\n",
    "elif isinstance(cb_args, (int, float)) and not isinstance(cb_args, bool):\n",
    "    # any real-valued Pr(y = 1); bool is excluded because it subclasses int\n",
    "    class_balance = np.asarray([1 - cb_args, cb_args])\n",
    "else:\n",
    "    raise ValueError(f\"Unknown class balance: {cb_args}\")\n",
    "\n",
    "label_model = LabelModel(k=2, seed=seed)\n",
    "# Shift 0/1 votes to 1/2 (presumably because MeTaL reserves 0 for abstains -- confirm).\n",
    "votes_scaled = binary_scores_flattened + 1\n",
    "\n",
    "label_model.train_model(\n",
    "    votes_scaled,\n",
    "    L_train_continuous=None,\n",
    "    abstains=False,\n",
    "    symmetric=False,\n",
    "    n_epochs=n_epochs,\n",
    "    mu_epochs=mu_epochs,\n",
    "    log_train_every=1000,\n",
    "    lr=lr,\n",
    "    class_balance=class_balance,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _stack_mu(tpr, tnr, fpr, fnr):\n",
    "    \"\"\"Interleave per-verifier rates into a 2-column float tensor (assumes 1-D rate vectors).\n",
    "\n",
    "    Rows alternate (v_i = 0, v_i = 1) for each verifier i; column 0 holds the\n",
    "    y = 0 rates (TNR, FPR) and column 1 the y = 1 rates (FNR, TPR).\n",
    "    \"\"\"\n",
    "    col_y0 = np.array([rate for pair in zip(tnr, fpr) for rate in pair])\n",
    "    col_y1 = np.array([rate for pair in zip(fnr, tpr) for rate in pair])\n",
    "    return torch.Tensor(np.hstack([col_y0[:, np.newaxis], col_y1[:, np.newaxis]]))\n",
    "\n",
    "\n",
    "def _label_model_loss(mu):\n",
    "    \"\"\"Objective value for `mu`, rebuilt from the fitted label_model's O, P and mask.\"\"\"\n",
    "    masked_fit = torch.norm((label_model.O - mu @ label_model.P @ mu.t())[label_model.mask]) ** 2\n",
    "    diag_fit = torch.norm(torch.sum(mu @ label_model.P, 1) - torch.diag(label_model.O)) ** 2\n",
    "    return masked_fit + diag_fit\n",
    "\n",
    "\n",
    "nb_mu = _stack_mu(nb_tpr, nb_tnr, nb_fpr, nb_fnr)\n",
    "ws_mu = _stack_mu(ws_tpr, ws_tnr, ws_fpr, ws_fnr)\n",
    "\n",
    "print(f\"NB loss: {_label_model_loss(nb_mu)}\")\n",
    "print(f\"WS loss: {_label_model_loss(ws_mu)}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Other tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pr(v_i = a, v_j = b) as stored in the fitted label model, rounded to 2 decimals for display.\n",
    "# These joint distributions turn out to be extremely skewed.\n",
    "label_model.O.round(decimals=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mayeeenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
