
from colbert.infra.run import Run
from colbert.data.collection import Collection
import os
import sys
import git
import tqdm
import ujson
import random

from argparse import ArgumentParser
from multiprocessing import Pool

from colbert.utils.utils import groupby_first_item, print_message
from utility.utils.qa_loaders import load_qas_, load_collection_
from utility.utils.save_metadata import format_metadata, get_metadata
from utility.evaluate.annotate_EM_helpers import *

from colbert.data.ranking import Ranking


class AnnotateEM:
    def __init__(self, collection, qas):
        # TODO: These should just be Queries! But Queries needs to support looking up answers as qid2answers below.
        qas = load_qas_(qas)
        collection = Collection.cast(collection)  # .tolist() #load_collection_(collection, retain_titles=True)

        self.parallel_pool = Pool(30)

        print_message('#> Tokenize the answers in the Q&As in parallel...')
        qas = list(self.parallel_pool.map(tokenize_all_answers, qas))

        qid2answers = {qid: tok_answers for qid, _, tok_answers in qas}
        assert len(qas) == len(qid2answers), (len(qas), len(qid2answers))

        self.qas, self.collection = qas, collection
        self.qid2answers = qid2answers

    def annotate(self, ranking):
        rankings = Ranking.cast(ranking)

        # print(len(rankings), rankings[0])

        print_message('#> Lookup passages from PIDs...')
        expanded_rankings = [(qid, pid, rank, self.collection[pid], self.qid2answers[qid])
                             for qid, pid, rank, *_ in rankings.tolist()]

        print_message('#> Assign labels in parallel...')
        labeled_rankings = list(self.parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings)))

        # Dump output.
        self.qid2rankings = groupby_first_item(labeled_rankings)

        self.num_judged_queries, self.num_ranked_queries = check_sizes(self.qid2answers, self.qid2rankings)

        # Evaluation metrics and depths.
        self.success, self.counts = self._compute_labels(self.qid2answers, self.qid2rankings)

        print(rankings.provenance(), self.success)

        return Ranking(data=self.qid2rankings, provenance=("AnnotateEM", rankings.provenance()))

    def _compute_labels(self, qid2answers, qid2rankings):
        cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all']
        success = {cutoff: 0.0 for cutoff in cutoffs}
        counts = {cutoff: 0.0 for cutoff in cutoffs}

        for qid in qid2answers:
            if qid not in qid2rankings:
                continue

            prev_rank = 0  # ranks should start at one (i.e., and not zero)
            labels = []

            for pid, rank, label in qid2rankings[qid]:
                assert rank == prev_rank+1, (qid, pid, (prev_rank, rank))
                prev_rank = rank

                labels.append(label)

            for cutoff in cutoffs:
                if cutoff != 'all':
                    success[cutoff] += sum(labels[:cutoff]) > 0
                    counts[cutoff] += sum(labels[:cutoff])
                else:
                    success[cutoff] += sum(labels) > 0
                    counts[cutoff] += sum(labels)

        return success, counts

    def save(self, new_path):
        print_message("#> Dumping output to", new_path, "...")

        Ranking(data=self.qid2rankings).save(new_path)

        # Dump metrics.
        with Run().open(f'{new_path}.metrics', 'w') as f:
            d = {'num_ranked_queries': self.num_ranked_queries, 'num_judged_queries': self.num_judged_queries}

            extra = '__WARNING' if self.num_judged_queries != self.num_ranked_queries else ''
            d[f'success{extra}'] = {k: v / self.num_judged_queries for k, v in self.success.items()}
            d[f'counts{extra}'] = {k: v / self.num_judged_queries for k, v in self.counts.items()}
            # d['arguments'] = get_metadata(args)  # TODO: Need arguments...

            f.write(format_metadata(d) + '\n')


if __name__ == '__main__':
    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.50.02/ranking.tsv'
    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.59.37/ranking.tsv'
    r = sys.argv[1]

    a = AnnotateEM(collection='/future/u/okhattab/root/unit/data/NQ-mini/collection.tsv',
                   qas='/future/u/okhattab/root/unit/data/NQ-mini/dev/qas.json')
    a.annotate(ranking=r)
