{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "sys.path.append(\"../..\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "from src.dataset_loaders import load_vectors, get_samplers_paired, get_samplers\n",
    "from src.utils import get_pca_models\n",
    "from src import utils\n",
    "from src.train import train_continuous\n",
    "import wandb\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import yaml\n",
    "import numpy as np\n",
    "import pickle\n",
    "import random\n",
    "import jax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Possible `DATASET_NAME` values are: `twitter`, `wiki-gigaword`, `bone_marrow`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_NAME = 'twitter'\n",
    "\n",
    "SOURCE_DIM   = 50  \n",
    "TARGET_DIM   = 25\n",
    "\n",
    "EMB_TYPE_SOURCE = 'glove'\n",
    "EMB_TYPE_TARGET = 'glove'\n",
    "\n",
    "\n",
    "if EMB_TYPE_SOURCE == 'BP' or EMB_TYPE_TARGET == 'BP':\n",
    "    VS = 200000\n",
    "    if DATASET_NAME == 'muse':\n",
    "        SOURCE_LANG = 'en'\n",
    "        TARGET_LANG = 'en'\n",
    "        N_MAX_SAMPLES = 90000\n",
    "    else:\n",
    "        SOURCE_LANG = None\n",
    "        TARGET_LANG = None\n",
    "        N_MAX_SAMPLES = 90000\n",
    "else:\n",
    "    VS = None\n",
    "    SOURCE_LANG = None\n",
    "    TARGET_LANG = None\n",
    "    N_MAX_SAMPLES = 400000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run-level settings.\n",
    "METHOD_NAME = 'NeuralGW'\n",
    "DEVICE = 'cuda:0'\n",
    "\n",
    "ALPHA_INIT = 1.0\n",
    "SEED_INIT = 43\n",
    "COST_DISCRETE = 'inner'\n",
    "\n",
    "N_REPEATS = 5\n",
    "\n",
    "# Hyperparameters for each method, keyed by method name.\n",
    "# Only the entry matching METHOD_NAME is placed into `config` below.\n",
    "model_specific_configs = {\n",
    "    'CycleGW': {\n",
    "        'HIDDEN_SIZES_MLP': [512],\n",
    "        'EPS': 5e-4,   # 5e-3\n",
    "        'F_LR': 1e-3,  # 1e-3\n",
    "        'G_LR': 1e-3,  # 1e-3\n",
    "        'REG': 0.1,    # 0.1\n",
    "        'SIGMAS': None,\n",
    "        'TAKE_MEDIAN': True,\n",
    "        'KERNEL_TYPE': 'sinkhorn',\n",
    "        'N_EPOCHS': 50,\n",
    "    },\n",
    "    'RegGW': {\n",
    "        'HIDDEN_SIZES_MLP': [512, 256, 256],\n",
    "        'EPS_FIT': 0.01,\n",
    "        'EPS_REG': 0.001,\n",
    "        'LAMBDA': 1,\n",
    "        'MOVER_LR': 1e-4,\n",
    "        'N_EPOCHS': 8000,\n",
    "    },\n",
    "    'NeuralGW': {\n",
    "        'COST_ITERS': 1,\n",
    "        'MOVER_ITERS': 10,\n",
    "        'CRITIC_ITERS': 1,\n",
    "        'REG_CRITIC': 0.1,\n",
    "        'COST_LR': 1e-4,\n",
    "        'MOVER_LR': 1e-4,\n",
    "        'CRITIC_LR': 1e-4,\n",
    "        'HIDDEN_SIZES_MLP': [512] * 4,\n",
    "        'N_EPOCHS': 100,\n",
    "    },\n",
    "    'FlowGW': {\n",
    "        'HIDDEN_SIZES_MLP': [1024] * 4,\n",
    "        'EPS': 1e-3,   # 1e-4\n",
    "        'N_FREQ': 128,\n",
    "        'MOVER_LR': 1e-4,\n",
    "        'N_EPOCHS': 2500,\n",
    "    },\n",
    "}\n",
    "\n",
    "# Single configuration object handed to the data-loading and training helpers.\n",
    "config = {\n",
    "    'dataset': {\n",
    "        'DATASET_NAME': DATASET_NAME,\n",
    "        'DEVICE': DEVICE,\n",
    "        'SOURCE_DIM': SOURCE_DIM,\n",
    "        'TARGET_DIM': TARGET_DIM,\n",
    "        'EMB_TYPE_SOURCE': EMB_TYPE_SOURCE,\n",
    "        'EMB_TYPE_TARGET': EMB_TYPE_TARGET,\n",
    "        'SOURCE_LANG': SOURCE_LANG,\n",
    "        'TARGET_LANG': TARGET_LANG,\n",
    "        'VS': VS,\n",
    "        'N_MAX_SAMPLES': N_MAX_SAMPLES,\n",
    "        'N_TRAIN_SAMPLES': 380000,\n",
    "        'N_TEST_SAMPLES': 2048,\n",
    "        'N_EVAL': 1,\n",
    "        'ALPHA': ALPHA_INIT,\n",
    "        'BATCH_SIZE_TRAIN': 500,\n",
    "        'BATCH_SIZE_TEST': 2048,\n",
    "        'SEED': SEED_INIT,\n",
    "        'NORMALIZE_VECS': False,\n",
    "        'SHUFFLE': True,\n",
    "    },\n",
    "    'training': {\n",
    "        'TRAIN_TYPE': 'continuous',\n",
    "        'METHOD_NAME': METHOD_NAME,\n",
    "        'N_EPOCHS': 100,\n",
    "        'COST_DISCRETE': COST_DISCRETE,\n",
    "        'N_REPEATS': N_REPEATS,\n",
    "    },\n",
    "    # Falls back to an empty dict when METHOD_NAME has no entry above.\n",
    "    'model_specific': model_specific_configs.get(METHOD_NAME, {}),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Loading dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading twitter_glove_50 to source...\n",
      "Loading twitter_glove_25 to target...\n",
      "Original: 1193514\n",
      "Final: 400000\n",
      "torch.Size([400000, 50])\n",
      "torch.Size([400000, 25])\n"
     ]
    }
   ],
   "source": [
    "# Location of the datasets folder, relative to this notebook; it is also added to the import path.\n",
    "dataset_path = '../../datasets'\n",
    "sys.path.append(dataset_path)\n",
    "\n",
    "source_vectors, target_vectors = load_vectors(dataset_path, config)\n",
    "\n",
    "# Show the shapes of both loaded sets, one per line.\n",
    "print(source_vectors.shape, target_vectors.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "\n",
    "wandb_report = True\n",
    "\n",
    "# Built once, before the ALPHA loop, and reused as the test sampler for every run.\n",
    "_, _, _, _, test_sampler = get_samplers(config, source_vectors, target_vectors)\n",
    "\n",
    "N            = config['dataset']['N_MAX_SAMPLES']//1000\n",
    "TRAIN_TYPE   = config['training']['TRAIN_TYPE']\n",
    "project_name = f'NIPS_Final2Fix_{METHOD_NAME}_{DATASET_NAME}_{SOURCE_DIM}({EMB_TYPE_SOURCE})->{TARGET_DIM}({EMB_TYPE_TARGET})_{N}K_{N_REPEATS}reps_metrics'\n",
    "\n",
    "# Create the output directory before any training starts, so that a missing folder\n",
    "# cannot make the pickle dump fail after a complete training run.\n",
    "results_dir  = f'results_{TRAIN_TYPE}_NIPS'\n",
    "results_path = f'{results_dir}/{project_name}_10_short.pkl'\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "metrics_names = ['Top@1', 'Top@5', 'Top@10', 'cossim_gt', 'foscttm_norm', 'foscttm_unnorm',\n",
    "                 'distortion', 'mmd', 'bw_uvp', 'sinkhorn_divergence']\n",
    "shuffle = config['dataset']['SHUFFLE']\n",
    "alpha_values = [1.0][::-1]\n",
    "\n",
    "# One list of per-repeat metric dicts for each ALPHA value, keyed by the rounded ALPHA as a string.\n",
    "metrics_out = {str(np.round(alpha, 1)): [] for alpha in alpha_values}\n",
    "\n",
    "# Dedicated generator seeded from SEED_INIT: the per-repeat seeds are identical on every\n",
    "# Restart & Run All, and the global `random` state is left untouched.\n",
    "seed_generator = random.Random(SEED_INIT)\n",
    "\n",
    "for ALPHA in alpha_values:\n",
    "\n",
    "    config['dataset']['ALPHA'] = ALPHA\n",
    "    alpha_key = str(np.round(ALPHA, 1))\n",
    "\n",
    "    print('================================')\n",
    "    print(f'Experiment for ALPHA={ALPHA}')\n",
    "    print('================================')\n",
    "\n",
    "    # Train samplers are rebuilt for each ALPHA, after config has been updated.\n",
    "    _, _, train_source_sampler, train_target_sampler, _ = get_samplers(config, source_vectors, target_vectors)\n",
    "\n",
    "    for ix in range(N_REPEATS):\n",
    "\n",
    "        SEED = seed_generator.randint(0, 10000)\n",
    "        # NOTE(review): rng is not passed to train_continuous below - confirm it is still needed.\n",
    "        rng = jax.random.PRNGKey(SEED)\n",
    "        config['dataset']['SEED'] = SEED\n",
    "        print('Seed: ', SEED)\n",
    "\n",
    "        if wandb_report:\n",
    "            exp_name = f'ALPHA_{np.round(ALPHA, 1)}_repeat_{ix}_shuffled_{shuffle}'\n",
    "            wandb.init(name=exp_name, config=config, project=project_name)\n",
    "\n",
    "        trained_class, metrics_dict = train_continuous(train_source_sampler, train_target_sampler,\n",
    "                                                       test_sampler,\n",
    "                                                       metrics_names, target_vectors,\n",
    "                                                       config,\n",
    "                                                       wandb_report=wandb_report,\n",
    "                                                       axis_lims=None, report_every=9000, source_vectors=source_vectors)\n",
    "\n",
    "        if wandb_report:\n",
    "            # Close this repeat's run explicitly so the next wandb.init starts a fresh one.\n",
    "            wandb.finish()\n",
    "\n",
    "        metrics_out[alpha_key].append(metrics_dict)\n",
    "\n",
    "        # The file is overwritten after every repeat with all metrics collected so far.\n",
    "        # NOTE: pickle files should only be loaded back from trusted locations.\n",
    "        with open(results_path, 'wb') as f:\n",
    "            pickle.dump(metrics_out, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
