from argparse import ArgumentParser
from cmath import nan
import csv
from sklearn import neighbors
import pandas as pd
import json
import pickle
import numpy as np
import pdb


# Command-line interface: where to read attention weights and augmented
# queries from, where to write the reranked queries, and how many grounding
# documents to keep per query.
parser = ArgumentParser(
    description="Rerank each query's grounding documents by attention weight "
                "and keep only the highest-weighted ones.")
parser.add_argument('--truncate_doc_num', type=int, required=True,
                    help="maximum number of grounding documents to keep per "
                         "query (must be >= 0)")
parser.add_argument('--q_att_file_path', type=str, required=True,
                    help="pickle file holding the tuple "
                         "(att_weights, qid_lookup_indices)")
parser.add_argument('--augmented_q_file_path', type=str, required=True,
                    help="input JSON-lines file; each line has 'id', 'text' "
                         "and 'grounds' fields")
parser.add_argument('--reranked_q_file_path', type=str, required=True,
                    help="output JSON-lines file for the reranked queries")
args = parser.parse_args()

# A negative limit would silently drop every grounding document; reject it
# up front with a usage error instead.
if args.truncate_doc_num < 0:
    parser.error("--truncate_doc_num must be a non-negative integer")

# Load the attention weights together with the query ids they belong to.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only point --q_att_file_path at files produced by a trusted pipeline.
with open(args.q_att_file_path, 'rb') as weights_file:
    att_weights, qid_lookup_indices = pickle.load(weights_file)

# Map each (int-converted) query id to its attention-weight row: the row at
# position k of att_weights belongs to qid_lookup_indices[k]. A repeated id
# keeps the last row seen.
qid_weight_dict = {
    int(qid): att_weights[row]
    for row, qid in enumerate(qid_lookup_indices)
}

# Rerank every query's grounding documents by attention weight and write the
# top `truncate_doc_num` of them to the output JSON-lines file.
#
# For each query, weights are looked up in qid_weight_dict and sorted in
# descending order. Weight position 0 is skipped (it never maps to a grounding
# document here). Weight position i maps to grounds[i - 1].
# NOTE(review): presumably position 0 is the query itself -- confirm against
# the code that produced the attention pickle.

# Per-query state kept at module level for post-hoc inspection.
all_ranks = []    # attention positions, highest weight first
all_indices = []  # indices into each query's original 'grounds' that were kept

# Open the input first so that a bad input path fails before the output file
# is created or truncated.
with open(args.augmented_q_file_path, 'r', encoding='utf-8') as fin:
    with open(args.reranked_q_file_path, 'w', encoding='utf-8') as fout:
        for line in fin:  # stream; no need to load the whole file
            if not line.strip():
                continue  # tolerate blank lines such as a trailing newline
            q_dict = json.loads(line)
            # The weight lookup requires an int id, so convert once. Any id
            # that survives this line is int-convertible, which made the old
            # str fallback for the output id unreachable.
            q_id = int(q_dict['id'])
            q_text = q_dict['text']
            q_grounds = q_dict['grounds']

            # Negate so argsort yields positions by descending weight.
            rank = np.argsort(-qid_weight_dict[q_id])
            all_ranks.append(rank)

            kept_indices = []
            for index in rank:
                if len(kept_indices) >= args.truncate_doc_num:
                    break  # top-k already collected
                if index == 0:
                    continue  # position 0 has no grounding document
                kept_indices.append(index - 1)
            all_indices.append(kept_indices)

            q_new_ground = [q_grounds[i] for i in kept_indices]
            q_reranked_dict = {'id': q_id, "text": q_text, "grounds": q_new_ground}
            json.dump(q_reranked_dict, fout)
            fout.write('\n')

# NOTE(review): assumes every query's weight vector has the same length,
# otherwise numpy cannot build a regular 2-D array here.
all_ranks = np.array(all_ranks)

print("finish reranking grounding docs!")

