from argparse import ArgumentParser
from cmath import nan
import csv
from sklearn import neighbors
import pandas as pd
import json


# Build PRF-augmented queries: for every query in --queries, attach the text
# of its top --prf_k documents from a whitespace-separated run file
# (--trec_file), looked up in a tab-separated collection (--collection), and
# write one JSON object per line to --save_to.

_TRUE_WORDS = frozenset({'1', 'true', 't', 'yes', 'y'})
_FALSE_WORDS = frozenset({'0', 'false', 'f', 'no', 'n', ''})


def str2bool(value):
    """Convert a command-line flag value to ``bool``.

    Without a converter argparse stores the raw string, so any non-empty
    value (including "False") was truthy.  Real bools (e.g. the default)
    pass through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised
            true/false spelling.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in _TRUE_WORDS:
        return True
    if lowered in _FALSE_WORDS:
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse command-line options (``argv=None`` means ``sys.argv[1:]``)."""
    parser = ArgumentParser()
    parser.add_argument('--trec_file', required=True,
                        help='run file with six whitespace-separated columns; '
                             'column 1 is the query id, column 3 the document id')
    parser.add_argument('--queries', required=True,
                        help='headerless TSV with columns: id, text, grounds')
    parser.add_argument('--collection', required=True,
                        help='headerless TSV with columns: id, title, text, name')
    parser.add_argument('--save_to', required=True,
                        help='output path; one JSON object per line')
    # nargs='?' + const=True lets a bare `--use_title` mean true while
    # `--use_title False` / `--use_title 0` now really mean false.
    parser.add_argument('--use_title', type=str2bool, nargs='?', const=True,
                        default=False,
                        help='include document titles in the grounds')
    parser.add_argument('--prf_k', type=int, default=3,
                        help='number of leading documents kept per query')
    return parser.parse_args(argv)


def group_top_neighbors(lines, prf_k):
    """Group run-file lines by query id and keep the first ``prf_k`` doc ids.

    Each non-blank line must have exactly six whitespace-separated fields;
    field 1 is the query id and field 3 the document id.  Documents are kept
    in file order (the file is assumed to be sorted by rank).  Lines for the
    same query need not be contiguous.

    Args:
        lines: iterable of text lines (e.g. an open file).
        prf_k: number of documents to keep per query (applied as a slice).

    Returns:
        dict mapping query id (str) -> list of document ids (str).

    Raises:
        ValueError: if a non-blank line does not have six fields.
    """
    ranked = {}
    for line in lines:
        fields = line.split()
        if not fields:
            continue  # tolerate blank / trailing lines
        q_id, _, p_id, _, _, _ = fields
        ranked.setdefault(q_id, []).append(p_id)
    return {q_id: p_ids[:prf_k] for q_id, p_ids in ranked.items()}


def resolve_doc_label(d_df, d_id):
    """Return the index label of ``d_df`` that refers to document ``d_id``.

    Run-file ids are strings, but pandas may have parsed the collection's id
    column as integers, so ``int(d_id)`` is tried when the raw string is not
    in the index.

    Raises:
        KeyError: if neither form of the id is present.
    """
    if d_id in d_df.index:
        return d_id
    try:
        label = int(d_id)
    except ValueError:
        raise KeyError(d_id) from None
    if label not in d_df.index:
        raise KeyError(d_id)
    return label


def as_text(value):
    """Render a collection cell as text, mapping missing (NaN/None) to ''."""
    return "" if pd.isna(value) else str(value)


def build_grounds(d_df, d_ids, use_title=False):
    """Format the grounding passages for one query.

    Args:
        d_df: collection indexed by document id with columns
            'title', 'text' and 'name'.
        d_ids: document ids (str) from the run file.
        use_title: when true, include the title between name and text.

    Returns:
        list of strings, one per document, in ``d_ids`` order.
    """
    grounds = []
    for d_id in d_ids:
        label = resolve_doc_label(d_df, d_id)
        name = as_text(d_df.loc[label, 'name'])
        text = as_text(d_df.loc[label, 'text'])
        if use_title:
            title = as_text(d_df.loc[label, 'title'])
            grounds.append(f"Corpus: {name} Title: {title} Text: {text}")
        else:
            grounds.append(f"Corpus: {name} {text}")
    return grounds


def make_query_record(q_id, q_text, grounds, ground_ids):
    """Build the JSON record for one query; ``id`` is an int when possible."""
    try:
        record_id = int(q_id)
    except (TypeError, ValueError, OverflowError):
        record_id = str(q_id)
    return {'id': record_id, "text": q_text, "grounds": grounds,
            "ground_ids": ground_ids}


def main(argv=None):
    """Read the inputs named on the command line and write augmented queries.

    Raises:
        KeyError: if a query has no entries in the run file, or a retrieved
            document id is missing from the collection.
    """
    args = parse_args(argv)

    # Read Docs
    d_df = pd.read_csv(args.collection, sep='\t', header=None,
                       names=['d-id', 'title', 'text', 'name'],
                       index_col='d-id')

    # Read Predicted Neighbors
    with open(args.trec_file, 'r', encoding='utf-8') as rf:
        augmented_q_dict = group_top_neighbors(rf, args.prf_k)

    # Read Query Files
    q_df = pd.read_csv(args.queries, sep='\t', header=None,
                       names=['q-id', 'text', 'grounds'])

    # Write Augmented Query Files
    with open(args.save_to, 'w', encoding='utf-8') as fout:
        for q_id, q_text in zip(q_df['q-id'], q_df['text']):
            # Run-file keys are raw strings; str(q_id) assumes pandas parsed
            # the query id the same way the run file spells it.
            try:
                augmented_d_ids = augmented_q_dict[str(q_id)]
            except KeyError:
                raise KeyError(
                    f"query {q_id!r} has no entries in {args.trec_file}"
                ) from None
            q_grounds = build_grounds(d_df, augmented_d_ids, args.use_title)
            json.dump(make_query_record(q_id, q_text, q_grounds,
                                        augmented_d_ids), fout)
            fout.write('\n')
    print("over!")


if __name__ == "__main__":
    main()