import csv
import json
import os
import random
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from datetime import datetime
from multiprocessing import Pool
from typing import Dict
from typing import Optional

import datasets
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer
@dataclass
class TrainPreProcessor:
    """Turn (query, positives, negatives, grounds) tuples into tokenized JSON.

    Queries are read from a ``qid<TAB>query`` file and the passage collection
    from a ``text_id<TAB>title<TAB>text`` file (loaded with ``datasets``).
    ``process_one`` converts one training tuple into a JSON string whose
    values are token-id lists, encoded without special tokens and truncated
    to ``max_length``.
    """

    query_file: str
    collection_file: str
    tokenizer: PreTrainedTokenizer
    use_title: bool

    max_length: int = 128
    # Optional passage template; "<title>" and "<text>" are substituted.
    template: Optional[str] = None

    # The attributes below are unannotated, so @dataclass treats them as
    # plain class attributes, not constructor fields.
    title_field = 'title'
    text_field = 'text'
    # Column names handed to the CSV loader for the collection file.
    columns = ['text_id', 'title', 'text']

    def __post_init__(self):
        self.queries = self.read_queries(self.query_file)
        # Column names are supplied explicitly (presumably the collection
        # file has no header row -- confirm per corpus).
        self.collection = datasets.load_dataset(
            'csv',
            data_files=self.collection_file,
            column_names=self.columns,
            delimiter='\t',
        )['train']

    @staticmethod
    def read_queries(queries):
        """Load a ``qid<TAB>query`` file into a ``{qid: query_text}`` dict.

        Blank lines are skipped. Only the first tab separates the id from
        the text, so a query that itself contains a tab is kept intact.

        Raises:
            ValueError: if a non-blank line has no tab separator.
        """
        qmap = {}
        with open(queries, encoding='utf8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                qid, qry = line.split('\t', 1)
                qmap[qid] = qry
        return qmap

    @staticmethod
    def read_qrel(relevance_file):
        """Load a 4-column TSV qrels file into ``{topic_id: [doc_id, ...]}``.

        The second column is ignored. Only binary positive judgements are
        supported.

        Raises:
            ValueError: if a non-blank row does not have exactly 4 columns,
                or its relevance label is not "1". This is checked explicitly
                rather than with ``assert`` so it also holds under
                ``python -O``.
        """
        qrel = {}
        with open(relevance_file, encoding='utf8') as f:
            for row in csv.reader(f, delimiter="\t"):
                if not row:
                    continue  # blank line
                topicid, _, docid, rel = row
                if rel != "1":
                    raise ValueError(
                        f"{relevance_file}: unsupported relevance label "
                        f"{rel!r} for topic {topicid!r}, doc {docid!r}; "
                        f"expected '1'"
                    )
                qrel.setdefault(topicid, []).append(docid)
        return qrel

    def _encode(self, text):
        """Tokenize ``text`` without special tokens, truncated to max_length."""
        return self.tokenizer.encode(
            text,
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True,
        )

    def get_query(self, q):
        """Return the token ids of query ``q`` (KeyError if unknown)."""
        return self._encode(self.queries[q])

    def get_passage(self, p):
        """Return the token ids of passage ``p``.

        NOTE(review): ``p`` is converted with ``int`` and used as a row index
        into the collection. This assumes passage ids equal their 0-based
        row position in the collection file -- confirm for each corpus.
        """
        entry = self.collection[int(p)]
        # Empty TSV cells may come back as None; treat them as "".
        title = (entry[self.title_field] or "") if self.use_title else ""
        body = entry[self.text_field] or ""
        if self.template is None:
            content = title + " " + body
        else:
            # Always substitute "<title>" (with "" when there is no title)
            # so the literal placeholder never reaches the tokenizer.
            content = self.template.replace("<title>", title).replace("<text>", body)
        return self._encode(content)

    def get_ground(self, q, g_list):
        """Tokenize each string in ``g_list`` prefixed by the text of query ``q``.

        Each element is encoded as ``query + ' [SEP] ' + ground``. The
        separator is inserted as literal text; whether it maps to a single
        special token depends on the tokenizer's vocabulary.
        """
        query = self.queries[q]
        return [self._encode(query + ' [SEP] ' + content) for content in g_list]

    def process_one(self, train):
        """Encode one ``(qid, positive_ids, negative_ids, grounds)`` tuple.

        Returns:
            A JSON string with keys ``query``, ``positives``, ``negatives``
            and ``grounds``, each holding token-id list(s).
        """
        q, pp, nn, gg = train
        train_example = {
            'query': self.get_query(q),
            'positives': [self.get_passage(p) for p in pp],
            'negatives': [self.get_passage(n) for n in nn],
            'grounds': self.get_ground(q, gg),
        }
        return json.dumps(train_example)

def load_ranking(rank_file, relevance, augmented_q_dict, n_sample, depth):
    """Yield one training tuple per query from a whitespace-separated run file.

    Each non-blank line must have six fields, of which the 1st (query id)
    and 3rd (passage id) are used. This is presumably a TREC run file; the
    other columns are ignored. Consecutive lines with the same query id form
    one group. A query id that reappears after a different one starts a new
    group and is yielded again.

    For each group, passages not listed in ``relevance[qid]`` are negatives.
    The first ``depth`` of them (in file order) are kept, shuffled with the
    module-level ``random`` state, and at most ``n_sample`` are returned.

    Args:
        rank_file: path to the run file.
        relevance: mapping qid -> list of positive passage ids. Every ranked
            query must be present (KeyError otherwise).
        augmented_q_dict: mapping ``str(qid)`` -> grounds for that query.
        n_sample: maximum number of negatives yielded per query.
        depth: number of leading negatives eligible for sampling.

    Yields:
        ``(qid, positives, negatives, grounds)`` tuples. An empty or blank
        file yields nothing.
    """

    def _emit(qid, negatives):
        # Restrict to the first `depth` negatives, then sample n_sample.
        negatives = negatives[:depth]
        random.shuffle(negatives)
        return qid, relevance[qid], negatives[:n_sample], augmented_q_dict[str(qid)]

    curr_q = None
    negatives = []
    with open(rank_file, encoding='utf8') as rf:
        for line in rf:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            q, _, p, _, _, _ = fields
            if q != curr_q:
                if curr_q is not None:
                    yield _emit(curr_q, negatives)
                curr_q = q
                negatives = []
            if p not in relevance[q]:
                negatives.append(p)
    if curr_q is not None:
        yield _emit(curr_q, negatives)


random.seed(datetime.now())
parser = ArgumentParser()
parser.add_argument('--tokenizer_name', required=True)
parser.add_argument('--hn_file', required=True)
parser.add_argument('--qrels', required=True)
parser.add_argument('--queries', required=True)
parser.add_argument('--collection', required=True)
parser.add_argument('--save_to', required=True)
parser.add_argument('--template', type=str, default=None)
parser.add_argument('--use_title', type=bool, default=False)
parser.add_argument('--truncate', type=int, default=200)
parser.add_argument('--n_sample', type=int, default=32)
parser.add_argument('--depth', type=int, default=200)
parser.add_argument('--mp_chunk_size', type=int, default=500)
parser.add_argument('--shard_size', type=int, default=45000)
parser.add_argument('--augmented_q_file_path', type=str,required=True)

args = parser.parse_args()

qrel = TrainPreProcessor.read_qrel(args.qrels)

augmented_q_dict={}
with open(args.augmented_q_file_path,'r',encoding='utf8') as fin:
		lines=iter(fin)
		while True:
			try:
				q_dict=json.loads(next(lines))
				q_id=q_dict['id']
				augmented_q_dict[str(q_id)]=q_dict['grounds']
			except StopIteration:
				break

tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
processor = TrainPreProcessor(
    query_file=args.queries,
    collection_file=args.collection,
    tokenizer=tokenizer,
    max_length=args.truncate,
    template=args.template,
    use_title= args.use_title
)

counter = 0
shard_id = 0
f = None
os.makedirs(args.save_to, exist_ok=True)

pbar = tqdm(load_ranking(args.hn_file, qrel, augmented_q_dict, args.n_sample, args.depth))
with Pool() as p:
    for x in p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size):
        counter += 1
        if f is None:
            f = open(os.path.join(args.save_to, f'split{shard_id:02d}.hn.jsonl'), 'w')
            pbar.set_description(f'split - {shard_id:02d}')
        f.write(x + '\n')

        if counter == args.shard_size:
            f.close()
            f = None
            shard_id += 1
            counter = 0

if f is not None:
    f.close()