{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filter_df: read the head of the raw KIDPAN extract and keep the modelling columns\n",
    "file_path = 'KIDPAN_DATA.DAT'\n",
    "separator = '\\t'\n",
    "n_rows = 1000  # number of data rows to read from the top of the file\n",
    "\n",
    "# NOTE(review): no `header=` argument is passed, so pandas takes the first line of the\n",
    "# file as column names -- confirm the .DAT file really starts with a header row.\n",
    "df1 = pd.read_csv(file_path, sep=separator, on_bad_lines='skip', nrows=n_rows)\n",
    "\n",
    "# Positional indices of the columns to keep, paired one-to-one with `column_names`\n",
    "cols = [433, 295, 25, 26, 176, 201, 199, 4, 168, 169, 170, 171, 308]\n",
    "column_names = ['TRR_ID_CODE', 'age', 'gender', 'ABO', 'age_don', 'gender_don', 'ABO_don', 'PRA', 'AMIS', 'BMIS', 'DRMIS', 'HLAMIS', 'GSTATUS_KI']\n",
    "\n",
    "# Take every parsed row rather than an explicit list of n_rows positions: positional\n",
    "# indexing past the end of a shorter frame raises IndexError.\n",
    "selected_data1 = df1.iloc[:, cols].copy()\n",
    "\n",
    "# Rename by position first so the filter below does not depend on the inferred header name\n",
    "selected_data1.columns = column_names\n",
    "\n",
    "# Keep only rows that have a TRR_ID_CODE\n",
    "filtered_df = selected_data1.dropna(subset=['TRR_ID_CODE'])\n",
    "\n",
    "# Drop rows whose GSTATUS_KI is the literal '.'; .copy() gives later cells a frame that\n",
    "# owns its data and can be assigned to without chained-assignment warnings.\n",
    "filtered_df = filtered_df[filtered_df['GSTATUS_KI'] != '.'].copy()\n",
    "\n",
    "filtered_df.to_csv('filtered_df.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "from sklearn.cluster import KMeans\n",
    "from src.lfgp import LFGP\n",
    "from collections import defaultdict\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mismatch_result\n",
    "# Work on an explicit copy: assigning columns below would otherwise overwrite\n",
    "# filtered_df itself, because a plain `data = filtered_df` is only a second name for it.\n",
    "data = filtered_df.copy()\n",
    "\n",
    "# Mismatch indicator columns; values that cannot be parsed as numbers become NaN\n",
    "mismatch_cols = ['PRA', 'AMIS', 'BMIS', 'DRMIS', 'HLAMIS']\n",
    "data[mismatch_cols] = data[mismatch_cols].apply(pd.to_numeric, errors='coerce')\n",
    "\n",
    "# 0.33 and 0.66 quantiles of every indicator, computed on the numeric columns\n",
    "quantiles = data[mismatch_cols].quantile([0.33, 0.66])\n",
    "\n",
    "# Function to classify results based on quantiles for PRA differently\n",
    "def classify_pra(value, high):\n",
    "    \"\"\"Return 0 when `value` is strictly greater than `high`, otherwise 1.\n",
    "\n",
    "    A NaN `value` never compares greater, so it yields 1.\n",
    "    \"\"\"\n",
    "    return 0 if value > high else 1\n",
    "\n",
    "# Function to classify results based on quantiles for other indicators\n",
    "def classify_result(value, low):\n",
    "    \"\"\"Return 1 when `value` is strictly greater than `low`, otherwise 0.\n",
    "\n",
    "    Values at or below `low` give 0; a NaN `value` never compares greater, so it also gives 0.\n",
    "    \"\"\"\n",
    "    return 1 if value > low else 0\n",
    "\n",
    "# Integer code stored in the output for each mismatch indicator\n",
    "indicator_mapping = {'PRA': 1, 'AMIS': 2, 'BMIS': 3, 'DRMIS': 4, 'HLAMIS': 5}\n",
    "\n",
    "# Build one long-format frame per indicator: TRR_ID_CODE, indicator name, 0/1 result\n",
    "results = []\n",
    "for col in mismatch_cols:\n",
    "    low, high = quantiles.loc[0.33, col], quantiles.loc[0.66, col]\n",
    "    # PRA is thresholded at its 0.66 quantile, every other indicator at its 0.33 quantile\n",
    "    if col == 'PRA':\n",
    "        classified = data[col].apply(classify_pra, args=(high,))\n",
    "    else:\n",
    "        classified = data[col].apply(classify_result, args=(low,))\n",
    "    temp_df = data[['TRR_ID_CODE']].copy()\n",
    "    temp_df['Mismatch_Indicator'] = col\n",
    "    temp_df['Transplantation_Result'] = classified\n",
    "    results.append(temp_df)\n",
    "\n",
    "# Stack the per-indicator frames and attach the integer indicator code\n",
    "final_data = pd.concat(results, ignore_index=True)\n",
    "final_data['Mismatch_Index'] = final_data['Mismatch_Indicator'].map(indicator_mapping)\n",
    "\n",
    "# .copy() so both frames own their data and can be modified later without\n",
    "# pandas chained-assignment warnings\n",
    "mismatch_results = final_data[['TRR_ID_CODE', 'Mismatch_Index', 'Transplantation_Result']].copy()\n",
    "kidney_transplant = filtered_df[['TRR_ID_CODE', 'GSTATUS_KI']].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "def modify_first_column(df):\n",
    "    \"\"\"Return a copy of `df` with 'A'-prefixed strings in the first column turned into ints.\n",
    "\n",
    "    A string such as 'A123' becomes the integer 123; every other value is kept as is.\n",
    "    The input frame is not modified, so the caller's data stays intact and the call\n",
    "    can be repeated safely.\n",
    "\n",
    "    Raises ValueError if the text after the leading 'A' is not a valid integer.\n",
    "    \"\"\"\n",
    "    def strip_prefix(x):\n",
    "        return int(x[1:]) if isinstance(x, str) and x.startswith('A') else x\n",
    "\n",
    "    out = df.copy()\n",
    "    out.iloc[:, 0] = out.iloc[:, 0].apply(strip_prefix)\n",
    "    return out\n",
    "\n",
    "# Strip the 'A' prefix from the first column of both frames\n",
    "mismatch_results_modified = modify_first_column(mismatch_results)\n",
    "kidney_transplant_modified = modify_first_column(kidney_transplant)\n",
    "\n",
    "# Number the distinct first-column values of each frame 0..n-1 in ascending order\n",
    "# NOTE(review): the two mappings are built independently, so they only agree when both\n",
    "# frames hold exactly the same set of values -- confirm.\n",
    "mismatch_ids = sorted(mismatch_results_modified.iloc[:, 0].unique())\n",
    "kidney_ids = sorted(kidney_transplant_modified.iloc[:, 0].unique())\n",
    "mismatch_mapping = dict(zip(mismatch_ids, range(len(mismatch_ids))))\n",
    "kidney_mapping = dict(zip(kidney_ids, range(len(kidney_ids))))\n",
    "\n",
    "# Replace the values by their index on copies of the *_modified frames\n",
    "mismatch_results_final = mismatch_results_modified.copy()\n",
    "mismatch_results_final.iloc[:, 0] = mismatch_results_modified.iloc[:, 0].map(mismatch_mapping)\n",
    "\n",
    "kidney_transplant_final = kidney_transplant_modified.copy()\n",
    "kidney_transplant_final.iloc[:, 0] = kidney_transplant_modified.iloc[:, 0].map(kidney_mapping)\n",
    "\n",
    "# Write both frames as comma-separated text without header row or index column\n",
    "mismatch_results_final.to_csv(\"mismatch_results_final.txt\", header=False, index=False)\n",
    "kidney_transplant_final.to_csv(\"kidney_transplant_final.txt\", header=False, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of tasks: 544\n",
      "number of workers: 5\n",
      "number of crowd labels: 2720\n"
     ]
    }
   ],
   "source": [
    "# Read the two comma-separated files back as float arrays\n",
    "rating = np.loadtxt(\"mismatch_results_final.txt\", delimiter=',')\n",
    "label = np.loadtxt('kidney_transplant_final.txt', delimiter=',')\n",
    "\n",
    "# Column 0 of `rating` is reported as tasks, column 1 as workers; each row is one crowd label\n",
    "n_tasks = len(np.unique(rating[:, 0]))\n",
    "n_workers = len(np.unique(rating[:, 1]))\n",
    "n_labels = len(rating[:, 0])\n",
    "\n",
    "print(\"number of tasks: {0}\".format(n_tasks))\n",
    "print(\"number of workers: {0}\".format(n_workers))\n",
    "print(\"number of crowd labels: {0}\".format(n_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5919117647058824\n"
     ]
    }
   ],
   "source": [
    "# Build the LFGP model and run its prescreen step on the rating array\n",
    "model = LFGP(lf_dim=5, n_worker_group=2, lambda1=0.1, lambda2=0.1)\n",
    "model._prescreen(rating)\n",
    "\n",
    "# Fix NumPy's global seed before fitting, for reproducibility of the initialization\n",
    "# NOTE(review): the seed is set after construction and prescreening -- confirm neither draws random numbers.\n",
    "np.random.seed(0)\n",
    "model._mc_fit(rating, epsilon=1e-4, scheme=\"mv\", maxiter=80, verbose=0)\n",
    "pred_label = model._mc_infer(rating)\n",
    "\n",
    "# Fraction of rows where column 1 of `label` equals column 1 of `pred_label`\n",
    "# NOTE(review): rows are compared by position -- presumably both arrays share the same row order; confirm.\n",
    "matches = np.equal(label[:, 1], pred_label[:, 1])\n",
    "print(np.mean(matches))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
