{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import argparse\n",
    "import pprint\n",
    "import torch\n",
    "\n",
    "from config import get_cluster_configs, load_configs, parse_config_arg\n",
    "from two_step_zoo import get_clustering_module, get_loaders, get_clustering_trainer, get_evaluator, Writer, get_clusterer,get_id_estimator\n",
    "from two_step_zoo.datasets.loaders import get_loaders_from_config\n",
    "\n",
    "import pdb\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "import copy\n",
    "\n",
    "from sklearn.cluster import AgglomerativeClustering, OPTICS, MiniBatchKMeans\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "from heapq import heappush, heappop\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "from collections import defaultdict\n",
    "\n",
    "import itertools\n",
    "import os\n",
    "\n",
    "from id_clusterer_helper import pickle_exists,save_pickle,load_pickle,id_variance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dist_index(x,y):\n",
    "    \"\"\"Translate a pair of point indices into a position in the flat distance vector.\n",
    "\n",
    "    Relies on the notebook-level global `n`. The arithmetic is the condensed-matrix\n",
    "    offset for row `lo`, column `hi` with lo < hi -- presumably the layout produced by\n",
    "    scipy's `pdist`; confirm against how `dists` was built.\n",
    "\n",
    "    Returns -1 when either index equals the -1 sentinel. Equal indices are reported\n",
    "    with a \"SAME\" print and then rejected by the assertion (unless both are -1).\n",
    "    \"\"\"\n",
    "    if x == y:\n",
    "        print(\"SAME\", x)\n",
    "    if x == -1 or y == -1:\n",
    "        return -1\n",
    "    assert x != y\n",
    "    # Order the pair so the smaller index selects the row.\n",
    "    lo, hi = min(x, y), max(x, y)\n",
    "    return lo*n + hi - ((lo + 2) * (lo + 1)) // 2\n",
    "\n",
    "def get_dist(x,y, dists):\n",
    "    \"\"\"Return the entry of `dists` stored for the pair (x, y), located via `dist_index`.\"\"\"\n",
    "    flat_pos = dist_index(x, y)\n",
    "    return dists[flat_pos]\n",
    "\n",
    "def get_nn_dists(oidxs, k, dists):\n",
    "    \"\"\"For each index in `oidxs`, find its `k` closest other members of `oidxs`.\n",
    "\n",
    "    Returns a pair of arrays (distances, neighbour indices), each with one row per\n",
    "    entry of `oidxs` and `k` columns ordered by increasing distance.\n",
    "    `k` is passed to `np.argpartition` as the pivot position, so it has to be smaller\n",
    "    than the number of candidate neighbours per row.\n",
    "    \"\"\"\n",
    "    # Row r lists the members of oidxs whose value differs from oidxs[r], in input order.\n",
    "    others = [[j for j in oidxs if j != idx] for idx in oidxs]\n",
    "    nn_neighbours = np.array(others)\n",
    "    nn_dists = np.array([[dists[dist_index(idx,j)] for j in row] for idx,row in zip(oidxs, others)])\n",
    "\n",
    "    # Partial selection first: the k smallest distances land in the first k columns (unordered).\n",
    "    smallest_k = np.argpartition(nn_dists, k)[:,:k]\n",
    "    nn_dists = np.take_along_axis(nn_dists, smallest_k, axis=-1)\n",
    "    nn_neighbours = np.take_along_axis(nn_neighbours, smallest_k, axis=-1)\n",
    "\n",
    "    # Then fully order only those k columns.\n",
    "    order = np.argsort(nn_dists, axis=-1)\n",
    "    return np.take_along_axis(nn_dists, order, axis=-1), np.take_along_axis(nn_neighbours, order, axis=-1)\n",
    "\n",
    "def calculate_id(idxs, dists, second_idx=0, return_idx=False, k=10):\n",
    "    \"\"\"Estimate the intrinsic dimension (ID) of the points `idxs` from neighbour distances.\n",
    "\n",
    "    For every point, the log-ratios between its k-th neighbour distance and each of its\n",
    "    nearer neighbour distances are summed and divided by (k-1); the mean of these values\n",
    "    over all points is inverted to give the estimate.\n",
    "    NOTE(review): this looks like a Levina-Bickel style MLE estimator -- confirm.\n",
    "\n",
    "    `k` is capped at len(idxs) - 2.\n",
    "    NOTE(review): with fewer than 4 points the (k-1) divisor is zero or negative, so the\n",
    "    result is not a meaningful estimate -- confirm callers never pass such clusters.\n",
    "\n",
    "    Returns `(second_idx, estimate)` when `return_idx` is true, otherwise\n",
    "    `(estimate, nn_neighbours)`.\n",
    "    \"\"\"\n",
    "    k = min(k, len(idxs)-2)\n",
    "    nn_dists,nn_neighbours = get_nn_dists(idxs, k, dists)\n",
    "\n",
    "    kth_dist = nn_dists[:, k - 1: k]      # kept as a column so it broadcasts over the nearer ones\n",
    "    nearer_dists = nn_dists[:, 0:k - 1]\n",
    "    inv_mle = np.log(kth_dist / nearer_dists).sum(-1) / (k-1)\n",
    "    estimate = 1 / inv_mle.mean()\n",
    "\n",
    "    if return_idx:\n",
    "        return (second_idx, estimate)\n",
    "    return estimate,nn_neighbours\n",
    "\n",
    "def id_variance(clusters, dists):\n",
    "    \"\"\"Sample variance (n-1 denominator) of the ID estimates of the given clusters.\n",
    "\n",
    "    Each cluster's estimate is the first element returned by `calculate_id(cluster, dists)`.\n",
    "    At least two clusters are needed, otherwise the denominator is zero.\n",
    "    NOTE(review): this definition replaces the `id_variance` imported from\n",
    "    `id_clusterer_helper` in the first cell -- confirm that is intended.\n",
    "    \"\"\"\n",
    "    estimates = [calculate_id(cluster, dists)[0] for cluster in clusters]\n",
    "    count = len(estimates)\n",
    "    mean_id = sum(estimates) / count\n",
    "    squared_devs = [(mean_id-est)**2 for est in estimates]\n",
    "    return sum(squared_devs) / (count-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Init code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: which dataset and clustering to load, and where to save the result.\n",
    "dataset = \"cifar10\"\n",
    "run_name = \"0330_1000_cap_cifar10_5_6\"\n",
    "# When not None, this pickle is loaded instead of f'{run_name}_final'.\n",
    "load_from = \"iter_770_0330_weightedstd_full_0404_cifar10_5_6.pickle\"\n",
    "# Name under which the train/valid/test partitions are pickled.\n",
    "save_partition_name=\"tiered_cifar10\"\n",
    "# Divisor applied to the flattened inputs when building tfeats/vfeats.\n",
    "norm = 255.\n",
    "# Placeholder only: overwritten with len(clusters) once the clustering is loaded.\n",
    "num_clusters = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the configs for this dataset and the train/valid/test loaders on the CPU.\n",
    "gae_cfg, de_cfg, shared_cfg, cluster_cfg = get_cluster_configs(\n",
    "    dataset=dataset, generalized_autoencoder=\"avb\", density_estimator=\"vae\"\n",
    ")\n",
    "train_loader, valid_loader, test_loader = get_loaders_from_config(shared_cfg, \"cpu\")\n",
    "\n",
    "# Inputs and targets of the train and validation splits.\n",
    "tdata, tlabs = train_loader.dataset.inputs.cpu(), train_loader.dataset.targets.cpu()\n",
    "vdata, vlabs = valid_loader.dataset.inputs.cpu(), valid_loader.dataset.targets.cpu()\n",
    "\n",
    "# Number of training points; read as a global by dist_index.\n",
    "n = len(tlabs)\n",
    "\n",
    "# One flattened row per sample, divided by `norm`.\n",
    "tfeats = tdata.reshape(len(tdata),-1)/norm\n",
    "vfeats = vdata.reshape(len(vdata),-1)/norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sections = 2000\n",
    "# Both feature matrices are moved to the GPU for the distance computation.\n",
    "vfeats = vfeats.cuda()\n",
    "tfeats = tfeats.cuda()\n",
    "num_entries = tfeats.shape[0]\n",
    "quotient, remainder = divmod(num_entries, sections)\n",
    "\n",
    "# Squared Euclidean distance from every training row to every validation row,\n",
    "# computed slice by slice: `sections` slices of `quotient` rows, then the leftover rows.\n",
    "train_val_dists = []\n",
    "for i in tqdm(range(sections)):\n",
    "    train_val_dists.append((tfeats[i*quotient:(i+1)*quotient].unsqueeze(1) - vfeats.unsqueeze(0)).pow(2).sum(-1))\n",
    "if remainder != 0:\n",
    "    train_val_dists.append((tfeats[-remainder:].unsqueeze(1) - vfeats.unsqueeze(0)).pow(2).sum(-1))\n",
    "\n",
    "# Stack the slices into one (num train, num valid) matrix and bring it back to the CPU.\n",
    "train_val_dists = torch.cat(train_val_dists).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored distance container for this dataset; it is the `dists` argument\n",
    "# expected by the helper functions (get_dist, calculate_id, id_variance).\n",
    "print(f\"Loading pdists from {dataset}_pdists\")\n",
    "dists = load_pickle(f'{dataset}_pdists')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved clustering: `load_from` when it is set, otherwise f'{run_name}_final'.\n",
    "# NOTE(review): if load_pickle unpickles the file, only load files from a trusted source.\n",
    "clusters_save = load_pickle(f'{run_name}_final' if load_from is None else load_from)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the clusters that contain at least one index.\n",
    "clusters = [c for c in clusters_save[\"clusters\"] if len(c) > 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: size of every cluster, and how many clusters there are.\n",
    "[len(c) for c in clusters], len(clusters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One partition dict per cluster, holding the indices of its three splits.\n",
    "num_clusters = len(clusters)\n",
    "partitions = [{\"train\": [], \"valid\": [], \"test\": []} for _ in range(num_clusters)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The training split of each partition is the cluster's own index collection\n",
    "# (the same object is stored, not a copy).\n",
    "for cidx,cluster in enumerate(clusters):\n",
    "    partitions[cidx][\"train\"] = cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign every validation point to the cluster whose training members are closest on average.\n",
    "\n",
    "# Hoisted out of the loops: a cluster's index tensor does not depend on the validation point,\n",
    "# so it is built once instead of once per (validation point, cluster) pair.\n",
    "train_idx_tensors = [torch.as_tensor(partitions[cidx][\"train\"]) for cidx in range(num_clusters)]\n",
    "\n",
    "# Start from empty lists so that re-running this cell does not append every index a second time.\n",
    "for part in partitions:\n",
    "    part[\"valid\"] = []\n",
    "    part[\"test\"] = []\n",
    "\n",
    "for val_idx in tqdm(range(train_val_dists.shape[1])):\n",
    "    min_distance = math.inf\n",
    "    current_match = None\n",
    "\n",
    "    for cidx, train_idxs in enumerate(train_idx_tensors):\n",
    "        # Mean distance from this validation point to the cluster's training points.\n",
    "        candidate_distance = train_val_dists[train_idxs, val_idx].mean()\n",
    "        if candidate_distance < min_distance:\n",
    "            min_distance = candidate_distance\n",
    "            current_match = cidx\n",
    "\n",
    "    # A numeric sentinel (1000) used to stand in for \"no match\": it raised an opaque IndexError,\n",
    "    # or with more than 1000 clusters silently selected cluster 1000. Fail explicitly instead.\n",
    "    if current_match is None:\n",
    "        raise ValueError(f\"No cluster matched validation index {val_idx}; check for NaN distances or empty clusters\")\n",
    "\n",
    "    # NOTE(review): the same validation index is stored under both \"valid\" and \"test\" -- confirm intended.\n",
    "    partitions[current_match][\"valid\"].append(val_idx)\n",
    "    partitions[current_match][\"test\"].append(val_idx)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of validation points assigned to each cluster.\n",
    "[len(c[\"valid\"]) for c in partitions]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the partitions under the name chosen in the configuration cell.\n",
    "save_pickle(save_partition_name, partitions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cluster Vis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this replaces the `partitions` and `run_name` set earlier in the notebook with a\n",
    "# separately saved partition set, so the figures below are drawn from \"cap6k_cifar\" -- confirm intended.\n",
    "partitions = load_pickle(\"cap6k_cifar\")\n",
    "run_name = \"cap6k_cifar\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid shape (rows, columns) of the sample images plotted for each cluster.\n",
    "nr,nc = 4,4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot a random nr x nc grid of training images for every cluster and save it as a PNG.\n",
    "\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(\"./id_cluster_vis\", exist_ok=True)\n",
    "\n",
    "for cluster_idx,cluster in enumerate(partitions):\n",
    "    # squeeze=False keeps axarr two-dimensional even when nr or nc is 1.\n",
    "    f, axarr = plt.subplots(nr,nc, squeeze=False)\n",
    "    f.suptitle(f\"Cluster: {cluster_idx}\", fontsize=16)\n",
    "\n",
    "    # Strip ticks and grid from every cell, including ones left empty by a small cluster.\n",
    "    for ax in axarr.flat:\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.grid(False)\n",
    "\n",
    "    # random.sample raises ValueError when asked for more items than exist, so cap the sample size.\n",
    "    train_idxs = cluster[\"train\"].tolist()\n",
    "    idxs = random.sample(train_idxs, min(nr*nc, len(train_idxs)))\n",
    "    for i,idx in enumerate(idxs):\n",
    "        img = tdata[idx]\n",
    "        # Row-major placement: there are nc cells per row, so both the row and the column\n",
    "        # are derived from nc (using nr here is only correct for square grids).\n",
    "        axarr[i // nc, i % nc].imshow(img.to(torch.int).permute(1,2,0))\n",
    "\n",
    "    f.show()\n",
    "    f.savefig(f\"./id_cluster_vis/{run_name}_{cluster_idx}.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): no later cell in this notebook reads run_name; presumably this is set before\n",
    "# re-running the visualisation cell by hand -- confirm, or remove.\n",
    "run_name = \"tiered\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "db448fe85518e560c3ef83ccc8d105f9c373042f28551637af6f8c1ae46c95bb"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 ('two_step_zoo': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
