{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4143 problem lines\n",
      "class 5527\n",
      "education 5245\n",
      "marital-status 4754\n",
      "occupation 4680\n",
      "relationship 4573\n",
      "race 4241\n",
      "sex 4084\n",
      "native-country 3986\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "import re\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# NOTE(review): absolute, machine-specific paths; consider making these configurable.\n",
    "pathreal = '/home/sonia/be_great/data/adult/latest'  # holds all.csv and config.json\n",
    "path = '/mnt/data/sonia/ckpts/adult/dgpt2/plain-aug16/'  # samples.txt in, samplesclean.csv out\n",
    "\n",
    "# Load the generated samples, rewriting 'is?' as 'is ?' so a '?' value becomes its own\n",
    "# space-separated token (parse_line expects exactly 'name is value').\n",
    "with open(os.path.join(path, 'samples.txt'), 'r', encoding='utf-8') as f:\n",
    "    raws = [re.sub(r'is\\?', 'is ?', raw) for raw in f]\n",
    "\n",
    "real = pd.read_csv(os.path.join(pathreal, 'all.csv'))\n",
    "cols = set(real.columns)\n",
    "\n",
    "\n",
    "def parse_line(l):\n",
    "    \"\"\"Parse one generated sample into a {column: value} dict.\n",
    "\n",
    "    The line is split on '.<EOS>' into entries, each tokenised on single spaces;\n",
    "    three-token entries whose first token is a real column are kept as\n",
    "    column -> third token. Returns the dict only if every column of the real\n",
    "    data was recovered, otherwise None.\n",
    "    \"\"\"\n",
    "    # rstrip rather than l[:-1]: the file's last line may have no trailing\n",
    "    # newline, and slicing would then chop a real character ('.<EOS>' -> '.<EOS').\n",
    "    entries = l.rstrip('\\n').split('.<EOS>')\n",
    "    words = [entry.split(' ') for entry in entries]  # ['name', 'is', 'value']\n",
    "    d = {w[0]: w[2] for w in words if len(w) == 3 and w[0] in cols}\n",
    "    return d if set(d) == cols else None\n",
    "\n",
    "\n",
    "line_dicts = [d for d in map(parse_line, raws) if d is not None]\n",
    "print(len(raws) - len(line_dicts), 'problem lines')\n",
    "# Explicit columns (in the real data's order) so that, even when no line parses,\n",
    "# df still has every column and the filtering below cannot raise KeyError.\n",
    "df = pd.DataFrame(line_dicts, columns=list(real.columns))\n",
    "\n",
    "with open(os.path.join(pathreal, 'config.json'), 'r', encoding='utf-8') as f:\n",
    "    dataconfig = json.load(f)\n",
    "ords = dataconfig['ords']\n",
    "\n",
    "# Allowed values per 'ords' column, as stripped strings so they compare equal to\n",
    "# the string tokens produced by parse_line.\n",
    "ordvals = {col: {str(val).strip() for val in real[col].unique()} for col in ords}\n",
    "\n",
    "# Keep only rows whose 'ords' values all occur in the real data.\n",
    "for col in ordvals:\n",
    "    df = df[df[col].isin(ordvals[col])]\n",
    "    print(col, len(df))\n",
    "\n",
    "df.to_csv(os.path.join(path, 'samplesclean.csv'), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "great",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
