import argparse
import pprint
import torch

from config import get_cluster_configs, load_configs, parse_config_arg
from two_step_zoo import get_clustering_module, get_loaders, get_clustering_trainer, get_evaluator, Writer, get_clusterer,get_id_estimator
from two_step_zoo.datasets.loaders import get_loaders_from_config

import pdb
import time
import matplotlib.pyplot as plt

import math
import copy

from sklearn.cluster import AgglomerativeClustering, OPTICS, MiniBatchKMeans
import numpy as np
from tqdm import tqdm
import random

from scipy.spatial.distance import pdist

from heapq import heappush, heappop
from collections import defaultdict
import pickle
from collections import defaultdict

import itertools
import os

def pickle_exists(name):
    """Report whether a cached pickle named *name* is present.

    The cache lives at ``pickles/<name>.pickle`` relative to the current
    working directory.
    """
    cache_path = f'pickles/{name}.pickle'
    return os.path.exists(cache_path)

def save_pickle(name, object):
    """Serialise *object* to ``pickles/<name>.pickle`` (relative to the cwd).

    The ``pickles/`` directory (and any sub-directory implied by *name*) is
    created if missing. Previously, a missing directory made every first save
    fail with FileNotFoundError.
    The highest available pickle protocol is used.
    """
    path = f'pickles/{name}.pickle'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as handle:
        pickle.dump(object, handle, protocol=pickle.HIGHEST_PROTOCOL)

def load_pickle(name):
    """Load and return the object stored in ``pickles/<name>.pickle``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on cache
    files this script wrote itself.
    """
    cache_path = f'pickles/{name}.pickle'
    with open(cache_path, 'rb') as stream:
        return pickle.load(stream)

def dist_index(x, y, num_points=None):
    """Map an unordered point pair to its position in a condensed distance vector.

    The layout matches ``scipy.spatial.distance.pdist``. For ``i < j`` the
    pair lives at ``i*N + j - (i+1)(i+2)/2``.

    Args:
        x, y: point indices in ``[0, N)``, or ``-1`` for a padding slot.
        num_points: ``N``. Defaults to the module-level ``n`` that the script
            sets from the feature matrix.

    Returns:
        The condensed index. Returns ``-1`` when either argument is ``-1``.
        Note that ``dists[-1]`` is a real distance, so callers must treat
        ``-1`` specially.

    Raises:
        ValueError: if ``x == y``. This replaces a debug print plus an
            ``assert`` that was stripped under ``python -O``.
    """
    if x == -1 or y == -1:
        return -1
    if x == y:
        raise ValueError(f"dist_index needs two distinct points, got {x} twice")
    size = n if num_points is None else num_points
    if x > y:
        x, y = y, x
    return x * size + y - ((x + 2) * (x + 1)) // 2

def get_dist(x,y, dists):
    """Look up the distance between points *x* and *y* in condensed vector *dists*.

    If either point is ``-1``, ``dist_index`` returns ``-1`` and this reads
    ``dists[-1]``.
    """
    position = dist_index(x, y)
    return dists[position]

def get_nn_dists(oidxs, k, dists):
    """Find the k nearest neighbours of every point in *oidxs*, restricted to *oidxs*.

    Args:
        oidxs: point indices forming the candidate set. Each point's neighbours
            are the other members of this set.
        k: number of neighbours to keep. Must be smaller than ``len(oidxs) - 1``
            for ``np.argpartition``.
        dists: condensed pairwise-distance vector addressed via ``dist_index``.

    Returns:
        ``(distances, neighbours)``, both of shape ``(len(oidxs), k)``. Row r
        belongs to ``oidxs[r]``. Each row is sorted by ascending distance, and
        ``neighbours`` holds point indices, not positions.
    """
    row_dists, row_points = [], []
    for anchor in oidxs:
        others = [other for other in oidxs if other != anchor]
        row_points.append(others)
        row_dists.append([dists[dist_index(anchor, other)] for other in others])
    all_dists = np.array(row_dists)
    all_points = np.array(row_points)

    # argpartition places the k smallest entries of each row (unordered) first.
    partition = np.argpartition(all_dists, k)[:, :k]
    top_dists = np.take_along_axis(all_dists, partition, axis=-1)
    top_points = np.take_along_axis(all_points, partition, axis=-1)

    # Now fully order just those k entries.
    order = np.argsort(top_dists, axis=-1)
    sorted_dists = np.take_along_axis(top_dists, order, axis=-1)
    sorted_points = np.take_along_axis(top_points, order, axis=-1)
    return sorted_dists, sorted_points

def calculate_id(idxs, dists, second_idx=0, return_idx=False, k=10):
    """Estimate the intrinsic dimension of the point set *idxs* from k-NN distances.

    For each point, ``inv = sum_{j<k} log(T_k / T_j) / (k - 1)``, where ``T_j``
    is its j-th nearest-neighbour distance within *idxs*. The returned estimate
    is ``1 / mean(inv)``, which averages the inverse per-point estimates.
    *k* is clamped to ``len(idxs) - 2`` because each point has only
    ``len(idxs) - 1`` candidate neighbours.

    Returns:
        ``(second_idx, id_estimate)`` when *return_idx* is true. Otherwise
        ``(id_estimate, neighbours)``, where ``neighbours`` is the
        ``(len(idxs), k)`` array of nearest-neighbour point ids from
        ``get_nn_dists``.
    """
    k = min(k, len(idxs) - 2)
    knn_dists, knn_points = get_nn_dists(idxs, k, dists)

    farthest = knn_dists[:, k - 1:k]
    nearer = knn_dists[:, :k - 1]
    inv_mle = np.log(farthest / nearer).sum(axis=-1) / (k - 1)
    id_estimate = 1 / inv_mle.mean()

    if return_idx:
        return (second_idx, id_estimate)
    return id_estimate, knn_points

def id_variance(clusters, dists):
    """Return the sample variance (n - 1 denominator) of the clusters' ID estimates.

    Each cluster is a sequence of point indices, scored with ``calculate_id``.
    """
    estimates = [calculate_id(members, dists)[0] for members in clusters]
    count = len(estimates)
    centre = sum(estimates) / count
    squared_devs = sum((centre - estimate) ** 2 for estimate in estimates)
    return squared_devs / (count - 1)

def update_id(idxs, nn_neighbours, dists, second_idx=0, return_idx=False, k=10):
    """Re-estimate a merged cluster's intrinsic dimension from candidate neighbour lists.

    Unlike ``calculate_id``, this does not search the whole cluster. It only
    considers neighbours already known from cached pairwise merges.

    Args:
        idxs: point indices of the merged cluster, in the row order of
            *nn_neighbours*.
        nn_neighbours: 2-D array of candidate neighbour point ids per row. It is
            built with ``cat_pad``, so ``-1`` marks padding.
        dists: condensed pairwise-distance vector.
        second_idx, return_idx: see the return value.
        k: maximum neighbours per point. It is lowered to the smallest number of
            distinct finite distances found in any row.

    Returns:
        ``(second_idx, id_estimate)`` if *return_idx*, else
        ``(id_estimate, neighbours)`` with ``neighbours`` of shape
        ``(len(idxs), k)``.
    """
    # Padding slots (-1) get +inf. dist_index(x, -1) returns -1, which would
    # otherwise read dists[-1] (the real distance between the last two points)
    # and let a fake neighbour into the estimate.
    nn_dists = np.array([
        [np.inf if j == -1 else dists[dist_index(point, j)] for j in nn_neighbours[row]]
        for row, point in enumerate(idxs)
    ])

    # np.unique sorts each row ascending and returns the first occurrence of each
    # value. This drops neighbours listed twice (e.g. via both source caches).
    uniques = [np.unique(row_dists, return_index=True) for row_dists in nn_dists]
    k = min(k, min(np.isfinite(values).sum() for values, _ in uniques))

    order = np.stack([first_pos[:k] for _, first_pos in uniques])
    nn_dists = np.take_along_axis(nn_dists, order, axis=-1)
    nn_neighbours = np.take_along_axis(nn_neighbours, order, axis=-1)

    # Same estimator as calculate_id: per-point mean of log(T_k / T_j), j < k.
    d = np.log(nn_dists[:, k - 1: k] / nn_dists[:, 0:k - 1])
    inv_mle = np.sum(d, -1) / (k - 1)

    if return_idx:
        return (second_idx, (1 / inv_mle.mean()))
    return (1 / inv_mle.mean()), nn_neighbours

def cat_pad(tuple, cat_axis, pad_axis):
    """Concatenate 2-D arrays along *cat_axis*, first right-padding them along *pad_axis*.

    When the arrays differ in length along *pad_axis* (0 or 1), the shorter ones
    are padded at the end with the constant ``-1`` up to the longest length.
    If all lengths already match, they are concatenated unchanged. For any other
    *pad_axis* value, no padding is applied.
    """
    arrays = tuple
    lengths = [arr.shape[pad_axis] for arr in arrays]
    target = max(lengths)
    if target == min(lengths):
        return np.concatenate(arrays, axis=cat_axis)

    padded = []
    for arr, length in zip(arrays, lengths):
        widths = [(0, 0), (0, 0)]
        if pad_axis in (0, 1):
            widths[pad_axis] = (0, target - length)
        padded.append(np.pad(arr, pad_width=widths, mode="constant", constant_values=-1))
    return np.concatenate(padded, axis=cat_axis)

def initial_clusters(idxs, num_merges=2):
    """Build starting clusters by greedy average-linkage pairing, repeated *num_merges* times.

    Each round:
      * score every cluster pair by the mean point-to-point distance between
        their members (``get_dist`` on the module-level ``dists``);
      * pop pairs from a min-heap, merging any pair whose clusters are both
        still unmerged;
      * once at most three clusters remain unmerged, fuse them into one
        cluster and end the round.

    Args:
        idxs: list of clusters, each a list of point indices.
        num_merges: number of pairing rounds.

    Returns:
        The merged list of clusters (lists of point indices). If fewer than two
        clusters are left, they are returned as-is.
    """
    for _ in tqdm(range(num_merges)):
        if len(idxs) < 2:
            # Nothing to pair up; without this guard the round would drop them.
            break

        distance_heap = []
        for i in range(len(idxs) - 1):
            for j in range(i + 1, len(idxs)):
                # Average linkage over *point* ids (previously used the cluster
                # positions i, j by mistake).
                inner_dists = [get_dist(first_idx, second_idx, dists)
                               for first_idx in idxs[i]
                               for second_idx in idxs[j]]
                heappush(distance_heap, (sum(inner_dists) / len(inner_dists), i, j))

        next_idxs = []
        used = set()
        while distance_heap:
            _, i, j = heappop(distance_heap)
            if i not in used and j not in used:
                next_idxs.append(idxs[i] + idxs[j])
                used.add(i)
                used.add(j)

            if len(idxs) - len(used) <= 3:
                leftover = [idxs[pos] for pos in range(len(idxs)) if pos not in used]
                if leftover:  # avoid appending an empty cluster
                    next_idxs.append(list(itertools.chain.from_iterable(leftover)))
                break

        idxs = next_idxs

    return idxs


if __name__ == "__main__":
    # Build initial clusters, then repeatedly merge the pair of clusters whose
    # merge gives the largest sample variance of per-cluster intrinsic-dimension
    # (ID) estimates, until args.m clusters remain. The module-level names `n`
    # and `dists` set below are read by dist_index() and initial_clusters().
    parser = argparse.ArgumentParser(description='Get clusters based on maximizing inter cluster ID variance')
    parser.add_argument('--dataset', type=str, help='dataset.')
    parser.add_argument('--m', type=int, help='number of clusters', default=5)
    parser.add_argument('--cap', type=int, help='max cluster size', default=-1)
    parser.add_argument('--save_graph_iter', type=int, default=1)
    parser.add_argument('--num_initial_merges', type=int, help='number of greedy pairwise merge rounds used to build the initial clusters', default=4)
    parser.add_argument('--run_name', type=str, default="")
    parser.add_argument('--norm', type=float, default=255.)
    parser.add_argument('--class_prior', action='store_true')
    parser.add_argument('--first_from_save', action='store_true')
    parser.add_argument('--print_times', action='store_true')
    parser.add_argument('--print_stats', action='store_true')
    # NOTE: store_false -> these two are ON by default; passing the flag turns them OFF.
    parser.add_argument('--save_plots', action='store_false')
    parser.add_argument('--save_first', action='store_false')

    args = parser.parse_args()

    args.run_name += f"_{args.dataset}_{args.m}_{args.num_initial_merges}"

    if args.class_prior:
        args.run_name += "_classprior"

    # save_pickle() and the per-iteration plots write into these directories.
    os.makedirs("pickles", exist_ok=True)
    if args.save_plots:
        os.makedirs("id_run_saves", exist_ok=True)

    gae_cfg, de_cfg, shared_cfg, cluster_cfg = get_cluster_configs(
        dataset=args.dataset,
        generalized_autoencoder="avb",
        density_estimator="vae"
    )

    train_loader, valid_loader, test_loader = get_loaders_from_config(shared_cfg, "cpu")

    tdata = train_loader.dataset.inputs.cpu()/args.norm
    tlabs = train_loader.dataset.targets.cpu()

    # One flattened feature row per sample; `n` is the global dist_index() uses.
    feats = tdata.reshape(tdata.shape[0], -1)
    n, f = feats.shape

    # Calculate pairwise distances (condensed pdist vector, cached on disk)
    pdists_name = f'{args.dataset}_pdists'
    if pickle_exists(pdists_name):
        print(f"Loading pdists from {pdists_name}")
        dists = load_pickle(pdists_name)
    else:
        print("Calculating pdists")
        dists = pdist(feats)
        print(f"Saving pdists as {pdists_name}")
        save_pickle(pdists_name, dists)

    # Near-zero distances (duplicate points) would make log(T_k / T_j) in the
    # ID estimators blow up.
    if dists.min() < 1e-4:
        print("Tiny pdists, adding an epsilon")
        dists = dists + 1e-4

    # Calculate initializations
    og_name = f'og_clusters_{args.run_name}'
    if pickle_exists(og_name):
        print("Loading initial clusters from", og_name)
        og_clusters = load_pickle(og_name)
    else:
        print("Calculating initial clusters...")

        if args.class_prior:
            # Build initial clusters separately inside each label class.
            classes = torch.unique(tlabs)
            class_to_ids = {cidx.item(): [] for cidx in classes}
            for idx, tlab in enumerate(tlabs):
                class_to_ids[tlab.item()].append([idx])

            class_id_var = id_variance(
                [list(itertools.chain.from_iterable(cidxs)) for cidxs in class_to_ids.values()], dists)
            print(f"Id variance of class clusters: {class_id_var}")

            og_clusters = [initial_clusters(idxs, num_merges=args.num_initial_merges) for idxs in class_to_ids.values()]
            og_clusters = list(itertools.chain.from_iterable(og_clusters))
        else:
            og_clusters = initial_clusters([[i] for i in range(tlabs.shape[0])], num_merges=args.num_initial_merges)

        og_lens = [len(c) for c in og_clusters]
        print("Max, min & mean cluster initial sizes",
              max(og_lens), min(og_lens), sum(og_lens) / len(og_lens))

        print("Saving initial clusters to", og_name)
        save_pickle(og_name, og_clusters)

    # Main clustering algo
    sections = 10000  # number of chunks for the candidate-variance computation

    if args.first_from_save:
        first_iter_name = f"first_iter_{args.run_name}"
        print(f"Loading initial iteration from pickles/{first_iter_name}.pickle")
        b = load_pickle(first_iter_name)

        clusters = b["clusters"]
        id_estimates = b["id_estimates"]
        id_sum = b["id_sum"]
        cluster_cache = b["cluster_cache"]
        used_idxs = b["used_idxs"]
        combined_ids = b["combined_ids"].cuda()
        id_estimates_one = b["id_estimates_one"].cuda()
        id_estimates_two = b["id_estimates_two"].cuda()
        idx_pair_to_index = b["idx_pair_to_index"]
        idx_pairs = b["idx_pairs"]

        # The merge chosen (but not yet applied) when the snapshot was taken.
        first_merge_idx = b.get("first_merge_idx", 0)
        second_merge_idx = b.get("second_merge_idx", 1)

        print(f"Using {len(clusters)} clusters")

    else:
        clusters = [np.array(cluster) for cluster in og_clusters]
        for cluster in clusters:
            if len(cluster) <= 3:
                raise ValueError("All clusters must have length greater than 3")

        print(f"Using {len(clusters)} clusters")

        id_estimates = [calculate_id(cluster, dists)[0] for cluster in clusters]
        print("Initial ID estimates min, max, mean:", min(id_estimates), max(id_estimates), sum(id_estimates)/len(id_estimates))

        id_sum = sum(id_estimates)
        used_idxs = set(range(len(clusters)))

        # (low_idx, high_idx) -> (combined ID, neighbour array) for each candidate merge
        cluster_cache = {}

    second = False  # True once at least one merge has been applied in this run
    switches = []
    new_n = len(clusters)
    merge_cluster_sizes = []
    num_merges = []
    merge_checker = defaultdict(int)
    merges = []
    # Clusters removed from merging by --cap but still counted in the variance.
    keep_in_calc_idxs = []

    id_estimates_maxes = []
    id_estimates_mins = []
    id_estimates_means = []
    inter_id_vars = []
    combined_ids_log = []

    save_set = list(range(15)) + [20, 25, 30]

    if not args.first_from_save:
        # Parallel per-pair arrays (lists until the first iteration turns them
        # into CUDA tensors); idx_pair_to_index maps (low, high) -> position.
        combined_ids = []
        id_estimates_one = []
        id_estimates_two = []
        idx_pairs = []
        idx_pair_to_index = {}

    def invalidate_pair(pair):
        """NaN out candidate *pair* so it can never be selected as a merge again."""
        pos = idx_pair_to_index[pair]
        combined_ids[pos] = torch.nan
        id_estimates_one[pos] = torch.nan
        id_estimates_two[pos] = torch.nan

    for main_iter in tqdm(range(new_n - args.m)):
        # Clusters contributing to the variance: mergeable ones plus capped ones.
        num_active = len(used_idxs) + len(keep_in_calc_idxs)
        if num_active <= args.m:
            break

        # Stays None when this iteration replays the merge stored in a snapshot.
        candidate_id_vars = None

        if not second and args.first_from_save:
            current_merge = (first_merge_idx, second_merge_idx)
            current_merge_id = cluster_cache[current_merge][0]
            current_max_var = 10  # placeholder: variance is not recomputed for a replayed merge
        else:
            current_max_var = -math.inf
            current_merge = (0, 0)
            current_merge_id = math.inf

        # Get optimal merge pair
        start = time.time()
        id_est_vec_start = time.time()
        id_est_vec = torch.tensor([id_estimates[cidx] for cidx in used_idxs], device="cuda", dtype=torch.float32)
        if len(keep_in_calc_idxs) > 0:
            id_est_vec_kept_out = torch.tensor([id_estimates[cidx] for cidx in keep_in_calc_idxs], device="cuda", dtype=torch.float32)
            id_est_vec = torch.cat([id_est_vec, id_est_vec_kept_out])
        if args.print_times: print("id_est_vec time:", time.time()-id_est_vec_start)

        if not args.first_from_save or second:
            agg_start = time.time()

            if main_iter == 0:
                # First pass: ID estimate of every possible pairwise merge.
                for first_idx in used_idxs:
                    for second_idx in used_idxs:
                        if second_idx <= first_idx:
                            continue

                        first_cluster, second_cluster = clusters[first_idx], clusters[second_idx]
                        combined_id, neighbours = calculate_id(np.concatenate((first_cluster, second_cluster)), dists)
                        cluster_cache[(first_idx, second_idx)] = (combined_id, neighbours)

                        idx_pair_to_index[(first_idx, second_idx)] = len(combined_ids)

                        combined_ids.append(combined_id)
                        id_estimates_one.append(id_estimates[first_idx])
                        id_estimates_two.append(id_estimates[second_idx])
                        idx_pairs.append((first_idx, second_idx))

                combined_ids = torch.tensor(combined_ids, device="cuda", dtype=torch.float32)
                id_estimates_one = torch.tensor(id_estimates_one, device="cuda", dtype=torch.float32)
                id_estimates_two = torch.tensor(id_estimates_two, device="cuda", dtype=torch.float32)

            if args.print_times: print("Aggregate ids_time time:", time.time()-agg_start)
            calc_start = time.time()
            # Mean ID after each candidate merge (one cluster fewer). The
            # denominator now counts capped clusters too, since id_sum and
            # id_est_vec include them.
            candidate_id_sums = id_sum - id_estimates_one - id_estimates_two + combined_ids
            candidate_id_means = candidate_id_sums / (num_active - 1)

            # Sum of squared deviations of all current IDs from each candidate
            # mean, done in chunks (presumably to bound peak GPU memory).
            long_op = time.time()
            num_entries = candidate_id_means.shape[0]
            quotient = num_entries // sections
            remainder = num_entries % sections
            candidate_id_vars = []
            for i in range(sections):
                candidate_id_vars.append(((id_est_vec[None, :]-candidate_id_means[i*quotient:(i+1)*quotient, None])**2).sum(axis=1))
            if remainder != 0:
                candidate_id_vars.append(((id_est_vec[None, :]-candidate_id_means[-remainder:, None])**2).sum(axis=1))
            candidate_id_vars = torch.cat(candidate_id_vars)
            if args.print_times: print("Long op time:", time.time()-long_op)

            # Swap the two merged clusters' terms for the combined cluster's term.
            candidate_id_vars -= ((candidate_id_means-id_estimates_one)**2 + (candidate_id_means-id_estimates_two)**2)
            candidate_id_vars += (candidate_id_means-combined_ids)**2
            candidate_id_vars /= num_active - 1 - 1  # Sample variance over num_active - 1 clusters
            max_index = torch.argmax(torch.nan_to_num(candidate_id_vars, nan=0)).item()

            current_max_var = candidate_id_vars[max_index].item()
            current_merge = idx_pairs[max_index]
            current_merge_id = combined_ids[max_index].item()
            if args.print_times: print("Calc time:", time.time()-calc_start)

        if (not second and not args.first_from_save and args.save_first) or len(used_idxs) in save_set:
            to_save = {
                "clusters": clusters,
                "id_estimates": id_estimates,
                "id_sum": id_sum,
                "cluster_cache": cluster_cache,
                "used_idxs": used_idxs,
                "first_merge_idx": current_merge[0],
                "second_merge_idx": current_merge[1],
                "combined_ids": combined_ids.cpu(),
                "id_estimates_one": id_estimates_one.cpu(),
                "id_estimates_two": id_estimates_two.cpu(),
                "idx_pair_to_index": idx_pair_to_index,
                "idx_pairs": idx_pairs
            }
            # save_pickle appends ".pickle" itself; passing it here previously
            # produced "*.pickle.pickle", which --first_from_save could not load.
            if second:
                save_pickle(f'iter_{main_iter}_{args.run_name}', to_save)
            else:
                save_pickle(f'first_iter_{args.run_name}', to_save)

        if args.print_times: print("Main loop", time.time()-start)

        if candidate_id_vars is not None and torch.nansum(candidate_id_vars) == 0:
            print("No more merge candidates")
            break

        start = time.time()

        # Apply the merge: second_merge_idx is folded into first_merge_idx.
        first_merge_idx, second_merge_idx = current_merge
        merge_pair = (first_merge_idx, second_merge_idx)
        switches.append(merge_pair)
        used_idxs.remove(second_merge_idx)

        if args.print_stats: print("Merging", second_merge_idx, "into", first_merge_idx, "size1:", len(clusters[first_merge_idx]), "size2:", len(clusters[second_merge_idx]), "max var:", current_max_var)

        if args.save_plots and main_iter % args.save_graph_iter == 0:
            id_ests = [(len(clusters[cidx]), id_estimates[cidx]) for cidx in used_idxs]
            plt.scatter([i[0] for i in id_ests], [i[1] for i in id_ests])
            plt.scatter([len(clusters[first_merge_idx]), len(clusters[second_merge_idx])], [(id_estimates[first_merge_idx]), (id_estimates[second_merge_idx])], color="red")
            plt.savefig(f"./id_run_saves/{args.run_name}_{main_iter}.png")
            plt.close()

        # Save stats
        id_estimates_non_empty = [id_estimates[cidx] for cidx in used_idxs]

        if min(id_estimates_non_empty) < 1e-3:
            print("Really low id estimate:", min(id_estimates_non_empty))

        if max(id_estimates_non_empty) > 100:
            print("Really high id estimate:", max(id_estimates_non_empty))

        id_estimates_maxes.append(max(id_estimates_non_empty))
        id_estimates_mins.append(min(id_estimates_non_empty))
        id_estimates_means.append(sum(id_estimates_non_empty)/len(id_estimates_non_empty))
        inter_id_vars.append(current_max_var)
        # Log the ID of the merge actually applied (was a stale loop variable).
        combined_ids_log.append(current_merge_id)

        clusters[first_merge_idx] = np.concatenate((clusters[first_merge_idx], clusters[second_merge_idx]))
        clusters[second_merge_idx] = []

        merge_cluster_sizes.append(len(clusters[first_merge_idx]))
        num_merges.append(max(merge_checker[first_merge_idx], merge_checker[second_merge_idx]))

        merge_checker[first_merge_idx] = max(merge_checker[first_merge_idx], merge_checker[second_merge_idx]) + 1

        merges.append(merge_pair)
        id_sum -= (id_estimates[first_merge_idx] + id_estimates[second_merge_idx])
        id_sum += current_merge_id

        id_estimates[first_merge_idx] = current_merge_id

        base_merging_neighours = cluster_cache[merge_pair][1]
        if args.print_times: print("Initial merge", time.time()-start)

        invalidate_pair(merge_pair)

        capped = args.cap != -1 and clusters[first_merge_idx].shape[0] >= args.cap
        if capped:
            print("Removing cluster of size:", clusters[first_merge_idx].shape[0])
            used_idxs.remove(first_merge_idx)
            keep_in_calc_idxs.append(first_merge_idx)

        # Refresh every candidate pair that involves the merged cluster.
        start = time.time()
        for idx in used_idxs:
            if idx == first_merge_idx:
                continue

            if capped:
                # The capped cluster can no longer merge: retire its pairs and skip
                # the refresh (which previously overwrote these NaNs).
                invalidate_pair((min(idx, first_merge_idx), max(idx, first_merge_idx)))
                invalidate_pair((min(idx, second_merge_idx), max(idx, second_merge_idx)))
                continue

            bs = clusters[idx].shape[0]

            # Cache keys are (low, high) and each cached neighbour array has the
            # low cluster's rows first. The three cases rebuild the row/column
            # layout of the new (idx, merged) pair from the two old pair caches.
            if idx < first_merge_idx:
                to_merge_low = cluster_cache[(idx, first_merge_idx)][1]
                to_merge_high = cluster_cache[(idx, second_merge_idx)][1]

                merging_additions = cat_pad((to_merge_low[bs:], to_merge_high[bs:]), pad_axis=1, cat_axis=0)
                new_merging_neighbours = cat_pad((base_merging_neighours, merging_additions), pad_axis=1, cat_axis=1)
                new_neighbours = cat_pad((to_merge_low[:bs], to_merge_high[:bs]), pad_axis=1, cat_axis=1)
                updated_neighbours = cat_pad((new_neighbours, new_merging_neighbours), pad_axis=1, cat_axis=0)

                cluster_cache[(idx, first_merge_idx)] = update_id(np.concatenate((clusters[idx], clusters[first_merge_idx]), axis=0), updated_neighbours, dists)

                combined_ids[idx_pair_to_index[(idx, first_merge_idx)]] = cluster_cache[(idx, first_merge_idx)][0]
                # id_estimates_one for (idx, first_merge_idx) is idx's ID: unchanged
                id_estimates_two[idx_pair_to_index[(idx, first_merge_idx)]] = current_merge_id

                invalidate_pair((idx, second_merge_idx))

            elif idx > first_merge_idx and idx < second_merge_idx:
                to_merge_low = cluster_cache[(first_merge_idx, idx)][1]
                to_merge_high = cluster_cache[(idx, second_merge_idx)][1]

                merging_additions = cat_pad((to_merge_low[:-bs], to_merge_high[bs:]), pad_axis=1, cat_axis=0)
                new_merging_neighbours = cat_pad((base_merging_neighours, merging_additions), pad_axis=1, cat_axis=1)
                new_neighbours = cat_pad((to_merge_low[-bs:], to_merge_high[:bs]), pad_axis=1, cat_axis=1)
                updated_neighbours = cat_pad((new_merging_neighbours, new_neighbours), pad_axis=1, cat_axis=0)

                cluster_cache[(first_merge_idx, idx)] = update_id(np.concatenate((clusters[first_merge_idx], clusters[idx]), axis=0), updated_neighbours, dists)

                combined_ids[idx_pair_to_index[(first_merge_idx, idx)]] = cluster_cache[(first_merge_idx, idx)][0]
                id_estimates_one[idx_pair_to_index[(first_merge_idx, idx)]] = current_merge_id
                # id_estimates_two for (first_merge_idx, idx) is idx's ID: unchanged

                invalidate_pair((idx, second_merge_idx))

            else:
                to_merge_low = cluster_cache[(first_merge_idx, idx)][1]
                to_merge_high = cluster_cache[(second_merge_idx, idx)][1]

                merging_additions = cat_pad((to_merge_low[:-bs], to_merge_high[:-bs]), pad_axis=1, cat_axis=0)
                new_merging_neighbours = cat_pad((base_merging_neighours, merging_additions), pad_axis=1, cat_axis=1)
                new_neighbours = cat_pad((to_merge_low[-bs:], to_merge_high[-bs:]), pad_axis=1, cat_axis=1)
                updated_neighbours = cat_pad((new_merging_neighbours, new_neighbours), pad_axis=1, cat_axis=0)
                cluster_cache[(first_merge_idx, idx)] = update_id(np.concatenate((clusters[first_merge_idx], clusters[idx]), axis=0), updated_neighbours, dists)

                combined_ids[idx_pair_to_index[(first_merge_idx, idx)]] = cluster_cache[(first_merge_idx, idx)][0]
                id_estimates_one[idx_pair_to_index[(first_merge_idx, idx)]] = current_merge_id
                # id_estimates_two for (first_merge_idx, idx) is idx's ID: unchanged

                invalidate_pair((second_merge_idx, idx))

        if args.print_times: print("Updating cache", time.time()-start)
        second = True

    # NOTE(review): "num_merges" stores the merge_checker dict, not the
    # num_merges list built above; kept as-is for existing readers of this file.
    to_save = {
        "clusters": clusters,
        "merge_cluster_sizes": merge_cluster_sizes,
        "num_merges": merge_checker,
        "merges": merges,
        "id_estimates_maxes": id_estimates_maxes,
        'id_estimates_mins': id_estimates_mins,
        "id_estimates_means": id_estimates_means,
        "inter_id_vars": inter_id_vars,
        "combined_ids_log": combined_ids_log
    }
    save_pickle(f'{args.run_name}_final', to_save)

    print(f"Final lenghts {[len(c) for c in clusters if len(c) > 0]}")
    print(f"Final ID variance: {id_variance([c for c in clusters if len(c) > 0], dists)}")