{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited project modules (e.g. `src.*`) automatically before each cell runs.\n",
    "# `%reload_ext` loads the extension if needed and stays quiet when it is already\n",
    "# loaded, so re-running this cell is harmless (unlike `%load_ext`).\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the parent directory importable so the `from src ...` imports below work.\n",
    "# Guarded so re-running this cell does not append duplicate entries.\n",
    "if \"..\" not in sys.path:\n",
    "    sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import wandb\n",
    "import yaml\n",
    "from IPython.display import clear_output\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "from src import utils\n",
    "from src.dataset_loaders import get_samplers, load_vectors\n",
    "from src.train import train_discrete\n",
    "from src.utils import get_pca_models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Possible `DATASET_NAME` values are: `twitter`, `wiki-gigaword`, `bone_marrow`, `muse` and `muse_multi` (used below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Cell tagged `parameters` (papermill convention): keep it to plain assignments.\n",
    "DATASET_NAME = 'muse_multi'\n",
    "\n",
    "# Embedding size and embedding type on each side of the alignment.\n",
    "SOURCE_DIM      = 100\n",
    "TARGET_DIM      = 100\n",
    "EMB_TYPE_SOURCE = 'BP'\n",
    "EMB_TYPE_TARGET = 'BP'\n",
    "\n",
    "# Language pair: source -> target.\n",
    "SOURCE_LANG = 'en'\n",
    "TARGET_LANG = 'es'\n",
    "\n",
    "MAX_ITERS = 100     # stored as config['training']['MAX_ITERS']\n",
    "VS        = 200000  # presumably vocabulary size (TODO confirm); shown as vs(...K) in the W&B project name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings fixed for this notebook (not exposed in the parameters cell).\n",
    "METHOD_NAME = 'AlignGW'\n",
    "DEVICE      = 'cpu'\n",
    "\n",
    "ALPHA_INIT    = 1.0     # initial config value; the training cell overwrites ALPHA per sweep value\n",
    "SEED_INIT     = 43      # initial config value; the training cell overwrites SEED per repeat\n",
    "COST_DISCRETE = 'cosine'\n",
    "\n",
    "config = {\n",
    "    'dataset': {\n",
    "        'DATASET_NAME':    DATASET_NAME,\n",
    "        'DEVICE':          DEVICE,\n",
    "        'SOURCE_DIM':      SOURCE_DIM,\n",
    "        'TARGET_DIM':      TARGET_DIM,\n",
    "        'EMB_TYPE_SOURCE': EMB_TYPE_SOURCE,\n",
    "        'EMB_TYPE_TARGET': EMB_TYPE_TARGET,\n",
    "        'VS':              VS,\n",
    "        'SOURCE_LANG':     SOURCE_LANG,\n",
    "        'TARGET_LANG':     TARGET_LANG,\n",
    "\n",
    "        'N_MAX_SAMPLES':   60000,  # set to 6667 to get N_train=3K\n",
    "        'N_TRAIN_SAMPLES': 6000,   # we used 6000 for the other runs\n",
    "        'N_TEST_SAMPLES':  512,\n",
    "        'N_EVAL':          4,\n",
    "        'ALPHA':           ALPHA_INIT,\n",
    "        'SEED':            SEED_INIT,\n",
    "        'NORMALIZE_VECS':  False,\n",
    "    },\n",
    "\n",
    "    'training': {\n",
    "        'TRAIN_TYPE':    'discrete',\n",
    "        'METHOD_NAME':   METHOD_NAME,\n",
    "        'MAX_ITERS':     MAX_ITERS,\n",
    "        'COST_DISCRETE': COST_DISCRETE,\n",
    "    },\n",
    "\n",
    "    # Hyperparameters for AlignGW / StructuredGW. Alternative `model_specific`\n",
    "    # settings used with the other methods, kept for reference:\n",
    "    #   RegGW:  {'HIDDEN_SIZES_MLP': [512, 256, 256], 'EPS_FIT': 0.01, 'EPS_REG': 0.001,\n",
    "    #            'LAMBDA': 1, 'MOVER_LR': 1e-4}\n",
    "    #   FlowGW: {'HIDDEN_SIZES_MLP': [1024, 1024, 1024, 1024], 'EPS': 1e-4,\n",
    "    #            'N_FREQ': 128, 'MOVER_LR': 1e-4}\n",
    "    'model_specific': {\n",
    "        'EPS': 1e-4,\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Loading dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([60000, 100])\n",
      "torch.Size([60000, 100])\n"
     ]
    }
   ],
   "source": [
    "dataset_path = '../datasets'\n",
    "# NOTE(review): presumably load_vectors needs modules from this folder importable - TODO confirm.\n",
    "# Guarded so re-running the cell does not append duplicate sys.path entries.\n",
    "if dataset_path not in sys.path:\n",
    "    sys.path.append(dataset_path)\n",
    "\n",
    "source_vectors, target_vectors = load_vectors(dataset_path, config)\n",
    "\n",
    "print(source_vectors.shape)\n",
    "print(target_vectors.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Sweep settings.\n",
    "n_repeats = 10       # independent training runs per ALPHA value\n",
    "wandb_report = True  # log each run to Weights & Biases\n",
    "\n",
    "N            = config['dataset']['N_MAX_SAMPLES']//1000\n",
    "TRAIN_TYPE   = config['training']['TRAIN_TYPE']\n",
    "project_name = f'{METHOD_NAME}_{DATASET_NAME}_{SOURCE_DIM}({SOURCE_LANG}_{EMB_TYPE_SOURCE})->{TARGET_DIM}({TARGET_LANG}_{EMB_TYPE_TARGET})_{N}K_{n_repeats}reps_vs({VS//1000}K)_final'\n",
    "\n",
    "# Create the output folder before any (expensive) training so the first\n",
    "# checkpoint dump cannot fail with FileNotFoundError.\n",
    "results_dir  = f'results_{TRAIN_TYPE}_final'\n",
    "results_path = f'{results_dir}/{project_name}_00.pkl'\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "metrics_names = ['Top@1', 'Top@5', 'Top@10', 'cossim_gt', 'inner_gw', 'foscttm']\n",
    "\n",
    "# The loop below overwrites config['dataset']['SEED'] and ['ALPHA']; restore the\n",
    "# initial values so re-running this cell builds the same test sampler as a fresh kernel.\n",
    "config['dataset']['SEED']  = SEED_INIT\n",
    "config['dataset']['ALPHA'] = ALPHA_INIT\n",
    "_, _, _, _, test_sampler = get_samplers(config, source_vectors, target_vectors)\n",
    "\n",
    "alpha_values = [0.0]  # list several values here to sweep ALPHA\n",
    "metrics_out  = {str(np.round(alpha, 1)): [] for alpha in alpha_values}\n",
    "\n",
    "# Per-repeat seeds are drawn from a generator seeded with SEED_INIT, so the whole\n",
    "# sweep is reproducible; each seed is also printed and stored in the run config.\n",
    "seed_rng = random.Random(SEED_INIT)\n",
    "\n",
    "for ALPHA in alpha_values:\n",
    "    alpha_key = str(np.round(ALPHA, 1))\n",
    "    config['dataset']['ALPHA'] = ALPHA\n",
    "\n",
    "    print('================================')\n",
    "    print(f'Experiment for ALPHA={ALPHA}')\n",
    "    print('================================')\n",
    "\n",
    "    for ix in range(n_repeats):\n",
    "        # Draw the seed BEFORE wandb.init so the config logged with the run\n",
    "        # records the seed this run actually uses.\n",
    "        SEED = seed_rng.randint(0, 10000)\n",
    "        config['dataset']['SEED'] = SEED\n",
    "        print('Seed: ', SEED)\n",
    "\n",
    "        if wandb_report:\n",
    "            exp_name = f'ALPHA_{np.round(ALPHA, 1)}_repeat_{ix}'\n",
    "            wandb.init(name=exp_name, config=config, project=project_name)\n",
    "\n",
    "        # Bind the returned vectors to run-local names so the loaded source_vectors /\n",
    "        # target_vectors stay untouched; otherwise each run's output would be fed into\n",
    "        # the next get_samplers call and the cell would not be idempotent.\n",
    "        _, run_target_vectors, train_source_sampler, train_target_sampler, _ = get_samplers(\n",
    "            config, source_vectors, target_vectors)\n",
    "\n",
    "        trained_class, metrics_dict = train_discrete(train_source_sampler, train_target_sampler,\n",
    "                                                     test_sampler,\n",
    "                                                     metrics_names, run_target_vectors,\n",
    "                                                     config,\n",
    "                                                     wandb_report=wandb_report,\n",
    "                                                     axis_lims=None, report_every=20)\n",
    "\n",
    "        metrics_out[alpha_key].append(metrics_dict)\n",
    "\n",
    "        # Checkpoint all results gathered so far after every run.\n",
    "        with open(results_path, 'wb') as f:\n",
    "            pickle.dump(metrics_out, f)\n",
    "\n",
    "        if wandb_report:\n",
    "            wandb.finish()  # close this run before the next wandb.init\n",
    "\n",
    "        gc.collect()"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
