from argparse import ArgumentParser
from cmath import nan
import csv
from sklearn import neighbors
import pandas as pd
import json
import pickle
import numpy as np
import pdb
from transformers import AutoTokenizer,PreTrainedTokenizer
import os
import random
from tqdm import tqdm
from datetime import datetime
from multiprocessing import Pool
import sys
from dataclasses import dataclass
from typing import Dict
import datasets
@dataclass
class TrainPreProcessor:
    """Tokenize one training example and serialize it as a JSON string."""

    # Tokenizer object; only its ``encode`` method is used here.
    tokenizer: PreTrainedTokenizer
    # ``max_length`` passed to ``encode`` for the query and every passage.
    truncate: int = 128

    def _encode(self, text):
        """Return the token ids of *text*, without special tokens, truncated."""
        return self.tokenizer.encode(
            text,
            add_special_tokens=False,
            max_length=self.truncate,
            truncation=True,
        )

    def process_one(self, train):
        """Convert one example into a JSON line.

        Args:
            train: Either ``(query, passages, scores)`` -- the shape produced
                by ``load_query`` in this file -- or the legacy four-element
                form ``(query, passages, extra_passages, scores)``, whose
                third element is ignored.

        Returns:
            A JSON string with keys ``query`` (token ids), ``passages``
            (one token-id list per passage) and ``scores`` (list of floats).

        Raises:
            ValueError: If *train* has neither three nor four elements.
        """
        # Previously a hard-coded four-way unpack: it failed on the
        # three-tuples that ``load_query`` yields.
        if len(train) == 3:
            q, pp, ss = train
        else:
            q, pp, _unused, ss = train

        train_example = {
            'query': self._encode(q),
            'passages': [self._encode(p) for p in pp],
            # asarray accepts numpy arrays as well as plain sequences.
            'scores': np.asarray(ss, dtype=float).tolist(),
        }
        return json.dumps(train_example)


def load_query(augmented_q_file, qid_weight_dict):
    """Stream augmented queries together with their grounding scores.

    Args:
        augmented_q_file: Path to a UTF-8 JSON-lines file. Each non-blank
            line is an object with the keys ``id``, ``text`` and ``grounds``.
        qid_weight_dict: Mapping from integer query id to a sliceable
            sequence of weights.

    Yields:
        ``(text, grounds, scores)`` for each record, where ``scores`` is the
        weight sequence for that query id with its first entry dropped.

    Raises:
        KeyError: If a record lacks a required key or its id is not in
            *qid_weight_dict*.
    """
    with open(augmented_q_file, 'r', encoding='utf8') as fin:
        for line in fin:
            # Blank lines (e.g. a trailing newline) are not valid JSON.
            if not line.strip():
                continue
            q_dict = json.loads(line)
            # [1:] ignores the cross attention towards the original query.
            ground_scores = qid_weight_dict[int(q_dict['id'])][1:]
            yield q_dict['text'], q_dict['grounds'], ground_scores

def _build_qid_weight_dict(att_weights, qid_lookup_indices):
    """Map each query id (as int) to the weights stored at the same index."""
    return {int(qid): att_weights[i] for i, qid in enumerate(qid_lookup_indices)}


def _write_shards(json_lines, out_dir, shard_size, pbar=None):
    """Write *json_lines* into ``splitNN.distill.jsonl`` files in *out_dir*.

    A new shard is started after every *shard_size* lines. The currently
    open shard is closed even if iterating *json_lines* raises.
    """
    shard_id = 0
    counter = 0
    f = None
    try:
        for x in json_lines:
            if f is None:
                # Shards are opened lazily so no empty trailing file is made.
                shard_path = os.path.join(out_dir, f'split{shard_id:02d}.distill.jsonl')
                f = open(shard_path, 'w')
                if pbar is not None:
                    pbar.set_description(f'split - {shard_id:02d}')
            f.write(x + '\n')
            counter += 1

            if counter == shard_size:
                f.close()
                f = None
                shard_id += 1
                counter = 0
    finally:
        if f is not None:
            f.close()


def main():
    """Parse CLI arguments and build the sharded distillation training files."""
    parser = ArgumentParser()
    parser.add_argument('--tokenizer_name', required=True)
    parser.add_argument('--q_att_file_path', type=str, required=True)
    parser.add_argument('--augmented_q_file_path', type=str, required=True)
    parser.add_argument('--distill_train_file_dir', type=str, required=True)
    parser.add_argument('--truncate', type=int, default=128)
    parser.add_argument('--mp_chunk_size', type=int, default=500)
    parser.add_argument('--shard_size', type=int, default=45000)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)

    # NOTE(security): unpickling can execute arbitrary code; only point
    # --q_att_file_path at files from a trusted source.
    with open(args.q_att_file_path, 'rb') as fin:
        att_weights, qid_lookup_indices = pickle.load(fin)
    qid_weight_dict = _build_qid_weight_dict(att_weights, qid_lookup_indices)

    processor = TrainPreProcessor(tokenizer=tokenizer, truncate=args.truncate)
    os.makedirs(args.distill_train_file_dir, exist_ok=True)

    pbar = tqdm(load_query(args.augmented_q_file_path, qid_weight_dict))
    with Pool() as p:
        results = p.imap(processor.process_one, pbar, chunksize=args.mp_chunk_size)
        _write_shards(results, args.distill_train_file_dir, args.shard_size, pbar)


# Guard the entry point: with the "spawn" start method worker processes
# re-import this module, and must not re-run the script body themselves.
if __name__ == '__main__':
    main()
