from argparse import ArgumentParser
from transformers import AutoTokenizer
import os
import random
from tqdm import tqdm
from datetime import datetime
from multiprocessing import Pool
import sys

import csv
import json
from dataclasses import dataclass
from typing import Dict
import datasets
from transformers import PreTrainedTokenizer
@dataclass
class TrainPreProcessor:
    """Turn (query id, positive ids, negative ids) triples into tokenized JSON lines.

    On construction the query file is read into ``self.queries`` (id -> text)
    and the collection file is loaded into ``self.collection`` through the
    ``datasets`` csv builder, tab-delimited, using the column names in
    ``columns``.

    Attributes:
        query_file: Path to a ``qid<TAB>query`` file.
        collection_file: Path to the tab-separated passage collection.
        tokenizer: Tokenizer whose ``encode`` method produces the token ids.
        use_title: When false, passage titles are ignored.
        max_length: Truncation length passed to ``tokenizer.encode``.
        template: Optional passage template containing the placeholders
            ``<title>`` and ``<text>``; when ``None`` title and text are
            joined with a single space.
    """
    query_file: str
    collection_file: str
    tokenizer: "PreTrainedTokenizer"  # forward reference; not evaluated at class creation
    use_title: bool

    max_length: int = 128
    # The three un-annotated names below are plain class attributes, not
    # dataclass fields, so they are not constructor parameters.
    title_field = 'title'
    text_field = 'text'
    template: str = None
    columns = ['text_id', 'title', 'text']

    def __post_init__(self):
        """Load the queries and the passage collection named by the fields."""
        self.queries = self.read_queries(self.query_file)
        self.collection = datasets.load_dataset(
            'csv',
            data_files=self.collection_file,
            column_names=self.columns,
            delimiter='\t',
        )['train']

    @staticmethod
    def read_queries(queries):
        """Read a ``qid<TAB>query`` file into a dict mapping qid to query text.

        Blank lines are skipped. Only the first tab separates the id from the
        query, so a query that itself contains a tab is kept intact.

        Args:
            queries: Path of the query file (read as UTF-8).

        Returns:
            Dict mapping query id (str) to query text (str).

        Raises:
            ValueError: If a non-blank line contains no tab.
        """
        qmap = {}
        with open(queries, encoding='utf8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                qid, qry = line.split('\t', 1)
                qmap[qid] = qry
        return qmap

    @staticmethod
    def read_qrel(relevance_file):
        """Read a four-column, tab-separated relevance file.

        Each row is unpacked as ``topicid, _, docid, rel``; the second column
        is ignored and ``rel`` must be the string ``"1"``.

        Args:
            relevance_file: Path of the relevance file (read as UTF-8).

        Returns:
            Dict mapping topic id to the list of its doc ids, in file order.

        Raises:
            ValueError: If a row does not have exactly four columns or its
                relevance label is not ``"1"``.
        """
        qrel = {}
        with open(relevance_file, encoding='utf8') as f:
            tsvreader = csv.reader(f, delimiter="\t")
            for [topicid, _, docid, rel] in tsvreader:
                # Explicit check instead of `assert`, which is removed under -O.
                if rel != "1":
                    raise ValueError(
                        f"unexpected relevance label {rel!r} for topic "
                        f"{topicid!r}, doc {docid!r}; expected '1'"
                    )
                qrel.setdefault(topicid, []).append(docid)
        return qrel

    def get_query(self, q):
        """Return the token ids of the query with id ``q``.

        Raises:
            KeyError: If ``q`` is not a key of ``self.queries``.
        """
        query_encoded = self.tokenizer.encode(
            self.queries[q],
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )
        return query_encoded

    def get_passage(self, p):
        """Return the token ids of the passage at row ``int(p)`` of the collection.

        The passage text is built as follows:

        * no template: ``title + " " + text`` (the separating space is kept
          even when there is no title);
        * template, no title: only ``<text>`` is substituted;
        * template and title: both ``<title>`` and ``<text>`` are substituted.

        A title of ``None`` is treated like an empty title in every branch.
        """
        # NOTE(review): the passage id is used as a row index, which presumes
        # ids equal row positions in the collection file -- confirm.
        entry = self.collection[int(p)]
        title = entry[self.title_field] if self.use_title else ""
        if title is None:
            # Normalise up front so the no-template branch cannot fail on
            # `None + " "`.
            title = ""
        body = entry[self.text_field]
        if self.template is None:
            content = title + " " + body
        elif title == "":
            # NOTE(review): a "<title>" placeholder in the template is left
            # in place here -- confirm that this is intended.
            content = self.template.replace("<text>", body)
        else:
            content = self.template.replace("<title>", title).replace("<text>", body)
        passage_encoded = self.tokenizer.encode(
            content,
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )

        return passage_encoded

    def process_one(self, train):
        """Tokenize one training triple and serialise it as a JSON string.

        Args:
            train: ``(query_id, positive_ids, negative_ids)``.

        Returns:
            JSON object string with keys ``query``, ``positives`` and
            ``negatives`` holding the encoded query and passages.
        """
        q, pp, nn = train
        train_example = {
            'query': self.get_query(q),
            'positives': [self.get_passage(p) for p in pp],
            'negatives': [self.get_passage(n) for n in nn],
        }

        return json.dumps(train_example)

def _parse_bool(value):
    """Convert a command-line string into a bool.

    `type=bool` would turn every non-empty string, including "False" and
    "0", into True. This converter accepts the usual spellings of both
    values (case-insensitive) and rejects anything else.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under bool('').
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Seed from the current time. A float timestamp is used because random.seed()
# rejects datetime objects on Python 3.11+ (only None/int/float/str/bytes).
random.seed(datetime.now().timestamp())
parser = ArgumentParser()
parser.add_argument('--tokenizer_name', required=True)
parser.add_argument('--negative_file', required=True)
parser.add_argument('--qrels', required=True)
parser.add_argument('--queries', required=True)
parser.add_argument('--collection', required=True)
parser.add_argument('--save_to', required=True)
parser.add_argument('--template', type=str, default=None)
parser.add_argument('--use_title', type=_parse_bool, default=False)
parser.add_argument('--truncate', type=int, default=128)
parser.add_argument('--n_sample', type=int, default=32)
parser.add_argument('--mp_chunk_size', type=int, default=500)
parser.add_argument('--shard_size', type=int, default=45000)

args = parser.parse_args()


# Positive passage ids per query id, consulted by read_line below.
qrel = TrainPreProcessor.read_qrel(args.qrels)

def read_line(l):
    """Parse one line of the negatives file into a training triple.

    The line has the form ``query_id<TAB>neg_id,neg_id,...``. The negative
    ids are shuffled and at most ``args.n_sample`` of them are kept.

    Returns:
        ``(query_id, positive_ids, sampled_negative_ids)`` where the
        positives are looked up in the module-level ``qrel`` mapping.
    """
    query_id, negatives_field = l.strip().split('\t')
    candidates = negatives_field.split(',')
    # Shuffle in place, then keep the leading slice as the random sample.
    random.shuffle(candidates)
    positives = qrel[query_id]
    sampled = candidates[:args.n_sample]
    return query_id, positives, sampled

tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
processor = TrainPreProcessor(
    query_file=args.queries,
    collection_file=args.collection,
    tokenizer=tokenizer,
    max_length=args.truncate,
    template=args.template,
    use_title=args.use_title,
)
counter = 0   # examples written to the currently open shard
shard_id = 0  # index used in the next shard's file name
f = None      # currently open shard file, or None between shards
os.makedirs(args.save_to, exist_ok=True)

# Stream the negatives file through the worker pool and write the resulting
# JSON lines into shard files of at most args.shard_size examples each.
try:
    with open(args.negative_file) as nf:
        pbar = tqdm(map(read_line, nf))
        with Pool() as p:
            for x in p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size):
                counter += 1
                if f is None:
                    # Shards are opened lazily so no empty trailing file is
                    # created when the input ends on a shard boundary.
                    f = open(os.path.join(args.save_to, f'split{shard_id:02d}.jsonl'), 'w')
                    pbar.set_description(f'split - {shard_id:02d}')
                f.write(x + '\n')

                if counter == args.shard_size:
                    f.close()
                    f = None
                    shard_id += 1
                    counter = 0
finally:
    # Close a partially filled shard even if tokenization or writing raised,
    # so the file handle is not leaked.
    if f is not None:
        f.close()