{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "283b26da",
   "metadata": {},
   "source": [
    "### Input data: \n",
    "    \n",
    "1. Install the TDC Python package by following the [installation instructions](https://tdcommons.ai/start/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e75a0805-5ea6-4b59-baf6-eec03870ddef",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tdc.benchmark_group.tcrepitope_group import TCREpitopeGroup\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ab2e9a36-25bf-4ae1-b7e4-eef8589b42b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark-group object; the train/valid and test splits below are fetched from it.\n",
    "group = TCREpitopeGroup()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "00731459-808b-4afc-8a24-e4354877cd3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n",
      "Found local copy...\n"
     ]
    }
   ],
   "source": [
    "# `train` is indexed as train[dataset_name][split] in the processing cell further down;\n",
    "# `val` is not referenced by any other cell visible in this notebook.\n",
    "train, val = group.get_train_valid_split() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09e4b10-93a4-41ed-b487-e2b8cc37328e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `test` is indexed as test[dataset_name][split] in the processing cell further down.\n",
    "test = group.get_test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e3fba5c9-ec2a-4373-8732-037af60445bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(522076, 5) (71915, 5)\n",
      "(498317, 5) (48329, 5)\n",
      "(522253, 5) (71615, 5)\n",
      "(496174, 5) (49934, 5)\n",
      "(520403, 5) (73541, 5)\n",
      "(498288, 5) (49325, 5)\n",
      "(523618, 5) (70236, 5)\n",
      "(498899, 5) (47491, 5)\n",
      "(522847, 5) (71023, 5)\n",
      "(498916, 5) (50514, 5)\n"
     ]
    }
   ],
   "source": [
    "# Columns retained in every processed table.\n",
    "columns_to_keep = ['cdr3.alpha', 'cdr3.beta', 'antigen.epitope', 'mhc.seq', 'Y']\n",
    "\n",
    "# Number of splits processed (indices 0 .. N_SPLITS - 1).\n",
    "N_SPLITS = 5\n",
    "\n",
    "# Dataset keys that are merged for each negative-example setting.\n",
    "# NOTE(review): `test` is looked up with the same '*_train' key names as `train`,\n",
    "# and the first neg-assays key carries no '_train' suffix -- confirm against the\n",
    "# keys actually exposed by the TDC benchmark group.\n",
    "SAMPLED_NEG_KEYS = ['tchard_pep_cdr3b_only_sampled_negs_train',\n",
    "                    'tchard_pep_cdr3b_cdr3a_mhc_only_sampled_negs_train']\n",
    "NEG_ASSAY_KEYS = ['tchard_pep_cdr3b_only_neg_assays',\n",
    "                  'tchard_pep_cdr3b_cdr3a_mhc_only_neg_assays_train']\n",
    "\n",
    "\n",
    "def merge_split(frames, keys, split):\n",
    "    \"\"\"Merge one split across several datasets.\n",
    "\n",
    "    Concatenates frames[key][split] for every key in `keys` (with a fresh index),\n",
    "    drops duplicate rows, and returns only the `columns_to_keep` columns.\n",
    "    \"\"\"\n",
    "    merged = pd.concat([frames[key][split] for key in keys], ignore_index=True)\n",
    "    # Duplicates are dropped on the full rows *before* the column subset, so if the\n",
    "    # frames carry extra columns, rows differing only in those columns are both kept.\n",
    "    return merged.drop_duplicates()[columns_to_keep]\n",
    "\n",
    "\n",
    "for split in range(N_SPLITS):\n",
    "    train_sampled = merge_split(train, SAMPLED_NEG_KEYS, split)\n",
    "    train_neg = merge_split(train, NEG_ASSAY_KEYS, split)\n",
    "    test_sampled = merge_split(test, SAMPLED_NEG_KEYS, split)\n",
    "    test_neg = merge_split(test, NEG_ASSAY_KEYS, split)\n",
    "\n",
    "    # One line per setting: (train shape) (test shape).\n",
    "    print(train_sampled.shape, test_sampled.shape)\n",
    "    print(train_neg.shape, test_neg.shape)\n",
    "\n",
    "    # CSV export is disabled; uncomment to write the processed tables to disk.\n",
    "    # train_sampled.to_csv(f'./processed_data/RN/train_{split}.csv')\n",
    "    # test_sampled.to_csv(f'./processed_data/RN/test_{split}.csv')\n",
    "\n",
    "    # train_neg.to_csv(f'./processed_data/NA/train_{split}.csv')\n",
    "    # test_neg.to_csv(f'./processed_data/NA/test_{split}.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5639b710-15aa-49a0-999e-98f094ef5b78",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
