{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import glob\n",
    "import copy\n",
    "from scipy.stats import pearsonr\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from pytorch_lightning.trainer.supporters import CombinedLoader\n",
    "from collections import defaultdict\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "import yaml\n",
    "from torch_geometric.data import Batch\n",
    "from torch_geometric.nn import global_mean_pool\n",
    "from models import get_test_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['vs', 'hv']\n",
    "batch_size = 512\n",
    "number_of_folds = 4\n",
    "epochs = 600\n",
    "device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lookup = torch.load('data/train_set_lookup_table.pt')\n",
    "lookup_test = torch.load('data/test_set_lookup_table.pt')\n",
    "predata = torch.load('data/train_set_preprocessed_data.pt')\n",
    "predata_test = torch.load('data/test_set_preprocessed_data.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {}\n",
    "data['train'] = {}\n",
    "data['val'] = {}\n",
    "data['test'] = {}\n",
    "\n",
    "latents = {}\n",
    "latents['train'] = {}\n",
    "latents['val'] = {}\n",
    "latents['test'] = {}\n",
    "\n",
    "for task in tasks:\n",
    "    data_length = len(lookup[~lookup[task].isna()])\n",
    "    data['train'][task] = {}\n",
    "    data['val'][task] = {}\n",
    "    data['test'][task] = {}\n",
    "    idx = np.array_split(list(range(data_length)), number_of_folds)\n",
    "    current_data = lookup[~lookup[task].isna()][['SMILES', task]].reset_index(drop = True)\n",
    "    for fold in range(4):\n",
    "        data['val'][task][fold] = current_data.iloc[idx[fold]]\n",
    "        data['train'][task][fold] = current_data[~current_data[~current_data.isin(data['val'][task])].SMILES.isna()]\n",
    "        data['test'][task][fold] = lookup_test[~lookup_test[task].isna()][['SMILES', task]]\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MolDataset(Dataset):\n",
    " \n",
    "  def __init__(self, df, predata):\n",
    "    x = predata\n",
    "    y = df.iloc[:,1].values\n",
    "    smiles = df.iloc[:,0].values.tolist()\n",
    " \n",
    "    self.x_train = x\n",
    "    self.y_train = torch.tensor(y,dtype=torch.float32)\n",
    "    self.smiles_train = smiles\n",
    " \n",
    "  def __len__(self):\n",
    "    return len(self.y_train)\n",
    "   \n",
    "  def __getitem__(self,idx):\n",
    "    return self.x_train[idx],self.y_train[idx],self.smiles_train[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_loss = nn.MSELoss()\n",
    "\n",
    "train_dataloaders = {}\n",
    "val_dataloaders = {}\n",
    "test_dataloaders = {}\n",
    "\n",
    "for task in tasks:\n",
    "    print(task)\n",
    "    train_dataloaders[task] = {}\n",
    "    val_dataloaders[task] = {}\n",
    "    test_dataloaders[task] = {}\n",
    "    for fold in range(number_of_folds):\n",
    "        train_dataloaders[task][fold] = DataLoader(MolDataset(data['train'][task][fold], predata[data['train'][task][fold].index.tolist()]),batch_size=batch_size,shuffle=True)\n",
    "        val_dataloaders[task][fold] = DataLoader(MolDataset(data['val'][task][fold], predata[data['val'][task][fold].index.tolist()]),batch_size=batch_size,shuffle=True)\n",
    "        test_dataloaders[task][fold] = DataLoader(MolDataset(data['test'][task][fold], predata[data['test'][task][fold].index.tolist()]),batch_size=batch_size,shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_kwargs = {'backbone_kwargs': {'depth': {'back': 2, 'front': 2},\n",
    "                'dropout': 0,\n",
    "                'heads': 1,\n",
    "                'hidden': [200],\n",
    "                'in_f': [134, 149],\n",
    "                'out_f': 100},\n",
    "                'bottleneck_kwargs': {'dropout': 0, 'hidden': [50], 'in_f': 100},\n",
    "                'head_kwargs': {'dropout': 0.2, 'hidden': [25, 12], 'in_f': 50, 'out_f': 1},\n",
    "                'transform_kwargs': {'dropout': 0.2,\n",
    "                'latent_size': 50,\n",
    "                'net_width': [100, 100, 100]},\n",
    "                'setup': 'gt',\n",
    "                'task_names': ['vs', 'hv'],\n",
    "                'd_num': 10,\n",
    "                'perb_ratio': 1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = get_test_model(**config_kwargs)\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dl = {}\n",
    "val_dl = {}\n",
    "test_dl = {}\n",
    "\n",
    "for fold in range(number_of_folds):\n",
    "    iterables = {}\n",
    "    iterables['train'] = {}\n",
    "    iterables['val'] = {}\n",
    "    iterables['test'] = {}   \n",
    "    for task in tasks:\n",
    "        iterables['train'][task] = train_dataloaders[task][fold]\n",
    "        iterables['val'][task] = val_dataloaders[task][fold]\n",
    "        iterables['test'][task] = test_dataloaders[task][fold]\n",
    "\n",
    "    train_dl[fold] = CombinedLoader(iterables['train'])\n",
    "    val_dl[fold] = CombinedLoader(iterables['val']) \n",
    "    test_dl[fold] = CombinedLoader(iterables['test'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha, beta, gamma, delta = 1,1,1,1\n",
    "\n",
    "result = defaultdict(dict)\n",
    "Best_models = {}\n",
    "z_dict = {}\n",
    "\n",
    "for fold in range(number_of_folds):\n",
    "    loss_log = []\n",
    "    val_loss_log = []\n",
    "    all_loss_log = []\n",
    "    all_val_loss_log = []\n",
    "    best_epoch = 0\n",
    "    best_val_loss = 99999999\n",
    "\n",
    "    current_model = copy.deepcopy(model)\n",
    "    optimizer = optim.Adam(current_model.parameters(),lr=5.0e-05)\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        cur_loss_log = []\n",
    "        cur_all_loss_log = []\n",
    "        current_model.train()\n",
    "        for batchs in train_dl[fold]:\n",
    "            reg_loss_all = []\n",
    "            for task in tasks:\n",
    "                batch = batchs[task]\n",
    "                x = batch[0].to(device)\n",
    "                y = batch[1].to(device)\n",
    "                smiles = batch[2]\n",
    "\n",
    "                pre_result_dict = current_model(x, task)\n",
    "                reg_loss = mse_loss(pre_result_dict[task]['down'].squeeze(), y)\n",
    "                ae_loss = mse_loss(torch.stack(pre_result_dict[task]['y_de']), torch.stack(pre_result_dict[task]['decoder']))\n",
    "\n",
    "                tr_ae_loss = mse_loss(torch.stack(pre_result_dict[task]['encoder']), torch.stack(pre_result_dict[task]['ori']))\n",
    "                \n",
    "                mapping_loss = mse_loss(pre_result_dict[task]['map_down'].squeeze(), y)\n",
    "\n",
    "                sub_task = [a for a in tasks if a != task][0]\n",
    "\n",
    "                y_dis = ((torch.stack(pre_result_dict[task]['flat'][1:]) - pre_result_dict[task]['flat'][0] + 0.0000001) ** 2).sum(-1).sqrt()\n",
    "                dis = ((torch.stack(pre_result_dict[sub_task]['flat'][1:]) - pre_result_dict[sub_task]['flat'][0] + 0.0000001) ** 2).sum(-1).sqrt()\n",
    "                dis_loss = mse_loss(dis, y_dis)\n",
    "\n",
    "                cons_loss = mse_loss(torch.stack(pre_result_dict[sub_task]['flat']), torch.stack(pre_result_dict[task]['flat']))\n",
    "\n",
    "                loss = reg_loss + ae_loss + alpha*tr_ae_loss + beta*mapping_loss + gamma*cons_loss + delta*dis_loss\n",
    "                reg_loss_all.append(reg_loss.cpu().detach().item())\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "            cur_loss_log.append(sum(reg_loss_all))\n",
    "            cur_all_loss_log.append(loss.cpu().detach().item())\n",
    "        loss_log.append(np.mean(cur_loss_log))\n",
    "        all_loss_log.append(np.mean(cur_all_loss_log))\n",
    "\n",
    "        cur_val_loss_log = []\n",
    "        cur_val_all_loss_log = []\n",
    "        current_model.eval()\n",
    "        for batchs in val_dl[fold]:\n",
    "            reg_loss_all = []\n",
    "            for task in tasks:\n",
    "                batch = batchs[task]\n",
    "                x = batch[0].to(device)\n",
    "                y = batch[1].to(device)\n",
    "                smiles = batch[2]\n",
    "\n",
    "                pre_result_dict = current_model(x, task)\n",
    "                reg_loss = mse_loss(pre_result_dict[task]['down'].squeeze(), y)\n",
    "                ae_loss = mse_loss(torch.stack(pre_result_dict[task]['y_de']), torch.stack(pre_result_dict[task]['decoder']))\n",
    "\n",
    "                tr_ae_loss = mse_loss(torch.stack(pre_result_dict[task]['encoder']), torch.stack(pre_result_dict[task]['ori']))\n",
    "                \n",
    "                mapping_loss = mse_loss(pre_result_dict[task]['map_down'].squeeze(), y)\n",
    "\n",
    "                sub_task = [a for a in tasks if a != task][0]\n",
    "\n",
    "                y_dis = ((torch.stack(pre_result_dict[task]['flat'][1:]) - pre_result_dict[task]['flat'][0] + 0.0000001) ** 2).sum(-1).sqrt()\n",
    "                dis = ((torch.stack(pre_result_dict[sub_task]['flat'][1:]) - pre_result_dict[sub_task]['flat'][0] + 0.0000001) ** 2).sum(-1).sqrt()\n",
    "                dis_loss = mse_loss(dis, y_dis)\n",
    "\n",
    "                cons_loss = mse_loss(torch.stack(pre_result_dict[sub_task]['flat']), torch.stack(pre_result_dict[task]['flat']))\n",
    "\n",
    "                loss = reg_loss + ae_loss + alpha*tr_ae_loss + beta*mapping_loss + gamma*cons_loss + delta*dis_loss\n",
    "                reg_loss_all.append(reg_loss.cpu().detach().item())\n",
    "\n",
    "            cur_val_loss_log.append(sum(reg_loss_all))\n",
    "            cur_val_all_loss_log.append(loss.cpu().detach().item())\n",
    "        val_loss_log.append(np.mean(cur_val_loss_log))\n",
    "        all_val_loss_log.append(np.mean(cur_val_all_loss_log))\n",
    "\n",
    "        if val_loss_log[-1] < best_val_loss:\n",
    "            best_val_loss = val_loss_log[-1]\n",
    "            best_epoch = epoch\n",
    "            Best_models[fold] = copy.deepcopy(current_model)\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            print(epoch, loss_log[-1],  val_loss_log[-1])\n",
    "    pred_result = {}\n",
    "    gts_result = {}\n",
    "    for task in tasks:\n",
    "        preds = []\n",
    "        gts = []\n",
    "        for batchs in test_dl[fold]:\n",
    "            batch = batchs[task]\n",
    "\n",
    "            x = batch[0].to(device)\n",
    "            pred = Best_models[fold](x, task)[task]['down'].squeeze()\n",
    "            preds = preds + pred.cpu().detach().squeeze().tolist()\n",
    "            gts = gts + batch[1].cpu().detach().squeeze().tolist()\n",
    "\n",
    "        pred_result[task] = preds\n",
    "        gts_result[task] = gts\n",
    "            \n",
    "    result['train_reg_loss'][fold] = loss_log\n",
    "    result['train_all_loss'][fold] = all_loss_log\n",
    "    result['val_reg_loss'][fold] = val_loss_log\n",
    "    result['val_all_loss'][fold] = all_val_loss_log\n",
    "    result['best_val_loss'][fold] = best_val_loss\n",
    "    result['best_epoch'][fold] = best_epoch\n",
    "    result['test_corr'][fold] = {}\n",
    "    result['test_mse'][fold] = {}\n",
    "    \n",
    "    for task in tasks:\n",
    "        print(task)\n",
    "        result['test_corr'][fold][task] = pearsonr(pred_result[task], gts_result[task])[0] \n",
    "        result['test_mse'][fold][task] = mean_squared_error(pred_result[task], gts_result[task])\n",
    "        print('final_all_loss:', all_loss_log[-1], 'best_val_loss:', best_val_loss, ', best_epoch:', best_epoch, ', test_corr:', result['test_corr'][fold][task], ', test_mse:', result['test_mse'][fold][task])\n",
    "                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
