import json
import numpy as np
import sys
import spacy
from spacy.lang.en import English
from tqdm import tqdm

# spacy_tagger = spacy.load("en_core_web_sm", disable=["parser", "tagger", "ner", "lemmatizer"])
# Blank English pipeline with only a rule-based sentence splitter. In this script it
# is referenced only by the disabled "sentence + marker" context builder in the loop.
nlp = English()
nlp.add_pipe('sentencizer')

# Command line: MODEL_NAME TEST_DATA OUT_DIR
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} MODEL_NAME TEST_DATA OUT_DIR")
model_name = sys.argv[1]
test_data = sys.argv[2]
out_dir = sys.argv[3]
print(model_name)
print(test_data)
print(out_dir)

# The prediction file is a JSON object; the main loop iterates only over its values.
# Use a context manager so the handle is closed right after parsing.
with open(f'/outputs/{model_name}/pred/{test_data}.pred', encoding='utf-8') as pred_file:
    my_pred = json.load(pred_file)

# Number of passages kept (and padded to) per question.
top_k = 20


# Convert every prediction into a record of the form
#   {"id": str, "question": ..., "answers": ..., "ctxs": [{"title": ..., "text": ...}, ...]}
# with exactly `top_k` contexts per question.
my_target = []
avg_len = []  # word count of every emitted context, summarised after the loop
for qid, pred in tqdm(enumerate(my_pred.values()), total=len(my_pred)):
    # Ids are positional ("0", "1", ...); the keys of my_pred are not carried over.
    my_dict = {"id": str(qid), "question": None, "answers": [], "ctxs": []}

    # truncate: keep only the first top_k entries of each per-passage list
    # (slicing also gives fresh lists, so the padding below never mutates my_pred).
    pred = {key: val[:top_k] if key in ['evidence', 'title', 'se_pos', 'prediction'] else val for key, val in pred.items()}

    # TODO: need to add id for predictions.pred
    my_dict["question"] = pred["question"]
    my_dict["answers"] = pred["answer"]
    # Each title entry is itself indexable; only its first element is used.
    pred["title"] = [titles[0] for titles in pred["title"]]

    # Restore original question + add the target
    # (disabled: relies on a `gold_file` that this script never loads)
    '''
    if (gold_file[qid]['question'][:-1] if gold_file[qid]['question'].endswith('?') else gold_file[qid]['question']) != pred['question']:
        import pdb; pdb.set_trace()
    my_dict["question"] = gold_file[qid]["question"]
    my_dict["target"] = gold_file[qid]["target"]
    '''


    assert len(set(pred["evidence"])) == len(pred["evidence"]) == len(pred["title"])       # All unique passages
    assert all(pr in evd for pr, evd in zip(pred["prediction"], pred["evidence"]))         # prediction included
    if not(len(pred["prediction"]) == len(pred["evidence"]) == len(pred["title"]) == top_k):  # Exactly top_k passages
        assert len(pred["prediction"]) == len(pred["evidence"]) == len(pred["title"]) < top_k, \
            (len(pred["prediction"]), len(pred["evidence"]), len(pred["title"]))
        print(len(pred["prediction"]), len(pred["evidence"]))
        # Padding repeats the last entry, so there must be at least one to repeat.
        assert pred["evidence"], f"question {qid} has no passages to pad from"

        # Pad each list to top_k by repeating its last entry. The pad size is computed
        # once so the result does not depend on the order the lists are extended.
        n_pad = top_k - len(pred["prediction"])
        pred["evidence"] += [pred["evidence"][-1]] * n_pad
        pred["title"] += [pred["title"][-1]] * n_pad
        pred["se_pos"] += [pred["se_pos"][-1]] * n_pad
        pred["prediction"] += [pred["prediction"][-1]] * n_pad

        assert len(pred["prediction"]) == len(pred["evidence"]) == len(pred["title"]) == top_k

    # Used for markers. START/END and se_idxs are consumed only by the disabled
    # alternatives below; they are kept so those paths can be re-enabled.
    START = '<p_start>'
    END = '<p_end>'
    # [start, end] per passage, with end clamped so it is never before start.
    se_idxs = [[se_pos[0], max(se_pos[0], se_pos[1])] for se_pos in pred["se_pos"]]

    # Use sentence + marker (disabled alternative context builder)
    '''
    sents = [[(X.text, X[0].idx) for X in nlp(evidence).sents] for evidence in pred['evidence']]
    sent_idxs = [
        sorted(set([sum(np.array([st[1] for st in sent]) <= se_idx[0]) - 1] + [sum(np.array([st[1] for st in sent]) <= se_idx[1]-1) - 1]))
        for se_idx, sent in zip(se_idxs, sents)
    ]
    se_idxs = [[se_pos[0]-sent[sent_idx[0]][1], se_pos[1]-sent[sent_idx[0]][1]] for se_pos, sent_idx, sent in zip(se_idxs, sent_idxs, sents)]
    if not all(pred.replace(' ', '') in ' '.join([sent[sidx][0] for sidx in range(sent_idx[0], sent_idx[-1]+1)]).replace(' ', '')
               for pred, sent, sent_idx in zip(pred['prediction'], sents, sent_idxs)):
        import pdb; pdb.set_trace()
        pass

    my_dict["ctxs"] = [
        {"title": title, "text": ' '.join(' '.join([sent[sidx][0] for sidx in range(sent_idx[0], sent_idx[-1]+1)]).split()[:98])} # TODO this misses last tokens
        for title, sent, sent_idx in zip(pred["title"], sents, sent_idxs)
    ]
    '''

    # Use raw outputs: each context is the evidence re-joined from its first
    # 100 whitespace-separated tokens.
    my_dict["ctxs"] = [
        {"title": title, "text": ' '.join(evd.split()[:100])}
        for evd, title in zip(pred["evidence"], pred["title"])
    ]

    # Disabled: wrap the predicted span in START/END markers.
    '''
    my_dict["ctxs"] = [
        {"title": ctx["title"], "text": ctx["text"][:se[0]] + f"{START} " + ctx["text"][se[0]:se[1]] + f" {END}" + ctx["text"][se[1]:]}
        for ctx, se in zip(my_dict["ctxs"], se_idxs)
    ]
    '''

    my_target.append(my_dict)
    avg_len += [len(ctx['text'].split()) for ctx in my_dict["ctxs"]]
    assert len(my_dict["ctxs"]) == top_k
    assert all(len(ctx['text'].split()) <= 100 for ctx in my_dict["ctxs"])

# Summary statistic over every emitted context. Guard the division so an empty
# prediction file does not raise ZeroDivisionError.
if avg_len:
    print(f"avg ctx len={sum(avg_len)/len(avg_len):.2f} for {len(my_pred)} preds")
else:
    print(f"avg ctx len=n/a for {len(my_pred)} preds")

# NOTE(review): the file name says "top100" but each record holds top_k contexts
# (20 in this script). It is kept as-is because downstream consumers may rely on
# this exact path.
recall_path = f"{out_dir}/{model_name}_{test_data}_top100_psg.json"
print(f"dump to {recall_path}")
# The context manager guarantees the output is flushed and closed before exit.
with open(recall_path, 'w', encoding='utf-8') as out_file:
    json.dump(my_target, out_file, indent=4)
