{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the demo\n",
    "\n",
    "FuxiCTR version: v1.0\n",
    "\n",
    "We provide [multiple demo scripts](https://github.com/xue-pai/FuxiCTR/tree/main/demo) to run a given model on the tiny dataset. Please follow these examples to get started. The code workflow is structured as follows:\n",
    "\n",
    "```python\n",
    "# Set data params and model params\n",
    "params = {...}\n",
    "\n",
    "# Set the feature encoding specs\n",
    "feature_encoder = FeatureEncoder(feature_cols, label_col, ...) # define the feature encoder\n",
    "feature_encoder.fit(...) # fit and transform the data\n",
    "\n",
    "# Load data generators\n",
    "train_gen, valid_gen, test_gen = data_generator(feature_encoder, ...)\n",
    "\n",
    "# Define a model\n",
    "model = DeepFM(...)\n",
    "\n",
    "# Train the model\n",
    "model.fit_generator(train_gen, validation_data=valid_gen, ...)\n",
    "\n",
    "# Evaluation\n",
    "model.evaluate_generator(test_gen)\n",
    "\n",
    "```\n",
    "\n",
    "In the following, we walk through the demo `DeepFM_demo.py` step by step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "from fuxictr.datasets import data_generator\n",
    "from fuxictr.datasets.taobao import FeatureEncoder\n",
    "from datetime import datetime\n",
    "from fuxictr.utils import set_logger, print_to_json\n",
    "import logging\n",
    "from fuxictr.pytorch.models import DeepFM\n",
    "from fuxictr.pytorch.utils import seed_everything"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After importing the required packages, define the feature columns, the label column, and the `params` dict for DeepFM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_cols = [{'name': [\"userid\",\"adgroup_id\",\"pid\",\"cate_id\",\"campaign_id\",\"customer\",\"brand\",\"cms_segid\",\n",
    "                          \"cms_group_id\",\"final_gender_code\",\"age_level\",\"pvalue_level\",\"shopping_level\",\"occupation\"],\n",
    "                 'active': True, 'dtype': 'str', 'type': 'categorical'}]\n",
    "label_col = {'name': 'clk', 'dtype': float}\n",
    "\n",
    "params = {'model_id': 'DeepFM_demo',\n",
    "          'dataset_id': 'tiny_data_demo',\n",
    "          'train_data': '../data/tiny_data/train_sample.csv',\n",
    "          'valid_data': '../data/tiny_data/valid_sample.csv',\n",
    "          'test_data': '../data/tiny_data/test_sample.csv',\n",
    "          'model_root': '../checkpoints/',\n",
    "          'data_root': '../data/',\n",
    "          'feature_cols': feature_cols,\n",
    "          'label_col': label_col,\n",
    "          'embedding_regularizer': 0,\n",
    "          'net_regularizer': 0,\n",
    "          'hidden_units': [64, 64],\n",
    "          'hidden_activations': \"relu\",\n",
    "          'learning_rate': 1e-3,\n",
    "          'net_dropout': 0,\n",
    "          'batch_norm': False,\n",
    "          'optimizer': 'adam',\n",
    "          'task': 'binary_classification',\n",
    "          'loss': 'binary_crossentropy',\n",
    "          'metrics': ['logloss', 'AUC'],\n",
    "          'min_categr_count': 1,\n",
    "          'embedding_dim': 10,\n",
    "          'batch_size': 16,\n",
    "          'epochs': 3,\n",
    "          'shuffle': True,\n",
    "          'seed': 2019,\n",
    "          'monitor': 'AUC',\n",
    "          'monitor_mode': 'max',\n",
    "          'use_hdf5': True,\n",
    "          'pickle_feature_encoder': True,\n",
    "          'save_best_only': True,\n",
    "          'every_x_epochs': 1,\n",
    "          'patience': 2,\n",
    "          'workers': 1,\n",
    "          'verbose': 0,\n",
    "          'version': 'pytorch',\n",
    "          'gpu': -1}\n",
    "\n",
    "# Set the logger and random seed\n",
    "set_logger(params)\n",
    "logging.info('Start the demo...')\n",
    "logging.info(print_to_json(params))\n",
    "seed_everything(seed=params['seed'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then set up the `FeatureEncoder` and fit it on the training data so that it can encode the raw features from the csv files (e.g., normalizing continuous values and mapping/reindexing categorical features)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_encoder = FeatureEncoder(feature_cols, \n",
    "                                 label_col, \n",
    "                                 dataset_id=params['dataset_id'], \n",
    "                                 data_root=params[\"data_root\"],\n",
    "                                 version=params['version'])\n",
    "feature_encoder.fit(train_data=params['train_data'], \n",
    "                    min_categr_count=params['min_categr_count'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert the csv files to h5 files and create the data generators for training, validation, and testing. Note that the h5 files can be reused directly in subsequent experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_gen, valid_gen, test_gen = data_generator(feature_encoder,\n",
    "                                                train_data=params['train_data'],\n",
    "                                                valid_data=params['valid_data'],\n",
    "                                                test_data=params['test_data'],\n",
    "                                                batch_size=params['batch_size'],\n",
    "                                                shuffle=params['shuffle'],\n",
    "                                                use_hdf5=params['use_hdf5'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize a DeepFM model and fit it with the training and validation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DeepFM(feature_encoder.feature_map, **params)\n",
    "model.fit_generator(train_gen, validation_data=valid_gen, epochs=params['epochs'],\n",
    "                    verbose=params['verbose'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reload the best saved model checkpoint and evaluate it on the validation and test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.info('***** validation/test results *****')\n",
    "model.load_weights(model.checkpoint)\n",
    "model.evaluate_generator(valid_gen)\n",
    "model.evaluate_generator(test_gen)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
