from argparse import ArgumentParser
from cmath import nan
import csv
from sklearn import neighbors
import pandas as pd
import json
import pickle
import numpy as np
import pdb


# Command-line interface: rerank each query's top-k first-stage candidates
# by the weight vectors stored in a pickle file.
parser = ArgumentParser(
	description='Rerank the top --prf_k documents per query of a TREC run '
	'file using per-query weight vectors loaded from a pickle file.')
parser.add_argument(
	'--q_att_file_path', type=str, required=True,
	help='Pickle file holding a (weights, qid_list) pair. The first entry '
	'of each weight vector is skipped; the remaining entries score the '
	"query's candidate documents in run-file order.")
parser.add_argument(
	'--trec_file', required=True,
	help='Whitespace-separated run file with six columns '
	'(qid, <ignored>, docid, rank, score, tag) supplying the candidates.')
parser.add_argument(
	'--output', required=True,
	help='Path of the tab-separated reranked run file to write.')
parser.add_argument(
	'--prf_k', type=int, default=10,
	help='Number of leading documents per query kept as rerank candidates '
	'(default: %(default)s).')
args = parser.parse_args()

# Collect, for every query in the first-stage run file, the first
# ``args.prf_k`` document ids in file order (file order is assumed to be
# rank order). Each non-blank line must have six whitespace-separated
# columns; only the query id (column 1) and document id (column 3) are used.
# Documents are grouped by query id rather than by contiguous runs of lines,
# so a query whose lines are not adjacent keeps all of its documents instead
# of having the earlier group overwritten. Blank lines are skipped, and an
# empty file yields an empty mapping instead of an IndexError.
augmented_q_dict = {}
with open(args.trec_file, 'r') as rf:
	for line_no, line in enumerate(rf, start=1):
		fields = line.split()
		if not fields:
			continue
		if len(fields) != 6:
			raise ValueError(
				f"{args.trec_file}:{line_no}: expected 6 columns, "
				f"got {len(fields)}: {line!r}")
		q, _, p, _rank, _score, _tag = fields
		augmented_q_dict.setdefault(str(q), []).append(p)
# Truncate after grouping so slicing semantics match the original exactly.
augmented_q_dict = {
	q_key: docs[:args.prf_k] for q_key, docs in augmented_q_dict.items()
}


# Load the per-query weight vectors and the query ids they belong to.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point --q_att_file_path at files produced by a trusted pipeline.
with open(args.q_att_file_path, 'rb') as fin:
	att_weights, qid_lookup_indices = pickle.load(fin)

if len(att_weights) != len(qid_lookup_indices):
	raise ValueError(
		f"{args.q_att_file_path}: {len(att_weights)} weight vectors but "
		f"{len(qid_lookup_indices)} query ids")

# Query id (as str, matching the keys of augmented_q_dict) -> weight vector.
qid_weight_dict = {
	str(qid): weights for qid, weights in zip(qid_lookup_indices, att_weights)
}

reranked_q_dict = {}
reranked_score_dict = {}
for q_id, weights in qid_weight_dict.items():
	# Entry 0 is skipped; entries 1.. score the candidates in run-file order.
	# NOTE(review): presumably entry 0 is the query's own weight — confirm
	# against the model that produced the pickle.
	neighbor_weights = np.asarray(weights)[1:]
	candidates = augmented_q_dict.get(q_id)
	if candidates is None:
		raise KeyError(
			f"query {q_id!r} from {args.q_att_file_path} "
			f"not found in {args.trec_file}")
	if len(neighbor_weights) > len(candidates):
		raise ValueError(
			f"query {q_id!r}: {len(neighbor_weights)} weights but only "
			f"{len(candidates)} candidate documents (prf_k={args.prf_k})")
	# NOTE: if there are fewer weights than candidates, the unscored trailing
	# candidates are not emitted.
	# Highest weight first; a stable sort keeps first-stage order among ties.
	order = np.argsort(-neighbor_weights, kind='stable')
	reranked_q_dict[q_id] = [candidates[j] for j in order]
	reranked_score_dict[q_id] = [neighbor_weights[j] for j in order]
# Write one tab-separated line per (query, document) in reranked order:
# qid, qid, docid, 1-based rank, score, run tag 'rerank'.
# NOTE(review): column 2 repeats the query id where the TREC run format
# conventionally has the literal "Q0"; trec_eval is generally understood to
# ignore this column, but confirm no downstream consumer depends on it.
with open(args.output, 'w') as fout:
	for q_id, docs in reranked_q_dict.items():
		scores = reranked_score_dict[q_id]
		for position, doc in enumerate(docs, start=1):
			row = [q_id, q_id, str(doc), str(position), str(scores[position - 1]), 'rerank']
			fout.write("\t".join(row) + "\n")
