{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically reload imported modules (e.g. the project's `src` package) before each\n",
    "# cell executes, so edits to them are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the parent directory (where the `src` package is expected) importable.\n",
    "# Skip if already present so re-running this cell does not grow sys.path.\n",
    "if \"..\" not in sys.path:\n",
    "    sys.path.append(\"..\")\n",
    "\n",
    "# Pin JAX to the CPU *before* any project module is imported: JAX reads this\n",
    "# environment variable when it is first imported, so setting it later has no effect.\n",
    "os.environ['JAX_PLATFORM_NAME'] = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import wandb\n",
    "import yaml\n",
    "from IPython.display import clear_output\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "from src import utils\n",
    "from src.dataset_loaders import load_vectors, get_samplers\n",
    "from src.train import train_discrete\n",
    "from src.utils import get_pca_models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Possible `DATASET_NAME` values are: `twitter`, `wiki-gigaword`, `bone_marrow` and `muse`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Papermill parameters: keep this cell to plain literals. Papermill injects overrides in a\n",
    "# new cell right after this one, so values derived from these are computed in the next cell.\n",
    "DATASET_NAME = 'twitter'\n",
    "\n",
    "SOURCE_DIM   = 100\n",
    "TARGET_DIM   = 50\n",
    "\n",
    "EMB_TYPE_SOURCE = 'BP'\n",
    "EMB_TYPE_TARGET = 'BP'\n",
    "\n",
    "MAX_ITERS    = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derived settings, recomputed from the (possibly overridden) parameters above.\n",
    "# VS is only set when either embedding type is 'BP'; language codes only for 'muse'.\n",
    "VS          = None\n",
    "SOURCE_LANG = None\n",
    "TARGET_LANG = None\n",
    "\n",
    "if 'BP' in (EMB_TYPE_SOURCE, EMB_TYPE_TARGET):\n",
    "    VS = 200000\n",
    "    if DATASET_NAME == 'muse':\n",
    "        SOURCE_LANG = 'en'\n",
    "        TARGET_LANG = 'en'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solver / training settings (not exposed as papermill parameters).\n",
    "METHOD_NAME   = 'StructuredGW'\n",
    "DEVICE        = 'cpu'\n",
    "\n",
    "ALPHA_INIT    = 1.0\n",
    "SEED_INIT     = 43\n",
    "COST_DISCRETE = 'inner'\n",
    "N_REPEATS     = 5\n",
    "\n",
    "# Per-method hyper-parameters; only the entry matching METHOD_NAME is placed in `config`\n",
    "# (an unknown METHOD_NAME falls back to an empty dict).\n",
    "model_specific_configs = {\n",
    "    'RegGW':        {'HIDDEN_SIZES_MLP': [512, 256, 256],\n",
    "                     'EPS_FIT':          0.01,\n",
    "                     'EPS_REG':          0.001,\n",
    "                     'LAMBDA':           1,\n",
    "                     'MOVER_LR':         1e-4},\n",
    "    'FlowGW':       {'HIDDEN_SIZES_MLP': [1024, 1024, 1024, 1024],\n",
    "                     'EPS':              1e-4,\n",
    "                     'N_FREQ':           128,\n",
    "                     'MOVER_LR':         1e-4},\n",
    "    'StructuredGW': {'EPS': 1e-3},\n",
    "    'AlignGW':      {'EPS': 1e-3},\n",
    "}\n",
    "\n",
    "# Dataset / sampling settings; passed (inside `config`) to load_vectors and get_samplers.\n",
    "dataset_config = {\n",
    "    'DATASET_NAME':    DATASET_NAME,\n",
    "    'DEVICE':          DEVICE,\n",
    "    'SOURCE_DIM':      SOURCE_DIM,\n",
    "    'TARGET_DIM':      TARGET_DIM,\n",
    "    'EMB_TYPE_SOURCE': EMB_TYPE_SOURCE,\n",
    "    'EMB_TYPE_TARGET': EMB_TYPE_TARGET,\n",
    "    'VS':              VS,\n",
    "    'SOURCE_LANG':     SOURCE_LANG,\n",
    "    'TARGET_LANG':     TARGET_LANG,\n",
    "    'N_MAX_SAMPLES':   90000,\n",
    "    'N_TRAIN_SAMPLES': 6000,\n",
    "    'N_TEST_SAMPLES':  2048,\n",
    "    'N_EVAL':          1,\n",
    "    'ALPHA':           ALPHA_INIT,  # overwritten per experiment in the training loop\n",
    "    'SEED':            SEED_INIT,   # overwritten per repeat in the training loop\n",
    "    'NORMALIZE_VECS':  False,\n",
    "    'SHUFFLE':         False,\n",
    "}\n",
    "\n",
    "training_config = {\n",
    "    'TRAIN_TYPE':    'discrete',\n",
    "    'METHOD_NAME':   METHOD_NAME,\n",
    "    'MAX_ITERS':     MAX_ITERS,\n",
    "    'COST_DISCRETE': COST_DISCRETE,\n",
    "    'N_REPEATS':     N_REPEATS,\n",
    "}\n",
    "\n",
    "config = {\n",
    "    'dataset':        dataset_config,\n",
    "    'training':       training_config,\n",
    "    'model_specific': model_specific_configs.get(METHOD_NAME, {}),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Loading dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading twitter_BP_100 to source...\n",
      "Loading twitter_BP_50 to target...\n",
      "Original: 92337\n",
      "Final: 90000\n",
      "torch.Size([90000, 100])\n",
      "torch.Size([90000, 50])\n"
     ]
    }
   ],
   "source": [
    "dataset_path = '../datasets'\n",
    "# NOTE(review): the datasets folder is also put on sys.path; presumably load_vectors (or\n",
    "# something it loads) imports modules from there -- confirm this is still needed.\n",
    "sys.path.append(dataset_path)\n",
    "\n",
    "source_vectors, target_vectors = load_vectors(dataset_path, config)\n",
    "\n",
    "# Print the loaded source shape, then the target shape.\n",
    "for vectors in (source_vectors, target_vectors):\n",
    "    print(vectors.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep JAX on the CPU (matches DEVICE = 'cpu'). JAX reads JAX_PLATFORM_NAME when it is\n",
    "# first imported, so if a project module has already imported jax, setting the variable\n",
    "# alone is a no-op; update the live JAX config in that case as well.\n",
    "import os\n",
    "\n",
    "os.environ['JAX_PLATFORM_NAME'] = 'cpu'\n",
    "if 'jax' in sys.modules:\n",
    "    sys.modules['jax'].config.update('jax_platform_name', 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Weights & Biases logging; mode 'disabled' makes every wandb call a no-op.\n",
    "wandb_report = True\n",
    "wandb_mode   = 'online' if wandb_report else 'disabled'\n",
    "\n",
    "TRAIN_TYPE   = config['training']['TRAIN_TYPE']\n",
    "project_name = f'NIPS_FINAL2_{METHOD_NAME}_{DATASET_NAME}_{SOURCE_DIM}({EMB_TYPE_SOURCE})->{TARGET_DIM}({EMB_TYPE_TARGET})_6K_{N_REPEATS}reps_fixed_marginals'\n",
    "\n",
    "# Metrics are checkpointed to this pickle after every repeat; create the folder up front\n",
    "# so the first dump cannot fail with FileNotFoundError.\n",
    "results_dir  = f'results_{TRAIN_TYPE}_NIPS'\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "results_path = f'{results_dir}/{project_name}_10_00.pkl'\n",
    "\n",
    "metrics_names = ['Top@1', 'Top@5', 'Top@10', 'Top@1_red', 'Top@5_red', 'Top@10_red', 'cossim_gt', 'foscttm_norm', 'foscttm_unnorm',\n",
    "                 'distortion', 'mmd', 'bw_uvp', 'sinkhorn_divergence']\n",
    "\n",
    "# The loop below overwrites config['dataset']['ALPHA'] and ['SEED']. Restore the initial\n",
    "# values first so get_samplers sees the same config on every (re-)run of this cell.\n",
    "config['dataset']['ALPHA'] = ALPHA_INIT\n",
    "config['dataset']['SEED']  = SEED_INIT\n",
    "_, _, _, _, test_sampler = get_samplers(config, source_vectors, target_vectors, 0)\n",
    "\n",
    "alpha_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0][::-1]\n",
    "\n",
    "# One list of per-repeat metric dicts per ALPHA, keyed by the rounded ALPHA as a string.\n",
    "metrics_out = {str(np.round(alpha, 1)): [] for alpha in alpha_values}\n",
    "\n",
    "# Per-repeat seeds come from a generator seeded with SEED_INIT, so reruns reuse the same seeds.\n",
    "seed_rng = random.Random(SEED_INIT)\n",
    "\n",
    "for ALPHA in alpha_values:\n",
    "    config['dataset']['ALPHA'] = ALPHA\n",
    "    alpha_key = str(np.round(ALPHA, 1))\n",
    "\n",
    "    print('================================')\n",
    "    print(f'Experiment for ALPHA={ALPHA}')\n",
    "    print('================================')\n",
    "\n",
    "    for ix in range(N_REPEATS):\n",
    "        # Choose the seed *before* wandb.init so the logged run config records the seed used.\n",
    "        SEED = seed_rng.randint(0, 10000)\n",
    "        config['dataset']['SEED'] = SEED\n",
    "        print('Seed: ', SEED)\n",
    "\n",
    "        exp_name = f'ALPHA_{np.round(ALPHA, 1)}_repeat_{ix}'\n",
    "        wandb.init(name=exp_name, config=config, project=project_name, mode=wandb_mode)\n",
    "        try:\n",
    "            source_vectors_red, target_vectors_red, train_source_sampler, train_target_sampler, _ = get_samplers(config, source_vectors, target_vectors, ix)\n",
    "\n",
    "            trained_class, metrics_dict = train_discrete(train_source_sampler, train_target_sampler,\n",
    "                                                         test_sampler,\n",
    "                                                         metrics_names, target_vectors, target_vectors_red,\n",
    "                                                         config,\n",
    "                                                         wandb_report=wandb_report,\n",
    "                                                         axis_lims=None, report_every=100)\n",
    "\n",
    "            metrics_out[alpha_key].append(metrics_dict)\n",
    "\n",
    "            with open(results_path, 'wb') as f:\n",
    "                pickle.dump(metrics_out, f)\n",
    "        finally:\n",
    "            # Close every run explicitly; otherwise the last repeat's run is never finished.\n",
    "            wandb.finish()\n",
    "\n",
    "        gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
