{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "from src.dataset_loaders import load_vectors, get_samplers\n",
    "from src.utils import get_pca_models\n",
    "from src import utils\n",
    "from src.train import train_continuous\n",
    "import wandb\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "import yaml\n",
    "import numpy as np\n",
    "import pickle\n",
    "import random\n",
    "import jax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Possible ```DATASET_NAME``` values are: ```twitter```, ```wiki-gigaword```, ```bone_marrow```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_NAME = 'twitter'\n",
    "\n",
    "SOURCE_DIM   = 25  \n",
    "TARGET_DIM   = 50\n",
    "\n",
    "EMB_TYPE_SOURCE = 'glove'\n",
    "EMB_TYPE_TARGET = 'glove'\n",
    "\n",
    "SOURCE_LANG     = 'en'\n",
    "TARGET_LANG     = 'fr'\n",
    "\n",
    "#VS              = 200000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "METHOD_NAME  = 'CycleGW'\n",
    "DEVICE       = 'cuda:0'\n",
    "\n",
    "ALPHA_INIT    = 1.0\n",
    "SEED_INIT     = 43\n",
    "#COST_DISCRETE = 'cosine'\n",
    "\n",
    "config = {'dataset':dict(DATASET_NAME     = DATASET_NAME,\n",
    "                         DEVICE           = DEVICE,\n",
    "                         SOURCE_DIM       = SOURCE_DIM,\n",
    "                         TARGET_DIM       = TARGET_DIM,\n",
    "                         EMB_TYPE_SOURCE  = EMB_TYPE_SOURCE,\n",
    "                         EMB_TYPE_TARGET  = EMB_TYPE_TARGET,\n",
    "                         SOURCE_LANG      = SOURCE_LANG,\n",
    "                         TARGET_LANG      = TARGET_LANG,\n",
    "                         \n",
    "                         #VS               = VS,\n",
    "                         N_MAX_SAMPLES    = 400000, \n",
    "                         N_TRAIN_SAMPLES  = 360000, \n",
    "                         N_TEST_SAMPLES   = 512,\n",
    "                         N_EVAL           = 4,\n",
    "                         ALPHA            = ALPHA_INIT, \n",
    "                         BATCH_SIZE_TRAIN = 512,\n",
    "                         BATCH_SIZE_TEST  = 512,\n",
    "                         SEED             = SEED_INIT,\n",
    "                         NORMALIZE_VECS   = False,\n",
    "                         SHUFFLE          = True\n",
    "                          ),\n",
    "          \n",
    "          'training':dict(TRAIN_TYPE      = 'continuous',\n",
    "                          METHOD_NAME     = METHOD_NAME,\n",
    "                          N_EPOCHS        = 100,\n",
    "                          #COST_DISCRETE   = COST_DISCRETE,\n",
    "                          ),\n",
    "\n",
    "          #===============================CycleGW===============================\n",
    "          'model_specific':dict(HIDDEN_SIZES_MLP = [512]*1,\n",
    "                                EPS              = 5e-4, #5e-3\n",
    "                                F_LR             = 1e-3, #1e-3\n",
    "                                G_LR             = 1e-3,#1e-3\n",
    "                                REG              = 0.1,#0.1\n",
    "                                SIGMAS           = None,\n",
    "                                TAKE_MEDIAN      = True,\n",
    "                                KERNEL_TYPE      = 'sinkhorn'\n",
    "                               )  \n",
    "          #===============================RegGW===============================\n",
    "          #'model_specific':dict(HIDDEN_SIZES_MLP= [512, 256, 256],\n",
    "          #                      EPS_FIT         = 0.01,\n",
    "          #                      EPS_REG         = 0.001,\n",
    "          #                      LAMBDA          = 1,\n",
    "          #                      MOVER_LR        = 1e-4,)\n",
    "\n",
    "          #===============================FlowGW===============================\n",
    "          #'model_specific':dict(HIDDEN_SIZES_MLP = [1024, 1024, 1024, 1024],\n",
    "          #                      EPS              = 1e-4,\n",
    "          #                      N_FREQ           = 128,\n",
    "          #                      MOVER_LR         = 1e-4)\n",
    "\n",
    "          #===============================EntropicGW===============================\n",
    "          #'model_specific':dict(COST_ITERS      = 1,\n",
    "          #                      CRITIC_ITERS    = 1,\n",
    "          #                      COST_LR         = 1e-4,\n",
    "          #                      CRITIC_LR       = 1e-4,\n",
    "          #                      HIDDEN_SIZES_MLP= [512, 512, 512, 512],\n",
    "          #                      EPS             = 1e-3)\n",
    "          \n",
    "          #===============================NeuralGW===============================\n",
    "          #'model_specific':dict(COST_ITERS      = 1,\n",
    "          #                      MOVER_ITERS     = 10,\n",
    "          #                      CRITIC_ITERS    = 1,\n",
    "          #                      REG_CRITIC      = 0.1,\n",
    "          #                      COST_LR         = 1e-4,\n",
    "          #                      MOVER_LR        = 1e-4, \n",
    "          #                      CRITIC_LR       = 1e-4,\n",
    "          #                      HIDDEN_SIZES_MLP= [512, 512, 512, 512],)\n",
    "           }\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Loading dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading twitter_glove_25 to source...\n",
      "Loading twitter_glove_50 to target...\n",
      "torch.Size([400000, 25])\n",
      "torch.Size([400000, 50])\n"
     ]
    }
   ],
   "source": [
    "dataset_path = '../datasets'\n",
    "sys.path.append(dataset_path)\n",
    "\n",
    "source_vectors, target_vectors = load_vectors(dataset_path, config)\n",
    "\n",
    "print(source_vectors.shape)\n",
    "print(target_vectors.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "\n",
    "wandb_report = True\n",
    "\n",
    "_, _, _, _, test_sampler = get_samplers(config, source_vectors, target_vectors)\n",
    "\n",
    "n_repeats    = 1\n",
    "N            = config['dataset']['N_MAX_SAMPLES']//1000\n",
    "TRAIN_TYPE   = config['training']['TRAIN_TYPE']\n",
    "project_name = f'{METHOD_NAME}_{DATASET_NAME}_{SOURCE_DIM}({EMB_TYPE_SOURCE})->{TARGET_DIM}({EMB_TYPE_TARGET})_{N}K_{n_repeats}reps_final2'\n",
    "#project_name = f'{METHOD_NAME}_{DATASET_NAME}_{SOURCE_DIM}({SOURCE_LANG}_{EMB_TYPE_SOURCE})->{TARGET_DIM}({TARGET_LANG}_{EMB_TYPE_TARGET})_{N}K_{n_repeats}reps'\n",
    "\n",
    "metrics_names = ['Top@1', 'Top@5', 'Top@10', 'cossim_gt', 'inner_gw', 'foscttm']\n",
    "shuffle = config['dataset']['SHUFFLE']\n",
    "alpha_values = [1.0][::-1]\n",
    "\n",
    "metrics_out = {str(np.round(alpha, 1)):[] for alpha in alpha_values}\n",
    "\n",
    "for ALPHA in alpha_values:\n",
    "    \n",
    "    config['dataset']['ALPHA'] = ALPHA \n",
    "\n",
    "        \n",
    "    print('================================')\n",
    "    print(f'Experiment for ALPHA={ALPHA}')\n",
    "    print('================================')\n",
    "    \n",
    "    source_vectors, target_vectors, train_source_sampler, train_target_sampler, _ = get_samplers(config, source_vectors, target_vectors) \n",
    "    \n",
    "    for ix in range(n_repeats):\n",
    "\n",
    "        SEED = random.randint(0, 10000)\n",
    "        rng = jax.random.PRNGKey(SEED)\n",
    "        config['dataset']['SEED'] = SEED\n",
    "        print('Seed: ', SEED)\n",
    "        \n",
    "        if wandb_report:\n",
    "            exp_name = f'ALPHA_{np.round(ALPHA, 1)}_repeat_{ix}_shuffled_{shuffle}'\n",
    "            wandb.init(name=exp_name, config=config, project=project_name)\n",
    "        \n",
    "        #trained_class, metrics_dict = train_continuous(train_source_sampler, train_target_sampler,\n",
    "        #                                               test_sampler, \n",
    "        #                                               metrics_names, target_vectors,\n",
    "        #                                               config,\n",
    "        #                                               wandb_report=wandb_report,\n",
    "        #                                               axis_lims=None, report_every=5)\n",
    "\n",
    "        trained_class, metrics_dict = train_continuous(train_source_sampler, train_target_sampler,\n",
    "                                                       test_sampler, \n",
    "                                                       metrics_names, target_vectors,\n",
    "                                                       config,\n",
    "                                                       wandb_report=wandb_report,\n",
    "                                                       axis_lims=None, report_every=10, source_vectors=source_vectors)\n",
    "        \n",
    "        metrics_out[str(np.round(ALPHA, 1))].append(metrics_dict)\n",
    "    \n",
    "        with open(f'results_{TRAIN_TYPE}_final/{project_name}_10_05_00.pkl', 'wb') as f:\n",
    "            pickle.dump(metrics_out, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
