{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "11b6ba5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pickle\n",
    "from sentence_transformers import SentenceTransformer\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "279e1002",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(dataset, folder = 'data/wrench_class'):\n",
    "    \n",
    "    path = folder + '/' + dataset + '/' \n",
    "    with open(path + 'train.json', 'r') as f:\n",
    "        df_train = json.load(f)  \n",
    "    with open(path + 'valid.json', 'r') as f:\n",
    "        df_val = json.load(f)   \n",
    "    with open(path + 'test.json', 'r') as f:\n",
    "        df_test = json.load(f)\n",
    "\n",
    "    ## X ##\n",
    "    if 'feature' in df_train[list(df_train.keys())[0]]['data']:\n",
    "        X_train = torch.tensor([df_train[key]['data']['feature'] for key in df_train.keys()]).double()\n",
    "        X_val = torch.tensor([df_val[key]['data']['feature'] for key in df_val.keys()]).double()\n",
    "        X_test = torch.tensor([df_test[key]['data']['feature'] for key in df_test.keys()]).double()\n",
    "        \n",
    "    elif 'text' in df_train[list(df_train.keys())[0]]['data']:\n",
    "        feature_extractor = SentenceTransformer('all-MiniLM-L6-v2')\n",
    "        texts = [df_train[key]['data']['text'] for key in df_train.keys()]\n",
    "        X_train = torch.stack([torch.tensor(feature_extractor.encode([t])[0]) for t in tqdm(texts)]).double()\n",
    "        texts = [df_val[key]['data']['text'] for key in df_val.keys()]\n",
    "        X_val = torch.stack([torch.tensor(feature_extractor.encode([t])[0]) for t in tqdm(texts)]).double()\n",
    "        texts = [df_test[key]['data']['text'] for key in df_test.keys()]\n",
    "        X_test = torch.stack([torch.tensor(feature_extractor.encode([t])[0]) for t in tqdm(texts)]).double()\n",
    "\n",
    "    ind = X_train.std(dim=0)!=0 #excluding cols with no variation in the training set\n",
    "    X_train = X_train[:,ind]\n",
    "    X_val = X_val[:,ind]\n",
    "    X_test = X_test[:,ind]\n",
    "    X_val = (X_val-X_train.mean(dim=0))/X_train.std(dim=0) #standardizing\n",
    "    X_test = (X_test-X_train.mean(dim=0))/X_train.std(dim=0) \n",
    "    X_train = (X_train-X_train.mean(dim=0))/X_train.std(dim=0)\n",
    "\n",
    "    ## Y ##\n",
    "    if dataset=='spouse':\n",
    "        Y_train = torch.tensor([])\n",
    "    else:\n",
    "        Y_train = torch.tensor([df_train[key]['label'] for key in df_train.keys()])\n",
    "    Y_val = torch.tensor([df_val[key]['label'] for key in df_val.keys()])\n",
    "    Y_test = torch.tensor([df_test[key]['label'] for key in df_test.keys()])\n",
    "\n",
    "    ## L ##\n",
    "    L_train = np.array([df_train[key]['weak_labels'] for key in df_train.keys()])\n",
    "    L_val = np.array([df_val[key]['weak_labels'] for key in df_val.keys()])\n",
    "    L_test = np.array([df_test[key]['weak_labels'] for key in df_test.keys()])\n",
    "\n",
    "    ## Save processed data ##\n",
    "    dic = {'X_train': X_train, 'X_val': X_val, 'X_test': X_test, 'Y_train': Y_train, 'Y_val':Y_val, 'Y_test':Y_test, 'L_train':L_train, 'L_val':L_val, 'L_test':L_test}\n",
    "    \n",
    "    with open(folder + '/' + dataset + '/processed_data.pickle', 'wb') as handle:\n",
    "        pickle.dump(dic, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "11f7488f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets = ['census', 'basketball', 'tennis', 'youtube', 'yelp', 'imdb', 'agnews', 'chemprot', 'sms', 'spouse', 'cdr', 'trec', 'semeval']\n",
    "\n",
    "len(datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "db19d9ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** census **************************************************\n",
      "\n",
      " ************************************************** basketball **************************************************\n",
      "\n",
      " ************************************************** tennis **************************************************\n",
      "\n",
      " ************************************************** youtube **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 1686/1686 [01:04<00:00, 26.00it/s]\n",
      "100%|█████████████████████████████████████████| 120/120 [00:04<00:00, 25.23it/s]\n",
      "100%|█████████████████████████████████████████| 250/250 [00:10<00:00, 23.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** yelp **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 30400/30400 [53:25<00:00,  9.48it/s]\n",
      "100%|███████████████████████████████████████| 3800/3800 [06:12<00:00, 10.20it/s]\n",
      "100%|███████████████████████████████████████| 3800/3800 [06:26<00:00,  9.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** imdb **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 20000/20000 [48:42<00:00,  6.84it/s]\n",
      "100%|███████████████████████████████████████| 2500/2500 [19:02<00:00,  2.19it/s]\n",
      "100%|███████████████████████████████████████| 2500/2500 [05:14<00:00,  7.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** agnews **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 96000/96000 [53:11<00:00, 30.08it/s]\n",
      "100%|█████████████████████████████████████| 12000/12000 [06:30<00:00, 30.73it/s]\n",
      "100%|█████████████████████████████████████| 12000/12000 [06:35<00:00, 30.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** chemprot **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 12861/12861 [09:17<00:00, 23.09it/s]\n",
      "100%|███████████████████████████████████████| 1607/1607 [01:09<00:00, 23.13it/s]\n",
      "100%|███████████████████████████████████████| 1607/1607 [01:09<00:00, 23.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** sms **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 4571/4571 [01:42<00:00, 44.52it/s]\n",
      "100%|█████████████████████████████████████████| 500/500 [00:11<00:00, 44.30it/s]\n",
      "100%|█████████████████████████████████████████| 500/500 [00:11<00:00, 45.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** spouse **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 22254/22254 [16:36<00:00, 22.34it/s]\n",
      "100%|███████████████████████████████████████| 2811/2811 [02:05<00:00, 22.45it/s]\n",
      "100%|███████████████████████████████████████| 2701/2701 [01:57<00:00, 22.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** cdr **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 8430/8430 [05:20<00:00, 26.33it/s]\n",
      "100%|█████████████████████████████████████████| 920/920 [00:34<00:00, 26.66it/s]\n",
      "100%|███████████████████████████████████████| 4673/4673 [02:52<00:00, 27.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** trec **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 5033/5033 [01:17<00:00, 64.98it/s]\n",
      "100%|█████████████████████████████████████████| 500/500 [00:07<00:00, 65.47it/s]\n",
      "100%|█████████████████████████████████████████| 500/500 [00:07<00:00, 70.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " ************************************************** semeval **************************************************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 1749/1749 [00:39<00:00, 44.23it/s]\n",
      "100%|█████████████████████████████████████████| 178/178 [00:04<00:00, 42.03it/s]\n",
      "100%|█████████████████████████████████████████| 600/600 [00:15<00:00, 38.47it/s]\n"
     ]
    }
   ],
   "source": [
    "for dataset in datasets:\n",
    "    print(\"\\n **************************************************\", dataset, \"**************************************************\")\n",
    "    load_data(dataset, folder = 'data/wrench_class')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4db32a7e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
