{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports and helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import pprint\n",
    "import torch\n",
    "\n",
    "from config import get_cluster_configs, load_configs, parse_config_arg\n",
    "from two_step_zoo import get_clustering_module, get_loaders, get_clustering_trainer, get_evaluator, Writer, get_clusterer,get_id_estimator\n",
    "from two_step_zoo.datasets.loaders import get_loaders_from_config\n",
    "\n",
    "import pdb\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "import copy\n",
    "\n",
    "from sklearn.cluster import AgglomerativeClustering, OPTICS, MiniBatchKMeans\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "from heapq import heappush, heappop\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "from collections import defaultdict\n",
    "\n",
    "import itertools\n",
    "import os\n",
    "\n",
    "def pickle_exists(name):\n",
    "    \"\"\"Return True if a cached pickle named `name` exists under pickles/.\"\"\"\n",
    "    cache_path = f'pickles/{name}.pickle'\n",
    "    return os.path.exists(cache_path)\n",
    "\n",
    "def save_pickle(name, object):\n",
    "    \"\"\"Pickle `object` to pickles/{name}.pickle, creating the directory first if it is missing.\"\"\"\n",
    "    path = f'pickles/{name}.pickle'\n",
    "    # Without this, a fresh checkout fails on the very first cache write.\n",
    "    os.makedirs(os.path.dirname(path), exist_ok=True)\n",
    "    with open(path, 'wb') as handle:\n",
    "        pickle.dump(object, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "def load_pickle(name):\n",
    "    \"\"\"Unpickle and return the object stored at pickles/{name}.pickle.\n",
    "\n",
    "    Only use on trusted, locally produced files: unpickling can execute arbitrary code.\n",
    "    \"\"\"\n",
    "    path = f'pickles/{name}.pickle'\n",
    "    with open(path, 'rb') as handle:\n",
    "        loaded = pickle.load(handle)\n",
    "    return loaded\n",
    "\n",
    "def dist_index(x,y):\n",
    "    \"\"\"Position of the pair (x, y) in a condensed pairwise-distance vector over the global `n` points.\n",
    "\n",
    "    Returns -1 when either index is the -1 padding value; note that indexing a\n",
    "    numpy array with -1 reads its last entry. Prints SAME when x == y, which has\n",
    "    no condensed entry (the assert then fails unless the pair was padding).\n",
    "    \"\"\"\n",
    "    if x == y:\n",
    "        print(\"SAME\", x)\n",
    "    if -1 in (x, y):\n",
    "        return -1\n",
    "    assert x != y\n",
    "    lo, hi = min(x, y), max(x, y)\n",
    "    # Upper triangle without the diagonal, stored row by row.\n",
    "    return lo*n + hi - ((lo + 2) * (lo + 1)) // 2\n",
    "\n",
    "def get_dist(x,y, dists):\n",
    "    \"\"\"Distance between points x and y, read from the condensed vector `dists`.\"\"\"\n",
    "    position = dist_index(x, y)\n",
    "    return dists[position]\n",
    "\n",
    "def get_nn_dists(oidxs, k, dists):\n",
    "    \"\"\"k nearest within-set neighbours of every point in `oidxs`.\n",
    "\n",
    "    Returns (distances, neighbour ids), both of shape (len(oidxs), k) and\n",
    "    sorted by increasing distance along each row.\n",
    "    \"\"\"\n",
    "    # For each point, every other member of the set is a candidate neighbour.\n",
    "    others = [[j for j in oidxs if j != idx] for idx in oidxs]\n",
    "    cand_dists = np.array([[dists[dist_index(idx, j)] for j in row] for idx, row in zip(oidxs, others)])\n",
    "    cand_ids = np.array(others)\n",
    "\n",
    "    # Select the k smallest per row (unordered), then order just those k.\n",
    "    part = np.argpartition(cand_dists, k)\n",
    "    top_dists = np.take_along_axis(cand_dists, part, axis=-1)[:, :k]\n",
    "    top_ids = np.take_along_axis(cand_ids, part, axis=-1)[:, :k]\n",
    "\n",
    "    order = np.argsort(top_dists, axis=-1)\n",
    "    return np.take_along_axis(top_dists, order, axis=-1), np.take_along_axis(top_ids, order, axis=-1)\n",
    "\n",
    "def calculate_id(idxs, dists, second_idx=0, return_idx=False, k=10):\n",
    "    \"\"\"Maximum-likelihood intrinsic-dimension estimate for the points `idxs`.\n",
    "\n",
    "    Per point: mean over j < k of log(T_k / T_j), with T_j the j-th nearest\n",
    "    neighbour distance inside the set; the estimate is 1 / (mean over points).\n",
    "    k is capped at len(idxs) - 2.\n",
    "\n",
    "    Returns (estimate, nn_neighbours), or (second_idx, estimate) if return_idx.\n",
    "    \"\"\"\n",
    "    k = min(k, len(idxs) - 2)\n",
    "    nn_dists, nn_neighbours = get_nn_dists(idxs, k, dists)\n",
    "\n",
    "    log_ratios = np.log(nn_dists[:, k - 1:k] / nn_dists[:, :k - 1])\n",
    "    per_point_inverse = log_ratios.sum(axis=-1) / (k - 1)\n",
    "    estimate = 1 / per_point_inverse.mean()\n",
    "\n",
    "    if return_idx:\n",
    "        return (second_idx, estimate)\n",
    "    return estimate, nn_neighbours\n",
    "\n",
    "def id_variance(clusters, dists):\n",
    "    \"\"\"Unweighted sample variance (n - 1 denominator) of the per-cluster ID estimates.\"\"\"\n",
    "    estimates = [calculate_id(members, dists)[0] for members in clusters]\n",
    "    mean_id = sum(estimates) / len(estimates)\n",
    "    squared_devs = [(mean_id - value) ** 2 for value in estimates]\n",
    "    return sum(squared_devs) / (len(estimates) - 1)\n",
    "\n",
    "def update_id(idxs, nn_neighbours, dists, second_idx=0, return_idx=False, k=10):\n",
    "    \"\"\"Re-estimate a merged cluster's ID from precomputed candidate neighbour lists.\n",
    "\n",
    "    idxs: point indices of the merged cluster, one per row of `nn_neighbours`.\n",
    "    nn_neighbours: candidate neighbour point ids per row, right-padded with -1 (see cat_pad).\n",
    "    Uses the k nearest distinct-distance candidates per row (k capped by the row with the\n",
    "    fewest real candidates) in the same MLE formula as calculate_id.\n",
    "\n",
    "    Returns (estimate, kept neighbours), or (second_idx, estimate) if return_idx.\n",
    "    \"\"\"\n",
    "    # -1 entries are padding. dist_index maps them to -1, which would silently read\n",
    "    # dists[-1] (a real distance); give them an infinite distance instead so they\n",
    "    # sort last and are excluded by the finite-candidate cap on k below.\n",
    "    nn_dists = np.array([[dists[dist_index(im_idx,j)] if j != -1 else np.inf for j in nn_neighbours[idx]] for idx,im_idx in enumerate(idxs)])\n",
    "\n",
    "    # Per row: sorted distinct distances and the column of their first occurrence\n",
    "    # (the same neighbour can appear in several of the merged candidate lists).\n",
    "    arg_sort = [np.unique(nn_dists[i], return_index=True) for i in range(nn_dists.shape[0])]\n",
    "    k = min(k,min([(nn[0] < 1e6).sum() for nn in arg_sort]))\n",
    "\n",
    "    arg_sort = np.stack([nn[1][:k] for nn in arg_sort])\n",
    "    nn_dists,nn_neighbours = np.take_along_axis(nn_dists, arg_sort, axis=-1), np.take_along_axis(nn_neighbours, arg_sort, axis=-1)\n",
    "\n",
    "    d = np.log(nn_dists[:, k - 1: k] / nn_dists[:, 0:k - 1])\n",
    "    inv_mle = np.sum(d, -1) / (k-1)\n",
    "\n",
    "    if return_idx:\n",
    "        return (second_idx, (1 / inv_mle.mean()))\n",
    "    return (1 / inv_mle.mean()),nn_neighbours\n",
    "\n",
    "def cat_pad(tuple,cat_axis,pad_axis):\n",
    "    \"\"\"Concatenate 2-D arrays along `cat_axis`, first right-padding `pad_axis` with -1 to a common length.\n",
    "\n",
    "    -1 marks padding; dist_index returns -1 for it.\n",
    "    \"\"\"\n",
    "    lengths = [t.shape[pad_axis] for t in tuple]\n",
    "    target = max(lengths)\n",
    "    if target == min(lengths):\n",
    "        return np.concatenate(tuple, axis=cat_axis)\n",
    "\n",
    "    padded = []\n",
    "    for t in tuple:\n",
    "        extra = target - t.shape[pad_axis]\n",
    "        widths = [(0, extra if pad_axis == axis else 0) for axis in (0, 1)]\n",
    "        padded.append(np.pad(t, pad_width=widths, mode=\"constant\", constant_values=-1))\n",
    "    return np.concatenate(padded, axis=cat_axis)\n",
    "\n",
    "def initial_clusters(idxs, num_merges=2):\n",
    "    \"\"\"Greedily pre-merge clusters by average-linkage distance.\n",
    "\n",
    "    Each round pairs clusters in order of increasing mean member-to-member distance\n",
    "    (read from the global condensed vector `dists`), each cluster used at most once.\n",
    "    As soon as fewer than 3 + round clusters are left unpaired, they are fused into one.\n",
    "\n",
    "    idxs: list of clusters, each a list of point indices.\n",
    "    num_merges: number of pairing rounds.\n",
    "    Returns the merged clusters as lists of point indices.\n",
    "    \"\"\"\n",
    "    for main_iter in tqdm(range(num_merges)):\n",
    "        # Nothing to pair: keep the clusters (previously they were replaced by []).\n",
    "        if len(idxs) < 2:\n",
    "            break\n",
    "\n",
    "        next_idxs = []\n",
    "        distance_heap = []\n",
    "        used = set()\n",
    "\n",
    "        for i in range(len(idxs)-1):\n",
    "            for j in range(i+1, len(idxs)):\n",
    "                # Average linkage over the member points, not the cluster positions i and j.\n",
    "                inner_dists = [get_dist(first_idx, second_idx, dists)\n",
    "                               for first_idx in idxs[i]\n",
    "                               for second_idx in idxs[j]]\n",
    "                heappush(distance_heap, (sum(inner_dists) / len(inner_dists), i, j))\n",
    "\n",
    "        while distance_heap:\n",
    "            _,i,j = heappop(distance_heap)\n",
    "            if i not in used and j not in used:\n",
    "                next_idxs.append(idxs[i] + idxs[j])\n",
    "                used.add(i)\n",
    "                used.add(j)\n",
    "\n",
    "            # Fuse the last few unpaired clusters and end the round.\n",
    "            if len(idxs) - len(used) < 3+main_iter:\n",
    "                leftover = [idxs[leftover_idx] for leftover_idx in range(len(idxs)) if leftover_idx not in used]\n",
    "                if leftover:\n",
    "                    next_idxs.append(list(itertools.chain.from_iterable(leftover)))\n",
    "                break\n",
    "\n",
    "        idxs = next_idxs\n",
    "\n",
    "    return idxs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Clustering algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration\n",
    "dataset = \"cifar100\"\n",
    "m = 5                      # stop once this many clusters remain\n",
    "cap = 10000                # clusters reaching this size stop merging (-1 disables)\n",
    "save_graph_iter = 1        # save a scatter plot every this many merges\n",
    "num_initial_merges = 4     # pairing rounds in initial_clusters\n",
    "norm = 255.0               # raw inputs are divided by this\n",
    "\n",
    "# Flags\n",
    "first_from_save = False    # resume from the pickled first-iteration state\n",
    "save_first = True          # pickle the state on the first iteration\n",
    "save_plots = True\n",
    "print_stats = False\n",
    "print_times = False\n",
    "class_prior = False        # build initial clusters within each label class\n",
    "\n",
    "run_name = \"0330_weightedstd_full\"\n",
    "run_name = f\"{run_name}_{dataset}_{m}_{num_initial_merges}\"\n",
    "run_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gae_cfg, de_cfg, shared_cfg, cluster_cfg = get_cluster_configs(\n",
    "    dataset=dataset,\n",
    "    generalized_autoencoder=\"avb\",\n",
    "    density_estimator=\"vae\"\n",
    ")\n",
    "\n",
    "train_loader, valid_loader, test_loader = get_loaders_from_config(shared_cfg, \"cpu\")\n",
    "\n",
    "# The flattened, rescaled training inputs are the points to cluster.\n",
    "tdata = train_loader.dataset.inputs.cpu()/norm\n",
    "tlabs = train_loader.dataset.targets.cpu()\n",
    "\n",
    "feats = tdata.reshape(tdata.shape[0],-1)\n",
    "n,f = feats.shape  # n is also read (as a global) by dist_index\n",
    "\n",
    "# Calculate pairwise distances (condensed vector, cached under pickles/)\n",
    "if pickle_exists(f'{dataset}_pdists'):\n",
    "    print(f\"Loading pdists from {dataset}_pdists\")\n",
    "    dists = load_pickle(f'{dataset}_pdists')\n",
    "else:\n",
    "    print(f\"Calculating pdists\")\n",
    "    dists = pdist(feats)\n",
    "    print(f\"Saving pdists as {dataset}_pdists\")\n",
    "    save_pickle(f'{dataset}_pdists', dists)\n",
    "\n",
    "# Zero distances would make the log-ratio ID estimates blow up. Add in place:\n",
    "# dists has n*(n-1)/2 entries and `dists + 1e-4` would allocate a second full copy.\n",
    "if dists.min() < 1e-4:\n",
    "    print(\"Tiny pdists, adding an epsilon\")\n",
    "    dists += 1e-4\n",
    "\n",
    "# Calculate initializations\n",
    "if pickle_exists(f'og_clusters_{run_name}'):\n",
    "\n",
    "    print(\"Loading initial clusters from\", f'og_clusters_{run_name}')\n",
    "    og_clusters = load_pickle(f'og_clusters_{run_name}')\n",
    "\n",
    "else:\n",
    "\n",
    "    print(\"Calculating initial clusters...\")\n",
    "\n",
    "    classes = torch.unique(tlabs)\n",
    "\n",
    "    if class_prior:\n",
    "        # Seed each label class separately, one singleton cluster per point.\n",
    "        class_to_ids = {cidx.item(): [] for cidx in classes}\n",
    "\n",
    "        for idx,tlab in enumerate(tlabs): class_to_ids[tlab.item()].append([idx])\n",
    "\n",
    "        print(f\"Id variance of class clusters: \\\n",
    "            {id_variance([list(itertools.chain.from_iterable(cidxs)) for cidxs in class_to_ids.values()], dists)}\")\n",
    "\n",
    "        og_clusters = [initial_clusters(idxs, num_merges=num_initial_merges) for idxs in class_to_ids.values()]\n",
    "        og_clusters = list(itertools.chain.from_iterable(og_clusters))\n",
    "    else:\n",
    "        og_clusters = initial_clusters([[i] for i in range(tlabs.shape[0])], num_merges=num_initial_merges)\n",
    "\n",
    "    og_lens = [len(c) for c in og_clusters]\n",
    "    print(\"Max, min & mean cluster initial sizes\",\\\n",
    "        max(og_lens), min(og_lens), sum(og_lens) / len(og_lens))\n",
    "\n",
    "    print(\"Saving initial clusters to\", f'og_clusters_{run_name}')\n",
    "    save_pickle(f'og_clusters_{run_name}', og_clusters)\n",
    "\n",
    "# Main clustering algo\n",
    "# Number of chunks the candidate-variance computation is split into (presumably to bound GPU memory).\n",
    "sections = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if first_from_save:\n",
    "    # Resume from the state pickled on a previous run's first iteration.\n",
    "    print(f\"Loading initial iteration from pickles/first_iter_{run_name}.pickle\")\n",
    "    b = load_pickle(f\"first_iter_{run_name}.pickle\")\n",
    "\n",
    "    # Plain Python state\n",
    "    clusters = b[\"clusters\"]\n",
    "    id_estimates = b[\"id_estimates\"]\n",
    "    id_sum = b[\"id_sum\"]\n",
    "    cluster_cache = b[\"cluster_cache\"]\n",
    "    used_idxs = b[\"used_idxs\"]\n",
    "    idx_pair_to_index = b[\"idx_pair_to_index\"]\n",
    "    idx_pairs = b[\"idx_pairs\"]\n",
    "    cluster_sizes = b[\"cluster_sizes\"]\n",
    "\n",
    "    # Per-pair buffers were saved on the CPU; move them back to the GPU.\n",
    "    combined_ids = b[\"combined_ids\"].cuda()\n",
    "    id_estimates_one = b[\"id_estimates_one\"].cuda()\n",
    "    id_estimates_two = b[\"id_estimates_two\"].cuda()\n",
    "    cluster_sizes_one = b[\"cluster_sizes_one\"].cuda()\n",
    "    merge_cluster_lens = b[\"merge_cluster_lens\"].cuda()\n",
    "    cluster_sizes_two = b[\"cluster_sizes_two\"].cuda()\n",
    "\n",
    "    # Saves without the pending merge pair default to (0, 1).\n",
    "    first_merge_idx = b.get(\"first_merge_idx\", 0)\n",
    "    second_merge_idx = b.get(\"second_merge_idx\", 1)\n",
    "\n",
    "    print(f\"Using {len(clusters)} clusters\")\n",
    "\n",
    "else:\n",
    "    # Fresh start from the initial clustering.\n",
    "    clusters = [np.array(cluster) for cluster in copy.deepcopy(og_clusters)]\n",
    "    for cluster in clusters:\n",
    "        assert len(cluster) > 3, \"All clusters must have length greater than 3\"\n",
    "\n",
    "    print(f\"Using {len(clusters)} clusters\")\n",
    "\n",
    "    id_estimates = [calculate_id(cluster, dists)[0] for cluster in clusters]\n",
    "    cluster_sizes = [len(c) for c in clusters]\n",
    "    print(\"Initial ID estimates min, max, mean:\", min(id_estimates), max(id_estimates), sum(id_estimates)/len(id_estimates))\n",
    "\n",
    "    # Size-weighted sum of the cluster ID estimates.\n",
    "    id_sum = sum(estimate*size for estimate, size in zip(id_estimates, cluster_sizes))\n",
    "    used_idxs = set(range(len(clusters)))\n",
    "    new_n = len(clusters)\n",
    "\n",
    "    cluster_cache = {}\n",
    "\n",
    "# Per-run bookkeeping (reset on every run, including resumed ones).\n",
    "second = False          # becomes True after the first loop iteration\n",
    "switches = []\n",
    "new_n = len(clusters)\n",
    "merge_cluster_sizes = []\n",
    "num_merges = []\n",
    "merge_checker = defaultdict(int)\n",
    "merges = []\n",
    "keep_in_calc_idxs = []  # capped clusters: no longer merged, still counted in the variance\n",
    "\n",
    "# Statistics logged once per merge\n",
    "id_estimates_maxes = []\n",
    "id_estimates_mins = []\n",
    "id_estimates_means = []\n",
    "inter_id_vars = []\n",
    "combined_ids_log = []\n",
    "bs = tlabs.shape[0]     # total number of points\n",
    "\n",
    "# Active-cluster counts at which the state is checkpointed.\n",
    "save_set = list(range(15)) + [20, 25, 30]\n",
    "\n",
    "if not first_from_save:\n",
    "    # Per-pair candidate buffers, filled on the first iteration and then turned into CUDA tensors.\n",
    "    combined_ids, id_estimates_one, id_estimates_two = [], [], []\n",
    "    idx_pairs = []\n",
    "    idx_pair_to_index = {}\n",
    "    cluster_sizes_two, cluster_sizes_one, merge_cluster_lens = [], [], []\n",
    "\n",
    "# Mark a cached merge candidate as unusable: NaN in every per-pair buffer makes its\n",
    "# candidate variance NaN, which the argmax below treats as 0.\n",
    "def _invalidate_pair(pair):\n",
    "    pos = idx_pair_to_index[pair]\n",
    "    for buffer in (combined_ids, id_estimates_one, id_estimates_two, cluster_sizes_one, cluster_sizes_two, merge_cluster_lens):\n",
    "        buffer[pos] = torch.nan\n",
    "\n",
    "# Greedy merging: each step merges the pair of active clusters whose merge maximises the\n",
    "# size-weighted variance of the per-cluster ID estimates.\n",
    "for main_iter in tqdm(range(new_n-m)):\n",
    "    # Merged-away slots in `clusters` become [] rather than being removed, so count active ones.\n",
    "    if len(used_idxs) + len(keep_in_calc_idxs) <= m: break\n",
    "\n",
    "    if not second and first_from_save:\n",
    "        # Resumed run: replay the merge that was chosen when the checkpoint was written.\n",
    "        current_merge = (first_merge_idx,second_merge_idx)\n",
    "        current_merge_id = cluster_cache[(first_merge_idx,second_merge_idx)][0]\n",
    "        current_max_var = 10 # not used\n",
    "    else:\n",
    "        current_max_var = -math.inf\n",
    "        current_merge = (0,0)\n",
    "        current_merge_id = math.inf\n",
    "\n",
    "    # Get optimal merge pair\n",
    "    start = time.time()\n",
    "    id_est_vec_start = time.time()\n",
    "    # ID estimate and size of every cluster counted in the variance, in matching order.\n",
    "    id_est_vec = torch.tensor([id_estimates[id] for id in used_idxs], device=\"cuda\", dtype=torch.float32)\n",
    "    cluster_lens = torch.tensor([cluster_sizes[id] for id in used_idxs], device=\"cuda\", dtype=torch.float32)\n",
    "\n",
    "    if len(keep_in_calc_idxs) > 0:\n",
    "        # Capped clusters no longer merge but still count towards the variance; extend\n",
    "        # cluster_lens as well so it stays aligned with id_est_vec.\n",
    "        id_est_vec_kept_out = torch.tensor([id_estimates[id] for id in keep_in_calc_idxs], device=\"cuda\", dtype=torch.float32)\n",
    "        cluster_lens_kept_out = torch.tensor([cluster_sizes[id] for id in keep_in_calc_idxs], device=\"cuda\", dtype=torch.float32)\n",
    "        id_est_vec = torch.cat([id_est_vec, id_est_vec_kept_out])\n",
    "        cluster_lens = torch.cat([cluster_lens, cluster_lens_kept_out])\n",
    "\n",
    "    if print_times: print(\"id_est_vec time:\", time.time()-id_est_vec_start)\n",
    "\n",
    "    if not first_from_save or second:\n",
    "        agg_start = time.time()\n",
    "\n",
    "        if main_iter == 0:\n",
    "            # Build the candidate buffers: one entry per unordered pair (first_idx < second_idx).\n",
    "            for first_idx in (used_idxs):\n",
    "\n",
    "                for second_idx in used_idxs:\n",
    "\n",
    "                    if second_idx <= first_idx: continue\n",
    "\n",
    "                    first_cluster, second_cluster = clusters[first_idx], clusters[second_idx]\n",
    "                    combined_id,neighbours = calculate_id(np.concatenate((first_cluster,second_cluster)),dists)\n",
    "                    cluster_cache[(first_idx, second_idx)] = (combined_id,neighbours)\n",
    "\n",
    "                    idx_pair_to_index[(first_idx, second_idx)] = len(combined_ids)\n",
    "\n",
    "                    combined_ids.append(combined_id)\n",
    "                    id_estimates_one.append(id_estimates[first_idx])\n",
    "                    id_estimates_two.append(id_estimates[second_idx])\n",
    "                    cluster_sizes_one.append(cluster_sizes[first_idx])\n",
    "                    cluster_sizes_two.append(cluster_sizes[second_idx])\n",
    "                    idx_pairs.append((first_idx, second_idx))\n",
    "                    merge_cluster_lens.append(len(clusters[first_idx])+len(clusters[second_idx]))\n",
    "\n",
    "            combined_ids = torch.tensor(combined_ids, device=\"cuda\", dtype=torch.float32)\n",
    "            id_estimates_one = torch.tensor(id_estimates_one, device=\"cuda\", dtype=torch.float32)\n",
    "            id_estimates_two = torch.tensor(id_estimates_two, device=\"cuda\", dtype=torch.float32)\n",
    "            cluster_sizes_one = torch.tensor(cluster_sizes_one, device=\"cuda\", dtype=torch.float32)\n",
    "            cluster_sizes_two = torch.tensor(cluster_sizes_two, device=\"cuda\", dtype=torch.float32)\n",
    "            merge_cluster_lens = torch.tensor(merge_cluster_lens, device=\"cuda\", dtype=torch.float32)\n",
    "\n",
    "        if print_times: print(\"Aggregate ids_time time:\", time.time()-agg_start)\n",
    "        calc_start = time.time()\n",
    "        # Weighted ID sum after each candidate merge: drop both clusters' terms, ADD the merged one.\n",
    "        candidate_id_sums = id_sum - id_estimates_one*cluster_sizes_one - id_estimates_two*cluster_sizes_two + combined_ids*merge_cluster_lens\n",
    "        candidate_id_means = candidate_id_sums / bs\n",
    "\n",
    "        long_op = time.time()\n",
    "        # Size-weighted squared deviations of all current clusters from each candidate mean,\n",
    "        # evaluated in at most `sections` chunks (every chunk weighted the same way).\n",
    "        num_entries = candidate_id_means.shape[0]\n",
    "        chunk_size = max(1, -(-num_entries // sections))\n",
    "        candidate_id_vars = torch.cat([\n",
    "            (cluster_lens[None,:]*((id_est_vec[None,:]-chunk[:,None])**2)).sum(axis=1)\n",
    "            for chunk in torch.split(candidate_id_means, chunk_size)\n",
    "        ])\n",
    "        if print_times: print(\"Long op time:\", time.time()-long_op)\n",
    "\n",
    "        # Swap the two merging clusters' terms for the merged cluster's term.\n",
    "        candidate_id_vars -= ( cluster_sizes_one*((candidate_id_means-id_estimates_one)**2) + cluster_sizes_two*((candidate_id_means-id_estimates_two)**2) )\n",
    "        candidate_id_vars += ((candidate_id_means-combined_ids)**2)*merge_cluster_lens\n",
    "        candidate_id_vars /= bs\n",
    "        max_index = torch.argmax(torch.nan_to_num(candidate_id_vars,nan=0)).item()\n",
    "\n",
    "        current_max_var = candidate_id_vars[max_index].item()\n",
    "        current_merge = idx_pairs[max_index]\n",
    "        current_merge_id = combined_ids[max_index].item()\n",
    "        if print_times: print(\"Calc time:\", time.time()-calc_start)\n",
    "\n",
    "    if (second == False and not  first_from_save and save_first) or len(used_idxs) in save_set: \n",
    "        to_save = {\n",
    "            \"clusters\": clusters,\n",
    "            \"id_estimates\": id_estimates,\n",
    "            \"id_sum\": id_sum,\n",
    "            \"cluster_cache\": cluster_cache,\n",
    "            \"used_idxs\": used_idxs,\n",
    "            \"first_merge_idx\": current_merge[0],\n",
    "            \"second_merge_idx\": current_merge[1],\n",
    "            \"combined_ids\": combined_ids.cpu(),\n",
    "            \"id_estimates_one\": id_estimates_one.cpu(),\n",
    "            \"id_estimates_two\": id_estimates_two.cpu(),\n",
    "            \"idx_pair_to_index\": idx_pair_to_index,\n",
    "            \"idx_pairs\": idx_pairs,\n",
    "            \"cluster_sizes_one\": cluster_sizes_one.cpu(),\n",
    "            \"merge_cluster_lens\": merge_cluster_lens.cpu(),\n",
    "            \"cluster_sizes_two\": cluster_sizes_two.cpu(),\n",
    "            \"cluster_sizes\": cluster_sizes\n",
    "        }\n",
    "        if second:\n",
    "            save_pickle(f'iter_{main_iter}_{run_name}.pickle', to_save)\n",
    "        else:\n",
    "            save_pickle(f'first_iter_{run_name}.pickle', to_save)\n",
    "\n",
    "    if print_times: print(\"Main loop\", time.time()-start)\n",
    "\n",
    "    # candidate_id_vars only exists once a full candidate evaluation has run.\n",
    "    if (not first_from_save or second) and torch.nansum(candidate_id_vars) == 0:\n",
    "        print(\"No more merge candidates\")\n",
    "        break\n",
    "\n",
    "    start = time.time()\n",
    "\n",
    "    first_merge_idx, second_merge_idx = current_merge\n",
    "    switches.append((first_merge_idx, second_merge_idx))\n",
    "    used_idxs.remove(second_merge_idx)\n",
    "\n",
    "    if print_stats: print(\"Merging\", first_merge_idx, \"into\", second_merge_idx, \"size1:\", len(clusters[first_merge_idx]), \"size2:\", len(clusters[second_merge_idx]), \"max var:\", current_max_var)\n",
    "\n",
    "    if save_plots and main_iter % save_graph_iter == 0:\n",
    "        id_ests=[]\n",
    "        for id in used_idxs:\n",
    "            id_ests.append((len(clusters[id]), id_estimates[id]))\n",
    "        plt.scatter([i[0] for i in id_ests], [i[1] for i in id_ests])\n",
    "        plt.scatter([len(clusters[first_merge_idx]), len(clusters[second_merge_idx])], [(id_estimates[first_merge_idx]), (id_estimates[second_merge_idx])], color=\"red\")\n",
    "        plt.savefig(f\"./id_run_saves/{run_name}_{main_iter}.png\")\n",
    "        plt.close()\n",
    "\n",
    "    # Save stats\n",
    "    id_estimates_non_empty = [id_estimates[id] for id in used_idxs]\n",
    "\n",
    "    if min(id_estimates_non_empty) < 1e-3:\n",
    "        print(\"Really low id estimate:\", min(id_estimates_non_empty))\n",
    "\n",
    "    if max(id_estimates_non_empty) > 100:\n",
    "        print(\"Really high id estimate:\", max(id_estimates_non_empty))\n",
    "\n",
    "    id_estimates_maxes.append(max(id_estimates_non_empty))\n",
    "    id_estimates_mins.append(min(id_estimates_non_empty))\n",
    "    id_estimates_means.append(sum(id_estimates_non_empty)/len(id_estimates_non_empty))\n",
    "    inter_id_vars.append(current_max_var)\n",
    "    combined_ids_log.append(current_merge_id)\n",
    "\n",
    "    clusters[first_merge_idx] = np.concatenate((clusters[first_merge_idx],clusters[second_merge_idx]))\n",
    "    clusters[second_merge_idx] = []\n",
    "\n",
    "    merge_cluster_sizes.append(len(clusters[first_merge_idx]))\n",
    "    num_merges.append(max(merge_checker[first_merge_idx], merge_checker[second_merge_idx]))\n",
    "\n",
    "    merge_checker[first_merge_idx]= max(merge_checker[first_merge_idx], merge_checker[second_merge_idx]) + 1\n",
    "\n",
    "    merges.append((first_merge_idx, second_merge_idx))\n",
    "    # id_sum is size-weighted, so remove/add the clusters' terms with their sizes\n",
    "    # (must run before cluster_sizes is updated below).\n",
    "    id_sum -= (id_estimates[first_merge_idx]*cluster_sizes[first_merge_idx] + id_estimates[second_merge_idx]*cluster_sizes[second_merge_idx])\n",
    "    id_sum += current_merge_id*(cluster_sizes[first_merge_idx] + cluster_sizes[second_merge_idx])\n",
    "\n",
    "    id_estimates[first_merge_idx] = current_merge_id\n",
    "    cluster_sizes[first_merge_idx] += cluster_sizes[second_merge_idx]\n",
    "    cluster_sizes[second_merge_idx] = 0\n",
    "\n",
    "    base_merging_neighours = cluster_cache[(first_merge_idx, second_merge_idx)][1]\n",
    "    if print_times: print(\"Initial merge\", time.time()-start)\n",
    "\n",
    "    # The merged pair itself is no longer a candidate.\n",
    "    _invalidate_pair((first_merge_idx, second_merge_idx))\n",
    "\n",
    "    # A cluster reaching the cap stops merging but stays in the variance via keep_in_calc_idxs.\n",
    "    capped = cap != -1 and clusters[first_merge_idx].shape[0] >= cap\n",
    "    if capped:\n",
    "        print(\"Removing cluster of size:\", clusters[first_merge_idx].shape[0])\n",
    "        used_idxs.remove(first_merge_idx)\n",
    "        keep_in_calc_idxs.append(first_merge_idx)\n",
    "\n",
    "    start = time.time()\n",
    "    for idx in used_idxs:\n",
    "        if idx == first_merge_idx: continue\n",
    "\n",
    "        # Pair keys are always (smaller index, larger index); first_merge_idx < second_merge_idx.\n",
    "        first_pair = (idx, first_merge_idx) if idx < first_merge_idx else (first_merge_idx, idx)\n",
    "        second_pair = (idx, second_merge_idx) if idx < second_merge_idx else (second_merge_idx, idx)\n",
    "\n",
    "        # Pairs with the absorbed second cluster are gone for good.\n",
    "        _invalidate_pair(second_pair)\n",
    "\n",
    "        if capped:\n",
    "            # Do not re-estimate: that would make the capped cluster a candidate again.\n",
    "            _invalidate_pair(first_pair)\n",
    "            continue\n",
    "\n",
    "        # Point count of the other cluster; its rows sit at the start or end of the cached\n",
    "        # neighbour arrays depending on pair order. (Not `bs`: that is the total point count.)\n",
    "        idx_len = clusters[idx].shape[0]\n",
    "\n",
    "        if idx < first_merge_idx:\n",
    "            # Cached row layouts: (idx, first) = [idx; first], (idx, second) = [idx; second].\n",
    "            to_merge_low = cluster_cache[(idx, first_merge_idx)][1]\n",
    "            to_merge_high = cluster_cache[(idx, second_merge_idx)][1]\n",
    "\n",
    "            merging_additions = cat_pad((to_merge_low[idx_len:], to_merge_high[idx_len:]), pad_axis=1, cat_axis=0)\n",
    "            new_merging_neighbours = cat_pad((base_merging_neighours, merging_additions), pad_axis=1, cat_axis=1)\n",
    "            new_neighbours = cat_pad((to_merge_low[:idx_len], to_merge_high[:idx_len]), pad_axis=1, cat_axis=1)\n",
    "            updated_neighbours = cat_pad((new_neighbours, new_merging_neighbours), pad_axis=1, cat_axis=0)\n",
    "\n",
    "            cluster_cache[first_pair] = update_id(np.concatenate((clusters[idx], clusters[first_merge_idx]), axis=0),updated_neighbours, dists)\n",
    "\n",
    "            pos = idx_pair_to_index[first_pair]\n",
    "            combined_ids[pos] = cluster_cache[first_pair][0]\n",
    "            # In (idx, first) the merged cluster is the second member; idx's own entries are unchanged.\n",
    "            id_estimates_two[pos] = current_merge_id\n",
    "            cluster_sizes_two[pos] = cluster_sizes[first_merge_idx]\n",
    "        else:\n",
    "            if idx < second_merge_idx:\n",
    "                # Cached row layouts: (first, idx) = [first; idx], (idx, second) = [idx; second].\n",
    "                to_merge_low = cluster_cache[(first_merge_idx, idx)][1]\n",
    "                to_merge_high = cluster_cache[(idx, second_merge_idx)][1]\n",
    "                merging_additions = cat_pad((to_merge_low[:-idx_len], to_merge_high[idx_len:]), pad_axis=1, cat_axis=0)\n",
    "                new_neighbours = cat_pad((to_merge_low[-idx_len:], to_merge_high[:idx_len]), pad_axis=1, cat_axis=1)\n",
    "            else:\n",
    "                # Cached row layouts: (first, idx) = [first; idx], (second, idx) = [second; idx].\n",
    "                to_merge_low = cluster_cache[(first_merge_idx, idx)][1]\n",
    "                to_merge_high = cluster_cache[(second_merge_idx, idx)][1]\n",
    "                merging_additions = cat_pad((to_merge_low[:-idx_len], to_merge_high[:-idx_len]), pad_axis=1, cat_axis=0)\n",
    "                new_neighbours = cat_pad((to_merge_low[-idx_len:], to_merge_high[-idx_len:]), pad_axis=1, cat_axis=1)\n",
    "\n",
    "            new_merging_neighbours = cat_pad((base_merging_neighours, merging_additions), pad_axis=1, cat_axis=1)\n",
    "            updated_neighbours = cat_pad((new_merging_neighbours, new_neighbours), pad_axis=1, cat_axis=0)\n",
    "\n",
    "            cluster_cache[first_pair] = update_id(np.concatenate((clusters[first_merge_idx], clusters[idx]), axis=0),updated_neighbours, dists)\n",
    "\n",
    "            pos = idx_pair_to_index[first_pair]\n",
    "            combined_ids[pos] = cluster_cache[first_pair][0]\n",
    "            # In (first, idx) the merged cluster is the first member; idx's own entries are unchanged.\n",
    "            id_estimates_one[pos] = current_merge_id\n",
    "            cluster_sizes_one[pos] = cluster_sizes[first_merge_idx]\n",
    "\n",
    "        # Size of the cluster this candidate merge would produce.\n",
    "        merge_cluster_lens[pos] = cluster_sizes[idx] + cluster_sizes[first_merge_idx]\n",
    "\n",
    "    if print_times: print(\"Updating cache\", time.time()-start)\n",
    "    second = True\n",
    "\n",
    "# Persist the final clustering and the per-merge logs.\n",
    "to_save = {\n",
    "    \"clusters\": clusters,\n",
    "    \"merge_cluster_sizes\":merge_cluster_sizes,\n",
    "    # NOTE(review): this stores merge_checker (per-cluster merge counter), not the\n",
    "    # per-merge `num_merges` list; confirm which one downstream readers expect.\n",
    "    \"num_merges\":merge_checker,\n",
    "    \"merges\":merges,\n",
    "    \"id_estimates_maxes\": id_estimates_maxes,\n",
    "    'id_estimates_mins': id_estimates_mins,\n",
    "    \"id_estimates_means\": id_estimates_means,\n",
    "    \"inter_id_vars\": inter_id_vars,\n",
    "    \"combined_ids_log\": combined_ids_log\n",
    "}\n",
    "save_pickle(f'{run_name}_final', to_save)\n",
    "\n",
    "# Merged-away slots are empty, so report only the surviving clusters.\n",
    "print(f\"Final lengths {[len(c) for c in clusters if len(c) > 0]}\")\n",
    "print(f\"Final ID variance: {id_variance([c for c in clusters if len(c) > 0], dists)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "db448fe85518e560c3ef83ccc8d105f9c373042f28551637af6f8c1ae46c95bb"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 ('two_step_zoo': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
