{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from sklearn import preprocessing, pipeline, ensemble, compose\n",
    "import datasets\n",
    "import os\n",
    "\n",
    "# Dataset name -> CSV file. Every entry is loaded with pd.read_csv in the next cell.\n",
    "# The 'real' key is handled separately from the other entries when the train/test\n",
    "# splits are built, so that name must be kept.\n",
    "# NOTE(review): these are machine-specific absolute paths; the notebook will not run\n",
    "# elsewhere without editing them.\n",
    "paths = {\n",
    "    'real': '/hdd3/sonia/data/adult.csv',\n",
    "    'dgpt2': '/hdd3/sonia/be_great/ckpts/dgpt2/adult-allcol/samples.csv',\n",
    "    'moe': '/hdd3/sonia/be_great/ckpts/moe/dgpt2/adult-allcol/jul21/samplesclean.csv',\n",
    "    'greatdgpt2': '/hdd3/sonia/be_great/ckpts/dgpt2-greatclean.csv',\n",
    "    'moegreatdgpt2': '/hdd3/sonia/be_great/ckpts/great/adult/moegreatdgpt2-aug12.csv',\n",
    "    'mhmoegreatdgpt2-0': '/hdd3/sonia/be_great/ckpts/moemh/dgpt2/adult-allcol/aug14-0.csv',\n",
    "    'mhmoegreatdgpt2-1': '/hdd3/sonia/be_great/ckpts/moemh/dgpt2/adult-allcol/aug14-1.csv',\n",
    "    'mhmoegreatdgpt2-2': '/hdd3/sonia/be_great/ckpts/moemh/dgpt2/adult-allcol/aug14-2.csv',\n",
    "    'fairgan': '/hdd3/sonia/be_great/ckpts/tabfairgan/adult-april.csv',\n",
    "}\n",
    "rs = 4 # random state, passed to DataFrame.sample and RandomForestClassifier below\n",
    "train_frac = 0.75 # fraction of the sampled rows that goes into each train split\n",
    "\n",
    "# Categorical feature columns; these are encoded with the OrdinalEncoder.\n",
    "# NOTE(review): 'MUST BE IN ORDER' presumably refers to the per-column category lists,\n",
    "# which are built by iterating this list - confirm before reordering.\n",
    "ords = ['workclass', 'education', 'marital-status', 'occupation', \n",
    "        'relationship', 'race', 'sex', 'native-country'] # MUST BE IN ORDER\n",
    "# Numeric feature columns; these are scaled with the StandardScaler.\n",
    "nums = ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week', ]\n",
    "# Target column; binarized with the LabelBinarizer.\n",
    "labs = ['income']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['real: (48842, 15)', 'dgpt2: (5935, 15)', 'moe: (9106, 15)', 'greatdgpt2: (9815, 15)', 'moegreatdgpt2: (9997, 15)', 'mhmoegreatdgpt2-0: (9995, 15)', 'mhmoegreatdgpt2-1: (9996, 15)', 'mhmoegreatdgpt2-2: (9996, 15)', 'fairgan: (32561, 15)']\n",
      "real\n",
      "dgpt2\n",
      "moe\n",
      "greatdgpt2\n",
      "moegreatdgpt2\n",
      "mhmoegreatdgpt2-0\n",
      "mhmoegreatdgpt2-1\n",
      "mhmoegreatdgpt2-2\n",
      "fairgan\n",
      "sampled 5935 rows from each dataset, made trainsets of 4451 rows each\n"
     ]
    }
   ],
   "source": [
    "# Load every CSV listed in `paths` into a DataFrame keyed by dataset name.\n",
    "datadict = {k:pd.read_csv(v) for (k,v) in paths.items()}\n",
    "print([k+': '+str(df.shape) for (k, df) in datadict.items()]) # print shape of each dataset\n",
    "\n",
    "# Non-'real' datasets are subsampled to the size of the smallest dataset.\n",
    "min_dataset_size = min(len(df) for df in datadict.values())\n",
    "train_size = int(min_dataset_size*train_frac)\n",
    "categoriesdict = {col: set() for col in ords} # collect all unique values for each of the ordinal columns\n",
    "for (k, df) in datadict.items():\n",
    "    print(k)\n",
    "    # remove extra spaces around strings, eg ' dog' -> 'dog'\n",
    "    df = df.map(lambda x: x.strip() if type(x) == str else x)\n",
    "    for col in ords:\n",
    "        # pool the values seen in *every* dataset so one encoder covers them all\n",
    "        categoriesdict[col].update(df[col].unique().tolist())\n",
    "\n",
    "    # shuffle, ensure cols in order w no extra cols\n",
    "    if k == 'real':\n",
    "        # Keep every real row, but hold out a test split: fitting and scoring on the\n",
    "        # same rows would report training accuracy for the real-data model.\n",
    "        df = df.sample(frac=1, random_state=rs, ignore_index=True)[ords+nums+labs]\n",
    "        n_train = int(len(df)*train_frac)\n",
    "    else:\n",
    "        # sample to min dataset size\n",
    "        df = df.sample(min_dataset_size, random_state=rs, ignore_index=True)[ords+nums+labs]\n",
    "        n_train = train_size\n",
    "    # Rebinding an existing key while iterating is safe (the dict's size does not change).\n",
    "    datadict[k] = {'train':df.iloc[:n_train, :],\n",
    "                    'test': df.iloc[n_train:, :]}\n",
    "print(f'sampled {min_dataset_size} rows from each dataset, made trainsets of {train_size} rows each')\n",
    "\n",
    "# Sort the pooled values so the ordinal codes are identical on every run: iteration\n",
    "# order of a set of strings changes between interpreter sessions (hash randomization),\n",
    "# which would silently change the encoded features and the fitted forests.\n",
    "# key=str keeps the sort well-defined if a column mixes strings with NaN.\n",
    "categories = [sorted(categoriesdict[col], key=str) for col in ords]\n",
    "ordenc = preprocessing.OrdinalEncoder(categories=categories)\n",
    "numenc = preprocessing.StandardScaler()\n",
    "lb = preprocessing.LabelBinarizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "real\n",
      "dgpt2\n",
      "moe\n",
      "greatdgpt2\n",
      "moegreatdgpt2\n",
      "mhmoegreatdgpt2-0\n",
      "mhmoegreatdgpt2-1\n",
      "mhmoegreatdgpt2-2\n",
      "fairgan\n"
     ]
    }
   ],
   "source": [
    "# make random forest sklearn pipeline\n",
    "def create_pipeline(trainset):\n",
    "    \"\"\"Fit a preprocessing + random-forest pipeline on one training frame.\n",
    "\n",
    "    The `ords` columns go through the module-level `ordenc`, the `nums` columns\n",
    "    through `numenc`, and the `labs` column is binarized with the module-level\n",
    "    `lb`, which is refit on every call.\n",
    "\n",
    "    Args:\n",
    "        trainset: DataFrame holding the `ords`, `nums` and `labs` columns.\n",
    "\n",
    "    Returns:\n",
    "        The fitted sklearn Pipeline (preprocessor followed by the classifier).\n",
    "    \"\"\"\n",
    "    features = trainset[ords+nums]\n",
    "    # flatten the single label column to 1-D before and after binarizing\n",
    "    targets = lb.fit_transform(trainset[labs].values.ravel()).ravel()\n",
    "\n",
    "    steps = [\n",
    "        (\"preprocessor\", compose.ColumnTransformer([\n",
    "            (\"ordinal_preprocessor\", ordenc, ords),\n",
    "            (\"numerical_preprocessor\", numenc, nums),\n",
    "        ])),\n",
    "        (\"estimator\", ensemble.RandomForestClassifier(\n",
    "            n_estimators=10, max_depth=4, random_state=rs)),\n",
    "    ]\n",
    "    # Pipeline.fit returns the pipeline itself, so the fitted object is returned directly\n",
    "    return pipeline.Pipeline(steps).fit(features, targets)\n",
    "\n",
    "# one fitted pipeline per dataset, keyed by dataset name\n",
    "rfdict = {}\n",
    "for src, splits in datadict.items():\n",
    "    print(src)\n",
    "    rfdict[src] = create_pipeline(splits['train'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "real\n",
      "real on real: \t\t\t0.838786290487695\n",
      "dgpt2 on real: \t\t\t0.8287948896441587\n",
      "moe on real: \t\t\t0.7953400761639573\n",
      "greatdgpt2 on real: \t\t\t0.8370459850128987\n",
      "moegreatdgpt2 on real: \t\t\t0.8225912124810614\n",
      "mhmoegreatdgpt2-0 on real: \t\t\t0.8278326030875066\n",
      "mhmoegreatdgpt2-1 on real: \t\t\t0.8190287048032431\n",
      "mhmoegreatdgpt2-2 on real: \t\t\t0.8181073666107039\n",
      "fairgan on real: \t\t\t0.824495311412309\n",
      "\n",
      "\n",
      "dgpt2\n",
      "real on dgpt2: \t\t\t0.7621293800539084\n",
      "dgpt2 on dgpt2: \t\t\t0.7978436657681941\n",
      "moe on dgpt2: \t\t\t0.7668463611859838\n",
      "greatdgpt2 on dgpt2: \t\t\t0.7681940700808625\n",
      "moegreatdgpt2 on dgpt2: \t\t\t0.7405660377358491\n",
      "mhmoegreatdgpt2-0 on dgpt2: \t\t\t0.7506738544474394\n",
      "mhmoegreatdgpt2-1 on dgpt2: \t\t\t0.7216981132075472\n",
      "mhmoegreatdgpt2-2 on dgpt2: \t\t\t0.7142857142857143\n",
      "fairgan on dgpt2: \t\t\t0.7196765498652291\n",
      "\n",
      "\n",
      "moe\n",
      "real on moe: \t\t\t0.7722371967654986\n",
      "dgpt2 on moe: \t\t\t0.7654986522911051\n",
      "moe on moe: \t\t\t0.7964959568733153\n",
      "greatdgpt2 on moe: \t\t\t0.7681940700808625\n",
      "moegreatdgpt2 on moe: \t\t\t0.77088948787062\n",
      "mhmoegreatdgpt2-0 on moe: \t\t\t0.7681940700808625\n",
      "mhmoegreatdgpt2-1 on moe: \t\t\t0.7681940700808625\n",
      "mhmoegreatdgpt2-2 on moe: \t\t\t0.7668463611859838\n",
      "fairgan on moe: \t\t\t0.7722371967654986\n",
      "\n",
      "\n",
      "greatdgpt2\n",
      "real on greatdgpt2: \t\t\t0.8079514824797843\n",
      "dgpt2 on greatdgpt2: \t\t\t0.8032345013477089\n",
      "moe on greatdgpt2: \t\t\t0.6475741239892183\n",
      "greatdgpt2 on greatdgpt2: \t\t\t0.807277628032345\n",
      "moegreatdgpt2 on greatdgpt2: \t\t\t0.8066037735849056\n",
      "mhmoegreatdgpt2-0 on greatdgpt2: \t\t\t0.8039083557951483\n",
      "mhmoegreatdgpt2-1 on greatdgpt2: \t\t\t0.8045822102425876\n",
      "mhmoegreatdgpt2-2 on greatdgpt2: \t\t\t0.8025606469002695\n",
      "fairgan on greatdgpt2: \t\t\t0.8119946091644205\n",
      "\n",
      "\n",
      "moegreatdgpt2\n",
      "real on moegreatdgpt2: \t\t\t0.8450134770889488\n",
      "dgpt2 on moegreatdgpt2: \t\t\t0.8402964959568733\n",
      "moe on moegreatdgpt2: \t\t\t0.7789757412398922\n",
      "greatdgpt2 on moegreatdgpt2: \t\t\t0.8463611859838275\n",
      "moegreatdgpt2 on moegreatdgpt2: \t\t\t0.8456873315363881\n",
      "mhmoegreatdgpt2-0 on moegreatdgpt2: \t\t\t0.8436657681940701\n",
      "mhmoegreatdgpt2-1 on moegreatdgpt2: \t\t\t0.8409703504043127\n",
      "mhmoegreatdgpt2-2 on moegreatdgpt2: \t\t\t0.8423180592991913\n",
      "fairgan on moegreatdgpt2: \t\t\t0.8416442048517521\n",
      "\n",
      "\n",
      "mhmoegreatdgpt2-0\n",
      "real on mhmoegreatdgpt2-0: \t\t\t0.8328840970350404\n",
      "dgpt2 on mhmoegreatdgpt2-0: \t\t\t0.8261455525606469\n",
      "moe on mhmoegreatdgpt2-0: \t\t\t0.7365229110512129\n",
      "greatdgpt2 on mhmoegreatdgpt2-0: \t\t\t0.8342318059299192\n",
      "moegreatdgpt2 on mhmoegreatdgpt2-0: \t\t\t0.8241239892183289\n",
      "mhmoegreatdgpt2-0 on mhmoegreatdgpt2-0: \t\t\t0.828167115902965\n",
      "mhmoegreatdgpt2-1 on mhmoegreatdgpt2-0: \t\t\t0.8227762803234502\n",
      "mhmoegreatdgpt2-2 on mhmoegreatdgpt2-0: \t\t\t0.8241239892183289\n",
      "fairgan on mhmoegreatdgpt2-0: \t\t\t0.8328840970350404\n",
      "\n",
      "\n",
      "mhmoegreatdgpt2-1\n",
      "real on mhmoegreatdgpt2-1: \t\t\t0.862533692722372\n",
      "dgpt2 on mhmoegreatdgpt2-1: \t\t\t0.8530997304582211\n",
      "moe on mhmoegreatdgpt2-1: \t\t\t0.7520215633423181\n",
      "greatdgpt2 on mhmoegreatdgpt2-1: \t\t\t0.8584905660377359\n",
      "moegreatdgpt2 on mhmoegreatdgpt2-1: \t\t\t0.8591644204851752\n",
      "mhmoegreatdgpt2-0 on mhmoegreatdgpt2-1: \t\t\t0.8618598382749326\n",
      "mhmoegreatdgpt2-1 on mhmoegreatdgpt2-1: \t\t\t0.860512129380054\n",
      "mhmoegreatdgpt2-2 on mhmoegreatdgpt2-1: \t\t\t0.8598382749326146\n",
      "fairgan on mhmoegreatdgpt2-1: \t\t\t0.8618598382749326\n",
      "\n",
      "\n",
      "mhmoegreatdgpt2-2\n",
      "real on mhmoegreatdgpt2-2: \t\t\t0.8328840970350404\n",
      "dgpt2 on mhmoegreatdgpt2-2: \t\t\t0.8355795148247979\n",
      "moe on mhmoegreatdgpt2-2: \t\t\t0.7459568733153639\n",
      "greatdgpt2 on mhmoegreatdgpt2-2: \t\t\t0.8355795148247979\n",
      "moegreatdgpt2 on mhmoegreatdgpt2-2: \t\t\t0.8342318059299192\n",
      "mhmoegreatdgpt2-0 on mhmoegreatdgpt2-2: \t\t\t0.8328840970350404\n",
      "mhmoegreatdgpt2-1 on mhmoegreatdgpt2-2: \t\t\t0.8274932614555256\n",
      "mhmoegreatdgpt2-2 on mhmoegreatdgpt2-2: \t\t\t0.8295148247978437\n",
      "fairgan on mhmoegreatdgpt2-2: \t\t\t0.8328840970350404\n",
      "\n",
      "\n",
      "fairgan\n",
      "real on fairgan: \t\t\t0.8234501347708895\n",
      "dgpt2 on fairgan: \t\t\t0.8247978436657682\n",
      "moe on fairgan: \t\t\t0.7931266846361186\n",
      "greatdgpt2 on fairgan: \t\t\t0.8274932614555256\n",
      "moegreatdgpt2 on fairgan: \t\t\t0.8274932614555256\n",
      "mhmoegreatdgpt2-0 on fairgan: \t\t\t0.8221024258760108\n",
      "mhmoegreatdgpt2-1 on fairgan: \t\t\t0.8221024258760108\n",
      "mhmoegreatdgpt2-2 on fairgan: \t\t\t0.8200808625336927\n",
      "fairgan on fairgan: \t\t\t0.8268194070080862\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Score every fitted pipeline on the test split of every dataset.\n",
    "# NOTE(review): `lb` is refit on each test split, so the 0/1 label mapping is only\n",
    "# consistent with training if every dataset contains the same label values - confirm.\n",
    "for data, splits in datadict.items():\n",
    "    print(data)\n",
    "    test_df = splits['test']\n",
    "    labels = lb.fit_transform(test_df[labs])\n",
    "    features = test_df[ords+nums]\n",
    "    for model, clf in rfdict.items():\n",
    "        score = clf.score(features, labels)\n",
    "        print(f'{model} on {data}: \\t\\t\\t{score}')\n",
    "    print('\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
