import argparse
import multiprocessing
import os
import time
import numpy as np
import mxnet as mx
import sys
sys.path.insert(0, "/open_clip")

from src.open_clip import tokenize


def read_worker(args, q_in):
    """Filter the source records and queue the re-labelled survivors.

    Every RecordIO file listed in ``args.rec_read`` (comma separated) is
    read in lockstep with the text file at the same position in
    ``args.text``: the n-th line of the text file is the caption of the
    n-th record. Labels and label scores are looked up with a running
    index that keeps counting across all input files.

    A record is dropped when its label occurs in the array loaded from
    ``args.label_for_filter`` or when its score is below 0.7. Each kept
    record is re-packed with a 79-value label vector: 77 caption token
    ids, then the label, then the score.

    Exactly one ``None`` sentinel is put on ``q_in`` when the worker stops,
    whether it finished normally or failed, so that a consumer waiting for
    the sentinel is never left blocked.

    Args:
        args: namespace providing ``rec_read``, ``text``, ``label``,
            ``label_score`` and ``label_for_filter``.
        q_in: queue receiving the packed records followed by ``None``.

    Raises:
        ValueError: if fewer text files than record files were given.
        FileNotFoundError: if a source record file does not exist.
    """
    token_count = 77
    min_score = 0.7

    try:
        rec_paths = args.rec_read.split(",")
        text_paths = args.text.split(",")
        if len(text_paths) < len(rec_paths):
            raise ValueError(
                f"{len(rec_paths)} record files but only "
                f"{len(text_paths)} text files were given.")

        np_label = np.load(args.label)
        np_label_score = np.load(args.label_score)
        # Build the exclusion set once: `x in ndarray` rescans the whole
        # array for every single record.
        excluded_labels = set(
            np.asarray(np.load(args.label_for_filter)).ravel().tolist())

        idx_global = 0
        for rec_path, text_path in zip(rec_paths, text_paths):
            # Check before opening, so a missing file is reported clearly.
            if not os.path.exists(rec_path):
                raise FileNotFoundError(f"{rec_path} not found.")

            with open(text_path, "r") as f:
                captions = f.readlines()

            record = mx.recordio.MXRecordIO(rec_path, 'r')
            try:
                idx_local = 0
                while True:
                    packed = record.read()
                    if packed is None:
                        break
                    header, jpeg = mx.recordio.unpack(packed)

                    caption = captions[idx_local]
                    current_label = int(np_label[idx_global])
                    current_label_score = float(np_label_score[idx_global][0])
                    # Advance before filtering so skipped records still
                    # consume their caption / label slot.
                    idx_local += 1
                    idx_global += 1

                    if current_label in excluded_labels:
                        continue
                    if current_label_score < min_score:
                        continue

                    # Tokenize only records that survived both filters.
                    token = tokenize(caption).flatten().numpy()
                    labels = np.zeros((token_count + 2, ))
                    labels[:token_count] = token
                    labels[token_count] = current_label
                    labels[token_count + 1] = current_label_score

                    q_in.put(mx.recordio.pack(
                        header=mx.recordio.IRHeader(
                            header.flag, labels,
                            header.id, header.id2), s=jpeg))
            finally:
                record.close()
    finally:
        # Always signal end-of-stream, even after an error.
        q_in.put(None)


def write_worker(args, q_out):
    """Drain ``q_out`` into a new RecordIO file at ``args.rec_save``.

    Items are written in the order they are received. Each ``None`` taken
    from the queue counts as one finished producer; the worker stops once
    at least ``args.num_thread`` of them have been seen. Progress (elapsed
    time since the previous report and the running count) is printed every
    100000 records, and the total count is printed at the end.

    Args:
        args: namespace providing ``rec_save`` and ``num_thread``.
        q_out: queue yielding packed records and ``None`` sentinels.

    Raises:
        FileExistsError: if ``args.rec_save`` already exists; the worker
            refuses to overwrite it.
    """
    pre_time = time.time()
    count = 0
    # Raised explicitly instead of asserted: an assert is stripped under
    # `python -O`, and opening in 'w' mode would then clobber the file.
    if os.path.exists(args.rec_save):
        raise FileExistsError("{} existing!.".format(args.rec_save))
    save_record = mx.recordio.MXRecordIO(args.rec_save, 'w')
    none_count = 0
    try:
        while True:
            item = q_out.get()
            if item is None:
                none_count += 1
                # `>=` so a non-positive num_thread cannot wait forever.
                if none_count >= args.num_thread:
                    break
                continue
            save_record.write(item)
            if count % 100000 == 0:
                cur_time = time.time()
                print('save time:', cur_time - pre_time, ' count:', count)
                pre_time = cur_time
            count += 1
        print(count)
    finally:
        # Close the output even when writing fails part-way through.
        save_record.close()


def main(args):
    """Run one reader and one writer process joined by a bounded queue.

    ``read_worker`` produces packed records and ``write_worker`` saves
    them; this function returns once the writer process has exited.

    Args:
        args: parsed command-line namespace, passed unchanged to both
            workers. ``args.num_thread`` is the number of ``None``
            sentinels the writer waits for.
    """
    # Bounded, so the reader cannot run arbitrarily far ahead of the writer.
    queue = multiprocessing.Queue(1024)
    read_process = multiprocessing.Process(target=read_worker, args=(args, queue))
    # Daemonic: the reader is never joined, and a reader blocked on a full
    # queue must not keep the program alive after the writer has gone.
    read_process.daemon = True
    read_process.start()
    write_process = multiprocessing.Process(target=write_worker, args=(args, queue))
    write_process.start()
    # Only one reader is started, so only one sentinel is ever produced,
    # while the writer waits for `args.num_thread` of them. Supply the
    # missing ones here so `--num-thread N` with N > 1 cannot hang. The
    # reader's own sentinel still follows all of its records, so the writer
    # cannot reach the full count before every record has been consumed.
    for _ in range(args.num_thread - 1):
        queue.put(None)
    write_process.join()


if __name__ == '__main__':
    # Command-line entry point. Every path option is required: the workers
    # dereference all of them unconditionally, so a missing one would
    # otherwise only show up as an error inside a child process.
    parser = argparse.ArgumentParser()
    # .npy arrays indexed by the global record position.
    parser.add_argument('--label', required=True,
                        help="npy path")
    parser.add_argument('--label-score', required=True,
                        help='npy path')
    # .npy array of labels whose records are dropped.
    parser.add_argument('--label-for-filter', required=True,
                        help='npy path')
    # Comma-separated caption files, one per entry of --rec-read.
    parser.add_argument('--text', required=True)
    # Comma-separated list of source .rec files.
    parser.add_argument('--rec-read', required=True, help='path to source rec.')
    parser.add_argument('--rec-save', required=True, help='path to new rec.')
    # Number of end-of-stream sentinels the writer waits for.
    parser.add_argument('--num-thread', type=int, default=1, help='')
    main(parser.parse_args())
