{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Actual Code start here\n",
    "\n",
    "Code before here are for debugging."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from VAE import train_vae_model, TabularTokenizerTransformer\n",
    "import pandas as pd\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import sdmetrics\n",
    "from sdmetrics.reports.single_table import QualityReport\n",
    "\n",
    "def reconstruct_dataframe(model, df, transformer, threshold=0.5, device='cpu'):\n",
    "    # Step 1: Convert df to X_num and X_cat using the fitted transformer\n",
    "    X_num, X_cat = transformer.transform(df)\n",
    "    \n",
    "    # Convert to torch tensors and move to device\n",
    "    X_num = torch.tensor(X_num).float().to(device)\n",
    "    X_cat = torch.tensor(X_cat).to(device)\n",
    "\n",
    "    # Step 2: Use the VAE model to produce Recon_X_num and Recon_X_cat\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        Recon_X_num, Recon_X_cat, _, _ = model(X_num, X_cat)\n",
    "\n",
    "    # Step 3: Convert Recon_X_cat (predicted probabilities) back to labels using threshold\n",
    "    Recon_X_cat_labels = []\n",
    "    for i, recon_cat in enumerate(Recon_X_cat):\n",
    "        # Apply threshold to convert probabilities to labels\n",
    "        recon_cat_labels = torch.max(recon_cat, dim=1)[1]\n",
    "        Recon_X_cat_labels.append(recon_cat_labels)\n",
    "    \n",
    "    # Stack Recon_X_cat_labels to match the original shape\n",
    "    Recon_X_cat_labels = torch.stack(Recon_X_cat_labels, dim=1).cpu().numpy()\n",
    "\n",
    "    # Step 4: Inversely transform Recon_X_num and Recon_X_cat_labels back to a pandas DataFrame\n",
    "    Recon_X_num = Recon_X_num.cpu().numpy()\n",
    "    recon_df = transformer.inverse_transform(Recon_X_num, Recon_X_cat_labels)\n",
    "\n",
    "    return recon_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_dir = \"../csv/test/insurance\"\n",
    "csvs = os.listdir(df_dir)\n",
    "\n",
    "LR = 1e-3\n",
    "WD = 0\n",
    "D_TOKEN = 4\n",
    "TOKEN_BIAS = True\n",
    "\n",
    "N_HEAD = 1\n",
    "FACTOR = 32\n",
    "NUM_LAYERS = 2\n",
    "\n",
    "recon_report = {}\n",
    "\n",
    "\n",
    "for csv in csvs[:1]:\n",
    "    dataset_name = csv.replace(\".csv\", \"\")\n",
    "    full_path = os.path.join(df_dir, csv)\n",
    "\n",
    "    df = pd.read_csv(full_path)\n",
    "    train,test = train_test_split(df, random_state=42, test_size=0.2)\n",
    "    print(dataset_name,df.shape)\n",
    "\n",
    "    transformer = TabularTokenizerTransformer()\n",
    "    transformer.fit(train)\n",
    "    x_num, x_cat = transformer.transform(train)\n",
    "    #print(x_num, x_cat)\n",
    "    inv_df = transformer.inverse_transform(x_num, x_cat)\n",
    "    #print(inv_df)\n",
    "    \n",
    "    vae_model = train_vae_model(train, \"test\",num_layers=NUM_LAYERS, d_token=D_TOKEN, n_head=N_HEAD, factor=FACTOR, lr=LR, wd=WD,device=\"cuda:0\", num_epochs=1000)\n",
    "    \n",
    "    recon_df = reconstruct_dataframe(vae_model, test, transformer, threshold=0.5, device='cuda:0')\n",
    "    print(recon_df)\n",
    "\n",
    "    report = QualityReport()\n",
    "    meta_for_sdmetrics = {\"primary_key\":None, \"columns\":{}}\n",
    "    for c in df.columns:\n",
    "        dtype = \"numerical\" if pd.api.types.is_numeric_dtype(df[c]) else \"categorical\"\n",
    "        if c in recon_df.columns:\n",
    "            meta_for_sdmetrics['columns'][c] = {\"sdtype\":dtype}\n",
    "    report.generate(test, recon_df, meta_for_sdmetrics)\n",
    "\n",
    "    report_transposed = report.get_properties().set_index('Property').T\n",
    "\n",
    "    recon_report[dataset_name] = {}\n",
    "    recon_report[dataset_name]['shape'] = report.get_details(property_name='Column Shapes')\n",
    "    recon_report[dataset_name]['corr'] = report.get_details(property_name='Column Pair Trends')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(recon_report[dataset_name]['shape'])\n",
    "print(recon_report[dataset_name]['shape']['Score'].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(recon_report[dataset_name]['corr']['Score'].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from scipy.stats import wasserstein_distance, rankdata\n",
    "\n",
    "def uniform_quantile_normalization(data: pd.Series) -> pd.Series:\n",
    "    \"\"\"\n",
    "    Perform uniform quantile normalization on a pandas Series.\n",
    "    Transforms the data into a uniform distribution [0, 1].\n",
    "    \"\"\"\n",
    "    ranks = rankdata(data, method='average')  # Get the ranks of the data\n",
    "    normalized = (ranks - 1) / (len(data) - 1)  # Scale to [0, 1]\n",
    "    return pd.Series(normalized, index=data.index)\n",
    "\n",
    "def compute_wasserstein_distance(real: pd.DataFrame, synthetic: pd.DataFrame):\n",
    "    # Ensure both dataframes have the same columns\n",
    "    if not all(real.columns == synthetic.columns):\n",
    "        raise ValueError(\"Both dataframes must have the same column names.\")\n",
    "    \n",
    "    # Separate numerical columns\n",
    "    num_cols = real.select_dtypes(include=['number']).columns\n",
    "    \n",
    "    # Dictionary to store the results\n",
    "    wasserstein_distances = {}\n",
    "\n",
    "    # Compute Wasserstein distance for each numerical column\n",
    "    for col in num_cols:\n",
    "        # Apply uniform quantile normalization\n",
    "        real_normalized = uniform_quantile_normalization(real[col].dropna())\n",
    "        synthetic_normalized = uniform_quantile_normalization(synthetic[col].dropna())\n",
    "\n",
    "        # Compute the Wasserstein distance on the normalized data\n",
    "        dist = wasserstein_distance(real_normalized, synthetic_normalized)\n",
    "        wasserstein_distances[col] = dist\n",
    "        print(f\"Wasserstein distance for {col} (after quantile normalization): {dist:.4f}\")\n",
    "\n",
    "    return wasserstein_distances\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "recon_df = recon_df.loc[:, df.columns]\n",
    "recon_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compute_wasserstein_distance(df, recon_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "def compare_real_synthetic(real: pd.DataFrame, synthetic: pd.DataFrame):\n",
    "    # Ensure both dataframes have the same columns and number of rows\n",
    "    if not all(real.columns == synthetic.columns):\n",
    "        raise ValueError(\"Both dataframes must have the same column names.\")\n",
    "    if len(real) != len(synthetic):\n",
    "        raise ValueError(\"Both dataframes must have the same number of rows.\")\n",
    "    \n",
    "    # Ensure both dataframes have the same index\n",
    "    if not all(real.index == synthetic.index):\n",
    "        print(\"Realigning the indices of the DataFrames.\")\n",
    "        real = real.reset_index(drop=True)\n",
    "        synthetic = synthetic.reset_index(drop=True)\n",
    "    \n",
    "    # Separate numerical and categorical columns\n",
    "    num_cols = real.select_dtypes(include=['number']).columns\n",
    "    cat_cols = real.select_dtypes(include=['object', 'category']).columns\n",
    "    \n",
    "    # Plot overlapping density plots for numerical columns\n",
    "    for col in num_cols:\n",
    "        plt.figure(figsize=(8, 6))\n",
    "        sns.kdeplot(real[col], label='Real', shade=True)\n",
    "        sns.kdeplot(synthetic[col], label='Synthetic', shade=True)\n",
    "        plt.title(f'Density Plot for {col}')\n",
    "        plt.legend()\n",
    "        plt.show()\n",
    "\n",
    "    # Compute percentage of categories match for categorical columns\n",
    "    cat_match_percentages = {}\n",
    "    for col in cat_cols:\n",
    "        match_count = (real[col] == synthetic[col]).sum()\n",
    "        total_count = len(real[col])\n",
    "        match_percentage = (match_count / total_count) * 100\n",
    "        cat_match_percentages[col] = match_percentage\n",
    "    \n",
    "    # Display the categorical match percentages\n",
    "    for col, match_percentage in cat_match_percentages.items():\n",
    "        print(f'Category match for {col}: {match_percentage:.2f}%')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compare_real_synthetic(test, recon_df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
