from argparse import ArgumentParser
from cmath import nan
import csv
from sklearn import neighbors
import pandas as pd


def str2bool(value):
    """Convert a command-line token into a real boolean.

    Without this, argparse stores the raw string, and any non-empty string
    (including ``"False"``) is truthy.

    Args:
        value: A bool, or a string such as ``true``/``false``/``1``/``0``.

    Returns:
        The parsed boolean. An empty string is treated as ``False``.

    Raises:
        ValueError: If the token is not a recognised boolean spelling
            (argparse reports this as a usage error).
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ValueError('expected a boolean value, got {!r}'.format(value))


def build_parser():
    """Build the command-line parser for the query augmentation script."""
    parser = ArgumentParser()
    parser.add_argument('--trec_file', required=True)
    parser.add_argument('--queries', required=True)
    parser.add_argument('--collection', required=True)
    parser.add_argument('--save_to', required=True)
    # Accepts both a bare `--use_title` and an explicit `--use_title False`.
    parser.add_argument('--use_title', type=str2bool, nargs='?',
                        const=True, default=False)
    parser.add_argument('--prf_k', type=int, default=3)
    return parser


def read_top_neighbors(trec_path, prf_k):
    """Collect the first ``prf_k`` document ids listed for each query.

    Each non-blank line must be whitespace-separated with the query id in the
    first field and the document id in the third field. Documents are kept in
    file order. If a query id reappears later in the file, its documents keep
    filling the same list until ``prf_k`` is reached.

    Args:
        trec_path: Path of the run file to read.
        prf_k: Maximum number of document ids to keep per query.

    Returns:
        Dict mapping query id (str) to a list of at most ``prf_k`` document
        ids (str). An empty file yields an empty dict.

    Raises:
        ValueError: If a non-blank line has fewer than three fields.
    """
    top_docs = {}
    with open(trec_path, 'r') as rf:
        # Stream the file: run files can be large and only k ids per query
        # are retained.
        for line_no, line in enumerate(rf, start=1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing newline)
            if len(fields) < 3:
                raise ValueError(
                    '{}:{}: expected at least 3 fields, got {!r}'.format(
                        trec_path, line_no, line))
            q_id, d_id = fields[0], fields[2]
            ranked = top_docs.setdefault(q_id, [])
            if len(ranked) < prf_k:
                ranked.append(d_id)
    return top_docs


def as_text(value):
    """Return ``value`` as a string, mapping missing cells (NaN/None) to ''."""
    if pd.isna(value):
        return ''
    return str(value)


def lookup_doc(d_df, d_id, column):
    """Fetch one cell of the collection for a document id read as a string.

    The id is first tried as-is. If that label is not in the index, it is
    retried as an integer, which covers collections whose id column pandas
    parsed as integers.

    Args:
        d_df: Collection DataFrame indexed by document id.
        d_id: Document id as a string.
        column: Name of the column to read.

    Returns:
        The cell value.

    Raises:
        KeyError: If the id is found neither as a string nor as an integer.
    """
    try:
        return d_df.loc[d_id, column]
    except (KeyError, TypeError):
        pass
    try:
        numeric_id = int(d_id)
    except ValueError:
        raise KeyError(
            'document id {!r} not found in collection'.format(d_id)) from None
    return d_df.loc[numeric_id, column]


def augment_query(q_text, doc_ids, d_df, use_title):
    """Append the feedback documents to a query, separated by `` [SEP] ``.

    Args:
        q_text: Original query text (a missing value is treated as '').
        doc_ids: Document ids to append, in order.
        d_df: Collection DataFrame with 'title' and 'text' columns, indexed by
            document id.
        use_title: If true, each document is rendered as ``title + ' ' + text``
            (a missing title renders as the empty string); otherwise only the
            text is appended.

    Returns:
        The augmented query string.
    """
    parts = [as_text(q_text)]
    for d_id in doc_ids:
        d_text = as_text(lookup_doc(d_df, d_id, 'text'))
        if use_title:
            d_title = as_text(lookup_doc(d_df, d_id, 'title'))
            parts.append(' [SEP] ' + d_title + ' ' + d_text)
        else:
            parts.append(' [SEP] ' + d_text)
    return ''.join(parts)


def main(argv=None):
    """Read the run, collection and queries, then write the augmented queries.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # Read Docs
    d_df = pd.read_csv(args.collection, sep='\t', header=None,
                       names=['d-id', 'title', 'text'], index_col='d-id')

    # Read Predicted Neighbors
    augmented_q_dict = read_top_neighbors(args.trec_file, args.prf_k)

    # Read Query Files
    q_df = pd.read_csv(args.queries, sep='\t', header=None,
                       names=['q-id', 'text'])

    # Write Augmented Query Files
    with open(args.save_to, 'w', newline='') as fout:
        tsv_w = csv.writer(fout, delimiter='\t')
        for q_id, q_text in q_df.itertuples(index=False, name=None):
            # A query with no retrieved documents is written unaugmented
            # instead of aborting the whole run.
            doc_ids = augmented_q_dict.get(str(q_id), [])
            tsv_w.writerow(
                [q_id, augment_query(q_text, doc_ids, d_df, args.use_title)])

    print("over!")


if __name__ == "__main__":
    main()
