{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle\n",
    "import os\n",
    "import sys\n",
    "from mindreadingautobots.sequence_generators import data_io"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_multitask_sparse_majority(n_tasks, n_bits, k, n_data, p_bitflip=0.0, seed=None):\n",
    "    \"\"\"\n",
    "    Generate a multitask sparse majority dataset.\n",
    "    \n",
    "    Args:\n",
    "        n_tasks: Number of subtasks (distinct versions of sparse majority)\n",
    "        n_bits: Total length of task bits\n",
    "        k: Size of the fixed subset for each majority calculation\n",
    "        n_data: Number of data points to generate\n",
    "        p_bitflip: Probability of flipping bits in the task bits (not control bits)\n",
    "        seed: Random seed for reproducibility\n",
    "        \n",
    "    Returns:\n",
    "        X: Array of shape (n_data, n_tasks + n_bits + 1) containing noiseless data:\n",
    "           - First n_tasks bits are control bits (one-hot encoding of task)\n",
    "           - Next n_bits are task bits\n",
    "           - Last bit is the output (majority of relevant task bits)\n",
    "        Z: Array of same shape as X but with noise in the task bits (if p_bitflip > 0)\n",
    "        task_subsets: List of k indices for each task indicating which bits to use for majority\n",
    "    \"\"\"\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "    \n",
    "    # Initialize the dataset\n",
    "    total_bits = n_tasks + n_bits + 1  # control bits + task bits + output bit\n",
    "    X = np.zeros((n_data, total_bits), dtype=np.int32)\n",
    "    \n",
    "    # Generate random task subsets (each task uses a different subset of k indices)\n",
    "    task_subsets = []\n",
    "    for i in range(n_tasks):\n",
    "        # Generate a random subset of k indices from the task bits\n",
    "        subset = np.sort(np.random.choice(n_bits, k, replace=False))\n",
    "        task_subsets.append(subset)\n",
    "    \n",
    "    # Generate data for each example\n",
    "    for i in range(n_data):\n",
    "        # Randomly select a task (which control bit to activate)\n",
    "        active_task = np.random.randint(0, n_tasks)\n",
    "        \n",
    "        # Set the control bit (one-hot encoding)\n",
    "        X[i, active_task] = 1\n",
    "        \n",
    "        # Generate random task bits\n",
    "        task_bits = np.random.randint(0, 2, n_bits)\n",
    "        X[i, n_tasks:n_tasks+n_bits] = task_bits\n",
    "        \n",
    "        # Compute the majority of the subset corresponding to the active task\n",
    "        relevant_subset = task_subsets[active_task]\n",
    "        relevant_bits = task_bits[relevant_subset]\n",
    "        # If sum >= k/2, majority is 1, otherwise 0\n",
    "        # For ties (sum = k/2), output 1 as specified\n",
    "        majority = 1 if np.sum(relevant_bits) >= k/2 else 0\n",
    "        \n",
    "        # Set the output bit\n",
    "        X[i, -1] = majority\n",
    "    \n",
    "    # Apply noise to task bits if specified\n",
    "    if p_bitflip > 0:\n",
    "        # Create a copy of X\n",
    "        Z = np.copy(X)\n",
    "        \n",
    "        # Generate noise mask for task bits only\n",
    "        flips = np.random.binomial(1, p_bitflip, size=(n_data, n_bits))\n",
    "        \n",
    "        # Apply noise to task bits only\n",
    "        Z[:, n_tasks:n_tasks+n_bits] = np.logical_xor(\n",
    "            X[:, n_tasks:n_tasks+n_bits], \n",
    "            flips\n",
    "        ).astype(np.int32)\n",
    "        \n",
    "        # Keep the original output bit from X to test noise robustness\n",
    "        # (not recomputing based on noisy bits)\n",
    "        Z[:, -1] = X[:, -1]\n",
    "\n",
    "        print(X, Z, task_subsets)            \n",
    "        return X, Z, task_subsets\n",
    "\n",
    "    else:\n",
    "        return X, X, task_subsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "unterminated string literal (detected at line 25) (2332664692.py, line 25)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  Cell \u001b[0;32mIn[35], line 25\u001b[0;36m\u001b[0m\n\u001b[0;31m    print(f\"Task bits: {''.join(\u001b[0m\n\u001b[0m          ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m unterminated string literal (detected at line 25)\n"
     ]
    }
   ],
   "source": [
    "def verify_examples(X, task_subsets, n_tasks, num_examples=5):\n",
    "    \"\"\"Verify correctness of several examples in the dataset\"\"\"\n",
    "    for i in range(min(num_examples, len(X))):\n",
    "        example = X[i]\n",
    "        \n",
    "        # Determine which task is active\n",
    "        active_task = np.argmax(example[:n_tasks])\n",
    "        \n",
    "        # Get the task bits\n",
    "        task_bits = example[n_tasks:-1]\n",
    "        \n",
    "        # Get the subset for the active task\n",
    "        relevant_subset = task_subsets[active_task]\n",
    "        relevant_bits = task_bits[relevant_subset]\n",
    "        \n",
    "        # Compute expected parity\n",
    "        expected_parity = np.sum(relevant_bits) % 2\n",
    "        actual_parity = example[-1]\n",
    "        \n",
    "        print(f\"Data: {X}\")\n",
    "        # Display the example\n",
    "        print(f\"\\nExample {i+1}:\")\n",
    "        print(f\"Full string: {''.join([str(x) for x in example[:-1]])}\")\n",
    "        print(f\"Control bits: {''.join([str(x) for x in example[:n_tasks]])}\")\n",
    "        print(f\"Task bits: {''.join(\n",
    "            [str(x) for x in task_bits])}\")\n",
    "        print(f\"Active task: {active_task+1}\")\n",
    "        print(f\"Relevant subset indices: {relevant_subset}\")\n",
    "        print(f\"Relevant bits: {relevant_bits}\")\n",
    "        print(f\"Expected answer: {expected_parity}\")\n",
    "        print(f\"Actual answer: {actual_parity}\")\n",
    "        print(f\"Correct: {expected_parity == actual_parity}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Generating multitask_sparse_parity with 4 tasks, 20 task bits, k=3, noise=0.05\n",
      "[[0 0 0 ... 0 1 0]\n",
      " [1 0 0 ... 0 1 1]\n",
      " [1 0 0 ... 0 0 1]\n",
      " ...\n",
      " [0 0 1 ... 1 1 0]\n",
      " [0 0 0 ... 1 0 1]\n",
      " [1 0 0 ... 0 1 0]] [[0 0 0 ... 0 1 0]\n",
      " [1 0 0 ... 0 1 1]\n",
      " [1 0 0 ... 0 0 1]\n",
      " ...\n",
      " [0 0 1 ... 1 1 0]\n",
      " [0 0 0 ... 1 0 1]\n",
      " [1 0 0 ... 0 1 0]] [array([ 2,  3, 13]), array([ 8, 13, 19]), array([ 2,  3, 17]), array([ 0,  7, 18])]\n",
      "Task 1 uses bits [ 2  3 13] for parity calculation\n",
      "Task 2 uses bits [ 8 13 19] for parity calculation\n",
      "Task 3 uses bits [ 2  3 17] for parity calculation\n",
      "Task 4 uses bits [ 0  7 18] for parity calculation\n",
      "[array([ 2,  3, 13]), array([ 8, 13, 19]), array([ 2,  3, 17]), array([ 0,  7, 18])]\n",
      "\n",
      "Comparing original and noisy examples:\n",
      "Example 1:\n",
      "Original task bits: [0 0 0 1 1 0 0 0 0 0 1 1 1 1 0 1 0 1 0 1]\n",
      "Noisy task bits: [0 0 0 1 1 0 0 0 0 0 1 1 1 1 0 1 0 0 0 1]\n",
      "Bits flipped: 1\n",
      "Original answer: 0\n",
      "Noisy answer: 0\n",
      "\n",
      "Example 2:\n",
      "Original task bits: [1 0 1 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 0 1]\n",
      "Noisy task bits: [1 1 1 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 0 1]\n",
      "Bits flipped: 1\n",
      "Original answer: 1\n",
      "Noisy answer: 1\n",
      "\n",
      "\n",
      "Printing saved data:\n",
      "\n",
      "Contents of noiseless_train.pkl:\n",
      "Number of examples: 2000\n",
      "First example:\n",
      "Line: 000100011000001111010101\n",
      "Label: 0\n",
      "\n",
      "Contents of noiseless_val.pkl:\n",
      "Number of examples: 10000\n",
      "First example:\n",
      "Line: 001011001110100111111010\n",
      "Label: 0\n",
      "\n",
      "Contents of train.pkl:\n",
      "Number of examples: 2000\n",
      "First example:\n",
      "Line: 000100011000001111010001\n",
      "Label: 0\n",
      "\n",
      "Contents of val.pkl:\n",
      "Number of examples: 10000\n",
      "First example:\n",
      "Line: 001011001110101111111011\n",
      "Label: 0\n",
      "\n",
      "Generating multitask_sparse_parity with 4 tasks, 20 task bits, k=3, noise=0.1\n",
      "[[0 0 0 ... 0 1 0]\n",
      " [1 0 0 ... 0 1 1]\n",
      " [1 0 0 ... 0 0 1]\n",
      " ...\n",
      " [0 0 1 ... 1 1 0]\n",
      " [0 0 0 ... 1 0 1]\n",
      " [1 0 0 ... 0 1 0]] [[0 0 0 ... 0 1 0]\n",
      " [1 0 0 ... 1 1 1]\n",
      " [1 0 0 ... 0 0 1]\n",
      " ...\n",
      " [0 0 1 ... 1 1 0]\n",
      " [0 0 0 ... 1 0 1]\n",
      " [1 0 0 ... 0 1 0]] [array([ 2,  3, 13]), array([ 8, 13, 19]), array([ 2,  3, 17]), array([ 0,  7, 18])]\n",
      "Task 1 uses bits [ 2  3 13] for parity calculation\n",
      "Task 2 uses bits [ 8 13 19] for parity calculation\n",
      "Task 3 uses bits [ 2  3 17] for parity calculation\n",
      "Task 4 uses bits [ 0  7 18] for parity calculation\n",
      "[array([ 2,  3, 13]), array([ 8, 13, 19]), array([ 2,  3, 17]), array([ 0,  7, 18])]\n",
      "\n",
      "Comparing original and noisy examples:\n",
      "Example 1:\n",
      "Original task bits: [0 0 0 1 1 0 0 0 0 0 1 1 1 1 0 1 0 1 0 1]\n",
      "Noisy task bits: [0 0 0 1 1 0 0 0 0 1 1 1 1 1 0 1 0 0 0 1]\n",
      "Bits flipped: 2\n",
      "Original answer: 0\n",
      "Noisy answer: 0\n",
      "\n",
      "Example 2:\n",
      "Original task bits: [1 0 1 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 0 1]\n",
      "Noisy task bits: [1 1 0 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 1 1]\n",
      "Bits flipped: 3\n",
      "Original answer: 1\n",
      "Noisy answer: 1\n",
      "\n",
      "\n",
      "Printing saved data:\n",
      "\n",
      "Contents of noiseless_train.pkl:\n",
      "Number of examples: 2000\n",
      "First example:\n",
      "Line: 000100011000001111010101\n",
      "Label: 0\n",
      "\n",
      "Contents of noiseless_val.pkl:\n",
      "Number of examples: 10000\n",
      "First example:\n",
      "Line: 001011001110100111111010\n",
      "Label: 0\n",
      "\n",
      "Contents of train.pkl:\n",
      "Number of examples: 2000\n",
      "First example:\n",
      "Line: 000100011000011111010001\n",
      "Label: 0\n",
      "\n",
      "Contents of val.pkl:\n",
      "Number of examples: 10000\n",
      "First example:\n",
      "Line: 001011001110101111111011\n",
      "Label: 0\n",
      "\n",
      "Generating multitask_sparse_parity with 4 tasks, 20 task bits, k=3, noise=0.2\n",
      "[[0 0 0 ... 0 1 0]\n",
      " [1 0 0 ... 0 1 1]\n",
      " [1 0 0 ... 0 0 1]\n",
      " ...\n",
      " [0 0 1 ... 1 1 0]\n",
      " [0 0 0 ... 1 0 1]\n",
      " [1 0 0 ... 0 1 0]] [[0 0 0 ... 0 0 0]\n",
      " [1 0 0 ... 1 1 1]\n",
      " [1 0 0 ... 0 0 1]\n",
      " ...\n",
      " [0 0 1 ... 1 1 0]\n",
      " [0 0 0 ... 1 1 1]\n",
      " [1 0 0 ... 0 1 0]] [array([ 2,  3, 13]), array([ 8, 13, 19]), array([ 2,  3, 17]), array([ 0,  7, 18])]\n",
      "Task 1 uses bits [ 2  3 13] for parity calculation\n",
      "Task 2 uses bits [ 8 13 19] for parity calculation\n",
      "Task 3 uses bits [ 2  3 17] for parity calculation\n",
      "Task 4 uses bits [ 0  7 18] for parity calculation\n",
      "[array([ 2,  3, 13]), array([ 8, 13, 19]), array([ 2,  3, 17]), array([ 0,  7, 18])]\n",
      "\n",
      "Comparing original and noisy examples:\n",
      "Example 1:\n",
      "Original task bits: [0 0 0 1 1 0 0 0 0 0 1 1 1 1 0 1 0 1 0 1]\n",
      "Noisy task bits: [0 0 0 1 1 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0]\n",
      "Bits flipped: 4\n",
      "Original answer: 0\n",
      "Noisy answer: 0\n",
      "\n",
      "Example 2:\n",
      "Original task bits: [1 0 1 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 0 1]\n",
      "Noisy task bits: [1 1 0 0 1 0 1 0 0 1 1 0 0 1 1 1 0 0 1 1]\n",
      "Bits flipped: 5\n",
      "Original answer: 1\n",
      "Noisy answer: 1\n",
      "\n",
      "\n",
      "Printing saved data:\n",
      "\n",
      "Contents of noiseless_train.pkl:\n",
      "Number of examples: 2000\n",
      "First example:\n",
      "Line: 000100011000001111010101\n",
      "Label: 0\n",
      "\n",
      "Contents of noiseless_val.pkl:\n",
      "Number of examples: 10000\n",
      "First example:\n",
      "Line: 001011001110100111111010\n",
      "Label: 0\n",
      "\n",
      "Contents of train.pkl:\n",
      "Number of examples: 2000\n",
      "First example:\n",
      "Line: 000100011000011111000000\n",
      "Label: 0\n",
      "\n",
      "Contents of val.pkl:\n",
      "Number of examples: 10000\n",
      "First example:\n",
      "Line: 001001001110101110111111\n",
      "Label: 0\n",
      "\n",
      "Generating multitask_sparse_parity with 4 tasks, 20 task bits, k=3, noise=0.3\n",
      "[[0 0 0 ... 0 1 0]\n",
      " [1 0 0 ... 0 1 1]\n",
      " [1 0 0 ... 0 0 1]\n",
      " ...\n",
      " [0 0 1 ... 1 1 0]\n",
      " [0 0 0 ... 1 0 1]\n",
      " [1 0 0 ... 0 1 0]] [[0 0 0 ... 0 0 0]\n",
      " [1 0 0 ... 1 1 1]\n",
      " [1 0 0 ... 0 0 1]\n",
      " ...\n",
      " [0 0 1 ... 0 1 0]\n",
      " [0 0 0 ... 1 1 1]\n",
      " [1 0 0 ... 0 1 0]] [array([ 2,  3, 13]), array([ 8, 13, 19]), array([ 2,  3, 17]), array([ 0,  7, 18])]\n",
      "Task 1 uses bits [ 2  3 13] for parity calculation\n",
      "Task 2 uses bits [ 8 13 19] for parity calculation\n",
      "Task 3 uses bits [ 2  3 17] for parity calculation\n",
      "Task 4 uses bits [ 0  7 18] for parity calculation\n",
      "[array([ 2,  3, 13]), array([ 8, 13, 19]), array([ 2,  3, 17]), array([ 0,  7, 18])]\n",
      "\n",
      "Comparing original and noisy examples:\n",
      "Example 1:\n",
      "Original task bits: [0 0 0 1 1 0 0 0 0 0 1 1 1 1 0 1 0 1 0 1]\n",
      "Noisy task bits: [0 1 0 1 1 0 0 0 0 1 1 0 0 1 0 0 0 0 0 0]\n",
      "Bits flipped: 7\n",
      "Original answer: 0\n",
      "Noisy answer: 0\n",
      "\n",
      "Example 2:\n",
      "Original task bits: [1 0 1 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 0 1]\n",
      "Noisy task bits: [1 1 0 0 1 0 0 0 0 1 1 0 0 1 1 1 0 0 1 1]\n",
      "Bits flipped: 6\n",
      "Original answer: 1\n",
      "Noisy answer: 1\n",
      "\n",
      "\n",
      "Printing saved data:\n",
      "\n",
      "Contents of noiseless_train.pkl:\n",
      "Number of examples: 2000\n",
      "First example:\n",
      "Line: 000100011000001111010101\n",
      "Label: 0\n",
      "\n",
      "Contents of noiseless_val.pkl:\n",
      "Number of examples: 10000\n",
      "First example:\n",
      "Line: 001011001110100111111010\n",
      "Label: 0\n",
      "\n",
      "Contents of train.pkl:\n",
      "Number of examples: 2000\n",
      "First example:\n",
      "Line: 000101011000011001000000\n",
      "Label: 0\n",
      "\n",
      "Contents of val.pkl:\n",
      "Number of examples: 10000\n",
      "First example:\n",
      "Line: 001001001110001000111111\n",
      "Label: 0\n"
     ]
    }
   ],
   "source": [
    "# sys.path.append(\"../src\")\n",
    "\n",
    "n_tasks = 4  # Number of tasks\n",
    "task_bits_length = 20  # Length of the task bits portion\n",
    "k = 3  # Size of subset for parity calculation\n",
    "n_train = 2000  # Number of training examples\n",
    "n_val = 10000  # Number of validation examples\n",
    "p_bitflips = [0.05, 0.1, 0.2, 0.3]  # Array of bit flip probabilities\n",
    "seed = 1234  # Random seed\n",
    "\n",
    "for p_bitflip in p_bitflips:\n",
    "    # Generate the dataset\n",
    "    print(f\"\\nGenerating multitask_sparse_parity with {n_tasks} tasks, {task_bits_length} task bits, k={k}, noise={p_bitflip}\")\n",
    "    X, Z, task_subsets = generate_multitask_sparse_majority(\n",
    "        n_tasks=n_tasks,\n",
    "        n_bits=task_bits_length,\n",
    "        k=k,\n",
    "        n_data=n_train + n_val,\n",
    "        p_bitflip=p_bitflip,\n",
    "        seed=seed\n",
    "    )\n",
    "\n",
    "    # Print the task subsets (which bits are used for each task)\n",
    "    for i, subset in enumerate(task_subsets):\n",
    "        print(f\"Task {i+1} uses bits {subset} for parity calculation\")\n",
    "    print(task_subsets)\n",
    "    # Split into train and validation sets\n",
    "    X_train = X[:n_train]\n",
    "    X_val = X[n_train:]\n",
    "    Z_train = Z[:n_train]\n",
    "    Z_val = Z[n_train:]\n",
    "\n",
    "    # Verify some examples\n",
    "    # print(\"\\nVerifying noiseless training examples:\")\n",
    "    # verify_examples(X_train, task_subsets, n_tasks, 2)\n",
    "\n",
    "    # if p_bitflip > 0:\n",
    "    #     print(\"\\nVerifying noisy training examples:\")\n",
    "    #     verify_examples(Z_train, task_subsets, n_tasks, 2)\n",
    "\n",
    "    # Compare original and noisy examples\n",
    "    if p_bitflip > 0:\n",
    "        print(\"\\nComparing original and noisy examples:\")\n",
    "        for i in range(2):  # Show 2 examples\n",
    "            print(f\"Example {i+1}:\")\n",
    "            print(f\"Original task bits: {X_train[i, n_tasks:-1]}\")\n",
    "            print(f\"Noisy task bits: {Z_train[i, n_tasks:-1]}\")\n",
    "            print(f\"Bits flipped: {np.sum(X_train[i, n_tasks:-1] != Z_train[i, n_tasks:-1])}\")\n",
    "            print(f\"Original answer: {X_train[i, -1]}\")\n",
    "            print(f\"Noisy answer: {Z_train[i, -1]}\\n\")\n",
    "\n",
    "    # Create directory for this bit flip rate\n",
    "    bf_str = '0' if p_bitflip == 0 else str(int(p_bitflip * 100))\n",
    "    output_dir = f'./multitask_sparse_majority_ntasks{task_bits_length}_ncontrol{n_tasks}_k{k}_ndata{n_train}_bf{bf_str}_seed{seed}'\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    # Save the datasets\n",
    "    data_io.save_numpy_as_dict(X_train, f'{output_dir}/noiseless_train.pkl')\n",
    "    data_io.save_numpy_as_dict(X_val, f'{output_dir}/noiseless_val.pkl')\n",
    "    \n",
    "    # For zero bit flip, save noiseless data as train.pkl and val.pkl as well\n",
    "    if p_bitflip == 0:\n",
    "        data_io.save_numpy_as_dict(X_train, f'{output_dir}/train.pkl')\n",
    "        data_io.save_numpy_as_dict(X_val, f'{output_dir}/val.pkl')\n",
    "    else:\n",
    "        data_io.save_numpy_as_dict(Z_train, f'{output_dir}/train.pkl')\n",
    "        data_io.save_numpy_as_dict(Z_val, f'{output_dir}/val.pkl')\n",
    "    \n",
    "    # Save task subsets\n",
    "    with open(f'{output_dir}/task_subsets.pkl', 'wb') as f:\n",
    "        pickle.dump(task_subsets, f)\n",
    "\n",
    "    # Print saved data\n",
    "    print(\"\\nPrinting saved data:\")\n",
    "    for filename in ['noiseless_train.pkl', 'noiseless_val.pkl', 'train.pkl', 'val.pkl']:\n",
    "        filepath = f'{output_dir}/{filename}'\n",
    "        if os.path.exists(filepath):\n",
    "            with open(filepath, 'rb') as f:\n",
    "                data = pickle.load(f)\n",
    "                print(f\"\\nContents of {filename}:\")\n",
    "                print(f\"Number of examples: {len(data['line'])}\")\n",
    "                print(\"First example:\")\n",
    "                print(f\"Line: {data['line'][0]}\")\n",
    "                print(f\"Label: {data['label'][0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[110], line 38\u001b[0m\n\u001b[1;32m     34\u001b[0m majority \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(relevant_bits) \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m \n\u001b[1;32m     36\u001b[0m label \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(labels[i])\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m majority \u001b[38;5;241m==\u001b[39m label\n",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "folder_path = \"multitask_sparse_majority_ntasks20_ncontrol4_k3_ndata2000_bf5_seed1234\"\n",
    "noisy_data_path = folder_path + \"/val.pkl\"\n",
    "noiseless_data_path = folder_path + \"/noiseless_val.pkl\"\n",
    "task_path = folder_path + \"/task_subsets.pkl\"\n",
    "\n",
    "with open(noisy_data_path, 'rb') as f:\n",
    "    noisy_data = pickle.load(f)\n",
    "with open(noiseless_data_path, 'rb') as f:\n",
    "    noiseless_data = pickle.load(f)\n",
    "with open(task_path, 'rb') as f:\n",
    "    task_subsets = pickle.load(f)\n",
    "  \n",
    "labels = noisy_data['label']\n",
    "data = noisy_data['line']\n",
    "\n",
    "\n",
    "\n",
    "  \n",
    "encoding_map = {\n",
    "    '0001': task_subsets[3],\n",
    "    '0010': task_subsets[2],\n",
    "    '0100': task_subsets[1],\n",
    "    '1000': task_subsets[0]\n",
    "}\n",
    "\n",
    "\n",
    "for i in range(len(data)):\n",
    "  one_hot_encoding = (data[i][:4])\n",
    "  relevant_indices = encoding_map[one_hot_encoding]\n",
    "  \n",
    "  task_bits = data[i][4:] \n",
    "  relevant_bits = [int(task_bits[i]) for i in relevant_indices]  \n",
    "\n",
    "  majority = sum(relevant_bits) >= 2 \n",
    "\n",
    "  label = int(labels[i])\n",
    "\n",
    "  assert majority == label\n",
    "\n",
    "  \n",
    "  \n",
    "  \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "autobots",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
