from argparse import ArgumentParser
from cmath import nan
import csv
from tkinter.font import names
from sklearn import neighbors
import pandas as pd
import json
import pickle
import numpy as np
import pdb
from transformers import AutoTokenizer,PreTrainedTokenizer
import os
import random
from tqdm import tqdm
from datetime import datetime
from multiprocessing import Pool
import sys
from dataclasses import dataclass
from typing import Dict
import datasets
@dataclass
class TrainPreProcessor:
    """Tokenize queries and passages into JSON-serialised training examples.

    On construction, ``query_file`` is read into ``self.queries`` (a
    ``{query_id: query_text}`` dict) and ``collection_file`` is loaded as a
    tab-separated file through ``datasets.load_dataset`` into
    ``self.collection`` using the column names in ``columns``.

    Attributes:
        query_file: Path to a ``<qid>\\t<query text>`` file.
        collection_file: Path to the tab-separated passage collection.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...,
            max_length=..., truncation=...)``.
        use_title: When false, passage titles are ignored.
        max_length: ``max_length`` handed to every ``tokenizer.encode`` call.
        template: Optional string containing the placeholders ``<title>``,
            ``<text>`` and ``<name>``; when ``None`` a passage is rendered as
            ``title + " " + text``.
    """
    query_file: str
    collection_file: str
    # String annotation: the class can be defined without the tokenizer
    # library being importable; dataclasses do not evaluate it.
    tokenizer: "PreTrainedTokenizer"
    use_title: bool

    max_length: int = 128
    # Not annotated, so these two are plain class attributes rather than
    # dataclass fields (they are not constructor parameters).
    title_field = 'title'
    text_field = 'text'
    template: str = None
    # Column names passed to the CSV loader; also a plain class attribute.
    columns = ['text_id', 'title', 'text', 'name']

    def __post_init__(self):
        """Load the query map and the passage collection from disk."""
        self.queries = self.read_queries(self.query_file)
        self.collection = datasets.load_dataset(
            'csv',
            data_files=self.collection_file,
            column_names=self.columns,
            delimiter='\t',
        )['train']

    @staticmethod
    def read_queries(queries):
        """Read a ``<qid>\\t<query>`` file into a ``{qid: query}`` dict.

        Args:
            queries: Path of the query file.

        Returns:
            Dict mapping query id (str) to query text (str).

        Raises:
            ValueError: If a non-blank line contains no tab separator.
        """
        qmap = {}
        # utf8 to match read_qrel instead of the platform default encoding.
        with open(queries, encoding='utf8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines.
                    continue
                # Split only on the first tab so a query that itself contains
                # a tab is kept whole instead of breaking the unpacking.
                qid, qry = line.split('\t', 1)
                qmap[qid] = qry
        return qmap

    @staticmethod
    def read_qrel(relevance_file):
        """Read a four-column, tab-separated qrels file.

        Args:
            relevance_file: Path of the qrels file; each row is
                ``topicid, <ignored>, docid, rel``.

        Returns:
            Dict mapping topic id (str) to the list of its doc ids (str), in
            file order.

        Raises:
            AssertionError: If a row's relevance label is not ``"1"``.
            ValueError: If a non-empty row does not have exactly four columns.
        """
        qrel = {}
        with open(relevance_file, encoding='utf8') as f:
            tsvreader = csv.reader(f, delimiter="\t")
            for row in tsvreader:
                if not row:
                    # csv.reader yields [] for blank lines; skip them.
                    continue
                topicid, _, docid, rel = row
                assert rel == "1", f"unexpected relevance label {rel!r} for topic {topicid}"
                qrel.setdefault(topicid, []).append(docid)
        return qrel

    def get_query(self, q):
        """Return the tokenizer encoding of the query stored under id ``q``.

        Raises:
            KeyError: If ``q`` is not in ``self.queries``.
        """
        query_encoded = self.tokenizer.encode(
            self.queries[q],
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )
        return query_encoded

    def get_passage(self, p):
        """Return the tokenizer encoding of the passage with id ``p``.

        Args:
            p: Passage id (anything ``int()`` accepts); used as a row index
                into ``self.collection``.

        Returns:
            Whatever ``tokenizer.encode`` returns for the rendered passage.
        """
        p = int(p)
        # NOTE(review): ids greater than 32407 + 8841823 are shifted down by
        # one before indexing. The reason is not visible in this file
        # (presumably one row is absent from the collection) - confirm.
        if p > (32407 + 8841823):
            p -= 1
        entry = self.collection[p]

        # Empty cells can come back as None; normalise to "" so the string
        # concatenation / replace calls below cannot raise TypeError.
        title = entry[self.title_field] if self.use_title else ""
        title = "" if title is None else title
        body = entry[self.text_field]
        body = "" if body is None else body
        name = entry['name']
        name = "" if name is None else name

        if self.template is None:
            content = title + " " + body
        elif not title:
            # Without a title the "<title>" placeholder is deliberately left
            # untouched, as in the original implementation.
            content = self.template.replace("<text>", body).replace("<name>", name)
        else:
            content = (
                self.template
                .replace("<title>", title)
                .replace("<text>", body)
                .replace("<name>", name)
            )

        passage_encoded = self.tokenizer.encode(
            content,
            add_special_tokens=False,
            max_length=self.max_length,
            truncation=True
        )

        return passage_encoded

    def process_one(self, train):
        """Encode one training tuple and return it as a JSON string.

        Args:
            train: 5-tuple ``(query_id, positive_ids, negative_ids, weights,
                ground_texts)``. ``weights`` must provide ``.astype`` and
                ``.tolist`` (e.g. a numpy array).

        Returns:
            JSON text with the keys ``query``, ``positives``, ``negatives``,
            ``scores`` and ``grounds``.
        """
        q, pp, nn, ww, gg = train
        train_example = {
            'query': self.get_query(q),
            'positives': [self.get_passage(p) for p in pp],
            'negatives': [self.get_passage(n) for n in nn],
            'scores': ww.astype('float').tolist(),
            'grounds': [
                self.tokenizer.encode(
                    g,
                    add_special_tokens=False,
                    max_length=self.max_length,
                    truncation=True,
                )
                for g in gg
            ],
        }

        return json.dumps(train_example)

def load_ranking(rank_file, relevance, n_sample, depth, qid_weight_dict, qid_ground_dict,qid_ground_id_dict):
    """Yield one training tuple per query from a ranking file.

    The file is read line by line; every non-blank line must split on
    whitespace into six fields, of which the first is the query id and the
    third the passage id. Consecutive lines sharing a query id form a group.

    For each group:
      * positives start as a copy of ``relevance[qid]``;
      * a ranked passage is added to the positives when its id is among the
        ground ids at the five highest-weighted positions for this query and
        ``int(pid) > 8841822``;
      * every other ranked passage that is not a positive becomes a negative.

    Negatives are cut to the first ``depth``, shuffled, and the first
    ``n_sample`` are kept; positives are shuffled as well.

    Args:
        rank_file: Path of the whitespace-separated ranking file.
        relevance: Dict ``{qid (str): [positive pid (str), ...]}``. It is not
            modified.
        n_sample: Maximum number of negatives yielded per query.
        depth: Number of leading negatives kept before shuffling.
        qid_weight_dict: Dict ``{int(qid): weights}``; weights must support
            unary minus and ``np.argsort`` (e.g. a numpy array).
        qid_ground_dict: Dict ``{int(qid): grounds}``, passed through as-is.
        qid_ground_id_dict: Dict ``{int(qid): ground ids}`` indexed by the
            positions obtained from the weights.
            NOTE(review): ids are compared against the *string* passage ids
            from the ranking file - confirm they are stored as strings.

    Yields:
        ``(qid, positives, negatives, qid_weight_dict[int(qid)],
        qid_ground_dict[int(qid)])``. Nothing is yielded for an empty file.

    Raises:
        KeyError: If a query id is missing from one of the dicts.
        ValueError: If a non-blank line does not have exactly six fields.
    """
    top_k_grounds = 5
    # NOTE(review): only ids above this value are promoted to positives;
    # presumably these lie outside the base collection - confirm.
    min_promoted_pid = 8841822

    def _finish(qid, positives, negatives):
        # Same order of RNG use as before: negatives first, then positives.
        negatives = negatives[:depth]
        random.shuffle(negatives)
        random.shuffle(positives)
        return (
            qid,
            positives,
            negatives[:n_sample],
            qid_weight_dict[int(qid)],
            qid_ground_dict[int(qid)],
        )

    curr_q = None
    positives, negatives, top_ground_id = [], [], []

    with open(rank_file) as rf:
        # A plain for-loop: calling next() directly inside a generator turns
        # an empty file into RuntimeError (PEP 479).
        for line in rf:
            fields = line.split()
            if not fields:
                continue
            q, _, p, _, _, _ = fields

            if q != curr_q:
                if curr_q is not None:
                    yield _finish(curr_q, positives, negatives)
                curr_q = q
                sorted_index = np.argsort(-qid_weight_dict[int(q)])
                top_ground_id = [
                    qid_ground_id_dict[int(q)][index]
                    for index in sorted_index[:top_k_grounds]
                ]
                # Copy so appending/shuffling never alters the caller's qrels.
                positives = list(relevance[q])
                negatives = []

            if p in top_ground_id and int(p) > min_promoted_pid:
                positives.append(p)
            if p not in positives:
                negatives.append(p)

        if curr_q is not None:
            yield _finish(curr_q, positives, negatives)

def _parse_bool(value):
    """Convert a command-line string such as ``True``/``false``/``1`` to bool.

    ``type=bool`` would treat every non-empty string (including ``"False"``)
    as true, so the text is interpreted explicitly instead.

    Raises:
        ValueError: If the text is not a recognised boolean spelling
            (argparse reports this as a usage error).
    """
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f'not a boolean: {value!r}')


def build_arg_parser():
    """Return the command-line parser for this script."""
    parser = ArgumentParser()
    parser.add_argument('--tokenizer_name', required=True)
    parser.add_argument('--q_att_file_path', type=str, required=True)
    parser.add_argument('--augmented_q_file_path', type=str, required=True)
    parser.add_argument('--distill_train_file_dir', type=str, required=True)
    parser.add_argument('--truncate', type=int, default=128)
    parser.add_argument('--mp_chunk_size', type=int, default=500)
    parser.add_argument('--shard_size', type=int, default=45000)
    parser.add_argument('--hn_file', required=True)
    parser.add_argument('--qrels', required=True)
    parser.add_argument('--queries', required=True)
    parser.add_argument('--collection', required=True)
    # Required but not read by the code below; shards are written to
    # --distill_train_file_dir.
    parser.add_argument('--save_to', required=True)
    parser.add_argument('--template', type=str, default=None)
    parser.add_argument('--use_title', type=_parse_bool, default=False)
    parser.add_argument('--n_sample', type=int, default=32)
    parser.add_argument('--depth', type=int, default=200)
    return parser


def _load_query_side_data(q_att_file_path, augmented_q_file_path):
    """Load per-query weights, grounds and ground ids.

    Args:
        q_att_file_path: Pickle file holding the pair
            ``(att_weights, qid_lookup_indices)``.
        augmented_q_file_path: JSON-lines file whose objects have the keys
            ``id``, ``grounds`` and ``ground_ids``.

    Returns:
        ``(qid_weight_dict, qid_ground_dict, qid_ground_id_dict)``, each keyed
        by ``int(query id)``. Each weight entry drops the first element of
        the corresponding ``att_weights`` row.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code; only pass files
    # from a trusted source.
    with open(q_att_file_path, 'rb') as fin:
        att_weights, qid_lookup_indices = pickle.load(fin)
    qid_weight_dict = {
        int(qid): att_weights[i][1:]
        for i, qid in enumerate(qid_lookup_indices)
    }

    qid_ground_dict = {}
    qid_ground_id_dict = {}
    with open(augmented_q_file_path, 'r', encoding='utf8') as fin:
        for line in fin:
            if not line.strip():
                # Tolerate blank / trailing empty lines.
                continue
            q_dict = json.loads(line)
            q_id = int(q_dict['id'])
            qid_ground_dict[q_id] = q_dict['grounds']
            qid_ground_id_dict[q_id] = q_dict['ground_ids']

    return qid_weight_dict, qid_ground_dict, qid_ground_id_dict


def main():
    """Build the training examples and write them out as JSONL shards.

    Shards are named ``split<NN>.distill.jsonl`` inside
    ``--distill_train_file_dir`` and hold at most ``--shard_size`` lines each.
    """
    args = build_arg_parser().parse_args()

    qid_weight_dict, qid_ground_dict, qid_ground_id_dict = _load_query_side_data(
        args.q_att_file_path, args.augmented_q_file_path
    )

    qrel = TrainPreProcessor.read_qrel(args.qrels)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
    processor = TrainPreProcessor(
        query_file=args.queries,
        collection_file=args.collection,
        tokenizer=tokenizer,
        max_length=args.truncate,
        template=args.template,
        use_title=args.use_title,
    )

    os.makedirs(args.distill_train_file_dir, exist_ok=True)
    pbar = tqdm(load_ranking(
        args.hn_file, qrel, args.n_sample, args.depth,
        qid_weight_dict, qid_ground_dict, qid_ground_id_dict,
    ))

    counter = 0
    shard_id = 0
    f = None
    try:
        with Pool() as p:
            for x in p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size):
                counter += 1
                if f is None:
                    # Open the next shard lazily so no empty file is created.
                    f = open(os.path.join(args.distill_train_file_dir, f'split{shard_id:02d}.distill.jsonl'), 'w')
                    pbar.set_description(f'split - {shard_id:02d}')
                f.write(x + '\n')

                if counter == args.shard_size:
                    f.close()
                    f = None
                    shard_id += 1
                    counter = 0
    finally:
        # Close the current shard even if a worker or the writer raised.
        if f is not None:
            f.close()


# The guard keeps worker processes that re-import this module (spawn /
# forkserver start methods) from re-running the whole pipeline.
if __name__ == '__main__':
    main()
