from argparse import ArgumentParser
from cmath import nan
import csv
from sklearn import neighbors
import pandas as pd
import json
import pdb

def _str2bool(value):
	"""Convert a command-line token to a bool.

	Without a converter argparse stores the raw string, and every non-empty
	string (including "False" and "0") is truthy. Explicit negatives and the
	empty string map to False; any other value maps to True, which is how
	such values were already treated by a plain truthiness check.
	"""
	if isinstance(value, bool):
		return value
	return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


# Command-line interface.
parser = ArgumentParser()
parser.add_argument('--trec_file', required=True,
	help='run file with six whitespace-separated fields per line')
parser.add_argument('--queries', required=True,
	help='tab-separated query file: q-id, text')
parser.add_argument('--collection', required=True,
	help='tab-separated document file: d-id, title, text, name')
parser.add_argument('--save_to', required=True,
	help='output path; one JSON object is written per query')
parser.add_argument('--use_title', type=_str2bool, default=False,
	help='include document titles in the grounds (true/false)')
parser.add_argument('--threshold', type=int, default=5,
	help="minimum number of 'maarco' documents wanted per query")
parser.add_argument('--prf_k', type=int, default=10,
	help='number of top-ranked documents considered per query')
args = parser.parse_args()

# Load the document collection: a header-less, tab-separated file whose
# columns are named (d-id, title, text, name) and which is indexed by d-id.
d_df = pd.read_csv(
	args.collection,
	sep='\t',
	header=None,
	names=['d-id', 'title', 'text', 'name'],
	index_col='d-id',
)

# Read Predicted Neighbors
# Each non-blank line of the run file must hold six whitespace-separated
# fields; only the first (query id) and the third (document id) are used.
# Lines of one query are expected to be contiguous: a new list is started
# whenever the query id changes, so a query id that reappears later replaces
# its earlier list.
augmented_q_dict = {}
cur_q = None
top_neighbors = []
with open(args.trec_file, 'r') as rf:
	for line in rf:
		fields = line.split()
		if not fields:
			# Tolerate blank lines (e.g. a trailing newline at end of file).
			continue
		q, _, p, _, _, _ = fields
		if q != cur_q:
			# Start a fresh list for the new query. The list is registered
			# immediately, so no special handling is needed for the first or
			# the last query and no document is recorded twice.
			cur_q = q
			top_neighbors = []
			augmented_q_dict[q] = top_neighbors
		top_neighbors.append(p)

# Load the queries: a header-less, tab-separated file whose two columns are
# named (q-id, text). The default positional index is kept.
q_df = pd.read_csv(
	args.queries,
	sep='\t',
	header=None,
	names=['q-id', 'text'],
)

# Write Augmented Query Files
# For every query, choose the grounding documents from its ranked neighbours
# and write one JSON object per line with the keys id, text, grounds and
# ground_ids.
def _doc_field(d_id, column):
	"""Return `column` of document `d_id` from `d_df`.

	Document ids taken from the run file are strings. When that label is not
	found in the collection index, the lookup is retried with the id cast to
	int.
	"""
	try:
		return d_df.loc[d_id, column]
	except (KeyError, TypeError):
		return d_df.loc[int(d_id), column]


with open(args.save_to, 'w') as fout:
	for q_id, q_text in zip(q_df['q-id'], q_df['text']):
		candidates = augmented_q_dict[str(q_id)]
		top_k = candidates[:args.prf_k]

		# Split the top-k neighbours by corpus name, keeping rank order.
		in_domain_augmented_d_ids = []
		out_domain_augmented_d_ids = []
		for d_id in top_k:
			if _doc_field(d_id, 'name') == 'maarco':
				in_domain_augmented_d_ids.append(d_id)
			else:
				out_domain_augmented_d_ids.append(d_id)
		# Only fails when the query has fewer than prf_k neighbours.
		assert len(in_domain_augmented_d_ids)+len(out_domain_augmented_d_ids)==args.prf_k, \
			"query %s has fewer than prf_k neighbours" % q_id

		if len(in_domain_augmented_d_ids) >= args.threshold:
			augmented_d_ids = top_k
		else:
			# Too few 'maarco' documents in the top-k: pull further ones from
			# the remaining neighbours until the threshold is reached, then
			# fill what is left of the prf_k slots with the out-of-domain
			# documents from the top-k.
			for d_id in candidates[args.prf_k:]:
				if _doc_field(d_id, 'name') == 'maarco':
					in_domain_augmented_d_ids.append(d_id)
					if len(in_domain_augmented_d_ids) >= args.threshold:
						break
			# Clamp at zero so a threshold larger than prf_k cannot turn the
			# slice bound negative.
			n_out = max(0, args.prf_k - len(in_domain_augmented_d_ids))
			augmented_d_ids = in_domain_augmented_d_ids + out_domain_augmented_d_ids[:n_out]

		q_grounds = []
		for d_id in augmented_d_ids:
			d_name = _doc_field(d_id, 'name')
			d_text = _doc_field(d_id, 'text')
			if args.use_title:
				d_title = _doc_field(d_id, 'title')
				# A missing title is read as NaN, and NaN never compares equal
				# to anything, so it has to be detected with pd.isna.
				if pd.isna(d_title):
					d_title = ""
				q_grounds.append("Corpus: "+d_name+" Title: "+str(d_title)+" Text: "+str(d_text))
			else:
				q_grounds.append("Corpus: "+d_name+" "+str(d_text))

		# Emit the id as an integer when possible, otherwise as a string.
		try:
			json_id = int(q_id)
		except (TypeError, ValueError):
			json_id = str(q_id)
		q_dict = {'id': json_id, "text": q_text, "grounds": q_grounds, "ground_ids": augmented_d_ids}
		json.dump(q_dict, fout)
		fout.write('\n')
print("over!")