{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline \n",
    "import argparse\n",
    "import pprint\n",
    "import torch\n",
    "\n",
    "from config import get_cluster_configs, load_configs, parse_config_arg\n",
    "from two_step_zoo import get_clustering_module, get_loaders, get_clustering_trainer, get_evaluator, Writer, get_clusterer,get_id_estimator\n",
    "from two_step_zoo.datasets.loaders import get_loaders_from_config\n",
    "\n",
    "import pdb\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "import copy\n",
    "\n",
    "from sklearn.cluster import AgglomerativeClustering, OPTICS, MiniBatchKMeans\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "from heapq import heappush, heappop\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "from collections import defaultdict\n",
    "\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"cifar100\"\n",
    "run_name = \"all_classes_16_initial\"\n",
    "norm = 255."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gae_cfg, de_cfg, shared_cfg, cluster_cfg = get_cluster_configs(\n",
    "        dataset=dataset,\n",
    "        generalized_autoencoder=\"avb\",\n",
    "        density_estimator=\"vae\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, valid_loader, test_loader = get_loaders_from_config(shared_cfg, \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tdata = train_loader.dataset.inputs.cpu()/norm\n",
    "tlabs = train_loader.dataset.targets.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vdata = valid_loader.dataset.inputs.cpu()/norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feats = tdata.reshape(tdata.shape[0],-1)\n",
    "m = 2\n",
    "n,f = feats.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dists = pdist(feats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dists = np.concatenate((dists,np.array((10000000,))), axis=0)\n",
    "dists.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_pickle(name, object):\n",
    "    with open(\"pickles/\"+name+\".pickle\", 'wb') as handle: \n",
    "        pickle.dump(object, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "def load_pickle(name):\n",
    "    with open(f'pickles/{name}.pickle', 'rb') as handle:\n",
    "        object = pickle.load(handle)\n",
    "    \n",
    "    return object\n",
    "\n",
    "# save_pickle(f'{dataset}_pdists_{run_name}.pickle', dists)\n",
    "\n",
    "dists = load_pickle(f'{dataset}_pdists')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dist_index(x,y): \n",
    "    if x == y: print(\"SAME\", x)\n",
    "    if x == -1 or y == -1: return -1\n",
    "    assert x != y\n",
    "    if x > y: x,y = y,x\n",
    "    return x*n + y - ((x + 2) * (x + 1)) // 2\n",
    "\n",
    "def get_dist(x,y): \n",
    "    return dists[dist_index(x,y)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initial Clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_nn_dists(oidxs, k, dists):\n",
    "    nn_dists = np.array([[dists[dist_index(idx,j)] for j in oidxs if j != idx] for idx in oidxs])\n",
    "    nn_neighbours = np.array([[j for j in oidxs if j != idx] for idx in oidxs])\n",
    "\n",
    "    arg_part = np.argpartition(nn_dists, k)\n",
    "    nn_dists = np.take_along_axis(nn_dists, arg_part, axis=-1)[:,:k]\n",
    "    nn_neighbours = np.take_along_axis(nn_neighbours, arg_part, axis=-1)[:,:k]\n",
    "    \n",
    "    arg_sort = np.argsort(nn_dists, axis=-1)\n",
    "    return np.take_along_axis(nn_dists, arg_sort, axis=-1), np.take_along_axis(nn_neighbours, arg_sort, axis=-1)\n",
    "\n",
    "def calculate_id(idxs, second_idx=0, return_idx=False, k=10):\n",
    "    k = min(k, len(idxs)-2)\n",
    "    \n",
    "    nn_dists,nn_neighbours = get_nn_dists(idxs, k, dists)\n",
    "\n",
    "    d = np.log(nn_dists[:, k - 1: k] / nn_dists[:, 0:k - 1])\n",
    "    inv_mle = np.sum(d, -1) / (k-1)\n",
    "\n",
    "    if return_idx:\n",
    "        return (second_idx, (1 / inv_mle.mean()))\n",
    "    return (1 / inv_mle.mean()),nn_neighbours\n",
    "\n",
    "def id_variance(clusters):\n",
    "    ids = [calculate_id(cluster)[0] for cluster in clusters]\n",
    "    bs = len(ids)\n",
    "    mean_id = sum(ids) / len(ids)\n",
    "    return sum( [(mean_id-id)**2 for id in ids] ) / (bs-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = torch.unique(tlabs)\n",
    "class_to_ids = {cidx.item(): [] for cidx in classes}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx,tlab in enumerate(tlabs): class_to_ids[tlab.item()].append([idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_variance([list(itertools.chain.from_iterable(cidxs)) for cidxs in class_to_ids.values()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initial_clusters(idxs, num_merges=2):\n",
    "    for main_iter in range(num_merges):\n",
    "        next_idxs = []\n",
    "        distance_heap = []\n",
    "\n",
    "        used = set()\n",
    "        for i in (range(len(idxs)-1)):\n",
    "            for j in range(i+1, len(idxs)):\n",
    "\n",
    "                inner_dists = []\n",
    "                for first_idx in idxs[i]:\n",
    "                    for second_idx in idxs[j]:\n",
    "                        inner_dists.append(get_dist(i,j))\n",
    "                try:\n",
    "                    heappush(distance_heap, (sum(inner_dists) / len(inner_dists), i, j))\n",
    "                except:\n",
    "                    pdb.set_trace()\n",
    "\n",
    "        while(len(distance_heap) > 0):\n",
    "            _,i,j = heappop(distance_heap)\n",
    "            if i not in used and j not in used:\n",
    "                next_idxs.append(idxs[i] + idxs[j])\n",
    "                used.add(i)\n",
    "                used.add(j)\n",
    "            \n",
    "            if len(idxs) - len(used) < 3+main_iter:\n",
    "                leftover = [idxs[leftover_idx] for leftover_idx in range(len(idxs)) if leftover_idx not in used]\n",
    "                next_idxs.append(list(itertools.chain.from_iterable(leftover)))\n",
    "                break\n",
    "        \n",
    "        idxs = next_idxs\n",
    "    \n",
    "    return idxs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "og_clusters = [initial_clusters(idxs, num_merges=4) for idxs in class_to_ids.values()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "og_clusters = list(itertools.chain.from_iterable(og_clusters))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "og_clusters[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "og_lens = [len(c) for c in og_clusters]\n",
    "max(og_lens), min(og_lens), sum(og_lens) / len(og_lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_pickle(f'{dataset}_og_clusters_{run_name}', og_clusters)\n",
    "# og_clusters = load_pickle(f'{dataset}_og_clusters_{run_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Clustering Algo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_id(idxs, nn_neighbours, second_idx=0, return_idx=False, k=10):\n",
    "\n",
    "    nn_dists = np.array([[dists[dist_index(im_idx,j)] for j in nn_neighbours[idx]] for idx,im_idx in enumerate(idxs)])\n",
    "    \n",
    "    arg_sort = [np.unique(nn_dists[i], return_index=True) for i in range(nn_dists.shape[0])]\n",
    "    k = min(k,min([(nn[0] < 1e6).sum() for nn in arg_sort]))\n",
    "\n",
    "    arg_sort = np.stack([nn[1][:k] for nn in arg_sort])\n",
    "    nn_dists,nn_neighbours = np.take_along_axis(nn_dists, arg_sort, axis=-1), np.take_along_axis(nn_neighbours, arg_sort, axis=-1)\n",
    "\n",
    "    d = np.log(nn_dists[:, k - 1: k] / nn_dists[:, 0:k - 1])\n",
    "        \n",
    "    inv_mle = np.sum(d, -1) / (k-1)\n",
    "\n",
    "    if return_idx:\n",
    "        return (second_idx, (1 / inv_mle.mean()))\n",
    "    return (1 / inv_mle.mean()),nn_neighbours\n",
    "\n",
    "def cat_pad(tuple,cat_axis,pad_axis):\n",
    "    max_len = max([t.shape[pad_axis] for t in tuple])\n",
    "    if max_len == min([t.shape[pad_axis] for t in tuple]): return np.concatenate(tuple, axis=cat_axis)\n",
    "    return np.concatenate([np.pad(t, pad_width=((0,max_len-t.shape[pad_axis] if pad_axis == 0 else 0),\\\n",
    "        (0, max_len-t.shape[pad_axis] if pad_axis == 1 else 0)), \\\n",
    "        mode=\"constant\", constant_values=-1) for t in tuple],axis=cat_axis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_from_save = False\n",
    "load_legacy = False\n",
    "save_first = True\n",
    "print_times = False\n",
    "print_stats = False\n",
    "m = 7\n",
    "num_test = 30\n",
    "sections = 10000 # Used to save mem on ID calc\n",
    "save_name = \"first_iter_\" + dataset + \"_\" + run_name\n",
    "error = False\n",
    "is_test_loop = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if first_from_save:\n",
    "    print(f\"Loading initial iteration from pickles/{save_name}.pickle\")\n",
    "    b = load_pickle(save_name)\n",
    " \n",
    "    clusters = b[\"clusters\"]\n",
    "    id_estimates = b[\"id_estimates\"]\n",
    "    id_sum = b[\"id_sum\"]\n",
    "    cluster_cache = b[\"cluster_cache\"]\n",
    "    used_idxs = b[\"used_idxs\"]\n",
    "\n",
    "    if \"first_merge_idx\" in b:\n",
    "        first_merge_idx = b[\"first_merge_idx\"]\n",
    "    else:\n",
    "        first_merge_idx=0\n",
    "\n",
    "    if \"second_merge_idx\" in b:\n",
    "        second_merge_idx = b[\"second_merge_idx\"]\n",
    "    else:\n",
    "        second_merge_idx=1\n",
    "    \n",
    "    print(f\"Using {len(clusters)} clusters\")\n",
    "\n",
    "    combined_ids = []\n",
    "    id_estimates_one = []\n",
    "    id_estimates_two = []\n",
    "    idx_pairs = []\n",
    "    idx_pair_to_index = {}\n",
    "\n",
    "    if load_legacy:\n",
    "        for first_idx in (used_idxs):\n",
    "                    \n",
    "            for second_idx in used_idxs:\n",
    "\n",
    "                if second_idx <= first_idx: continue\n",
    "\n",
    "                idx_pair_to_index[(first_idx, second_idx)] = len(combined_ids)\n",
    "\n",
    "                combined_id = cluster_cache[(first_idx, second_idx)][0]\n",
    "\n",
    "                combined_ids.append(combined_id)\n",
    "                id_estimates_one.append(id_estimates[first_idx])\n",
    "                id_estimates_two.append(id_estimates[second_idx])\n",
    "                idx_pairs.append((first_idx, second_idx))\n",
    "        \n",
    "        combined_ids = torch.tensor(combined_ids, device=\"cuda\", dtype=torch.float32)\n",
    "        id_estimates_one = torch.tensor(id_estimates_one, device=\"cuda\", dtype=torch.float32)\n",
    "        id_estimates_two = torch.tensor(id_estimates_two, device=\"cuda\", dtype=torch.float32)\n",
    " \n",
    "else:\n",
    "  \n",
    "    clusters = copy.deepcopy(og_clusters)\n",
    "    clusters = [np.array(cluster) for cluster in clusters]\n",
    "    for cluster in clusters: assert len(cluster) > 3, \"All clusters must have length greater than 3\"\n",
    "\n",
    "    print(f\"Using {len(clusters)} clusters\")\n",
    "\n",
    "    id_estimates = [calculate_id(cluster)[0] for cluster in clusters]\n",
    "    id_sum = sum(id_estimates)\n",
    "    used_idxs = set([i for i in range(len(clusters))])\n",
    "    new_n = len(clusters)\n",
    "\n",
    "    cluster_cache = {}\n",
    "\n",
    "second=False\n",
    "switches=[]\n",
    "new_n = len(clusters)\n",
    "merge_cluster_sizes = []\n",
    "num_merges = []\n",
    "merge_checker = defaultdict(int)\n",
    "merges = []\n",
    "\n",
    "if not load_legacy:\n",
    "    combined_ids = []\n",
    "    id_estimates_one = []\n",
    "    id_estimates_two = []\n",
    "    idx_pairs = []\n",
    "    idx_pair_to_index = {}\n",
    "\n",
    "for main_iter_orig in tqdm(range(num_test if is_test_loop else new_n-m)):\n",
    "    main_iter = main_iter_orig\n",
    "    lc = len(used_idxs)\n",
    "\n",
    "    if not second and first_from_save:\n",
    "        current_merge = (first_merge_idx,second_merge_idx)\n",
    "        current_merge_id = cluster_cache[(first_merge_idx,second_merge_idx)][0]\n",
    "        current_max_var = 10 # not used\n",
    "    else:\n",
    "        current_max_var = -math.inf\n",
    "        current_merge = (0,0)\n",
    "        current_merge_id = math.inf\n",
    "\n",
    "    # Get optimal merge pair\n",
    "    start = time.time()\n",
    "    calcs = []\n",
    "    id_est_vec_start = time.time()\n",
    "    id_est_vec = torch.tensor([id_estimates[id] for id in used_idxs], device=\"cuda\", dtype=torch.float32)\n",
    "    if print_times: print(\"id_est_vec time:\", time.time()-id_est_vec_start)\n",
    "\n",
    "    if not first_from_save or second:\n",
    "        agg_start = time.time()\n",
    "\n",
    "        if main_iter == 0:\n",
    "            for first_idx in (used_idxs):\n",
    "                \n",
    "                for second_idx in used_idxs:\n",
    "\n",
    "                    if second_idx <= first_idx: continue\n",
    "\n",
    "                    # if main_iter == 0:\n",
    "                    first_cluster, second_cluster = clusters[first_idx], clusters[second_idx]\n",
    "                    combined_id,neighbours = calculate_id(np.concatenate((first_cluster,second_cluster)))\n",
    "                    cluster_cache[(first_idx, second_idx)] = (combined_id,neighbours)\n",
    "\n",
    "                    idx_pair_to_index[(first_idx, second_idx)] = len(combined_ids)\n",
    "\n",
    "                    combined_ids.append(combined_id)\n",
    "                    id_estimates_one.append(id_estimates[first_idx])\n",
    "                    id_estimates_two.append(id_estimates[second_idx])\n",
    "                    idx_pairs.append((first_idx, second_idx))\n",
    "                \n",
    "            \n",
    "            combined_ids = torch.tensor(combined_ids, device=\"cuda\", dtype=torch.float32)\n",
    "            id_estimates_one = torch.tensor(id_estimates_one, device=\"cuda\", dtype=torch.float32)\n",
    "            id_estimates_two = torch.tensor(id_estimates_two, device=\"cuda\", dtype=torch.float32)\n",
    "\n",
    "        if print_times: print(\"Aggregate ids_time time:\", time.time()-agg_start)\n",
    "        calc_start = time.time()\n",
    "        candidate_id_sums = id_sum - id_estimates_one - id_estimates_two + combined_ids\n",
    "        candidate_id_means = candidate_id_sums / (len(used_idxs)-1)\n",
    "\n",
    "        long_op = time.time()\n",
    "        num_entries = candidate_id_means.shape[0]\n",
    "        quotient = num_entries // sections\n",
    "        remainder = num_entries % sections\n",
    "        candidate_id_vars = []\n",
    "        for i in (range(sections)):\n",
    "            candidate_id_vars.append( ((id_est_vec[None,:]-candidate_id_means[i*quotient:(i+1)*quotient,None])**2).sum(axis=1) )\n",
    "        if remainder != 0:\n",
    "            candidate_id_vars.append( ((id_est_vec[None,:]-candidate_id_means[-remainder:,None])**2).sum(axis=1) )\n",
    "        candidate_id_vars = torch.cat(candidate_id_vars)\n",
    "        if print_times: print(\"Long op time:\", time.time()-long_op)\n",
    "\n",
    "        candidate_id_vars -= ( (candidate_id_means-id_estimates_one)**2 + (candidate_id_means-id_estimates_two)**2 )\n",
    "        candidate_id_vars += (candidate_id_means-combined_ids)**2\n",
    "        candidate_id_vars /= len(used_idxs)-1-1 # Sample variance\n",
    "        max_index = torch.argmax(torch.nan_to_num(candidate_id_vars,nan=0)).item()\n",
    "\n",
    "        current_max_var = candidate_id_vars[max_index].item()\n",
    "        current_merge = idx_pairs[max_index]\n",
    "        current_merge_id = combined_ids[max_index].item()\n",
    "        if print_times: print(\"Calc time:\", time.time()-calc_start)\n",
    "\n",
    "    if second == False and not first_from_save and save_first: \n",
    "        to_save = {\n",
    "            \"clusters\": clusters,\n",
    "            \"id_estimates\": id_estimates,\n",
    "            \"id_sum\": id_sum,\n",
    "            \"cluster_cache\": cluster_cache,\n",
    "            \"used_idxs\": used_idxs,\n",
    "            \"first_merge_idx\": current_merge[0],\n",
    "            \"second_merge_idx\": current_merge[1]\n",
    "        }\n",
    "        save_pickle(f'{save_name}.pickle', to_save)\n",
    "\n",
    "    if print_times: print(\"Main loop\", time.time()-start)\n",
    "    start = time.time()\n",
    "\n",
    "    first_merge_idx, second_merge_idx = current_merge\n",
    "    switches.append((first_merge_idx, second_merge_idx))\n",
    "    used_idxs.remove(second_merge_idx)\n",
    "\n",
    "    if print_stats:\n",
    "        id_ests=[]\n",
    "        for id in used_idxs:\n",
    "            id_ests.append((len(clusters[id]), id_estimates[id]))\n",
    "        plt.scatter([i[0] for i in id_ests], [i[1] for i in id_ests])\n",
    "        plt.scatter([len(clusters[first_merge_idx]), len(clusters[second_merge_idx])], [(id_estimates[first_merge_idx]), (id_estimates[second_merge_idx])], color=\"red\")\n",
    "        plt.show()\n",
    "\n",
    "    if print_stats: print(\"Merging\", first_merge_idx, \"into\", second_merge_idx, \"size1:\", len(clusters[first_merge_idx]), \"size2:\", len(clusters[second_merge_idx]), \"max var:\", current_max_var)\n",
    "\n",
    "    clusters[first_merge_idx] = np.concatenate((clusters[first_merge_idx],clusters[second_merge_idx]))\n",
    "    clusters[second_merge_idx] = []\n",
    "\n",
    "    merge_cluster_sizes.append(len(clusters[first_merge_idx]))\n",
    "    num_merges.append(max(merge_checker[first_merge_idx], merge_checker[second_merge_idx]))\n",
    "\n",
    "    merge_checker[first_merge_idx]= max(merge_checker[first_merge_idx], merge_checker[second_merge_idx]) + 1\n",
    "\n",
    "    merges.append((first_merge_idx, second_merge_idx))\n",
    "    id_sum -= (id_estimates[first_merge_idx] + id_estimates[second_merge_idx])\n",
    "    id_sum += current_merge_id\n",
    "\n",
    "    id_estimates[first_merge_idx] = current_merge_id\n",
    "\n",
    "    base_merging_neighours = cluster_cache[(first_merge_idx, second_merge_idx)][1]\n",
    "    if print_times: print(\"Initial merge\", time.time()-start)\n",
    "\n",
    "    combined_ids[idx_pair_to_index[(first_merge_idx, second_merge_idx)]] = torch.nan\n",
    "    id_estimates_one[idx_pair_to_index[(first_merge_idx, second_merge_idx)]] = torch.nan\n",
    "    id_estimates_two[idx_pair_to_index[(first_merge_idx, second_merge_idx)]] = torch.nan\n",
    "\n",
    "    start = time.time()\n",
    "    for idx in used_idxs:\n",
    "        if idx == first_merge_idx: continue\n",
    "        \n",
    "        bs = clusters[idx].shape[0]\n",
    "\n",
    "        start_update_id_time = time.time()\n",
    "        if idx < first_merge_idx:\n",
    "            to_merge_low = cluster_cache[(idx, first_merge_idx)][1]\n",
    "            to_merge_high = cluster_cache[(idx, second_merge_idx)][1]\n",
    "\n",
    "            merging_additions = cat_pad((to_merge_low[bs:], to_merge_high[bs:]), pad_axis=1, cat_axis=0)\n",
    "            new_merging_neighbours = cat_pad((base_merging_neighours, merging_additions), pad_axis=1, cat_axis=1)\n",
    "            new_neighbours = cat_pad((to_merge_low[:bs], to_merge_high[:bs]), pad_axis=1, cat_axis=1)\n",
    "            updated_neighbours = cat_pad((new_neighbours, new_merging_neighbours), pad_axis=1, cat_axis=0)\n",
    "            \n",
    "            cluster_cache[(idx, first_merge_idx)] = update_id(np.concatenate((clusters[idx], clusters[first_merge_idx]), axis=0),updated_neighbours)\n",
    "         \n",
    "            combined_ids[idx_pair_to_index[(idx, first_merge_idx)]] = cluster_cache[(idx, first_merge_idx)][0]\n",
    "            # doesn't change id_estimates_one[idx_pair_to_index[(idx, first_merge_idx)]]\n",
    "            id_estimates_two[idx_pair_to_index[(idx, first_merge_idx)]] = current_merge_id\n",
    "\n",
    "            combined_ids[idx_pair_to_index[(idx, second_merge_idx)]] = torch.nan\n",
    "            id_estimates_one[idx_pair_to_index[(idx, second_merge_idx)]] = torch.nan\n",
    "            id_estimates_two[idx_pair_to_index[(idx, second_merge_idx)]] = torch.nan\n",
    "\n",
    "        elif idx > first_merge_idx and idx < second_merge_idx:\n",
    "        \n",
    "            to_merge_low = cluster_cache[(first_merge_idx, idx)][1]\n",
    "            to_merge_high = cluster_cache[(idx, second_merge_idx)][1]\n",
    "\n",
    "            merging_additions = cat_pad((to_merge_low[:-bs], to_merge_high[bs:]), pad_axis=1, cat_axis=0)\n",
    "            new_merging_neighbours = cat_pad((base_merging_neighours, merging_additions), pad_axis=1, cat_axis=1)\n",
    "            new_neighbours = cat_pad((to_merge_low[-bs:], to_merge_high[:bs]), pad_axis=1, cat_axis=1)\n",
    "            updated_neighbours = cat_pad((new_merging_neighbours, new_neighbours), pad_axis=1, cat_axis=0)\n",
    "            \n",
    "            cluster_cache[(first_merge_idx, idx)] = update_id(np.concatenate((clusters[first_merge_idx], clusters[idx]), axis=0),updated_neighbours)\n",
    "            \n",
    "            combined_ids[idx_pair_to_index[(first_merge_idx, idx)]] = cluster_cache[(first_merge_idx, idx)][0]\n",
    "            id_estimates_one[idx_pair_to_index[(first_merge_idx, idx)]] = current_merge_id\n",
    "            # doesn't change id_estimates_two[idx_pair_to_index[(first_merge_idx, idx)]] = current_merge_id\n",
    "\n",
    "            combined_ids[idx_pair_to_index[(idx, second_merge_idx)]] = torch.nan\n",
    "            id_estimates_one[idx_pair_to_index[(idx, second_merge_idx)]] = torch.nan\n",
    "            id_estimates_two[idx_pair_to_index[(idx, second_merge_idx)]] = torch.nan\n",
    "            \n",
    "        else:\n",
    "            to_merge_low = cluster_cache[(first_merge_idx, idx)][1]\n",
    "            to_merge_high = cluster_cache[(second_merge_idx, idx)][1]\n",
    "\n",
    "            merging_additions = cat_pad((to_merge_low[:-bs], to_merge_high[:-bs]), pad_axis=1, cat_axis=0)\n",
    "            new_merging_neighbours = cat_pad((base_merging_neighours, merging_additions), pad_axis=1, cat_axis=1)\n",
    "            new_neighbours = cat_pad((to_merge_low[-bs:], to_merge_high[-bs:]), pad_axis=1, cat_axis=1)\n",
    "            updated_neighbours = cat_pad((new_merging_neighbours, new_neighbours), pad_axis=1, cat_axis=0)\n",
    "            cluster_cache[(first_merge_idx, idx)] = update_id(np.concatenate((clusters[first_merge_idx], clusters[idx]), axis=0),updated_neighbours)\n",
    "\n",
    "            combined_ids[idx_pair_to_index[(first_merge_idx, idx)]] = cluster_cache[(first_merge_idx, idx)][0]\n",
    "            id_estimates_one[idx_pair_to_index[(first_merge_idx, idx)]] = current_merge_id\n",
    "            # doesn't change id_estimates_two[idx_pair_to_index[(first_merge_idx, idx)]]\n",
    "\n",
    "            combined_ids[idx_pair_to_index[(second_merge_idx, idx)]] = torch.nan\n",
    "            id_estimates_one[idx_pair_to_index[(second_merge_idx, idx)]] = torch.nan\n",
    "            id_estimates_two[idx_pair_to_index[(second_merge_idx, idx)]] = torch.nan\n",
    "\n",
    "    if print_times: print(\"Updating cache\", time.time()-start)\n",
    "    if error: break\n",
    "    second = True\n",
    "\n",
    "to_save = {\n",
    "    \"clusters\": clusters\n",
    "}\n",
    "save_pickle(f'{save_name}_final', to_save)\n",
    "clusters = [c for c in clusters if len(c) > 0]\n",
    "id_variance(clusters), switches #, [len(c) for c in clusters]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_spread = 0\n",
    "min_spread = math.inf\n",
    "min_c, max_c = [], []\n",
    "switches_max = []\n",
    "\n",
    "for _ in tqdm(range(1000)):\n",
    "    baseline_clusters = copy.deepcopy(og_clusters)\n",
    "    baseline_clusters = [np.array(c) for c in baseline_clusters]\n",
    "    baseline_used_idxs = set([i for i in range(len(baseline_clusters))])\n",
    "\n",
    "    # print([len(x) for x in baseline_clusters])\n",
    "    switches = []\n",
    "    # to_do = [(19,37), (12,34)]\n",
    "    # for idx in range(num_test if \n",
    "    \n",
    "    # is_test_loop else new_n-m):\n",
    "    while len(baseline_used_idxs) > m:\n",
    "        x,y = random.sample(baseline_used_idxs, 2)\n",
    "        # x,y = to_do[idx]\n",
    "        if x > y: x,y = y,x\n",
    "        # print(baseline_clusters,x,y)\n",
    "        baseline_clusters[x] = np.concatenate((baseline_clusters[x],baseline_clusters[y]))\n",
    "        baseline_clusters[y] = []\n",
    "        baseline_used_idxs.remove(y)\n",
    "        switches.append((x,y))\n",
    "    # print([len(x) for x in baseline_clusters])\n",
    "    baseline_clusters = [b for b in baseline_clusters if len(b) > 0]\n",
    "    spread = id_variance(baseline_clusters)\n",
    "    if spread < min_spread: min_c = baseline_clusters\n",
    "    if spread > max_spread: \n",
    "        max_c = baseline_clusters\n",
    "        switches_max = switches\n",
    "    max_spread = max(spread, max_spread)\n",
    "    min_spread = min(spread, min_spread)\n",
    "max_spread, min_spread,max_c,switches_max #[len(x) for x in max_c], [len(x) for x in min_c] "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "db448fe85518e560c3ef83ccc8d105f9c373042f28551637af6f8c1ae46c95bb"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 ('two_step_zoo': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
