{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e67d0c65-ed29-4bff-bc88-006eeb145d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import glob\n",
    "import random\n",
    "\n",
    "from datasets import load_dataset, Dataset, DatasetDict\n",
    "from huggingface_hub.hf_api import HfFolder\n",
    "\n",
    "# Take the Hugging Face token from the HF_TOKEN environment variable so it is never\n",
    "# stored in the notebook itself. Nothing is saved when the variable is unset or empty,\n",
    "# so a token already stored on this machine is not replaced by an empty string.\n",
    "hf_token = os.environ.get('HF_TOKEN', '')\n",
    "if hf_token:\n",
    "    HfFolder.save_token(hf_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57f51f35-6490-49b9-ba04-3c0307e25868",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the source dataset and display its splits.\n",
    "# NOTE(review): the dataset id is an empty string here - presumably redacted; set it before running.\n",
    "dataset = load_dataset(path='')\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24b6d5a5-a65e-4e6d-96e6-47960777478c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on the 'train' split as a pandas DataFrame and preview its first rows.\n",
    "train_split = dataset['train']\n",
    "df = train_split.to_pandas()\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e060c462-e4e5-4b7b-873c-a7bfa7f81cef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empty frame that fixes the column layout of the final training table.\n",
    "final_columns = ['anchor', 'positive', 'negatives', 'phase']\n",
    "final_df = pd.DataFrame(columns=final_columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe3d7ea3-bc90-4f6d-8bbd-91503df844ba",
   "metadata": {},
   "source": [
    "## Phase 1: Query - Question NLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1fb25e9-8faf-4fbd-bac1-e16e05f35def",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the query / NLP-question columns and count missing values in each.\n",
    "phase_1_columns = ['query_syll', 'question_NLP_syll']\n",
    "df_1 = df[phase_1_columns]\n",
    "df_1.isnull().sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb75ef9c-c8ee-488e-be89-d8ed487fd08b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase 1 pairs: the query is the anchor and its NLP question is the positive.\n",
    "# Statement order matters: final_df starts with no rows, so the first Series assignment\n",
    "# is what gives it a row index; the scalar phase value is filled in only after rows exist.\n",
    "# 'negatives' is not assigned here (presumably left as missing values for these rows).\n",
    "final_df['anchor'] = df_1['query_syll']\n",
    "final_df['positive'] = df_1['question_NLP_syll']\n",
    "final_df['phase'] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d523239-65d8-407c-addf-9794ccc9e64e",
   "metadata": {},
   "source": [
    "## Phase 2: Question NLP - Mini schema <br>Phase 3: Question - Mini schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b55f33ec-13b0-4c4b-816b-aa50da4e0214",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take an explicit copy: later cells add columns to df_2, and writing into a\n",
    "# column selection of df raises pandas' SettingWithCopyWarning.\n",
    "df_2 = df[['question_syll', 'question_NLP_syll', 'schema_syll', 'mini_schema_syll']].copy()\n",
    "df_2.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93a509a5-bdc8-44cf-b28a-81cac308f325",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(sample):\n",
    "    # clean schema chatgpt gen\n",
    "    sample = sample.replace('```sql\\n', '').replace('\\n```', '')\n",
    "\n",
    "    # re-format list of table\n",
    "    tables = sample.split('CREATE TABLE')\n",
    "    schema = ['CREATE TABLE ' + table.strip() for table in tables if table]\n",
    "    # re-format -> #####\n",
    "    sample = '#####'.join(schema)\n",
    "\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ddce386-ef7c-4f69-952a-8b9c94af1b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the '#####'-separated versions of both schema columns.\n",
    "# assign() returns a new frame, so no column is written into a selection of df\n",
    "# (which would raise SettingWithCopyWarning), and re-running the cell is harmless.\n",
    "df_2 = df_2.assign(\n",
    "    mini_schema_syll_reformat=df_2['mini_schema_syll'].map(preprocess),\n",
    "    schema_syll_reformat=df_2['schema_syll'].map(preprocess),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd8f7029-8efd-496f-93be-e753a4ec6a2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview df_2 with the reformatted schema columns.\n",
    "df_2.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abc4c0c6-d3bc-44cf-a6d4-b8f0ff6b296e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every distinct table definition that appears in any reformatted schema.\n",
    "# Empty strings (a schema that produced no table) are dropped so they can never be\n",
    "# drawn as a negative, and the result is sorted because the iteration order of a set\n",
    "# of strings can change between kernel sessions.\n",
    "unique_tables = set()\n",
    "for sample in df_2['schema_syll_reformat'].tolist():\n",
    "    unique_tables.update(table for table in sample.split('#####') if table)\n",
    "\n",
    "schema_list = sorted(unique_tables)\n",
    "len(schema_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1654a6ba-e51a-42f9-a8a1-94f913f7071a",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_random_negatives = 3\n",
    "# Dedicated, seeded generator so the sampled negatives are the same on every run.\n",
    "rng = random.Random(42)\n",
    "# Sort once so sampling does not depend on the incoming order of schema_list.\n",
    "schema_pool = sorted(schema_list)\n",
    "\n",
    "hard_negatives_col, negatives_col = [], []\n",
    "for schema_text, mini_schema_text in zip(df_2['schema_syll_reformat'], df_2['mini_schema_syll_reformat']):\n",
    "    schema = schema_text.split('#####')\n",
    "    mini_schema = mini_schema_text.split('#####')\n",
    "\n",
    "    # Hard negatives: tables of the full schema that are not part of the mini schema.\n",
    "    for table in mini_schema:\n",
    "        if table in schema:\n",
    "            schema.remove(table)\n",
    "    hard_negatives_col.append('#####'.join(schema))\n",
    "\n",
    "    # Random negatives: tables that belong to neither this row's schema nor its mini schema.\n",
    "    excluded = set(schema) | set(mini_schema)\n",
    "    candidates = [table for table in schema_pool if table not in excluded]\n",
    "    # Sample without replacement, capped at what is available: the previous rejection\n",
    "    # loop never terminated when fewer than three eligible tables existed.\n",
    "    picked = rng.sample(candidates, min(num_random_negatives, len(candidates)))\n",
    "    negatives_col.append('#####'.join(picked))\n",
    "\n",
    "# assign() builds a new frame instead of writing cell by cell into df_2.\n",
    "df_2 = df_2.assign(hard_negatives=hard_negatives_col, negatives=negatives_col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d591f54-bb3d-4276-ba9f-489dcce02fb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview df_2 with the hard_negatives / negatives columns filled in.\n",
    "df_2.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccc9989b-1da5-4a5a-ab4c-da87bff04a16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show one example row, with a blank line between the fields.\n",
    "i = 1\n",
    "example = df_2.iloc[i]\n",
    "\n",
    "print(\n",
    "    f'Anchor: {example.question_NLP_syll}',\n",
    "    f'Positives: {example.mini_schema_syll_reformat}',\n",
    "    f'Hard negatives: {example.hard_negatives}',\n",
    "    f'Negatives: {example.negatives}',\n",
    "    sep='\\n\\n',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17046a2a-cf78-4921-af80-8a6d542d9928",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows without hard negatives (empty string or null) fall back to the random negatives.\n",
    "missing_hard = df_2['hard_negatives'].isnull() | (df_2['hard_negatives'] == '')\n",
    "df_2.loc[missing_hard, 'hard_negatives'] = df_2.loc[missing_hard, 'negatives']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c554579-7341-41b7-9710-ad867f00775a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview df_2 after empty hard negatives were replaced.\n",
    "df_2.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fe7fb33-eacf-4ffc-aff2-a4fad8c53de0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase 2 rows: NLP question as anchor, mini schema as positive, hard negatives as negatives.\n",
    "# rename()/assign() each return a new frame, so nothing is written into a selection\n",
    "# of df_2 (which raised SettingWithCopyWarning).\n",
    "df_2_phase_2 = (\n",
    "    df_2[['question_NLP_syll', 'mini_schema_syll_reformat', 'hard_negatives']]\n",
    "    .rename(columns={\n",
    "        'question_NLP_syll': 'anchor',\n",
    "        'mini_schema_syll_reformat': 'positive',\n",
    "        'hard_negatives': 'negatives',\n",
    "    })\n",
    "    .assign(phase=2)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18fbdb7a-8998-4fd7-90f5-9062736f0249",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase 3 rows: plain question as anchor, mini schema as positive, hard negatives as negatives.\n",
    "# rename()/assign() each return a new frame, so nothing is written into a selection\n",
    "# of df_2 (which raised SettingWithCopyWarning).\n",
    "df_2_phase_3 = (\n",
    "    df_2[['question_syll', 'mini_schema_syll_reformat', 'hard_negatives']]\n",
    "    .rename(columns={\n",
    "        'question_syll': 'anchor',\n",
    "        'mini_schema_syll_reformat': 'positive',\n",
    "        'hard_negatives': 'negatives',\n",
    "    })\n",
    "    .assign(phase=3)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ce4ee24-b7ec-4ed9-9b32-a25b7e5de6ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append phases 2 and 3 to the phase-1 rows. Only the phase-1 rows of final_df are used\n",
    "# as the base, so re-running this cell does not stack duplicate copies of phases 2 and 3.\n",
    "phase_1_rows = final_df[final_df['phase'] == 1]\n",
    "final_df = pd.concat([phase_1_rows, df_2_phase_2, df_2_phase_3], ignore_index=True)\n",
    "final_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95a5c73a-f0cd-4b97-8265-682a668188e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check five random rows of the combined table.\n",
    "final_df.sample(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "085658ad-6e08-4002-9bf0-490d32bc3255",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: list rows whose positive text is identical to their negatives text.\n",
    "same_text = final_df['positive'] == final_df['negatives']\n",
    "final_df[same_text]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39c1e0fa-ccec-4bf2-acb7-b792d2eeaa4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the combined table as the single 'train' split and upload it to the Hub.\n",
    "ds = Dataset.from_pandas(final_df)\n",
    "ds_dict = DatasetDict({'train': ds})\n",
    "\n",
    "ds_dict.push_to_hub('strongpear/text2sql_positive_negatives')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
