{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TorchSpatial Tutorial\n",
    "\n",
    "- [I. Environment Settings](#i-environment-settings)\n",
    "- [II. Data Download](#ii-data-download)\n",
    "- [III. Example of Initializing a Location Encoder](#iii-example-of-initializing-a-location-encoder)\n",
    "- [IV. Experiments on Benchmark Datasets](#iv-experiments-on-benchmark-datasets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## I. Environment Settings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To set up the environment for the TorchSpatial project, follow these steps:\n",
    "\n",
    "1. **Install Conda (if not already installed)**  \n",
    "   Ensure you have [Anaconda](https://www.anaconda.com/products/distribution) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html) installed on your system. You can download the installer from the respective links.\n",
    "\n",
    "2. **Create the Conda Environment**  \n",
    "   Open your terminal (or Anaconda Prompt on Windows) and navigate to the directory containing the `environment.yml` file. Use the following command to create the environment:\n",
    "   ```bash\n",
    "   conda env create -f environment.yml\n",
    "   ```\n",
    "   Alternatively, you can create an environment directly by following the `requirements.txt`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## II. Data Download"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The datasets can be downloaded from the following DOI link: [Download Data](https://figshare.com/articles/dataset/LocBench/26026798).\n",
    "\n",
    "1. **Download the dataset from the specified URL**  \n",
    "   For example, to download the Birdsnap dataset, use the following command in the terminal:\n",
    "   ```bash\n",
    "   wget https://figshare.com/ndownloader/files/47020978\n",
    "2. **Extract the contents of the tar file and move the files to the desired folder if necessary**  \n",
    "   ```bash\n",
    "   tar -xvf 47020978\n",
    "3. **Update the path for dataset**  \n",
    "   Please navigate to `main/paths.py` and ensure that the dataset path matches the corresponding entry in `main/paths.py`.\n",
    "\n"
   ]
  },
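  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The download and extraction steps above can also be scripted from Python. The following is a minimal sketch that uses only the standard library and is not part of TorchSpatial: the helper names `download_file` and `extract_tar`, the archive name `birdsnap.tar`, and the target folder `../data` are hypothetical choices made for illustration. Whether Figshare serves the file to a plain `urllib` request has not been verified here, so `wget` remains the documented route.\n",
    "\n",
    "```python\n",
    "import tarfile\n",
    "import urllib.request\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def download_file(url, dest):\n",
    "    \"\"\"Download url to dest unless dest already exists (hypothetical helper).\"\"\"\n",
    "    dest = Path(dest)\n",
    "    if not dest.exists():\n",
    "        dest.parent.mkdir(parents=True, exist_ok=True)\n",
    "        urllib.request.urlretrieve(url, str(dest))\n",
    "    return dest\n",
    "\n",
    "\n",
    "def extract_tar(archive, out_dir):\n",
    "    \"\"\"Extract a tar archive into out_dir and return the member names.\"\"\"\n",
    "    out_dir = Path(out_dir)\n",
    "    out_dir.mkdir(parents=True, exist_ok=True)\n",
    "    with tarfile.open(str(archive)) as tar:\n",
    "        names = tar.getnames()\n",
    "        tar.extractall(str(out_dir))\n",
    "    return names\n",
    "\n",
    "\n",
    "# Example usage (commented out because it downloads a large file):\n",
    "# url = 'https://figshare.com/ndownloader/files/47020978'  # Birdsnap, as above\n",
    "# archive = download_file(url, '../data/birdsnap.tar')     # assumed location\n",
    "# print(extract_tar(archive, '../data')[:5])\n",
    "```\n",
    "\n",
    "Only extract archives from a source you trust, and remember to point the matching entry in `main/paths.py` at the folder passed as `out_dir`."
   ]
  },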
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## III. Example of Initializing a Location Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from SpatialRelationEncoder import *\n",
    "from module import *\n",
    "from data_utils import *\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify the params for the location encoder, current spa_embed_dim is 128\n",
    "params = {\n",
    "    # The type of location encoder you will use, FIXED\n",
    "    'spa_enc_type': 'Space2Vec-grid',\n",
    "    'spa_embed_dim': 128,  # The dimension of the location embedding,TUNE\n",
    "    'extent': (0, 200, 0, 200),  # Extent of the coords, FIXED\n",
    "    'freq': 16,  # The number of scales, TUNE (See Equation(3) in (Mai, 2020))\n",
    "    # Lambda_max, maximum scale, FIXED (See Equation(4) in (Mai, 2020))\n",
    "    'max_radius': 1,\n",
    "    # Lambda_min, minimum scale, TUNE (See Equation(4) in (Mai, 2020))\n",
    "    'min_radius': 0.0001,\n",
    "    'spa_f_act': \"leakyrelu\",  # Activation function, FIXED\n",
    "    'freq_init': 'geometric',  # The method to make the Fourier frequency, FIXED\n",
    "    'num_hidden_layer': 1,  # The number of hidden layer, TUNE\n",
    "    'dropout': 0.5,  # Dropout rate, TUNE\n",
    "    'hidden_dim': 512,  # Hidden embedding dimension, TUNE\n",
    "    'use_layn': True,  # whether to you layer normalization, FIXED\n",
    "    'skip_connection': True,  # Whether to use skip connection, FIXED\n",
    "    'spa_enc_use_postmat': True,  # FIXED\n",
    "    'device': 'cpu'  # The device, ‘cpu’ or ‘cuda:0’, etc\n",
    "}\n",
    "\n",
    "loc_enc = get_spa_encoder(\n",
    "    train_locs=[],\n",
    "    params=params,\n",
    "    spa_enc_type=params['spa_enc_type'],\n",
    "    spa_embed_dim=params['spa_embed_dim'],\n",
    "    extent=params['extent'],\n",
    "    coord_dim=2,\n",
    "    frequency_num=params['freq'],\n",
    "    max_radius=params['max_radius'],\n",
    "    min_radius=params['min_radius'],\n",
    "    f_act=params['spa_f_act'],\n",
    "    freq_init=params['freq_init'],\n",
    "    use_postmat=params['spa_enc_use_postmat'],\n",
    "    device=params['device']).to(params['device'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  8,  18],\n",
       "       [187, 170],\n",
       "       [ 62,  87],\n",
       "       [169, 135],\n",
       "       [ 97, 171],\n",
       "       [145, 101],\n",
       "       [152,   2],\n",
       "       [ 54, 105],\n",
       "       [ 27,  41],\n",
       "       [ 53,  54]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# synthetic coords data\n",
    "batch_size, coord_dim = 10, 2\n",
    "coords = np.random.randint(1, 201, size=(batch_size, coord_dim))\n",
    "\n",
    "# coords: shape [batch_size, 2]\n",
    "coords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000, -0.1178,  2.8982,  ..., -0.0000, -0.8643, -0.0000],\n",
       "        [-0.6335,  3.9401, -0.1604,  ...,  1.5705, -0.0699,  0.0000],\n",
       "        [ 0.0000,  3.4673,  0.0000,  ...,  0.0000,  0.7862, -0.0000],\n",
       "        ...,\n",
       "        [ 0.0751, -0.0000, -0.0000,  ...,  0.0000, -0.2110, -0.0000],\n",
       "        [ 0.0000,  3.2486, -0.3153,  ..., -0.0000,  1.8864,  0.0000],\n",
       "        [ 0.0000,  1.2793, -0.6869,  ..., -0.4061, -0.7918, -0.4826]],\n",
       "       grad_fn=<SqueezeBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coords = np.array(coords)\n",
    "coords = np.expand_dims(coords, axis=1)\n",
    "loc_embeds = torch.squeeze(loc_enc(coords))\n",
    "\n",
    "# loc_embed: shape [batch_size, spa_embed_dim]\n",
    "loc_embeds"
   ]
  },
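  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give some intuition for the `freq`, `min_radius`, `max_radius`, and `freq_init` parameters used above, the following NumPy sketch builds geometrically spaced scales between `min_radius` and `max_radius` and applies sine and cosine to each coordinate at each scale. It is an illustration only, written under the assumption that the grid encoder follows the multi-scale sinusoidal form described in (Mai, 2020); it leaves out the feed-forward layers and is not the TorchSpatial implementation. The function names `geometric_scales` and `grid_position_features` are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def geometric_scales(freq, min_radius, max_radius):\n",
    "    \"\"\"Return freq scales spaced geometrically from min_radius to max_radius.\"\"\"\n",
    "    if freq == 1:\n",
    "        return np.array([float(min_radius)])\n",
    "    growth = np.log(max_radius / min_radius) / (freq - 1)\n",
    "    return min_radius * np.exp(np.arange(freq) * growth)\n",
    "\n",
    "\n",
    "def grid_position_features(coords, freq, min_radius, max_radius):\n",
    "    \"\"\"Map [batch, coord_dim] coords to [batch, coord_dim * freq * 2] features.\"\"\"\n",
    "    coords = np.asarray(coords, dtype=np.float64)\n",
    "    scales = geometric_scales(freq, min_radius, max_radius)\n",
    "    # Divide every coordinate by every scale: shape [batch, coord_dim, freq].\n",
    "    scaled = coords[:, :, None] / scales[None, None, :]\n",
    "    # Sine and cosine per (coordinate, scale): shape [batch, coord_dim, freq, 2].\n",
    "    feats = np.stack([np.sin(scaled), np.cos(scaled)], axis=-1)\n",
    "    return feats.reshape(coords.shape[0], -1)\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "demo_coords = rng.integers(1, 201, size=(10, 2))\n",
    "demo_feats = grid_position_features(\n",
    "    demo_coords, freq=16, min_radius=0.0001, max_radius=1)\n",
    "print(demo_feats.shape)  # (10, 64): 2 coords * 16 scales * 2 functions\n",
    "```\n",
    "\n",
    "In the real encoder, features of this kind are presumably passed through the layers configured by `num_hidden_layer`, `hidden_dim`, and `dropout` to produce the final `spa_embed_dim`-dimensional embedding."
   ]
  },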
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IV. Experiments on Benchmark Datasets "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Model Evaluation\n",
    "\n",
    "The **main script** for model evaluation and experiments is located at: `TorchSpatial/main/main.py`. Below are the **key command-line arguments** with their descriptions and default values.\n",
    "\n",
    "##### 🛠️ General Options\n",
    "```python\n",
    "    parser.add_argument(\"--save_results\", type=str, default=\"T\", \n",
    "        help=\"Save the results (lon, lat, rr, acc1, acc3) to a CSV file for final evaluation.\"\n",
    "    )\n",
    "    parser.add_argument(\"--device\", type=str, \n",
    "        default=torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\"),\n",
    "        help=\"Device to use: 'cuda' for GPU or 'cpu' for CPU.\"\n",
    "    )\n",
    "    parser.add_argument(\"--model_dir\", type=str, default=\"../models/\",\n",
    "        help=\"Directory where models are stored.\"\n",
    "    )\n",
    "    parser.add_argument(\"--num_epochs\", type=int, default=20,\n",
    "        help=\"Number of training epochs.\"\n",
    "    )\n",
    "    parser.add_argument(\"--load_super_model\", type=str, default=\"F\", \n",
    "        help=\"Load a pretrained supervised model (T/F).\"\n",
    "    )\n",
    "```\n",
    "#### 📂 Dataset Options\n",
    "```python\n",
    "    parser.add_argument(\"--dataset\", type=str, default=\"birdsnap\", \n",
    "        choices=[\n",
    "            \"inat_2021\", \"inat_2018\", \"inat_2017\", \"birdsnap\", \"nabirds\", \n",
    "            \"yfcc\", \"fmow\", \"sustainbench_asset_index\", \n",
    "            \"sustainbench_under5_mort\", \"sustainbench_water_index\", \n",
    "            \"sustainbench_women_bmi\", \"sustainbench_women_edu\", \n",
    "            \"sustainbench_sanitation_index\", \"mosaiks_population\", \n",
    "            \"mosaiks_elevation\", \"mosaiks_forest_cover\", \n",
    "            \"mosaiks_nightlights\"\n",
    "        ],\n",
    "        help=\"Dataset to use for the experiment.\"\n",
    "    )\n",
    "    parser.add_argument(\"--train_sample_ratio\", type=float, default=0.01,\n",
    "        help=\"Training dataset sample ratio for supervised learning.\"\n",
    "    )\n",
    "    parser.add_argument(\"--train_sample_method\", type=str, default=\"random-fix\",\n",
    "        help=\"\"\"Training dataset sampling method:\n",
    "        - 'stratified-fix': Stratified sampling with fixed indices.\n",
    "        - 'stratified-random': Stratified sampling with random indices.\n",
    "        - 'random-fix': Random sampling with fixed indices.\n",
    "        - 'random-random': Random sampling with random indices.\n",
    "        - 'ssi-sample': Sample based on spatial self-information.\n",
    "        \"\"\"\n",
    "    )\n",
    "```\n",
    "#### ⚙️ Training Hyperparameters\n",
    "```python\n",
    "    parser.add_argument(\"--lr\", type=float, default=0.001, \n",
    "        help=\"Learning rate.\"\n",
    "    )\n",
    "    parser.add_argument(\"--lr_decay\", type=float, default=0.98, \n",
    "        help=\"Learning rate decay factor.\"\n",
    "    )\n",
    "    parser.add_argument(\"--weight_decay\", type=float, default=0.0, \n",
    "        help=\"Weight decay (L2 regularization).\"\n",
    "    )\n",
    "    parser.add_argument(\"--dropout\", type=float, default=0.5, \n",
    "        help=\"Dropout rate used in the feedforward neural network.\"\n",
    "    )\n",
    "    parser.add_argument(\"--batch_size\", type=int, default=1024, \n",
    "        help=\"Batch size for training.\"\n",
    "    )\n",
    "\n",
    "```\n",
    "#### 🔍 Logging and Evaluation\n",
    "```python\n",
    "    parser.add_argument(\"--log_frequency\", type=int, default=50, \n",
    "        help=\"Frequency of logging (in batches).\"\n",
    "    )\n",
    "    parser.add_argument(\"--max_num_exs_per_class\", type=int, default=100, \n",
    "        help=\"Maximum number of examples per class.\"\n",
    "    )\n",
    "    parser.add_argument(\"--balanced_train_loader\", type=str, default=\"T\", \n",
    "        help=\"Use a balanced train loader (T/F).\"\n",
    "    )\n",
    "    parser.add_argument(\"--eval_frequency\", type=int, default=100, \n",
    "        help=\"Frequency of model evaluation (in batches).\"\n",
    "    )\n",
    "    parser.add_argument(\"--unsuper_save_frequency\", type=int, default=5, \n",
    "        help=\"Frequency of saving unsupervised models (in epochs).\"\n",
    "    )\n",
    "    parser.add_argument(\"--do_epoch_save\", type=str, default=\"F\", \n",
    "        help=\"Save the model at each epoch (T/F).\"\n",
    "    )\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run the model, please navigate to the `TorchSpatial/main/` folder in the terminal and execute the command, for example:\n",
    "\n",
    "```bash\n",
    "python3 main.py \\\n",
    "    --save_results T \\\n",
    "    --load_super_model F \\\n",
    "    --spa_enc_type Sphere2Vec-sphereC \\\n",
    "    --meta_type ebird_meta \\\n",
    "    --dataset birdsnap \\\n",
    "    --eval_split test \\\n",
    "    --frequency_num 64 \\\n",
    "    --max_radius 1 \\\n",
    "    --min_radius 0.001 \\\n",
    "    --num_hidden_layer 1 \\\n",
    "    --hidden_dim 512 \\\n",
    "    --spa_f_act relu \\\n",
    "    --unsuper_lr 0.1 \\\n",
    "    --lr 0.001 \\\n",
    "    --model_dir ../models/sphere2vec_sphereC/ \\\n",
    "    --num_epochs 100 \\\n",
    "    --train_sample_ratio 1.0 \\\n",
    "    --device cpu\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Model Hyperparameter Tuning\n",
    "\n",
    "To tune the model hyperparameters, sample bash files are provided in the `TorchSpatial/main/run_sh` directory.\n"
   ]
  },
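  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an alternative to editing shell scripts by hand, a small Python helper can generate one `main.py` command per hyperparameter combination. This is a hypothetical sketch, not one of the provided scripts: it only uses flags that appear in the example command above, and the grid values are arbitrary placeholders rather than recommended settings.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "import shlex\n",
    "\n",
    "# Flags held constant across runs, taken from the example command above.\n",
    "BASE_ARGS = {\n",
    "    '--spa_enc_type': 'Sphere2Vec-sphereC',\n",
    "    '--meta_type': 'ebird_meta',\n",
    "    '--dataset': 'birdsnap',\n",
    "    '--eval_split': 'test',\n",
    "    '--max_radius': 1,\n",
    "    '--num_epochs': 100,\n",
    "    '--device': 'cpu',\n",
    "}\n",
    "\n",
    "# Hypothetical search grid; the values are placeholders, not recommendations.\n",
    "GRID = {\n",
    "    '--frequency_num': [32, 64],\n",
    "    '--min_radius': [0.001, 0.0001],\n",
    "    '--hidden_dim': [256, 512],\n",
    "}\n",
    "\n",
    "\n",
    "def build_commands(base_args, grid):\n",
    "    \"\"\"Return one argv list per combination of the grid values.\"\"\"\n",
    "    keys = list(grid)\n",
    "    commands = []\n",
    "    for values in itertools.product(*(grid[k] for k in keys)):\n",
    "        args = dict(base_args)\n",
    "        args.update(zip(keys, values))\n",
    "        argv = ['python3', 'main.py']\n",
    "        for flag, value in args.items():\n",
    "            argv += [flag, str(value)]\n",
    "        commands.append(argv)\n",
    "    return commands\n",
    "\n",
    "\n",
    "commands = build_commands(BASE_ARGS, GRID)\n",
    "print(len(commands))  # 8 combinations: 2 * 2 * 2\n",
    "print(shlex.join(commands[0]))\n",
    "```\n",
    "\n",
    "Each list in `commands` can be passed to `subprocess.run(argv, check=True)` from inside `TorchSpatial/main/` to launch the runs one after another."
   ]
  },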
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Model Fine Tuning\n",
    "\n",
    "We have also provided several pre-trained models in the `TorchSpatial/pre_trained_models` directory. If you would like to fine-tune one of these models, please set the corresponding model path and load the model. Example Bash files can be found in `TorchSpatial/main/eva_sh`. For example:\n",
    "\n",
    "```bash\n",
    "python3 main.py \\\n",
    "    --save_results T \\\n",
    "    --load_super_model T \\\n",
    "    --spa_enc_type Sphere2Vec-sphereC \\\n",
    "    --meta_type ebird_meta \\\n",
    "    --dataset birdsnap \\\n",
    "    --eval_split test \\\n",
    "    --frequency_num 64 \\\n",
    "    --max_radius 1 \\\n",
    "    --min_radius 0.001 \\\n",
    "    --num_hidden_layer 1 \\\n",
    "    --hidden_dim 512 \\\n",
    "    --spa_f_act relu \\\n",
    "    --unsuper_lr 0.1 \\\n",
    "    --lr 0.001 \\\n",
    "    --model_dir ../models/sphere2vec_sphereC/ \\\n",
    "    --train_sample_ratio 1.0 \\\n",
    "    --device cpu"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Experiment Outputs\n",
    "\n",
    "#### a. 🗒️ Log File  \n",
    "- **Description**: A log file recording the training progress and final results.  \n",
    "- **Example Path**:  \n",
    "  `TorchSpatial/pre_trained_models/sphere2vec_sphereC/model_birdsnap_ebird_meta_Sphere2Vec-sphereC_inception_v3_0.0010_64_0.0010000_1_512.log`\n",
    "\n",
    "  \n",
    "#### b. 💾 Model Checkpoint  \n",
    "- **Description**: A corresponding saved model checkpoint with the same name as the log file but with a `.pth.tar` extension.  \n",
    "- **Example Path**:  \n",
    "  `TorchSpatial/pre_trained_models/sphere2vec_sphereC/model_birdsnap_ebird_meta_Sphere2Vec-sphereC_inception_v3_0.0010_64_0.0010000_1_512.pth.tar`\n",
    "\n",
    "#### c. 📊 Evaluation Table\n",
    "- **Description** if `--save_results` is set to `T`, an evaluated table will be generated and saved.\n",
    "- **Example Path**  \n",
    "  `TorchSpatial/eval_results/classification/eval_birdsnap_ebird_meta_test_Sphere2Vec-sphereC.csv`\n",
    "- **Table Contents**  \n",
    "  - Each row represents a data point.  \n",
    "  - Columns include the location and predicted performance metrics.\n",
    "- **Example Columns**  \n",
    "  - Classification Tasks:  \n",
    "    - `lon`, `lat`, `true_class_prob`, `reciprocal_rank`, `hit@1`, `hit@3`  \n",
    "  - Regression Tasks:  \n",
    "    - `lon`, `lat`, `predictions`, `labels`, `relative_error`"
   ]
  }
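  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classification columns listed above can be aggregated into summary metrics. The sketch below averages `reciprocal_rank`, `hit@1`, and `hit@3` over all rows using the standard `csv` module. It assumes the hit columns are stored as numeric 0/1 values, which is not confirmed here, and the function name `summarize_classification` is hypothetical.\n",
    "\n",
    "```python\n",
    "import csv\n",
    "\n",
    "\n",
    "def summarize_classification(csv_path):\n",
    "    \"\"\"Average the per-sample metrics of a classification evaluation table.\"\"\"\n",
    "    totals = {'reciprocal_rank': 0.0, 'hit@1': 0.0, 'hit@3': 0.0}\n",
    "    count = 0\n",
    "    with open(csv_path, newline='') as f:\n",
    "        for row in csv.DictReader(f):\n",
    "            for name in totals:\n",
    "                # Assumes each metric column parses as a number.\n",
    "                totals[name] += float(row[name])\n",
    "            count += 1\n",
    "    if count == 0:\n",
    "        raise ValueError('evaluation table is empty')\n",
    "    return {name: total / count for name, total in totals.items()}\n",
    "\n",
    "\n",
    "# Example usage with the path shown above:\n",
    "# path = ('TorchSpatial/eval_results/classification/'\n",
    "#         'eval_birdsnap_ebird_meta_test_Sphere2Vec-sphereC.csv')\n",
    "# print(summarize_classification(path))\n",
    "```\n",
    "\n",
    "The mean of `reciprocal_rank` is the mean reciprocal rank, and the means of `hit@1` and `hit@3` are the top-1 and top-3 accuracies."
   ]
  }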
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
