import argparse
import json
import gzip
import lzma

import tqdm
import torch
import numpy as np
from cffi import FFI

def neg_dot_product_distances(embedding, embeddings):
    """Return the negated dot product of `embedding` with each row of `embeddings`.

    `embedding` is lifted to a single-row matrix and multiplied against the
    transpose of `embeddings`; the only row of that product is negated so
    that larger dot products yield smaller "distances".
    """
    query_row = embedding.unsqueeze(0)
    products = query_row.mm(embeddings.T)
    return -1 * products[0]

def l2_distances(embedding, embeddings):
    """Return the Euclidean distance from `embedding` to each row of `embeddings`.

    The difference is squared element-wise, summed along dim 1, and
    square-rooted.
    """
    deltas = embeddings - embedding
    squared = deltas ** 2
    return squared.sum(dim=1).sqrt()

def init_levenshtein_c():
    """Compile the inline C Levenshtein routine with CFFI and return its handles.

    Registers the C source below as an extension module named
    ``_levenshtein``, declares the exported function signature, compiles the
    module, imports it, and returns the module's ``(ffi, lib)`` pair.
    """
    c_source = r"""
            int levenshtein(int *seq1, int seq1_len, int *seq2, int seq2_len, int *v0)
            {
                // Adapted from https://en.wikipedia.org/wiki/Levenshtein_distance  (CC-BY-SA)

                // v0 is just a buffer for temporary calculations; easier to
                // ask the caller to allocate it than to deal with C mem
                // management

                int substitutionCost, insertionCost, deletionCost;
                int tmpval;

                for (int i = 0; i < seq2_len+1; i++) {
                    v0[i] = i;
                }

                for (int i = 0; i < seq1_len; i++){
                    // calculate v1 (current row distances) from the previous row v0

                    // first element of v1 is A[i+1][0]
                    //   edit distance is delete (i+1) chars from s to match empty t
                    //v1[0] = i + 1;
                    tmpval = i + 1;

                    // use formula to fill in the rest of the row
                    for(int j = 0; j < seq2_len; j++){
                        // calculating costs for A[i+1][j+1]
                        deletionCost = v0[j + 1] + 1;
                        insertionCost = tmpval + 1;
                        substitutionCost = v0[j];
                        if (seq1[i] != seq2[j]) {
                            substitutionCost++;
                        }

                        v0[j] = tmpval;

                        tmpval = deletionCost;
                        if (insertionCost < tmpval) {
                            tmpval = insertionCost;
                        }
                        if (substitutionCost < tmpval) {
                            tmpval = substitutionCost;
                        }
                    }
                    v0[seq2_len] = tmpval;

                    // copy v1 (current row) to v0 (previous row) for next iteration
                    // since data in v1 is always invalidated, a swap without copy could be more efficient
                    //tmp = v0;
                    //v0 = v1;
                    //v1 = tmp;
                }
                // after the last swap, the results of v1 are now in v0
                return v0[seq2_len];
            }
        """
    c_declaration = "int levenshtein(int*, int, int*, int, int*);"

    builder = FFI()
    builder.set_source("_levenshtein", c_source)
    builder.cdef(c_declaration)

    # Build the extension, then load the freshly compiled module.
    builder.compile(verbose=True)
    import _levenshtein

    return _levenshtein.ffi, _levenshtein.lib

# Lazily-initialised handles to the compiled C module (see levenshtein_distance).
levenshtein_ffi, levenshtein_lib = None, None
def levenshtein_distance(seq1, seq2):
    """Return the Levenshtein (edit) distance between two integer sequences.

    Args:
        seq1: Sized sequence of values convertible to ``np.int32``.
        seq2: Sized sequence of values convertible to ``np.int32``.

    Returns:
        The edit distance as computed by the C ``levenshtein`` function.

    The C module is compiled on first use and cached in the module-level
    ``levenshtein_ffi`` / ``levenshtein_lib`` globals.
    """
    # We call a C function for levenshtein via CFFI because it is about 1000x
    # faster than the python version (the difference between running in an hour
    # vs running in a month)

    global levenshtein_ffi, levenshtein_lib

    if levenshtein_ffi is None:
        levenshtein_ffi, levenshtein_lib = init_levenshtein_c()

    # Pass the shorter sequence first; the scratch row is sized by the
    # second sequence.
    if len(seq1) > len(seq2):
        seq1, seq2 = seq2, seq1

    # Bind the arrays and their from_buffer() views to local names. A pointer
    # produced by ffi.cast() does not keep its source alive, so without these
    # references the arrays could be freed (and their memory reused) before
    # the C call reads them.
    seq1_arr = np.array(seq1, dtype=np.int32)
    seq2_arr = np.array(seq2, dtype=np.int32)
    scratch_arr = np.zeros(len(seq2) + 1, dtype=np.int32)

    seq1_view = levenshtein_ffi.from_buffer(seq1_arr)
    seq2_view = levenshtein_ffi.from_buffer(seq2_arr)
    scratch_view = levenshtein_ffi.from_buffer(scratch_arr)

    # NOTE(review): the buffers are int32 while the C signature uses `int`;
    # this assumes a platform where `int` is 32 bits wide.
    seq1_buf = levenshtein_ffi.cast("int*", seq1_view)
    seq2_buf = levenshtein_ffi.cast("int*", seq2_view)
    v0 = levenshtein_ffi.cast("int*", scratch_view)

    result = levenshtein_lib.levenshtein(seq1_buf, len(seq1), seq2_buf, len(seq2), v0)
    return result

def levenshtein_distances(embedding, embeddings):
    """Return a tensor of Levenshtein distances from `embedding` to each item.

    Each element of `embeddings` is compared against `embedding` with
    `levenshtein_distance`, in iteration order.
    """
    return torch.tensor(
        [levenshtein_distance(embedding, other) for other in embeddings]
    )

class CoresetBuilder:
    """Incrementally select a coreset of indices over a collection of embeddings.

    The builder tracks which indices are in the coreset, which are still in
    the unlabeled pool, and, for every index, the smallest distance to any
    coreset member seen so far (`nearest_neighbor_dists`).
    """

    def __init__(self, embeddings, dist_fn, initial_size=1, rng=None):
        """Builds a coreset over `embeddings` based on `dist_fn`.

        Args:
            embeddings: Sized, indexable collection of embeddings. It is
                passed whole as the second argument of `dist_fn`.
            dist_fn: Callable ``dist_fn(embedding, embeddings)`` whose result
                has a ``.numpy()`` method yielding one distance per entry of
                `embeddings`.
            initial_size: Number of indices drawn at random, without
                replacement, to seed the coreset. ``0`` starts empty.
            rng: Object providing a NumPy ``Generator``-style ``choice``
                method. Defaults to ``np.random.default_rng()``.
        """
        self.embeddings = embeddings
        self.dist_fn = dist_fn
        self.rng = rng if rng is not None else np.random.default_rng()

        self.full_data_size = len(embeddings)

        # Initialize coreset
        self._reset_coreset()
        if initial_size > 0:
            # Use self.rng (not the raw argument) so the default rng=None works.
            seed_idxs = self.rng.choice(
                list(self.unlabeled_pool), size=initial_size, replace=False
            ).tolist()
            for idx in set(seed_idxs):
                self.add_to_coreset(idx)

    def _reset_coreset(self):
        """Empty the coreset and reset every nearest-neighbor distance to infinity."""
        # These sets just store indices into the full data to make them easier to manipulate
        self.unlabeled_pool = set(range(self.full_data_size))
        self.coreset = set()

        self.nearest_neighbor_dists = [float('inf')] * self.full_data_size

    def set_embeddings(self, new_embeddings):
        """Replace the embeddings and recompute distances for the current coreset.

        The coreset membership is preserved; nearest-neighbor distances are
        rebuilt from scratch using `new_embeddings`.

        Raises:
            ValueError: If `new_embeddings` differs in length from the
                current embeddings.
        """
        if len(new_embeddings) != len(self.embeddings):
            raise ValueError("New embeddings length ({}) should be the same as the old embeddings length ({})".format(len(new_embeddings), len(self.embeddings)))

        # Store the new embeddings before re-adding so that add_to_coreset
        # measures distances against the new data.
        self.embeddings = new_embeddings

        coreset_idxs = self.coreset.copy()
        self._reset_coreset()

        for idx in coreset_idxs:
            self.add_to_coreset(idx)

    def get_coreset_indices(self):
        """Return a copy of the set of indices currently in the coreset."""
        return self.coreset.copy()

    def add_to_coreset(self, idx):
        """Move `idx` from the unlabeled pool into the coreset.

        Also lowers each entry of `nearest_neighbor_dists` to the distance
        from `idx` whenever that distance is smaller.

        Raises:
            KeyError: If `idx` is not in the unlabeled pool.
        """
        self.unlabeled_pool.remove(idx)
        self.coreset.add(idx)

        dists = self.dist_fn(self.embeddings[idx], self.embeddings).numpy()
        for i, current in enumerate(self.nearest_neighbor_dists):
            self.nearest_neighbor_dists[i] = min(current, dists[i])

    def acquire_points(self, n_points=1):
        """Acquire N new points into the coreset using the greedy coreset
        algorithm from (Sener and Savarese, 2018)

        Each step picks the pooled index with the largest nearest-neighbor
        distance and adds it to the coreset. If the pool runs out before
        `n_points` have been selected, acquisition stops early.

        Returns:
            The set of indices added by this call (possibly fewer than
            `n_points`).
        """
        selected_idxs = set()
        for _ in range(n_points):
            if not self.unlabeled_pool:
                # Nothing left to select; return what was acquired so far.
                break
            # max() keeps the first maximal element it encounters.
            best_idx = max(
                self.unlabeled_pool, key=self.nearest_neighbor_dists.__getitem__
            )
            selected_idxs.add(best_idx)
            self.add_to_coreset(best_idx)
        return selected_idxs

    def __len__(self):
        """Return the number of indices in the coreset."""
        return len(self.coreset)

