{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-09-23 18:04:00,529 - INFO - PyTorch version 2.3.0 available.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import grp\n",
    "import json\n",
    "import seaborn as sns\n",
    "from treatment_effects import calculate_treatment_effects, calculate_rewrite_effect, calculate_unpaired_treatment_effect\n",
    "from datasets import load_dataset\n",
    "from pathlib import Path\n",
    "from transformers import DistilBertTokenizer\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score\n",
    "\n",
    "# set random seeds\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "intervened_concept = \"helpfulness\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset_from_json(filepath):\n",
    "    if filepath.suffix == \".jsonl\":\n",
    "        with filepath.open(\"r\") as file:\n",
    "            dataset = {idx: json.loads(line) for idx, line in enumerate(file)}\n",
    "    elif filepath.suffix == \".json\":\n",
    "        with filepath.open(\"r\") as file:\n",
    "            dataset = json.load(file)\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            \"Unsupported file format. Only .json and .jsonl are supported.\"\n",
    "        )\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locate the scored HelpSteer files (the root can be overridden with $HELPSTEER_SCORED_DIR)\n",
    "data_dir = Path(os.environ.get(\"HELPSTEER_SCORED_DIR\", \"/net/projects/user/prompt_distributions/data/scored\"))\n",
    "complete_dir = data_dir / \"complete\"\n",
    "\n",
    "helpful_filename = \"helpsteer_helpfulness_complete_scored.jsonl\"\n",
    "helpful_path = complete_dir / helpful_filename\n",
    "\n",
    "complexity_filename = \"helpsteer_complexity_complete_scored.jsonl\"\n",
    "complexity_path = complete_dir / complexity_filename\n",
    "\n",
    "# assert the inputs exist before loading anything\n",
    "assert complete_dir.exists(), f\"Directory {complete_dir} does not exist.\"\n",
    "for path in [helpful_path, complexity_path]:\n",
    "    assert path.exists(), f\"File {path} does not exist.\"\n",
    "\n",
    "# The intervened concept's file is the primary dataset; the other concept's file provides\n",
    "# the spurious w_original label. Unknown concept names are rejected explicitly rather than\n",
    "# being treated as \"complexity\".\n",
    "concept_paths = {\"helpfulness\": helpful_path, \"complexity\": complexity_path}\n",
    "if intervened_concept not in concept_paths:\n",
    "    raise ValueError(f\"Unknown intervened_concept {intervened_concept!r}; expected one of {sorted(concept_paths)}.\")\n",
    "spurious_concept = next(name for name in concept_paths if name != intervened_concept)\n",
    "data_intervened = load_dataset_from_json(concept_paths[intervened_concept])\n",
    "data_spurious = load_dataset_from_json(concept_paths[spurious_concept])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Intervened dataset has 17343 positive examples and 7658 negative examples.\n",
      "Spurious dataset has 2053 positive examples and 22948 negative examples.\n"
     ]
    }
   ],
   "source": [
    "# Label balance per dataset: records whose w_original is truthy vs. falsy\n",
    "def _label_counts(records):\n",
    "    positives = sum(1 for record in records.values() if record[\"w_original\"])\n",
    "    negatives = sum(1 for record in records.values() if not record[\"w_original\"])\n",
    "    return positives, negatives\n",
    "\n",
    "num_intervened_positive, num_intervened_negative = _label_counts(data_intervened)\n",
    "num_spurious_positive, num_spurious_negative = _label_counts(data_spurious)\n",
    "\n",
    "print(f\"Intervened dataset has {num_intervened_positive} positive examples and {num_intervened_negative} negative examples.\")\n",
    "print(f\"Spurious dataset has {num_spurious_positive} positive examples and {num_spurious_negative} negative examples.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Print the first 2 keys and their corresponding values\n",
    "# for data in [data_intervened, data_spurious]:\n",
    "#     keys = list(data.keys())\n",
    "#     assert keys == sorted(keys), \"Keys are not sorted.\"\n",
    "#     sample_size = len(keys)\n",
    "#     print(\"Number of samples:\", sample_size)\n",
    "#     for key in keys[:2]:\n",
    "#         print(f\"{key}:\")\n",
    "#         print(json.dumps(data[key], indent=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next: Create the base dataset.\n",
    "\n",
    "def create_groups(data_intervened, data_spurious):\n",
    "    \"\"\"Partition record keys by (intervened label, spurious label).\n",
    "\n",
    "    Both datasets are indexed by the keys of ``data_intervened`` and must agree on\n",
    "    each record's original completion.\n",
    "\n",
    "    Returns:\n",
    "        (keys_11, keys_01, keys_10, keys_00): lists of keys, named by the intervened\n",
    "        ``w_original`` label followed by the spurious ``w_original`` label.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the two datasets' original completions differ for a key.\n",
    "        ValueError: if a ``w_original`` label is not binary (0/1 or bool).\n",
    "    \"\"\"\n",
    "    groups = {(1, 1): [], (0, 1): [], (1, 0): [], (0, 0): []}\n",
    "\n",
    "    for key in tqdm(data_intervened.keys()):\n",
    "        assert data_intervened[key][\"completions\"][\"original\"] == data_spurious[key][\"completions\"][\"original\"], \"Original completions do not match.\"\n",
    "        concept1_label = data_intervened[key][\"w_original\"]\n",
    "        concept2_label = data_spurious[key][\"w_original\"]\n",
    "\n",
    "        # Reject anything other than 0/1 (e.g. None) instead of silently bucketing it.\n",
    "        if concept1_label not in (0, 1) or concept2_label not in (0, 1):\n",
    "            raise ValueError(f\"Non-binary w_original labels for key {key}: {concept1_label!r}, {concept2_label!r}\")\n",
    "        groups[(int(concept1_label), int(concept2_label))].append(key)\n",
    "\n",
    "    return groups[(1, 1)], groups[(0, 1)], groups[(1, 0)], groups[(0, 0)]\n",
    "\n",
    "def create_base_dataset(data_intervened, data_spurious):\n",
    "    \"\"\"Build a balanced base dataset with the same number of records in each label group.\n",
    "\n",
    "    Keys are split into groups 11, 01, 10, 00 (intervened label first, spurious label\n",
    "    second) by ``create_groups``; each group is subsampled without replacement to the size\n",
    "    of the smallest group using numpy's global RNG, so ``np.random.seed`` makes it reproducible.\n",
    "\n",
    "    Returns:\n",
    "        base_dataset: dict of sampled key -> record taken from ``data_intervened``, in sorted key order.\n",
    "        min_count: number of records kept per group.\n",
    "        filtered_groups: [sampled_11, sampled_01, sampled_10, sampled_00] arrays of keys.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if any of the four groups is empty.\n",
    "    \"\"\"\n",
    "    keys_11, keys_01, keys_10, keys_00 = create_groups(data_intervened, data_spurious)\n",
    "\n",
    "    print(\"Number of samples in each group:\")\n",
    "    print(\"00:\", len(keys_00))\n",
    "    print(\"01:\", len(keys_01))\n",
    "    print(\"10:\", len(keys_10))\n",
    "    print(\"11:\", len(keys_11))\n",
    "\n",
    "    print(\"Number of samples in each concept:\")\n",
    "    print(\"Intervened concept = 0:\", len(keys_00) + len(keys_01))\n",
    "    print(\"Intervened concept = 1:\", len(keys_10) + len(keys_11))\n",
    "    print(\"Spurious concept = 0:\", len(keys_00) + len(keys_10))\n",
    "    print(\"Spurious concept = 1:\", len(keys_01) + len(keys_11))\n",
    "\n",
    "    min_count = min(len(keys_00), len(keys_01), len(keys_10), len(keys_11))\n",
    "    print(\"Minimum number of samples in each group:\", min_count)\n",
    "    if min_count == 0:\n",
    "        # An empty group makes a balanced base dataset impossible (it would be empty).\n",
    "        raise ValueError(\"At least one (intervened, spurious) label group is empty; cannot build a balanced base dataset.\")\n",
    "\n",
    "    # Sample each group down to min_count; the fixed draw order 00, 01, 10, 11 keeps seeded runs reproducible\n",
    "    sampled_00 = np.random.choice(keys_00, min_count, replace=False)\n",
    "    sampled_01 = np.random.choice(keys_01, min_count, replace=False)\n",
    "    sampled_10 = np.random.choice(keys_10, min_count, replace=False)\n",
    "    sampled_11 = np.random.choice(keys_11, min_count, replace=False)\n",
    "\n",
    "    # Order matches the (long_positive, long_negative, short_positive, short_negative)\n",
    "    # unpacking in create_correlated_datasets\n",
    "    filtered_groups = [sampled_11, sampled_01, sampled_10, sampled_00]\n",
    "    base_keys = sorted(np.concatenate(filtered_groups))\n",
    "    # Make sure we pull from the intervened dataset\n",
    "    base_dataset = {key: data_intervened[key] for key in base_keys}\n",
    "\n",
    "    print(\"Number of samples in the base dataset:\", len(base_dataset))\n",
    "    assert len(base_dataset) == 4 * min_count, \"Base dataset does not have the correct number of samples.\"\n",
    "    return base_dataset, min_count, filtered_groups\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25001/25001 [00:00<00:00, 1149245.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples in each group:\n",
      "00: 7363\n",
      "01: 295\n",
      "10: 15585\n",
      "11: 1758\n",
      "Number of samples in each concept:\n",
      "Intervened concept = 0: 7658\n",
      "Intervened concept = 1: 17343\n",
      "Spurious concept = 0: 22948\n",
      "Spurious concept = 1: 2053\n",
      "Minimum number of samples in each group: 295\n",
      "Number of samples in the base dataset: 1180\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Create the base dataset\n",
    "base_dataset, min_count, filtered_group = create_base_dataset(data_intervened, data_spurious)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the synthetic base and correlated datasets\n",
    "save_dir = data_dir.joinpath(\"synthetic_helpsteer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save this base dataset for future use, keeping original name\n",
    "save_file = \"synthetic_helpsteer_base.jsonl\"\n",
    "save_path = save_dir / save_file\n",
    "\n",
    "def save_dataset(dataset, save_path, group_name=\"user-lab\"):\n",
    "    \"\"\"Write the records of ``dataset`` to ``save_path`` as JSON Lines and lock the file down.\n",
    "\n",
    "    Any existing file at ``save_path`` is replaced. Only the dict values are written, one\n",
    "    JSON object per line, in iteration order. The result is made read-only for everyone\n",
    "    (0o444) and, unless ``group_name`` is None, its group is set to ``group_name``.\n",
    "\n",
    "    Args:\n",
    "        dataset: dict whose values are JSON-serialisable records (keys are not written).\n",
    "        save_path: pathlib.Path ending in ``.jsonl``; missing parent directories are created.\n",
    "        group_name: Unix group to assign to the file, or None to leave the group unchanged.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if ``save_path`` does not end in ``.jsonl`` (checked before anything is deleted).\n",
    "        KeyError: if ``group_name`` is not a known group (checked before anything is written).\n",
    "    \"\"\"\n",
    "    # Validate everything up front so a bad argument never deletes or half-writes a file.\n",
    "    assert save_path.suffix == \".jsonl\", \"Save path must have .jsonl extension.\"\n",
    "    gid = grp.getgrnam(group_name).gr_gid if group_name is not None else None\n",
    "\n",
    "    save_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "    # Files written here end up read-only (0o444), so remove an existing one rather than overwrite it.\n",
    "    if save_path.exists():\n",
    "        os.remove(save_path)\n",
    "\n",
    "    with open(save_path, \"w\", encoding=\"utf-8\") as f:\n",
    "        for record in dataset.values():\n",
    "            json.dump(record, f)\n",
    "            f.write(\"\\n\")\n",
    "\n",
    "    # Make sure the permissions are read-only for everyone\n",
    "    os.chmod(save_path, 0o444)\n",
    "\n",
    "    if gid is not None:\n",
    "        # Change only the group; -1 leaves the owner (UID) unchanged.\n",
    "        os.chown(save_path, -1, gid)\n",
    "\n",
    "save_dataset(base_dataset, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Base dataset size: 1180\n",
      "Minimum count per group: 295\n",
      "Counts for each correlation level\n",
      "Correlation level 0: [295 295 295 295]\n",
      "Total samples: 1180\n",
      "Correlation level 1: [295 241 241 295]\n",
      "Total samples: 1072\n",
      "Correlation level 2: [295 196 196 295]\n",
      "Total samples: 982\n",
      "Correlation level 3: [295 158 158 295]\n",
      "Total samples: 906\n",
      "Correlation level 4: [295 126 126 295]\n",
      "Total samples: 842\n",
      "Correlation level 5: [295  98  98 295]\n",
      "Total samples: 786\n",
      "Correlation level 6: [295  73  73 295]\n",
      "Total samples: 736\n",
      "Correlation level 7: [295  51  51 295]\n",
      "Total samples: 692\n",
      "Correlation level 8: [295  32  32 295]\n",
      "Total samples: 654\n",
      "Correlation level 9: [295  15  15 295]\n",
      "Total samples: 620\n",
      "Correlation level 10: [295   0   0 295]\n",
      "Total samples: 590\n"
     ]
    }
   ],
   "source": [
    "# Now we can create datasets with increasing correlation\n",
    "\n",
    "def get_counts(N, max_per_group):\n",
    "    # Base probabilities\n",
    "    P_positive = 0.5\n",
    "    P_negative = 0.5\n",
    "\n",
    "    counts = []\n",
    "\n",
    "    # Loop through each desired correlation level\n",
    "    for i in range(11):\n",
    "        P_long_given_positive = 0.5 + i * 0.05\n",
    "        P_long_given_negative = 0.5 - i * 0.05\n",
    "\n",
    "        # Calculate joint probabilities\n",
    "        P_pos_long = P_positive * P_long_given_positive\n",
    "        P_pos_short = P_positive * (1 - P_long_given_positive)\n",
    "        P_neg_long = P_negative * P_long_given_negative\n",
    "        P_neg_short = P_negative * (1 - P_long_given_negative)\n",
    "\n",
    "        # Calculate sample counts\n",
    "        n_long_positive = int(N * P_pos_long)\n",
    "        n_short_positive = int(N * P_pos_short)\n",
    "        n_long_negative = int(N * P_neg_long)\n",
    "        n_short_negative = int(N * P_neg_short)\n",
    "        counts.append((n_long_positive, n_long_negative, n_short_positive, n_short_negative))\n",
    "    \n",
    "    # convert to numpy array for easier manipulation\n",
    "    counts = np.array(counts)\n",
    "    max_groups = counts.max(axis=1)\n",
    "    # Ensure that counts do not exceed max_per_group, while maintaining proportions\n",
    "    counts = (counts / max_groups[:, None]) * max_per_group\n",
    "    return counts.astype(int)\n",
    "\n",
    "# Group sizes for a nominal sample size of len(base_dataset); each level's largest group gets min_count\n",
    "print(\"Base dataset size:\", len(base_dataset))\n",
    "print(\"Minimum count per group:\", min_count)\n",
    "counts = get_counts(len(base_dataset), min_count)\n",
    "print(\"Counts for each correlation level\")\n",
    "for level, level_counts in enumerate(counts):\n",
    "    print(f\"Correlation level {level}: {level_counts}\")\n",
    "    print(f\"Total samples: {level_counts.sum()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Smallest sample size across correlation levels: 590\n",
      "Adjusted counts for each correlation level\n",
      "Correlation level 0: [147 147 147 147]\n",
      "Total samples: 588\n",
      "Correlation level 1: [162 132 132 162]\n",
      "Total samples: 588\n",
      "Correlation level 2: [177 117 117 177]\n",
      "Total samples: 588\n",
      "Correlation level 3: [192 102 102 192]\n",
      "Total samples: 588\n",
      "Correlation level 4: [206  88  88 206]\n",
      "Total samples: 588\n",
      "Correlation level 5: [221  73  73 221]\n",
      "Total samples: 588\n",
      "Correlation level 6: [236  58  58 236]\n",
      "Total samples: 588\n",
      "Correlation level 7: [251  43  43 251]\n",
      "Total samples: 588\n",
      "Correlation level 8: [266  28  28 266]\n",
      "Total samples: 588\n",
      "Correlation level 9: [280  14  14 280]\n",
      "Total samples: 588\n",
      "Correlation level 10: [295   0   0 295]\n",
      "Total samples: 590\n",
      "Proportions of each group:\n",
      "Correlation level 0: [0.25 0.25 0.25 0.25]\n",
      "Correlation level 1: [0.276 0.224 0.224 0.276]\n",
      "Correlation level 2: [0.301 0.199 0.199 0.301]\n",
      "Correlation level 3: [0.327 0.173 0.173 0.327]\n",
      "Correlation level 4: [0.35 0.15 0.15 0.35]\n",
      "Correlation level 5: [0.376 0.124 0.124 0.376]\n",
      "Correlation level 6: [0.401 0.099 0.099 0.401]\n",
      "Correlation level 7: [0.427 0.073 0.073 0.427]\n",
      "Correlation level 8: [0.452 0.048 0.048 0.452]\n",
      "Correlation level 9: [0.476 0.024 0.024 0.476]\n",
      "Correlation level 10: [0.5 0.  0.  0.5]\n",
      "True proportions:\n",
      "Correlation level 0: [0.25, 0.25, 0.25, 0.25]\n",
      "Correlation level 1: [0.275, 0.225, 0.225, 0.275]\n",
      "Correlation level 2: [0.3, 0.2, 0.2, 0.3]\n",
      "Correlation level 3: [0.325, 0.175, 0.175, 0.325]\n",
      "Correlation level 4: [0.35, 0.15, 0.15, 0.35]\n",
      "Correlation level 5: [0.375, 0.125, 0.125, 0.375]\n",
      "Correlation level 6: [0.4, 0.1, 0.1, 0.4]\n",
      "Correlation level 7: [0.425, 0.075, 0.075, 0.425]\n",
      "Correlation level 8: [0.45, 0.05, 0.05, 0.45]\n",
      "Correlation level 9: [0.475, 0.025, 0.025, 0.475]\n",
      "Correlation level 10: [0.5, 0.0, 0.0, 0.5]\n"
     ]
    }
   ],
   "source": [
    "# Every correlation level should have the same number of samples, so scale each level\n",
    "# down to the smallest level's total.\n",
    "smallest_sample_size = np.min(counts.sum(axis=1))\n",
    "print(\"Smallest sample size across correlation levels:\", smallest_sample_size)\n",
    "\n",
    "# Rescale (factor <= 1) and truncate to ints, so no group grows beyond its original count\n",
    "adjusted_counts = counts * (smallest_sample_size / counts.sum(axis=1)[:, None])\n",
    "adjusted_counts = adjusted_counts.astype(int)\n",
    "\n",
    "# Truncation can leave some levels with a slightly larger total than others (e.g. 590 vs. 588).\n",
    "# Trim those levels one sample at a time from their currently largest group until every\n",
    "# level has exactly the same total.\n",
    "target_total = adjusted_counts.sum(axis=1).min()\n",
    "for row in adjusted_counts:  # rows are views, so this edits adjusted_counts in place\n",
    "    while row.sum() > target_total:\n",
    "        row[np.argmax(row)] -= 1\n",
    "\n",
    "# Check that all counts are valid\n",
    "assert (adjusted_counts.sum(axis=1) == target_total).all(), \"Counts do not sum to the same total.\"\n",
    "assert ((adjusted_counts >= 0) & (adjusted_counts <= counts)).all(), \"Adjusted counts out of range.\"\n",
    "\n",
    "print(\"Adjusted counts for each correlation level\")\n",
    "for i, count in enumerate(adjusted_counts):\n",
    "    print(f\"Correlation level {i}: {count}\")\n",
    "    print(f\"Total samples: {sum(count)}\")\n",
    "\n",
    "# print table of proportions\n",
    "print(\"Proportions of each group:\")\n",
    "for i, count in enumerate(adjusted_counts):\n",
    "    total = sum(count)\n",
    "    proportions = count / total\n",
    "    print(f\"Correlation level {i}: {np.round(proportions,3)}\")\n",
    "\n",
    "# Target proportions implied by get_counts' default probabilities\n",
    "print(\"True proportions:\")\n",
    "P_positive = 0.5\n",
    "P_negative = 0.5\n",
    "for i in range(len(adjusted_counts)):\n",
    "    P_long_given_positive = 0.5 + i * 0.05\n",
    "    P_long_given_negative = 0.5 - i * 0.05\n",
    "\n",
    "    # Calculate joint probabilities; built-in round keeps plain floats so the list prints\n",
    "    # as numbers (np.float64 elements print as \"np.float64(...)\" under NumPy 2).\n",
    "    P_pos_long = round(P_positive * P_long_given_positive, 3)\n",
    "    P_pos_short = round(P_positive * (1 - P_long_given_positive), 3)\n",
    "    P_neg_long = round(P_negative * P_long_given_negative, 3)\n",
    "    P_neg_short = round(P_negative * (1 - P_long_given_negative), 3)\n",
    "\n",
    "    print(f\"Correlation level {i}: {[P_pos_long, P_neg_long, P_pos_short, P_neg_short]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adjusted counts:\n",
      "00,  01,  10,  11\n",
      "[[147 147 147 147]\n",
      " [162 132 132 162]\n",
      " [177 117 117 177]\n",
      " [192 102 102 192]\n",
      " [206  88  88 206]\n",
      " [221  73  73 221]\n",
      " [236  58  58 236]\n",
      " [251  43  43 251]\n",
      " [266  28  28 266]\n",
      " [280  14  14 280]\n",
      " [295   0   0 295]]\n"
     ]
    }
   ],
   "source": [
    "# Columns follow get_counts' order (long+, long-, short+, short-), i.e. label groups\n",
    "# 11, 01, 10, 00 with the intervened label first and the spurious label second.\n",
    "print(\"Adjusted counts:\")\n",
    "print(\"11,  01,  10,  00\")\n",
    "print(adjusted_counts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now we can create the datasets with increasing correlation\n",
    "\n",
    "def create_correlated_datasets(base_dataset, adjusted_counts, filtered_groups):\n",
    "    long_positive_filtered, long_negative_filtered, short_positive_filtered, short_negative_sampled = filtered_groups\n",
    "    datasets = []\n",
    "    \n",
    "    for i, (n_long_positive, n_long_negative, n_short_positive, n_short_negative) in enumerate(adjusted_counts):\n",
    "        long_positive = np.random.choice(long_positive_filtered, n_long_positive, replace=False)\n",
    "        long_negative = np.random.choice(long_negative_filtered, n_long_negative, replace=False)\n",
    "        short_positive = np.random.choice(short_positive_filtered, n_short_positive, replace=False)\n",
    "        short_negative = np.random.choice(short_negative_sampled, n_short_negative, replace=False)\n",
    "\n",
    "        # Combine sampled data into new dataset\n",
    "        correlated_keys = np.concatenate((long_positive, long_negative, short_positive, short_negative))\n",
    "        correlated_keys = sorted(correlated_keys.tolist())\n",
    "        correlated_dataset = {key: base_dataset[key] for key in correlated_keys}\n",
    "        datasets.append(correlated_dataset)\n",
    "    return datasets\n",
    "\n",
    "correlated_datasets = create_correlated_datasets(\n",
    "    base_dataset=base_dataset, adjusted_counts=adjusted_counts, filtered_groups=filtered_group\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "synthetic_helpsteer_base_correlated_0.jsonl\n",
      "synthetic_helpsteer_base_correlated_1.jsonl\n",
      "synthetic_helpsteer_base_correlated_2.jsonl\n",
      "synthetic_helpsteer_base_correlated_3.jsonl\n",
      "synthetic_helpsteer_base_correlated_4.jsonl\n",
      "synthetic_helpsteer_base_correlated_5.jsonl\n",
      "synthetic_helpsteer_base_correlated_6.jsonl\n",
      "synthetic_helpsteer_base_correlated_7.jsonl\n",
      "synthetic_helpsteer_base_correlated_8.jsonl\n",
      "synthetic_helpsteer_base_correlated_9.jsonl\n",
      "synthetic_helpsteer_base_correlated_10.jsonl\n"
     ]
    }
   ],
   "source": [
    "# Save each correlated dataset as <base stem>_correlated_<level>.jsonl in save_dir\n",
    "base_stem = Path(save_file).stem  # drops only the final suffix, so dots elsewhere in the name survive\n",
    "for i, dataset in enumerate(correlated_datasets):\n",
    "    filename = f\"{base_stem}_correlated_{i}.jsonl\"\n",
    "    print(filename)\n",
    "    save_path = save_dir / filename\n",
    "    save_dataset(dataset, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "editeval",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
