{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "xoqe2lkk9Ooy"
   },
   "outputs": [],
   "source": [
    "import os, sys\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "from torch.nn import utils\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import math, collections, itertools, random\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from time import sleep\n",
    "from livelossplot import PlotLosses  \n",
    "\n",
    "from models.autoencoder import AE\n",
    "from models.matrix_factorization import PQ\n",
    "from utils.regularizers import FairnessLoss\n",
    "\n",
    "from preprocessing.synthetic import data_loader_synthetic\n",
    "from preprocessing.movielens import data_loader_movielens\n",
    "from preprocessing.lastfm import data_loader_lastfm\n",
    "\n",
    "from train import train_PQ, train_AE\n",
    "from utils.metrics import metrics\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "2iP-18jr9Oo3"
   },
   "outputs": [],
   "source": [
    "# Experiment configuration: dataset, rating system, model type and algorithm.\n",
    "\n",
    "# Step 1 -- dataset: synthetic, movielens or lastfm (keep exactly one loader active).\n",
    "# p0, p1, q0, q1 are the arguments forwarded to the synthetic loader.\n",
    "p0 = 0.4\n",
    "p1 = 0.1\n",
    "q0 = 0.2\n",
    "q1 = 0.2\n",
    "\n",
    "# The loader returns ((train_data, test_data), user attribute, item attribute).\n",
    "data_tuple = data_loader_synthetic(p0, p1, q0, q1)\n",
    "# data_tuple = data_loader_movielens()\n",
    "# data_tuple = data_loader_lastfm()\n",
    "\n",
    "# Step 2 -- rating system: 'binary' or 'five-stars'.\n",
    "data_type = 'binary'\n",
    "\n",
    "# Step 3 -- model type: 'PQ' or 'AE'.\n",
    "model_type = 'PQ'\n",
    "\n",
    "# Step 4 -- algorithm: 'unfair', 'ours', 'VAL', 'UGF' or 'CVS'.\n",
    "algorithm_type = 'ours'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "z5kZA-2h9Oo5",
    "outputId": "b69362c8-d37d-4b7a-ec39-cbf0e54c2a9d"
   },
   "outputs": [],
   "source": [
    "# Hyper-parameters, fairness criterion, model construction and training.\n",
    "# Reads data_tuple, data_type, model_type and algorithm_type from the configuration cell.\n",
    "\n",
    "# parameter settings\n",
    "learning_rate = 1e-3\n",
    "l_value = 0\n",
    "num_epochs = 1000\n",
    "lambda_ = 0.9\n",
    "tau = 0\n",
    "# n, m are the first two dimensions of the training data (data_tuple[0][0]).\n",
    "n, m = data_tuple[0][0].shape[:2]\n",
    "# r is passed as the fourth argument of PQ\n",
    "# (presumably the latent dimension -- confirm against models.matrix_factorization).\n",
    "r = 20\n",
    "\n",
    "# Prefer GPU 1, but clamp to the last existing ordinal so that single-GPU hosts\n",
    "# (e.g. Colab) use cuda:0 instead of failing with an invalid device ordinal.\n",
    "gpu_index = 1\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda', min(gpu_index, torch.cuda.device_count() - 1))\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# The 'unfair' baseline trains without a fairness criterion.\n",
    "if algorithm_type == 'unfair':\n",
    "    f_criterion = None\n",
    "else:\n",
    "    f_criterion = FairnessLoss(h=0.01, tau=tau, delta=0.01, device=device,\n",
    "                               data_tuple=data_tuple, type_=algorithm_type)\n",
    "\n",
    "# Build the model and select its training routine.\n",
    "if model_type == 'PQ':\n",
    "    model = PQ(data_type, n, m, r).to(device)\n",
    "    train_fn = train_PQ\n",
    "elif model_type == 'AE':\n",
    "    model = AE(data_type, m).to(device)\n",
    "    train_fn = train_AE\n",
    "else:\n",
    "    # Fail here rather than with a NameError on model in a later cell.\n",
    "    raise ValueError('Unsupported model_type {!r}; expected PQ or AE'.format(model_type))\n",
    "\n",
    "# train the model\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
    "logs = train_fn(data_tuple, model, optimizer, num_epochs, device,\n",
    "                l_value=l_value, lambda_=lambda_, f_criterion=f_criterion, tau=tau)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZOxHDdeC9Oo8",
    "outputId": "5adaace3-f314-4703-c889-138940a6db46"
   },
   "outputs": [],
   "source": [
    "# Check the results: run the project's metrics() helper on the trained model.\n",
    "# As the last expression of the cell, its return value (if any) is displayed.\n",
    "metrics(\n",
    "    model,\n",
    "    data_tuple,\n",
    "    device,\n",
    "    model_type,\n",
    "    tau,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "experiments.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
