{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "896e1f4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import multiprocessing as mp\n",
    "from functools import partial\n",
    "from multiprocess import Pool\n",
    "from typing import Any, List, Tuple\n",
    "import os\n",
    "import sys\n",
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "import faiss\n",
    "import fsspec\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from tqdm import tqdm, trange\n",
    "import copy\n",
    "from pathlib import Path\n",
    "\n",
    "sys.path.append('../')\n",
    "from lib.metrics import utils\n",
    "\n",
    "# Number of neighbours kept per sample; the searches further down request\n",
    "# k + 1 hits and one of them is discarded afterwards.\n",
    "k = 30\n",
    "\n",
    "# Directory that receives the score tables and uid arrays written by this notebook.\n",
    "out_dir = Path('./datacomp_preds')\n",
    "out_dir.mkdir(exist_ok = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7495b4c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_embedding_helper(\n",
    "    fs_root: Tuple[Any, str]\n",
    ") -> Tuple[np.ndarray, np.ndarray, np.ndarray]:\n",
    "    \"\"\"Load the uids and both embedding matrices stored under one path root.\n",
    "\n",
    "    Args:\n",
    "        fs_root: ``(fs, path_root)`` pair. ``fs`` must provide ``open``;\n",
    "            ``path_root`` is the prefix shared by ``<path_root>.npz`` and\n",
    "            ``<path_root>.parquet``.\n",
    "\n",
    "    Returns:\n",
    "        ``(uids, img_embeds, txt_embeds)``: the values of the parquet ``uid``\n",
    "        column and the ``l14_img`` / ``l14_txt`` arrays of the npz archive.\n",
    "    \"\"\"\n",
    "    fs, path_root = fs_root\n",
    "\n",
    "    # Pull both arrays out while the archive is still open, then release the\n",
    "    # handle; previously the opened files were never closed.\n",
    "    with fs.open(f\"{path_root}.npz\") as npz_file:\n",
    "        with np.load(npz_file) as embed:\n",
    "            img_embeds = embed[\"l14_img\"]\n",
    "            txt_embeds = embed[\"l14_txt\"]\n",
    "\n",
    "    with fs.open(f'{path_root}.parquet') as parquet_file:\n",
    "        uids = pd.read_parquet(parquet_file)['uid'].values\n",
    "\n",
    "    return uids, img_embeds, txt_embeds\n",
    "\n",
    "\n",
    "def load_embedding(\n",
    "    paths: List[Tuple[Any, str]],\n",
    "    n_workers: int = 10,\n",
    ") -> Tuple[np.ndarray, np.ndarray, np.ndarray]:\n",
    "    \"\"\"Load every metadata shard in a process pool and stack the results.\n",
    "\n",
    "    Args:\n",
    "        paths: ``(fs, path_root)`` pairs, one per shard, in the form accepted\n",
    "            by ``load_embedding_helper``.\n",
    "        n_workers: Number of worker processes in the pool.\n",
    "\n",
    "    Returns:\n",
    "        ``(uids, img_embeds, txt_embeds)`` with the per-shard results joined\n",
    "        in the order of ``paths``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``paths`` is empty.\n",
    "    \"\"\"\n",
    "    if len(paths) == 0:\n",
    "        # Fail early with a readable message; np.concatenate on an empty list\n",
    "        # would raise a ValueError that does not mention the real cause.\n",
    "        raise ValueError('load_embedding received no paths to load')\n",
    "\n",
    "    # NOTE(review): this configures the stdlib `multiprocessing` module, while\n",
    "    # `Pool` is imported from the third-party `multiprocess` package - confirm\n",
    "    # that the start method actually applies to the pool created below.\n",
    "    mp.set_start_method(\"spawn\", force=True)\n",
    "    print(\"start loading embedding\")\n",
    "\n",
    "    with Pool(n_workers) as pool:\n",
    "        # imap (not imap_unordered) yields results in input order, so the\n",
    "        # stacked rows come out in the same order on every run.\n",
    "        shards = list(\n",
    "            tqdm(pool.imap(load_embedding_helper, paths), total=len(paths))\n",
    "        )\n",
    "\n",
    "    # Every shard result is a (uids, img, txt) triple; regroup by component.\n",
    "    shard_uids, shard_imgs, shard_txts = zip(*shards)\n",
    "    return np.concatenate(shard_uids), np.vstack(shard_imgs), np.vstack(shard_txts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "439a43cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the `<shard>.parquet` / `<shard>.npz` metadata files.\n",
    "metadata_dir = '/data/FOLDER/FOLDER/data/datacomp/metadata'\n",
    "\n",
    "fs, url = fsspec.core.url_to_fs(metadata_dir)\n",
    "\n",
    "# One (filesystem, path-without-extension) entry per listed parquet file.\n",
    "paths = []\n",
    "for entry in fs.ls(url):\n",
    "    if \".parquet\" in entry:\n",
    "        paths.append((fs, str(entry.split(\".parquet\")[0])))\n",
    "\n",
    "uids, img_embeds, txt_embeds = load_embedding(paths, n_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44fe18b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the reference subset, and everything derived from it, is the\n",
    "# same on every Restart & Run All.\n",
    "SUBSET_SEED = 0\n",
    "# Cap at the number of available rows: drawing without replacement raises a\n",
    "# ValueError when more items are requested than exist.\n",
    "n_train = min(100_000, len(txt_embeds))\n",
    "\n",
    "train_indices_in_compr = np.random.default_rng(SUBSET_SEED).choice(\n",
    "    len(txt_embeds), size=n_train, replace=False\n",
    ")\n",
    "\n",
    "# Row-normalised copies of the sampled text / image embeddings.\n",
    "emb_txt_tr = normalize(txt_embeds[train_indices_in_compr, :])\n",
    "emb_img_tr = normalize(img_embeds[train_indices_in_compr, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa9604ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distance of each reference image/text pair: 1 minus their row-wise dot product.\n",
    "dists_tr = 1 - (emb_txt_tr * emb_img_tr).sum(axis=1)\n",
    "\n",
    "# Flat inner-product index over the reference text embeddings.\n",
    "txt_dim = emb_txt_tr.shape[1]\n",
    "index_txt = faiss.IndexFlatIP(txt_dim)\n",
    "index_txt.add(emb_txt_tr)\n",
    "\n",
    "# Flat inner-product index over the reference image embeddings.\n",
    "img_dim = emb_img_tr.shape[1]\n",
    "index_img = faiss.IndexFlatIP(img_dim)\n",
    "index_img.add(emb_img_tr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4418f4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every sample, fetch k + 1 hits from each index: D_* receive the first\n",
    "# return value of `search`, I_* the second (used below to index the reference rows).\n",
    "# NOTE(review): the queries are the embeddings as loaded, while the indexed rows\n",
    "# went through `normalize` - presumably the stored embeddings are already\n",
    "# unit-norm; confirm.\n",
    "D_ns, I_ns = index_img.search(img_embeds, k + 1)\n",
    "D_ms, I_ms = index_txt.search(txt_embeds, k + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06892967",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _drop_one_hit(scores, positions, is_indexed):\n",
    "    \"\"\"Trim a (k + 1)-long hit list down to k entries.\n",
    "\n",
    "    When the query sample is itself part of the index, the first hit is\n",
    "    dropped (presumably the sample's own entry); otherwise the last hit is.\n",
    "    Returns the trimmed ``(scores, positions)`` pair.\n",
    "    \"\"\"\n",
    "    if is_indexed:\n",
    "        return scores[1:], positions[1:]\n",
    "    return scores[:-1], positions[:-1]\n",
    "\n",
    "\n",
    "# Hash-set copy of the subset positions: membership is O(1) per sample, whereas\n",
    "# `sample_idx in <ndarray>` rescans the whole array on every test.\n",
    "train_index_set = set(train_indices_in_compr.tolist())\n",
    "\n",
    "logs = []\n",
    "for sample_idx in trange(len(uids)):\n",
    "    # Keep a leading axis of length 1 so the products below broadcast row-wise.\n",
    "    img_embed = img_embeds[sample_idx, None]\n",
    "    text_embed = txt_embeds[sample_idx, None]\n",
    "    # NOTE(review): these rows are used as loaded, while y_n / x_m below come\n",
    "    # from the `normalize`d subset - presumably the stored embeddings are\n",
    "    # already unit-norm; confirm.\n",
    "    d1 = 1 - np.dot(img_embed.flatten(), text_embed.flatten())\n",
    "    in_subset = sample_idx in train_index_set\n",
    "\n",
    "    # d_n: hits from the image index, compared with this sample's text embedding\n",
    "    D_n, I_n = _drop_one_hit(D_ns[sample_idx], I_ns[sample_idx], in_subset)\n",
    "    y_n = emb_txt_tr[I_n]\n",
    "    D_n = -D_n\n",
    "    dists_n = 1 - (text_embed * y_n).sum(axis=1)\n",
    "\n",
    "    # d_m: hits from the text index, compared with this sample's image embedding\n",
    "    D_m, I_m = _drop_one_hit(D_ms[sample_idx], I_ms[sample_idx], in_subset)\n",
    "    x_m = emb_img_tr[I_m]\n",
    "    D_m = -D_m\n",
    "    dists_m = 1 - (img_embed * x_m).sum(axis=1)\n",
    "\n",
    "    logs.append({\n",
    "        'idx': sample_idx,\n",
    "        'uid': uids[sample_idx],\n",
    "        'd_1': d1.item(),\n",
    "        'dists_n': dists_n,\n",
    "        'D_n': D_n.flatten(),\n",
    "        'dists_tr_n': dists_tr[I_n],\n",
    "        'dists_m': dists_m,\n",
    "        'D_m': D_m.flatten(),\n",
    "        'dists_tr_m': dists_tr[I_m]\n",
    "    })\n",
    "\n",
    "# One row per sample; the array-valued entries are stored as object cells.\n",
    "df = pd.DataFrame(logs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f95d5e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the uids whose shard record reports a successful download.\n",
    "shard_dir = Path('/data/FOLDER/FOLDER/data/datacomp/shards/')\n",
    "valid_uids = []\n",
    "\n",
    "for shard_path in shard_dir.glob('**/*.parquet'):\n",
    "    shard_meta = pd.read_parquet(shard_path)\n",
    "    downloaded = shard_meta.query('status == \"success\"')\n",
    "    valid_uids.extend(downloaded['uid'].values.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "944edcbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameter settings to score with; one score table and one uid file\n",
    "# are written per entry.\n",
    "hparam_settings = [{\n",
    "    'beta': 5,\n",
    "    'gamma': 5,\n",
    "    'tau_1_n': 0.5,\n",
    "    'tau_2_n': 10,\n",
    "    'tau_1_m': 0.5,\n",
    "    'tau_2_m': 10,\n",
    "},{\n",
    "    'beta': 0,\n",
    "    'gamma': 0,\n",
    "    'tau_1_n': 0,\n",
    "    'tau_2_n': 0,\n",
    "    'tau_1_m': 0,\n",
    "    'tau_2_m': 0\n",
    "}]\n",
    "\n",
    "# Maximum number of uids kept per setting.\n",
    "n = 3_500_000\n",
    "\n",
    "for c, hparam_dict in enumerate(hparam_settings):\n",
    "    # Score on a copy so `df` itself never gains a `score` column.\n",
    "    score_df = copy.deepcopy(df)\n",
    "    score_df['score'] = utils.calc_scores_given_hparams_vectorized(score_df, hparam_dict)\n",
    "    score_df.to_pickle(out_dir/f'score_df_{c}.pkl')\n",
    "\n",
    "    # Restrict to the downloaded samples, then keep the n lowest scores.\n",
    "    selected_df = (\n",
    "        score_df.set_index(['uid'])\n",
    "        .loc[valid_uids]\n",
    "        .reset_index()\n",
    "        .sort_values(by='score', ascending=True)\n",
    "        .iloc[:n]\n",
    "    )\n",
    "    # Deliberately not called `uids`: that name holds the full uid array from\n",
    "    # the loading cell, and overwriting it would corrupt any later re-run of\n",
    "    # the cells that index into it.\n",
    "    selected_uids = selected_df['uid'].values.tolist()\n",
    "\n",
    "    # Each uid is parsed as two 16-hex-digit halves into a pair of unsigned\n",
    "    # 64-bit integers, and the pairs are sorted before saving.\n",
    "    processed_uids = np.array(\n",
    "        [(int(uid[:16], 16), int(uid[16:32], 16)) for uid in selected_uids],\n",
    "        np.dtype(\"u8,u8\"),\n",
    "    )\n",
    "    processed_uids.sort()\n",
    "    np.save(out_dir/f'uids_{c}.npy', processed_uids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdadf824",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
