import argparse
import json
import os
import random
import pickle

import torch
import faiss
import h5py
import numpy as np
from tqdm import tqdm

from densephrases.utils.embed_utils import int8_to_float


def get_args():
    """Parse the command line and derive all index-related paths.

    Besides the raw options, the returned namespace carries:

    * ``index_name`` -- ``<num_clusters>_<flat|hnsw>_<fine_quant>``.
    * ``index_dir`` -- ``dump_dir/<phrase file stem>/<index_name>``.
    * ``quantizer_path``, ``trained_index_path``, ``inv_path`` -- rewritten
      to live inside ``index_dir``.
    * ``phrase_path`` -- rewritten to live inside ``dump_dir``.
    * ``subindex_dir`` -- ``index_dir/<subindex_name>``.
    * ``index_path`` / ``idx2id_path`` -- inside ``index_dir`` when no
      ``--dump_paths`` is given, otherwise ``<offset>.faiss`` /
      ``<offset>.hdf5`` inside ``subindex_dir``.
    * ``dump_paths`` -- ``None`` or a list of paths under ``dump_dir/phrase``.

    Returns:
        argparse.Namespace: the parsed and post-processed arguments.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('dump_dir')
    ap.add_argument('stage')

    # large-scale add option
    ap.add_argument('--dump_paths', default=None,
                    help='Relative to `dump_dir/phrase`. '
                         'If specified, creates subindex dir and save there with same name')
    ap.add_argument('--subindex_name', default='index', help='used only if dump_path is specified.')
    ap.add_argument('--offset', default=0, type=int)

    # relative paths in dump_dir/index_name
    ap.add_argument('--quantizer_path', default='quantizer.faiss')
    ap.add_argument('--trained_index_path', default='trained.faiss')
    ap.add_argument('--phrase_path', default='phrases_min0.pkl')
    ap.add_argument('--index_path', default='index.faiss')
    ap.add_argument('--idx2id_path', default='idx2id.hdf5')
    ap.add_argument('--inv_path', default='merged.invdata')

    # adding options
    ap.add_argument('--add_all', default=False, action='store_true')

    # coarse, fine, add
    ap.add_argument('--num_clusters', type=int, default=16384)
    ap.add_argument('--hnsw', default=False, action='store_true')
    ap.add_argument('--fine_quant', default='SQ4',
                    help='SQ4|PQ# where # is number of bytes per vector')
    # stable params
    ap.add_argument('--norm_th', default=999, type=float)
    ap.add_argument('--para', default=False, action='store_true')
    ap.add_argument('--doc_sample_ratio', default=0.2, type=float)
    ap.add_argument('--vec_sample_ratio', default=0.2, type=float)
    ap.add_argument('--cuda', default=False, action='store_true')
    ap.add_argument('--replace', default=False, action='store_true')
    ap.add_argument('--num_phrases_per_add', default=50000, type=int)

    opts = ap.parse_args()

    # The index directory is named after its configuration.
    coarse_kind = 'flat'
    if opts.hnsw:
        coarse_kind = 'hnsw'
    opts.index_name = f'{opts.num_clusters:d}_{coarse_kind}_{opts.fine_quant}'
    phrase_stem = opts.phrase_path.partition('.')[0]
    opts.index_dir = os.path.join(opts.dump_dir, phrase_stem, opts.index_name)

    def in_index_dir(relative):
        return os.path.join(opts.index_dir, relative)

    opts.quantizer_path = in_index_dir(opts.quantizer_path)
    opts.trained_index_path = in_index_dir(opts.trained_index_path)
    opts.phrase_path = os.path.join(opts.dump_dir, opts.phrase_path)
    opts.inv_path = in_index_dir(opts.inv_path)
    opts.subindex_dir = in_index_dir(opts.subindex_name)

    if opts.dump_paths is not None:
        # Sub-index mode: outputs are keyed by the id offset of this shard.
        phrase_root = os.path.join(opts.dump_dir, 'phrase')
        opts.dump_paths = [os.path.join(phrase_root, rel) for rel in opts.dump_paths.split(',')]
        opts.index_path = os.path.join(opts.subindex_dir, f'{opts.offset:d}.faiss')
        opts.idx2id_path = os.path.join(opts.subindex_dir, f'{opts.offset:d}.hdf5')
    else:
        opts.index_path = in_index_dir(opts.index_path)
        opts.idx2id_path = in_index_dir(opts.idx2id_path)

    return opts


def concat_vectors(vectors):
    """Stack a sequence of arrays along their first axis.

    The output is preallocated with the dtype and trailing dimensions of the
    first array, and every input is copied into its slice in order.

    Args:
        vectors: non-empty sequence of arrays that agree on all dimensions
            after the first.

    Returns:
        A new array whose first dimension is the sum of the inputs' first
        dimensions.
    """
    first = vectors[0]
    row_counts = [vec.shape[0] for vec in vectors]
    # ``first.shape[1:]`` is empty for 1-D input, so one shape expression
    # covers both the 1-D and the N-D case.
    merged = np.zeros((sum(row_counts), *first.shape[1:]), dtype=first.dtype)
    begin = 0
    for vec, rows in zip(vectors, row_counts):
        end = begin + rows
        merged[begin:end] = vec
        begin = end
    return merged


def sample_data(dump_paths, doc_sample_ratio=0.2, vec_sample_ratio=0.2, seed=29, norm_th=999):
    """Sample start vectors from HDF5 phrase dumps for index training.

    For every dump, a ``doc_sample_ratio`` fraction of the document groups is
    drawn; within each drawn document a ``vec_sample_ratio`` fraction of the
    rows of its ``start`` dataset is drawn, dequantized with the group's
    ``offset``/``scale`` attributes, and kept if its L2 norm is at most
    ``norm_th``.

    Args:
        dump_paths: paths of the HDF5 dump files to sample from.
        doc_sample_ratio: fraction of documents sampled per dump.
        vec_sample_ratio: fraction of vectors sampled per document.
        seed: seed applied to both ``random`` and ``numpy.random``.
        norm_th: vectors with a larger L2 norm are dropped.

    Returns:
        A 2-D array holding all sampled vectors.

    Raises:
        ValueError: if no document contributed any vectors, so there is
            nothing to train on.
    """
    start_vecs = []
    random.seed(seed)
    np.random.seed(seed)
    print('sampling from:')
    for dump_path in dump_paths:
        print(dump_path)

    dumps = []
    try:
        # Open every dump up front so a missing file fails before any sampling.
        for dump_path in dump_paths:
            dumps.append(h5py.File(dump_path, 'r'))

        for i, f in enumerate(tqdm(dumps)):
            doc_ids = list(f.keys())
            sampled_doc_ids = random.sample(doc_ids, int(doc_sample_ratio * len(doc_ids)))
            for doc_id in tqdm(sampled_doc_ids, desc='sampling from %d' % i):
                group = f[doc_id]
                start_set = group['start'][:]
                num_start, _ = start_set.shape
                if num_start == 0:
                    continue
                # NOTE: np.random.choice draws with replacement by default, so
                # the same row may be picked more than once.
                sampled_start_idxs = np.random.choice(num_start, int(vec_sample_ratio * num_start))
                start_vec = int8_to_float(start_set, group.attrs['offset'], group.attrs['scale'])[sampled_start_idxs]
                start_vec = start_vec[np.linalg.norm(start_vec, axis=1) <= norm_th]
                start_vecs.append(start_vec)
    finally:
        # Close the dumps even if sampling fails part-way through.
        for dump in dumps:
            dump.close()

    if not start_vecs:
        raise ValueError('no vectors were sampled from the given dumps')

    return concat_vectors(start_vecs)


def train_index(start_data, quantizer_path, trained_index_path, num_clusters,
        fine_quant='SQ4', cuda=False, hnsw=False):
    """Train an inner-product IVF index on sampled vectors and write it to disk.

    Only training happens here; no vectors are added to the index.

    Args:
        start_data: 2-D array of training vectors; its second dimension is
            used as the index dimensionality.
        quantizer_path: accepted for interface compatibility; not used by
            this function.
        trained_index_path: file the trained index is written to.
        num_clusters: number of inverted lists.
        fine_quant: ``'SQ4'`` for a 4-bit scalar quantizer, or
            ``'PQ<code_size>'`` optionally followed by ``'_<bits>'`` for a
            product quantizer. Only 8 bits per sub-quantizer is supported,
            and that is the default when the suffix is omitted.
        cuda: train on GPU 0 with float16 enabled, then convert the index
            back to CPU.
        hnsw: accepted for interface compatibility; not used by this
            function (the coarse quantizer is always ``IndexFlatIP``).

    Raises:
        ValueError: if ``fine_quant`` is not a supported specification.
    """
    ds = start_data.shape[1]
    quantizer = faiss.IndexFlatIP(ds)

    if fine_quant == 'SQ4':
        start_index = faiss.IndexIVFScalarQuantizer(
            quantizer, ds, num_clusters, faiss.ScalarQuantizer.QT_4bit, faiss.METRIC_INNER_PRODUCT
        )
    elif 'PQ' in fine_quant:
        pq_spec = fine_quant.split('_')
        code_size = int(pq_spec[0][2:])
        # The bit-width suffix is optional: the CLI documents plain 'PQ#'.
        bits_per_sub = int(pq_spec[1]) if len(pq_spec) > 1 else 8
        if bits_per_sub != 8:
            raise ValueError('only 8 bits per PQ sub-quantizer is supported: %s' % fine_quant)
        start_index = faiss.IndexIVFPQ(quantizer, ds, num_clusters, code_size, bits_per_sub, faiss.METRIC_INNER_PRODUCT)
    else:
        raise ValueError(fine_quant)

    start_index.verbose = True
    if cuda:
        # Convert to GPU index
        res = faiss.StandardGpuResources()
        co = faiss.GpuClonerOptions()
        co.useFloat16 = True
        gpu_index = faiss.index_cpu_to_gpu(res, 0, start_index, co)
        gpu_index.verbose = True

        # Train on GPU and back to CPU
        gpu_index.train(start_data)
        start_index = faiss.index_gpu_to_cpu(gpu_index)
    else:
        start_index.train(start_data)

    # Make sure to set direct map again
    start_index.make_direct_map()
    start_index.set_direct_map_type(faiss.DirectMap.Hashtable)
    faiss.write_index(start_index, trained_index_path)


def add_with_offset(start_index, start_data, start_total, offset):
    """Add a batch of vectors under consecutive int64 ids.

    The i-th row of ``start_data`` receives the id
    ``i + offset + start_total``.

    Args:
        start_index: index object exposing ``add_with_ids``.
        start_data: array of vectors; one id is generated per row.
        start_total: number of vectors already added before this batch.
        offset: id offset applied to every vector.
    """
    num_rows = start_data.shape[0]
    first_id = offset + start_total
    batch_ids = np.arange(num_rows) + first_id
    start_index.add_with_ids(start_data, batch_ids.astype(np.int64))


def get_doc_group(dump_ranges, phrase_dumps, doc_idx):
    """Look up the group stored for a document across one or more dumps.

    Args:
        dump_ranges: per-dump ``[low, high]`` pairs, where a dump covers
            document ids in ``[low * 1000, high * 1000)``; or ``None`` when
            no range information is available.
        phrase_dumps: mappings keyed by the document id as a string,
            parallel to ``dump_ranges``.
        doc_idx: document id (int, or a string of digits).

    Returns:
        The entry stored under ``str(doc_idx)`` in the dump that holds it.

    Raises:
        ValueError: if several dumps are given and none holds the document.
    """
    doc_key = str(doc_idx)
    if len(phrase_dumps) == 1:
        return phrase_dumps[0][doc_key]

    if dump_ranges is None:
        # Without range information the owning dump cannot be computed,
        # so probe each dump in turn.
        for dump in phrase_dumps:
            if doc_key in dump:
                return dump[doc_key]
        raise ValueError('%d not found in dump list' % int(doc_idx))

    for dump_range, dump in zip(dump_ranges, phrase_dumps):
        if dump_range[0] * 1000 <= int(doc_idx) < dump_range[1] * 1000:
            if doc_key not in dump:
                raise ValueError('%d not found in dump list' % int(doc_idx))
            return dump[doc_key]

    # Check last
    if doc_key not in phrase_dumps[-1]:
        raise ValueError('%d not found in dump list' % int(doc_idx))
    return phrase_dumps[-1][doc_key]


def add_to_index(dump_paths, trained_index_path, phrase_path, target_index_path, idx2id_path,
                 num_phrases_per_add=1000, cuda=False, fine_quant='SQ4', offset=0, norm_th=999,
                 ignore_ids=None):
    """Add the start vectors of selected phrases to a trained index.

    The pickle at ``phrase_path`` maps each phrase to an iterable of
    ``(doc_id, start_word, end_word)`` triples. For every triple the vectors
    at the start and end word positions of the document's ``start`` dataset
    are added, each (doc, word) position at most once: if one endpoint is
    already indexed only the other is added, and if both are the triple is
    skipped. Vectors are flushed to the index in batches, and the id-to-
    (doc, word) mapping is written to ``idx2id_path``.

    Args:
        dump_paths: HDF5 dump files holding the per-document vectors.
        trained_index_path: trained faiss index to start from.
        phrase_path: pickle file with the phrase -> triples mapping.
        target_index_path: where the populated index is written.
        idx2id_path: HDF5 file receiving the ``doc`` and ``word`` id arrays.
        num_phrases_per_add: pending vectors are flushed whenever the phrase
            counter is a multiple of this value.
        cuda: use the GPU for adding (only the coarse quantizer for PQ).
        fine_quant: fine quantizer spec; only used to pick the GPU strategy.
        offset: id offset of this (sub)index.
        norm_th: accepted for interface compatibility; not used here.
        ignore_ids: accepted for interface compatibility; not used here.
    """
    sidx2doc_id = []
    sidx2word_id = []

    print('Reading HDF5 files')
    input_dumps = []
    try:
        for path in dump_paths:
            input_dumps.append(h5py.File(path, 'r'))
        dump_names = [os.path.splitext(os.path.basename(path))[0] for path in dump_paths]
        dump_ranges = None
        if '-' in dump_names[0] and ('dev' not in dump_names[0]): # Range check
            dump_ranges = [list(map(int, name.split('-'))) for name in dump_names]

        print('reading %s' % trained_index_path)
        start_index = faiss.read_index(trained_index_path)
        start_index.make_direct_map()
        start_index.set_direct_map_type(faiss.DirectMap.Hashtable)

        print(f'loading phrase {phrase_path}')
        # NOTE(security): pickle executes arbitrary code on load; only point
        # this at phrase files from a trusted source.
        with open(phrase_path, 'rb') as phrase_file:
            phrases = pickle.load(phrase_file)

        if cuda:
            if 'PQ' in fine_quant:
                # Only the coarse quantizer is moved to the GPUs.
                index_ivf = faiss.extract_index_ivf(start_index)
                quantizer = index_ivf.quantizer
                quantizer_gpu = faiss.index_cpu_to_all_gpus(quantizer)
                index_ivf.quantizer = quantizer_gpu
            else:
                res = faiss.StandardGpuResources()
                co = faiss.GpuClonerOptions()
                co.useFloat16 = True
                start_index = faiss.index_cpu_to_gpu(res, 0, start_index, co)

        print('adding following dumps:')
        for dump_path in dump_paths:
            print(dump_path)
        start_total = 0
        start_total_prev = 0
        cnt = 0
        # offset/scale are read once from the first document and reused for all.
        offset_, scale_ = None, None
        starts = []
        # Keys '<doc>_<word>' of every vector already scheduled for adding.
        cache = {}

        for pi, (phrase, doc_start_ends) in enumerate(tqdm(phrases.items(), desc='dumps')):
            for dse in doc_start_ends:
                dse = [int(dd) for dd in dse]
                # Collapse onto whichever endpoint is not indexed yet.
                if f'{dse[0]}_{dse[1]}' in cache:
                    dse[1] = dse[2]
                if f'{dse[0]}_{dse[2]}' in cache:
                    dse[2] = dse[1]
                if dse[1] == dse[2] and f'{dse[0]}_{dse[1]}' in cache:
                    # Both endpoints are already indexed.
                    continue

                doc_group = get_doc_group(dump_ranges, input_dumps, dse[0])
                if offset_ is None:
                    offset_, scale_ = doc_group.attrs['offset'], doc_group.attrs['scale']

                if dse[1] == dse[2]:
                    word_id = [dse[1]]
                else:
                    word_id = [dse[1], dse[2]]
                start_end = int8_to_float(doc_group['start'][word_id], offset_, scale_)

                for wid in word_id:
                    assert f'{dse[0]}_{wid}' not in cache
                    cache[f'{dse[0]}_{wid}'] = None

                num_vec = start_end.shape[0]
                # One vector per requested word position keeps ids and
                # metadata aligned.
                assert num_vec == len(word_id)

                starts.append(start_end)
                sidx2doc_id.extend([int(dse[0])] * num_vec)
                sidx2word_id.extend(word_id)
                start_total += num_vec
                cnt += num_vec

            if len(starts) > 0 and (pi % num_phrases_per_add == 0):
                print(f'adding at {pi+1}, cache_size={len(cache)}')
                add_with_offset(
                    start_index, concat_vectors(starts), start_total_prev, offset
                )
                start_total_prev = start_total
                starts = []

        if len(starts) > 0:
            print(f'final adding at {pi+1}, cache_size={len(cache)}')
            add_with_offset(
                start_index, concat_vectors(starts), start_total_prev, offset
            )
            start_total_prev = start_total
    finally:
        # Release the HDF5 handles whether or not adding succeeded.
        for dump in input_dumps:
            dump.close()
    print('number of phrases', cnt)

    if cuda:
        print('moving back to cpu')
        if 'PQ' in fine_quant:
            index_ivf.quantizer = quantizer
            del quantizer_gpu
        else:
            start_index = faiss.index_gpu_to_cpu(start_index)

    print('start_index ntotal: %d' % start_index.ntotal)
    print(start_total)
    sidx2doc_id = np.array(sidx2doc_id, dtype=np.int32)
    sidx2word_id = np.array(sidx2word_id, dtype=np.int32)

    print('writing index and metadata')
    with h5py.File(idx2id_path, 'w') as f:
        g = f.create_group(str(offset))
        g.create_dataset('doc', data=sidx2doc_id)
        g.create_dataset('word', data=sidx2word_id)
        g.attrs['offset'] = offset

    faiss.write_index(start_index, target_index_path)
    print('done')


def merge_indexes(subindex_dir, trained_index_path, target_index_path, target_idx2id_path, target_inv_path):
    """Merge every sub-index found in ``subindex_dir`` into one on-disk index.

    The ``.hdf5`` id mappings are copied group by group into
    ``target_idx2id_path``; the inverted lists of the ``.faiss`` sub-indexes
    are merged into ``target_inv_path`` and attached to a copy of the trained
    index, which is then written to ``target_index_path``.

    Args:
        subindex_dir: directory containing the ``.faiss`` / ``.hdf5`` shards.
        trained_index_path: trained index used as the merge target.
        target_index_path: output path of the merged index.
        target_idx2id_path: output path of the merged id mapping.
        target_inv_path: file backing the merged on-disk inverted lists.
    """
    # Split the directory listing into id-mapping files and index files.
    idx2id_paths = []
    index_paths = []
    for entry in os.listdir(subindex_dir):
        full_path = os.path.join(subindex_dir, entry)
        if entry.endswith('.hdf5'):
            idx2id_paths.append(full_path)
        if entry.endswith('.faiss'):
            index_paths.append(full_path)
    print(len(idx2id_paths))
    print(len(index_paths))

    print('copying idx2id')
    with h5py.File(target_idx2id_path, 'w') as out:
        for idx2id_path in tqdm(idx2id_paths, desc='copying idx2id'):
            with h5py.File(idx2id_path, 'r') as in_:
                for key, src_group in in_.items():
                    offset_key = str(src_group.attrs['offset'])
                    # Each group is expected to be named after its own offset.
                    assert key == offset_key
                    dst_group = out.create_group(offset_key)
                    for field in ('doc', 'word'):
                        dst_group.create_dataset(field, data=in_[key][field])

    print('loading invlists')
    ivfs = []
    for index_path in tqdm(index_paths, desc='loading invlists'):
        # Memory-map the shard instead of loading it, so the combined size
        # of the inverted lists may exceed the available RAM.
        shard = faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
        ivfs.append(shard.invlists)

        # Keep the inverted lists alive after the shard object is released.
        shard.own_invlists = False

    # The merged index starts from the trained (empty) index.
    index = faiss.read_index(trained_index_path)

    # Destination inverted lists, backed by the file at target_inv_path.
    invlists = faiss.OnDiskInvertedLists(index.nlist, index.code_size, target_inv_path)

    print('merging')
    ivf_vector = faiss.InvertedListsPtrVector()
    for ivf in tqdm(ivfs):
        ivf_vector.push_back(ivf)

    print("merge %d inverted lists " % ivf_vector.size())
    ntotal = invlists.merge_from(ivf_vector.data(), ivf_vector.size())
    print(ntotal)

    # Swap the merged lists into the output index.
    index.ntotal = ntotal
    index.replace_invlists(invlists)

    print('writing index')
    faiss.write_index(index, target_index_path)


def run_index(args):
    """Run the indexing stage selected by ``args.stage``.

    Stages:
        * ``coarse`` -- sample training vectors (only with ``--replace``).
        * ``fine`` -- train the index, sampling first if nothing was sampled.
        * ``add`` -- add phrase vectors to the trained index.
        * ``all`` -- ``coarse``, ``fine`` and ``add`` in sequence.
        * ``merge`` -- merge the sub-indexes of ``args.subindex_dir``.
        * ``move`` -- re-attach the on-disk inverted lists at
          ``args.inv_path`` to the trained index and write it out.

    Existing outputs are kept unless ``args.replace`` is set.

    Args:
        args: namespace produced by ``get_args``.
    """
    dump_names = os.listdir(os.path.join(args.dump_dir, 'phrase'))
    dump_paths = sorted([os.path.join(args.dump_dir, 'phrase', name) for name in dump_names if name.endswith('.hdf5')])

    # Must exist before the 'fine' stage checks it, even if 'coarse' is skipped.
    start_data = None
    if args.stage in ['all', 'coarse']:
        if args.replace:
            os.makedirs(args.index_dir, exist_ok=True)
            start_data = sample_data(
                dump_paths, doc_sample_ratio=args.doc_sample_ratio, vec_sample_ratio=args.vec_sample_ratio,
                norm_th=args.norm_th
            )

    if args.stage in ['all', 'fine']:
        if args.replace or not os.path.exists(args.trained_index_path):
            # The trained index is written into index_dir, so make sure it exists.
            os.makedirs(args.index_dir, exist_ok=True)
            if start_data is None:
                start_data = sample_data(
                    dump_paths,
                    doc_sample_ratio=args.doc_sample_ratio, vec_sample_ratio=args.vec_sample_ratio,
                    norm_th=args.norm_th
                )
            train_index(
                start_data, args.quantizer_path, args.trained_index_path, args.num_clusters,
                fine_quant=args.fine_quant, cuda=args.cuda, hnsw=args.hnsw
            )

    if args.stage in ['all', 'add']:
        if args.replace or not os.path.exists(args.index_path):
            if args.dump_paths is not None:
                # Sub-index mode: restrict to the requested dumps.
                dump_paths = args.dump_paths
                os.makedirs(args.subindex_dir, exist_ok=True)
            add_to_index(
                dump_paths, args.trained_index_path, args.phrase_path, args.index_path, args.idx2id_path,
                cuda=args.cuda, num_phrases_per_add=args.num_phrases_per_add, offset=args.offset, norm_th=args.norm_th,
                fine_quant=args.fine_quant
            )

    if args.stage == 'merge':
        if args.replace or not os.path.exists(args.index_path):
            merge_indexes(args.subindex_dir, args.trained_index_path, args.index_path, args.idx2id_path, args.inv_path)

    if args.stage == 'move':
        index = faiss.read_index(args.trained_index_path)
        invlists = faiss.OnDiskInvertedLists(
            index.nlist, index.code_size,
            args.inv_path)
        index.replace_invlists(invlists)
        faiss.write_index(index, args.index_path)


def main():
    """Script entry point: parse the command line and run the chosen stage."""
    run_index(get_args())


# Run the indexing pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
