{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "831de653-07c4-4675-8d74-67157f8f82da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "import cond as cd\n",
    "import DP_Sliding as dp\n",
    "import process_edited as pce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39e40d20-d38a-4456-8660-d41d28d27021",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Metro_Traffic dataset.\n",
    "# NOTE(review): DATA_DIR is relative to the working directory and assumes the notebook runs from the\n",
    "# TimeAutoDiff repo root; point it at your local Dataset folder otherwise.\n",
    "DATA_DIR = Path('Dataset')\n",
    "filename = DATA_DIR / 'Metro_Traffic.csv'\n",
    "\n",
    "# Keep only the first N_ROWS rows; set N_ROWS = None to use the entire data if your memory is sufficient!\n",
    "N_ROWS = 2000\n",
    "real_df = pd.read_csv(filename).iloc[:N_ROWS]\n",
    "\n",
    "# The parser is fitted on every column except 'date'.\n",
    "real_df1 = real_df.drop(columns=['date'])\n",
    "# NOTE(review): the meaning of the second argument (1) is defined in process_edited.DataFrameParser.fit -- confirm.\n",
    "parser = pce.DataFrameParser().fit(real_df1, 1)\n",
    "name = parser.column_name()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67ecc28e-cfef-4e51-8247-c65f5ddbb815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG source up front so any randomness in splitting, training and sampling is reproducible.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Column(s) treated as the response; passed to divide_respons_cond_time below.\n",
    "response_list = ['traffic_volume']\n",
    "# NOTE(review): VAE_training / diff_training presumably set the training length of the two stages -- confirm in cond.py.\n",
    "VAE_training = 10000\n",
    "diff_training = 10000\n",
    "\n",
    "# NOTE(review): 48 is presumably the sliding-window length used by DP_Sliding -- confirm.\n",
    "processed_data, time_info, train_idx, test_idx = dp.single_split_train_test(real_df, 48)\n",
    "response_df, con_df, response_train, response_test, cond_train, cond_test, time_info_train, time_info_test = dp.divide_respons_cond_time(\n",
    "    real_df, processed_data, time_info, response_list, train_idx, test_idx\n",
    ")\n",
    "model = cd.C_TimeAutoDiff(response_df, con_df, response_train, cond_train, time_info_train, VAE_training, diff_training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "575db28d-c8b7-431f-ab26-ecb0f08d8c96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conditionally generate sequences with the same batch size and sequence length as the held-out test responses.\n",
    "Batch_size, Seq_len, _ = response_test.shape\n",
    "# NOTE(review): no device or diffusion-step count is passed to cond_sampling, so it presumably uses its own\n",
    "# defaults; model[0] / model[1] are the first two elements returned by C_TimeAutoDiff, and Lat_dim presumably\n",
    "# has to match the latent size used during training -- confirm all three in cond.py.\n",
    "_synth_data = cd.cond_sampling(response_df, model[0], model[1], cond_test, time_info_test, Batch_size, Seq_len, Lat_dim=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80a4d27c-d685-420f-91ef-b3626953ccfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "label = 89 # 10/6/2012 - Saturday; index along the first axis of response_test / _synth_data\n",
    "column_name = response_list\n",
    "\n",
    "\n",
    "def _to_numpy(x):\n",
    "    # response_test / _synth_data may be torch tensors (possibly on the GPU or requiring grad),\n",
    "    # which matplotlib cannot plot directly; move them to host NumPy arrays first.\n",
    "    if torch.is_tensor(x):\n",
    "        return x.detach().cpu().numpy()\n",
    "    return x\n",
    "\n",
    "\n",
    "##################################################################################################################\n",
    "# Line plots of the real vs. conditionally generated series for sample `label`, one panel per response column.\n",
    "n_resp = response_test.shape[2]\n",
    "# squeeze=False keeps `axes` 2-D even when there is a single panel.\n",
    "fig, axes = plt.subplots(1, n_resp, figsize=(10 * n_resp, 4), squeeze=False)\n",
    "\n",
    "for i, ax in enumerate(axes[0]):\n",
    "    ax.plot(_to_numpy(response_test[label, :, i]), marker='o', linestyle='-', color='b', label=f'Real {column_name[i]}')\n",
    "    ax.plot(_to_numpy(_synth_data[label, :, i]), marker='o', linestyle='-', color='r', label=f'Conditionally Generated {column_name[i]}')\n",
    "    ax.set_title(f'{column_name[i]} - test sample {label}')\n",
    "    ax.set_xlabel('Time')\n",
    "    ax.set_ylabel(f'{column_name[i]}')\n",
    "    ax.grid(True)\n",
    "    ax.legend(loc='upper left')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
