from argparse import ArgumentParser
from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedTokenizer
import os
import csv
import json
import datasets
import random
from tqdm import tqdm
from datetime import datetime
from multiprocessing import Pool
from dataclasses import dataclass


def load_queries(queries_file):
    """Load a ``qid<TAB>query`` TSV file into a ``{int qid: query text}`` dict.

    Each line is stripped of surrounding whitespace and split on the FIRST tab
    only, so tab characters inside the query text are kept instead of causing
    an unpack error. Blank lines are skipped. If a qid appears more than once,
    the later line wins. Progress is reported with tqdm.

    Raises:
        ValueError: if a non-blank line contains no tab, or its id is not an
            integer.
    """
    queries = {}
    with open(queries_file, 'r', encoding='utf-8') as fIn:
        for line in tqdm(fIn):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing on the unpack below.
                continue
            qid, query = line.split("\t", 1)
            queries[int(qid)] = query
    return queries


def load_corpus(collection_file):
    """Load a ``pid<TAB>passage`` TSV file into a ``{int pid: passage text}`` dict.

    Each line is stripped of surrounding whitespace and split on the FIRST tab
    only, so tab characters inside the passage text are kept instead of
    causing an unpack error. Blank lines are skipped. If a pid appears more
    than once, the later line wins. Progress is reported with tqdm.

    Raises:
        ValueError: if a non-blank line contains no tab, or its id is not an
            integer.
    """
    corpus = {}
    with open(collection_file, 'r', encoding='utf-8') as fIn:
        for line in tqdm(fIn):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing on the unpack below.
                continue
            pid, passage = line.split("\t", 1)
            corpus[int(pid)] = passage
    return corpus


def read_line(l):
    """Decode one JSON training record into a ``(qid, pos, neg)`` tuple.

    ``l`` is a JSON object string with ``"qid"``, ``"pos"`` and ``"neg"``
    keys. The ``"neg"`` value is shuffled in place with the module-level
    ``random`` state before being returned, so each call sees a fresh order.
    """
    record = json.loads(l)
    qid, pos, negatives = record["qid"], record["pos"], record["neg"]
    random.shuffle(negatives)
    return qid, pos, negatives


@dataclass
class GPTTrainPreProcessor:
    """Encode ``(qid, positive pids, negative pids)`` triples as token-id JSON.

    Queries are wrapped as ``[`` + tokens + ``]`` and passages as ``{`` +
    tokens + ``}``, where the markers are the tokenizer's own ids for those
    bracket characters. Only the text between the markers is truncated to
    ``max_length`` ids, so a wrapped sequence can be up to ``max_length + 2``
    ids long.
    """

    queries: dict  # qid -> query text (looked up in get_query)
    corpus: dict  # pid -> passage text (looked up in get_passage)
    tokenizer: PreTrainedTokenizer
    max_length: int = 128  # cap on encoded text ids, markers excluded

    def __post_init__(self):
        # Marker ids, resolved once. Query markers come from "[" / "]",
        # passage markers from "{" / "}".
        self.bos_spec_token_q_rep = self._first_token_id("[")
        self.eos_spec_token_q = self._first_token_id("]")

        self.bos_spec_token_d_rep = self._first_token_id("{")
        self.eos_spec_token_d = self._first_token_id("}")

    def _first_token_id(self, text):
        """Return the first id ``text`` encodes to, without special tokens.

        NOTE(review): any further ids are dropped, which assumes each bracket
        encodes to a single token for the tokenizer in use — TODO confirm.
        """
        return self.tokenizer.encode(text, add_special_tokens=False)[0]

    def _encode_truncated(self, text):
        """Encode ``text`` without special tokens, truncated to ``max_length`` ids."""
        return self.tokenizer.encode(
            text,
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )

    def get_query(self, qid):
        """Return query ``qid`` as ``[bos_q] + ids + [eos_q]``."""
        body = self._encode_truncated(self.queries[qid])
        return [self.bos_spec_token_q_rep, *body, self.eos_spec_token_q]

    def get_passage(self, pid):
        """Return passage ``pid`` as ``[bos_d] + ids + [eos_d]``."""
        body = self._encode_truncated(self.corpus[pid])
        return [self.bos_spec_token_d_rep, *body, self.eos_spec_token_d]

    def process_one(self, train):
        """Serialize one ``(qid, positive pids, negative pids)`` triple to JSON.

        The resulting object has keys ``query`` (one wrapped id list) and
        ``positives`` / ``negatives`` (lists of wrapped id lists, in input
        order).
        """
        qid, positive_ids, negative_ids = train
        return json.dumps({
            'query': self.get_query(qid),
            'positives': list(map(self.get_passage, positive_ids)),
            'negatives': list(map(self.get_passage, negative_ids)),
        })
    
    

# Per-worker preprocessor, installed once per worker process by `init_worker`.
# Passing `processor.process_one` to Pool.imap would instead pickle the bound
# method -- and with it the whole processor (queries, corpus, tokenizer) --
# into every task chunk sent to the workers.
_worker_processor = None


def init_worker(processor):
    """Pool initializer: make ``processor`` this worker's preprocessor."""
    global _worker_processor
    _worker_processor = processor


def process_in_worker(train):
    """Encode one training example with this worker's preprocessor."""
    return _worker_processor.process_one(train)


def main():
    """Tokenize ``--train_file`` and write it as sharded JSONL under ``--save_to``.

    Each input line is parsed by ``read_line`` and encoded by a
    ``GPTTrainPreProcessor`` in a pool of worker processes. Results are written
    in input order to ``split00.json``, ``split01.json``, ... with
    ``--shard_size`` examples per file (the last shard may be smaller).
    """
    parser = ArgumentParser()
    parser.add_argument('--train_file', required=True)
    parser.add_argument('--tokenizer_name', required=True)
    parser.add_argument('--queries', required=True)
    parser.add_argument('--collection', required=True)
    parser.add_argument('--save_to', required=True)

    parser.add_argument('--truncate', type=int, default=128)
    parser.add_argument('--shard_size', type=int, default=45000)
    parser.add_argument('--mp_chunk_size', type=int, default=500)
    parser.add_argument('--seed', type=int, default=None,
                        help='seed for shuffling negatives (default: nondeterministic)')

    args = parser.parse_args()

    # random.seed(datetime.now()) raises TypeError on Python 3.11+; a None seed
    # reseeds from OS entropy, keeping the old nondeterministic default.
    random.seed(args.seed)

    queries = load_queries(args.queries)
    corpus = load_corpus(args.collection)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)

    processor = GPTTrainPreProcessor(
        queries=queries,
        corpus=corpus,
        tokenizer=tokenizer,
        max_length=args.truncate,
    )

    os.makedirs(args.save_to, exist_ok=True)

    counter = 0
    shard_id = 0
    f = None
    try:
        with open(args.train_file, "r", encoding="utf-8") as nf:
            # read_line (incl. the negative shuffle) runs in this process as
            # the pool pulls input, so --seed governs the shuffle order.
            pbar = tqdm(map(read_line, nf))
            with Pool(initializer=init_worker, initargs=(processor,)) as p:
                for x in p.imap(process_in_worker, pbar, chunksize=args.mp_chunk_size):
                    counter += 1
                    if f is None:
                        f = open(os.path.join(args.save_to, f'split{shard_id:02d}.json'), 'w',
                                 encoding='utf-8')
                        pbar.set_description(f'split - {shard_id:02d}')
                    f.write(x + '\n')

                    # Rotate to a new shard once this one holds shard_size examples.
                    if counter == args.shard_size:
                        f.close()
                        f = None
                        shard_id += 1
                        counter = 0
    finally:
        # Close the last (possibly partial) shard, also when an error interrupts the loop.
        if f is not None:
            f.close()


if __name__ == "__main__":
    main()