{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data augmentation process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install rdkit\n",
    "\n",
    "from augmentation import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example input shared by the augmentation demo in the next cell:\n",
    "# a one-element list holding a single SMILES string.\n",
    "test_smiles_list = ['CC1=C2C=C(C=C(C2=CC=C1)C(=O)O)[O-]']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run each augmentation helper (brought in by the star import in the first code cell)\n",
    "# on the same input list. Each '#output:' line records the result the author reported\n",
    "# for the call directly above it; nothing is displayed by this cell itself.\n",
    "\n",
    "# Reported result uses explicit single/double bonds only (no lowercase aromatic atoms).\n",
    "kekulize = get_kekulizeSmiles(test_smiles_list)\n",
    "#output: ['CC1=C2C=C([O-])C=C(C(=O)O)C2=CC=C1']\n",
    "# Reported result writes every atom in brackets with its hydrogen count.\n",
    "hydrogen = get_ExplicitHs(test_smiles_list)\n",
    "#output: ['[CH3][c]1[cH][cH][cH][c]2[c]([C](=[O])[OH])[cH][c]([O-])[cH][c]12']\n",
    "# Reported result uses lowercase aromatic atoms; presumably RDKit's canonical SMILES - confirm.\n",
    "canonical = get_rdkitCanonical(test_smiles_list)\n",
    "#output: ['Cc1cccc2c(C(=O)O)cc([O-])cc12']\n",
    "# Reported result differs from the input only in one ring-closure label (2 becomes 3).\n",
    "cycle = get_cycleRenumering(test_smiles_list)\n",
    "#output: ['CC1=C3C=C(C=C(C3=CC=C1)C(=O)O)[O-]']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model evaluation on augmented data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "# Load the MolT5 SMILES-to-caption checkpoint; tokenizer and model share one name.\n",
    "molt5_checkpoint = \"laituan245/molt5-large-smiles2caption\"\n",
    "tokenizer = T5Tokenizer.from_pretrained(molt5_checkpoint, model_max_length=512)\n",
    "model = T5ForConditionalGeneration.from_pretrained(molt5_checkpoint)\n",
    "\n",
    "# Same string as the recorded kekulize output in the augmentation cell.\n",
    "input_text = \"CC1=C2C=C([O-])C=C(C(=O)O)C2=CC=C1\"\n",
    "encoded = tokenizer(input_text, return_tensors=\"pt\")\n",
    "input_ids = encoded.input_ids\n",
    "\n",
    "# Beam-search decode, then print the first returned sequence without special tokens.\n",
    "outputs = model.generate(input_ids, num_beams=5, max_length=512)\n",
    "caption = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "print(caption)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
    "\n",
    "# Caption the same SMILES string with the GT4SD multitask text-and-chemistry T5 checkpoint.\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(\"GT4SD/multitask-text-and-chemistry-t5-base-augm\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"GT4SD/multitask-text-and-chemistry-t5-base-augm\")\n",
    "\n",
    "# Generation settings, defined here so the cell runs on a fresh kernel\n",
    "# (nothing else in the notebook sets these names).\n",
    "# NOTE(review): values presumed from the model's published usage example - confirm;\n",
    "# the MolT5 cell above decodes with num_beams=5, max_length=512.\n",
    "max_length = 512\n",
    "num_beams = 10\n",
    "\n",
    "# The prompt is the task instruction followed by the SMILES string.\n",
    "instance = \"CC1=C2C=C([O-])C=C(C(=O)O)C2=CC=C1\"\n",
    "input_text = f\"Caption the following molecule: {instance}\"\n",
    "\n",
    "text = tokenizer(input_text, return_tensors=\"pt\")\n",
    "generated_ids = model.generate(input_ids=text[\"input_ids\"], max_length=max_length, num_beams=num_beams)\n",
    "decoded = tokenizer.decode(generated_ids[0].cpu())\n",
    "\n",
    "# decode() is called without skip_special_tokens, so trim by hand: keep the text before\n",
    "# the first EOS token, drop pad tokens, then strip surrounding whitespace.\n",
    "output = decoded.split(tokenizer.eos_token)[0].replace(tokenizer.pad_token, \"\").strip()\n",
    "\n",
    "output"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
