from argparse import ArgumentParser
from transformers import AutoTokenizer
import os
import random
from tqdm import tqdm
from datetime import datetime
from multiprocessing import Pool
import sys
from lib.openmatch.utils import SimpleTrainPreProcessor as TrainPreProcessor

def _sample_negatives(negatives_n, negatives_l, n_sample, depth, num_new):
    """Mix the two negative pools of one query into a single shuffled sample.

    Each pool is cut to its first ``depth`` entries (file order) and shuffled.
    Up to ``num_new`` entries are taken from ``negatives_n``; the sample is
    then topped up with entries from ``negatives_l`` that are not already
    present until it holds ``n_sample`` items or ``negatives_l`` runs out.
    The merged sample is shuffled once before being returned.
    """
    negatives_n = negatives_n[:depth]
    negatives_l = negatives_l[:depth]
    random.shuffle(negatives_n)
    random.shuffle(negatives_l)

    negatives = negatives_n[:num_new]
    seen = set(negatives)  # O(1) duplicate checks while topping up
    for neg_l in negatives_l:
        if len(negatives) >= n_sample:
            break
        if neg_l not in seen:
            seen.add(neg_l)
            negatives.append(neg_l)

    # Shuffle exactly once, after the merge, so the two sources are interleaved
    # for every query (including the last one in the file).
    random.shuffle(negatives)
    return negatives


def load_ranking(rank_file, relevance, n_sample, depth, ratio_new):
    """Yield ``(query_id, positives, negatives)`` for each query in a rank file.

    Every non-blank line of ``rank_file`` must split into six
    whitespace-separated fields; the first is used as the query id, the third
    as the passage id and the sixth as the source tag (the remaining fields
    are ignored). Consecutive lines sharing a query id form one group.
    Passages found in ``relevance[query_id]`` are skipped; the others are
    collected as "last iteration" negatives when the source tag equals
    ``'LastIter'`` and as "new" negatives otherwise.

    Args:
        rank_file: Path of the ranking file to read.
        relevance: Mapping from query id to a container of positive passage
            ids. It is indexed with ``relevance[query_id]``, so a plain dict
            raises ``KeyError`` for a query id it does not contain.
        n_sample: Maximum size of the sample once last-iteration negatives
            are added.
        depth: Number of leading negatives kept from each pool before
            shuffling.
        ratio_new: Share of ``n_sample`` reserved for new negatives, in
            tenths (``int(n_sample * ratio_new / 10)`` entries).

    Yields:
        Tuples ``(query_id, relevance[query_id], negatives)`` in file order.
        Nothing is yielded for an empty file.

    Raises:
        ValueError: If a non-blank line does not have exactly six fields.
    """
    float_ratio_new = ratio_new / 10
    num_new = int(n_sample * float_ratio_new)

    curr_q = None  # None means no query group has been started yet
    negatives_n = []
    negatives_l = []
    with open(rank_file) as rf:
        for line in rf:
            fields = line.strip().split()
            if not fields:
                continue  # tolerate blank lines
            q, _, p, _, _, source = fields

            if curr_q is not None and q != curr_q:
                # Query id changed: emit the finished group and start anew.
                yield curr_q, relevance[curr_q], _sample_negatives(
                    negatives_n, negatives_l, n_sample, depth, num_new)
                negatives_n = []
                negatives_l = []
            curr_q = q

            if p not in relevance[q]:
                if source == 'LastIter':
                    negatives_l.append(p)
                else:
                    negatives_n.append(p)

    if curr_q is not None:
        # Flush the last group; skipped entirely when the file had no records.
        yield curr_q, relevance[curr_q], _sample_negatives(
            negatives_n, negatives_l, n_sample, depth, num_new)

def _str2bool(value):
    """Interpret a command-line string as a boolean.

    ``type=bool`` treats every non-empty string as true, so
    ``--use_title False`` used to enable titles. Common false spellings are
    recognised here; any other non-empty value keeps evaluating to ``True``.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in ('', 'false', 'f', 'no', 'n', '0')


def main():
    """Parse the command line and write hard-negative training shards.

    Samples produced by ``load_ranking`` are passed to
    ``processor.process_one`` through a process pool and the returned strings
    are written one per line to ``split{NN}.hn.jsonl`` files under
    ``--save_to``; a new shard file is started after ``--shard_size`` lines.
    """
    # random.seed() rejects datetime objects on Python 3.11+; a float
    # timestamp keeps the time-based (non-reproducible) seeding.
    random.seed(datetime.now().timestamp())

    parser = ArgumentParser()
    parser.add_argument('--tokenizer_name', required=True)
    parser.add_argument('--hn_file', required=True)
    parser.add_argument('--ratio_new', type=int, required=True)
    parser.add_argument('--qrels', required=True)
    parser.add_argument('--queries', required=True)
    parser.add_argument('--collection', required=True)
    parser.add_argument('--save_to', required=True)
    parser.add_argument('--template', type=str, default=None)
    parser.add_argument('--use_title', type=_str2bool, default=False)
    parser.add_argument('--truncate', type=int, default=128)
    parser.add_argument('--n_sample', type=int, default=32)
    parser.add_argument('--depth', type=int, default=200)
    parser.add_argument('--mp_chunk_size', type=int, default=500)
    parser.add_argument('--shard_size', type=int, default=45000)
    args = parser.parse_args()

    qrel = TrainPreProcessor.read_qrel(args.qrels)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
    processor = TrainPreProcessor(
        query_file=args.queries,
        collection_file=args.collection,
        tokenizer=tokenizer,
        max_length=args.truncate,
        template=args.template,
        use_title=args.use_title,
    )

    os.makedirs(args.save_to, exist_ok=True)

    counter = 0      # lines written to the current shard
    shard_id = 0
    f = None         # current shard file, opened lazily on the first line
    pbar = tqdm(load_ranking(args.hn_file, qrel, args.n_sample, args.depth, args.ratio_new))
    try:
        with Pool() as p:
            for x in p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size):
                counter += 1
                if f is None:
                    f = open(os.path.join(args.save_to, f'split{shard_id:02d}.hn.jsonl'), 'w')
                    pbar.set_description(f'split - {shard_id:02d}')
                f.write(x + '\n')

                if counter == args.shard_size:
                    f.close()
                    f = None
                    shard_id += 1
                    counter = 0
    finally:
        # Close the partially filled shard even when processing fails.
        if f is not None:
            f.close()


# The guard keeps worker processes started with the spawn/forkserver methods
# from re-running the script when they import the main module.
if __name__ == '__main__':
    main()