 
import pdb
import numpy as np
from heapq import heappush, heappop
import pickle
import itertools
import os

# Factor that dist_index multiplies the smaller index by when flattening an
# (x, y) pair into a 1-D offset; presumably the total number of points covered
# by the distance vector -- TODO confirm against the code that builds `dists`.
n = 50000

def pickle_exists(name):
    """Return True if ``pickles/<name>.pickle`` exists.

    The path is relative, so it is resolved against the current working
    directory.
    """
    pickle_path = f'pickles/{name}.pickle'
    return os.path.exists(pickle_path)

def save_pickle(name, object):
    """Pickle ``object`` into ``pickles/<name>.pickle``.

    The file is written with ``pickle.HIGHEST_PROTOCOL``. The ``pickles``
    directory is not created here; ``open`` fails if it is missing.
    """
    target_path = f'pickles/{name}.pickle'
    with open(target_path, 'wb') as out_file:
        pickle.dump(object, out_file, protocol=pickle.HIGHEST_PROTOCOL)

def load_pickle(name):
    """Load and return the object stored in ``pickles/<name>.pickle``.

    NOTE: unpickling can execute arbitrary code, so this should only be
    pointed at files written by trusted code (e.g. ``save_pickle``).
    """
    source_path = f'pickles/{name}.pickle'
    with open(source_path, 'rb') as in_file:
        return pickle.load(in_file)

def dist_index(x, y, num_points=None):
    """Map an unordered pair of point indices to a flat distance-vector offset.

    For ``0 <= x, y < size`` with ``x != y`` the result enumerates the strict
    upper triangle of a ``size x size`` matrix row by row, i.e. pair
    ``(0, 1)`` maps to 0 and pair ``(size - 2, size - 1)`` maps to
    ``size * (size - 1) // 2 - 1``. The order of ``x`` and ``y`` is irrelevant.

    Args:
        x: First point index, or -1 as a "no point" marker.
        y: Second point index, or -1 as a "no point" marker.
        num_points: Matrix size used by the formula. Defaults to the
            module-level ``n`` so existing two-argument calls are unchanged.

    Returns:
        The flat offset, or -1 when either index is -1 (callers that index
        with the result then read the last element of their sequence).

    Raises:
        AssertionError: if ``x == y`` and neither is -1; a point has no
            stored distance to itself.
    """
    if x == y:
        # Diagnostic output kept from the original implementation.
        print("SAME", x)
    if x == -1 or y == -1:
        return -1
    if x == y:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError(f"dist_index needs two distinct indices, got {x} twice")

    size = n if num_points is None else num_points
    lo, hi = (y, x) if x > y else (x, y)
    return lo * size + hi - ((lo + 2) * (lo + 1)) // 2

def get_dist(x, y, dists):
    """Return the entry of ``dists`` that ``dist_index`` assigns to ``(x, y)``."""
    flat_offset = dist_index(x, y)
    return dists[flat_offset]

def get_nn_dists(oidxs, k, dists):
    """Find the ``k`` nearest neighbours of every point in ``oidxs``.

    For each point, distances to all other members of ``oidxs`` are looked up
    in ``dists`` through ``dist_index``; the ``k`` smallest are kept and
    sorted ascending.

    Returns:
        A pair ``(distances, neighbours)`` of arrays with one row per entry
        of ``oidxs`` and ``k`` columns; ``neighbours[r, c]`` is the point
        whose distance is ``distances[r, c]``.
    """
    distance_rows = []
    neighbour_rows = []
    for idx in oidxs:
        others = [j for j in oidxs if j != idx]
        distance_rows.append([dists[dist_index(idx, j)] for j in others])
        neighbour_rows.append(others)
    all_dists = np.array(distance_rows)
    all_neighbours = np.array(neighbour_rows)

    # Partial selection first: after partitioning at k, the k smallest
    # distances of each row occupy the first k columns (in arbitrary order).
    partition = np.argpartition(all_dists, k)
    top_dists = np.take_along_axis(all_dists, partition, axis=-1)[:, :k]
    top_neighbours = np.take_along_axis(all_neighbours, partition, axis=-1)[:, :k]

    # Then order just those k columns by distance.
    order = np.argsort(top_dists, axis=-1)
    sorted_dists = np.take_along_axis(top_dists, order, axis=-1)
    sorted_neighbours = np.take_along_axis(top_neighbours, order, axis=-1)
    return sorted_dists, sorted_neighbours

def calculate_id(idxs, dists, second_idx=0, return_idx=False, k=10):
    """Estimate a dimension-like statistic for ``idxs`` from k-NN distances.

    ``k`` is capped at ``len(idxs) - 2``. For every point the log-ratios of
    its k-th nearest-neighbour distance to each closer neighbour distance are
    averaged over ``k - 1``; the result is the reciprocal of the mean of
    those per-point averages. (The names suggest an MLE-style intrinsic
    dimension estimate -- confirm against the method being implemented.)

    Returns:
        ``(second_idx, estimate)`` when ``return_idx`` is true, otherwise
        ``(estimate, nn_neighbours)`` where ``nn_neighbours`` is the
        neighbour array produced by ``get_nn_dists``.
    """
    k = min(k, len(idxs) - 2)
    nn_dists, nn_neighbours = get_nn_dists(idxs, k, dists)

    kth_dist = nn_dists[:, k - 1: k]
    closer_dists = nn_dists[:, 0:k - 1]
    log_ratios = np.log(kth_dist / closer_dists)
    inv_mle = np.sum(log_ratios, -1) / (k - 1)
    estimate = 1 / inv_mle.mean()

    if not return_idx:
        return estimate, nn_neighbours
    return (second_idx, estimate)

def id_variance(clusters, dists):
    """Return the sample variance of the ``calculate_id`` estimates of ``clusters``.

    The sum of squared deviations from the mean is divided by the number of
    clusters minus one.
    """
    estimates = []
    for cluster in clusters:
        estimates.append(calculate_id(cluster, dists)[0])

    count = len(estimates)
    mean_id = sum(estimates) / len(estimates)
    squared_devs = [(mean_id - estimate) ** 2 for estimate in estimates]
    return sum(squared_devs) / (count - 1)

def update_id(idxs, nn_neighbours, dists, second_idx=0, return_idx=False, k=10):
    """Recompute the ``calculate_id``-style estimate from given neighbour lists.

    Row ``r`` of ``nn_neighbours`` lists candidate neighbours of ``idxs[r]``.
    Distances are looked up through ``dist_index``, de-duplicated per row
    with ``np.unique`` (which also sorts them ascending), and ``k`` is
    reduced to the smallest per-row count of unique distances below ``1e6``
    (presumably values at or above that are padding/sentinel entries --
    confirm against how ``dists`` is built).

    Returns:
        ``(second_idx, estimate)`` when ``return_idx`` is true, otherwise
        ``(estimate, nn_neighbours)`` with ``nn_neighbours`` reduced to the
        ``k`` selected columns.
    """
    distance_rows = []
    for row, im_idx in enumerate(idxs):
        distance_rows.append([dists[dist_index(im_idx, j)] for j in nn_neighbours[row]])
    nn_dists = np.array(distance_rows)

    # For each row: (sorted unique distances, column of first occurrence).
    uniques = [np.unique(nn_dists[row], return_index=True) for row in range(nn_dists.shape[0])]
    usable_counts = [(values < 1e6).sum() for values, _ in uniques]
    k = min(k, min(usable_counts))

    # Column indices of the k smallest distinct distances per row.
    selection = np.stack([first_seen[:k] for _, first_seen in uniques])
    nn_dists = np.take_along_axis(nn_dists, selection, axis=-1)
    nn_neighbours = np.take_along_axis(nn_neighbours, selection, axis=-1)

    kth_dist = nn_dists[:, k - 1: k]
    closer_dists = nn_dists[:, 0:k - 1]
    log_ratios = np.log(kth_dist / closer_dists)
    inv_mle = np.sum(log_ratios, -1) / (k - 1)
    estimate = 1 / inv_mle.mean()

    if not return_idx:
        return estimate, nn_neighbours
    return (second_idx, estimate)

def cat_pad(tuple, cat_axis, pad_axis):
    """Concatenate 2-D arrays along ``cat_axis``, padding ``pad_axis`` with -1.

    Each array in ``tuple`` is extended at the end of ``pad_axis`` (0 or 1)
    up to the largest size found on that axis, using -1 as the fill value.
    When all arrays already agree on ``pad_axis`` they are concatenated
    without padding.
    """
    lengths = [arr.shape[pad_axis] for arr in tuple]
    target_len = max(lengths)
    if min(lengths) == target_len:
        return np.concatenate(tuple, axis=cat_axis)

    padded = []
    for arr in tuple:
        missing = target_len - arr.shape[pad_axis]
        extra_rows = missing if pad_axis == 0 else 0
        extra_cols = missing if pad_axis == 1 else 0
        padded.append(
            np.pad(
                arr,
                pad_width=((0, extra_rows), (0, extra_cols)),
                mode="constant",
                constant_values=-1,
            )
        )
    return np.concatenate(padded, axis=cat_axis)

def initial_clusters(idxs, num_merges=2, dists=None):
    """Greedily merge clusters pairwise by average inter-cluster distance.

    Each of the ``num_merges`` rounds computes, for every pair of clusters,
    the mean of the point-to-point distances between their members, then
    repeatedly merges the closest pair of still-unmerged clusters. Once
    fewer than ``3 + round`` clusters remain unmerged, the remainder is
    combined into a single extra cluster and the round ends.

    Args:
        idxs: List of clusters, each a non-empty list of point indices.
            Clusters are combined with ``+``, so they should be lists.
        num_merges: Number of merge rounds to run.
        dists: Flat distance sequence understood by ``get_dist``. When
            omitted, the module-level ``dists`` is used, which is what the
            function relied on before this parameter existed.

    Returns:
        The list of clusters after the final round.

    Raises:
        NameError: if ``dists`` is omitted and no module-level ``dists``
            exists.
        ValueError: if a cluster is empty (its average distance to other
            clusters is undefined).
    """
    if dists is None:
        # Backward compatibility with callers that set a global ``dists``.
        try:
            dists = globals()["dists"]
        except KeyError:
            raise NameError("name 'dists' is not defined") from None

    for main_iter in range(num_merges):
        if any(len(cluster) == 0 for cluster in idxs):
            raise ValueError("initial_clusters: every cluster must be non-empty")

        distance_heap = []
        for i in range(len(idxs) - 1):
            for j in range(i + 1, len(idxs)):
                # Average linkage over the *member points* of both clusters.
                # (Previously the cluster positions i and j were passed to
                # get_dist, so member distances were never consulted.)
                total = sum(
                    get_dist(first_idx, second_idx, dists)
                    for first_idx in idxs[i]
                    for second_idx in idxs[j]
                )
                pair_count = len(idxs[i]) * len(idxs[j])
                heappush(distance_heap, (total / pair_count, i, j))

        next_idxs = []
        used = set()
        while distance_heap:
            _, i, j = heappop(distance_heap)
            if i not in used and j not in used:
                next_idxs.append(idxs[i] + idxs[j])
                used.add(i)
                used.add(j)

            # Too few clusters left for further pairing in this round.
            if len(idxs) - len(used) < 3 + main_iter:
                break

        # Whatever was not paired becomes one combined cluster. Skipping an
        # empty remainder avoids emitting an empty cluster, and doing this
        # after the loop keeps a lone input cluster instead of dropping it.
        leftover = [idxs[pos] for pos in range(len(idxs)) if pos not in used]
        if leftover:
            next_idxs.append(list(itertools.chain.from_iterable(leftover)))

        idxs = next_idxs

    return idxs
