
"""Module for computing performance metrics

"""
import math
import numbers
from pathlib import Path
from unittest import result
import ipdb
import numpy as np
import torch
import scipy.stats
from sklearn.metrics import average_precision_score
import ipdb
import pdb

def i2t(sims, captions_per_image=5):
    """Compute image-to-text retrieval metrics from a similarity matrix.

    Row `i` of `sims` holds the similarities between image `i` and every caption.
    The ground-truth captions of image `i` are the contiguous columns
    [i * captions_per_image, (i + 1) * captions_per_image).  The rank of an image
    is the best (lowest, zero-based) position reached by any of its ground-truth
    captions when the row is sorted by decreasing similarity.

    Args:
        sims (np.ndarray): N x (captions_per_image * N) image-to-text similarities.
        captions_per_image (int): number of ground-truth captions per image.
            Defaults to 5, which was previously hard-coded.

    Returns:
        (dict[str:float]): "R1", "R5", "R10", "R50" (recall percentages) together
        with "MedR" and "MeanR" (one-based median and mean rank).
    """
    npts = sims.shape[0]

    ranks = np.zeros(npts)
    for index in range(npts):
        inds = np.argsort(sims[index])[::-1]
        first_gt = index * captions_per_image
        is_gt = (inds >= first_gt) & (inds < first_gt + captions_per_image)
        # The first ranked position that holds a ground-truth caption is the best
        # rank over all ground-truth captions of this image.
        ranks[index] = np.flatnonzero(is_gt)[0]

    # Compute metrics
    metrics = {}
    for name, cutoff in (("R1", 1), ("R5", 5), ("R10", 10), ("R50", 50)):
        metrics[name] = 100 * len(np.flatnonzero(ranks < cutoff)) / len(ranks)
    metrics["MedR"] = np.floor(np.median(ranks)) + 1
    metrics["MeanR"] = ranks.mean() + 1

    return metrics

def t2i(sims, captions_per_image=5):
    """Compute text-to-image retrieval metrics from a similarity matrix.

    Row `c` of `sims` holds the similarities between caption `c` and every image.
    The ground-truth image of caption `c` is image `c // captions_per_image`, and
    the rank of a caption is the zero-based position of that image when the row
    is sorted by decreasing similarity.

    Args:
        sims (np.ndarray): (captions_per_image * N) x N text-to-image similarities.
        captions_per_image (int): number of captions belonging to each image.
            Defaults to 5, which was previously hard-coded.

    Returns:
        (dict[str:float]): "R1", "R5", "R10", "R50" (recall percentages) together
        with "MedR" and "MeanR" (one-based median and mean rank).
    """
    npts = sims.shape[1]
    num_captions = captions_per_image * npts
    ranks = np.zeros(num_captions)

    for caption in range(num_captions):
        inds = np.argsort(sims[caption])[::-1]
        ranks[caption] = np.where(inds == caption // captions_per_image)[0][0]

    # Compute metrics
    metrics = {}
    for name, cutoff in (("R1", 1), ("R5", 5), ("R10", 10), ("R50", 50)):
        metrics[name] = 100 * len(np.flatnonzero(ranks < cutoff)) / len(ranks)
    metrics["MedR"] = np.floor(np.median(ranks)) + 1
    metrics["MeanR"] = ranks.mean() + 1

    return metrics

def coco_t2v_metrics(sims, fold=1):
    """Compute text-to-image metrics, optionally averaged over diagonal folds.

    Args:
        sims (np.ndarray): 2-D text-to-image similarity matrix (captions along the
            rows, images along the columns).
        fold (int): with 1, the whole matrix is scored by `t2i`.  Otherwise fold
            `i` scores the block of rows [5000 * i, 5000 * (i + 1)) and columns
            [1000 * i, 1000 * (i + 1)), and every metric is averaged over folds.

    Returns:
        (dict[str:float]): the metrics produced by `t2i`.
    """
    assert sims.ndim == 2, "expected a matrix"
    if fold == 1:
        return t2i(sims)

    totals = {}
    for fold_idx in range(fold):
        cap_start = fold_idx * 5000
        img_start = fold_idx * 1000
        fold_sims = sims[cap_start:cap_start + 5000, img_start:img_start + 1000]
        # Accumulate each metric so that it can be averaged after the last fold
        for key, value in t2i(fold_sims).items():
            totals[key] = totals.get(key, 0) + value
    return {key: total / fold for key, total in totals.items()}

def coco_v2t_metrics(sims, fold=1):
    """Compute image-to-text metrics, optionally averaged over diagonal folds.

    Args:
        sims (np.ndarray): text-to-image similarity matrix (captions along the
            rows, images along the columns); it is transposed before scoring.
        fold (int): with 1, the whole transposed matrix is scored by `i2t`.
            Otherwise fold `i` scores the block of rows [1000 * i, 1000 * (i + 1))
            and columns [5000 * i, 5000 * (i + 1)) of the transposed matrix, and
            every metric is averaged over folds.

    Returns:
        (dict[str:float]): the metrics produced by `i2t`.
    """
    # switch axes of text and video
    by_image = sims.T
    if fold == 1:
        return i2t(by_image)

    totals = {}
    for fold_idx in range(fold):
        img_start = fold_idx * 1000
        cap_start = fold_idx * 5000
        fold_sims = by_image[img_start:img_start + 1000, cap_start:cap_start + 5000]
        # Accumulate each metric so that it can be averaged after the last fold
        for key, value in i2t(fold_sims).items():
            totals[key] = totals.get(key, 0) + value
    return {key: total / fold for key, total in totals.items()}

def t2v_metrics(sims, query_masks=None):
    """Compute text-to-video retrieval metrics from a similarity matrix.

    Args:
        sims (np.ndarray): N x M matrix of similarities between embeddings, where
             x_{i,j} = <text_embd[i], vid_embed[j]>.  The queries of video `j` are
             the contiguous rows [j * (N // M), (j + 1) * (N // M)).
        query_masks (np.ndarray): mask any missing queries from the dataset (two
             videos in MSRVTT only have 19, rather than 20 captions).  It must hold
             N entries in total; entries that are truthy mark the valid queries.

    Returns:
        (dict[str:float]): retrieval metrics, as produced by `cols2metrics`

    Raises:
        AssertionError: if `sims` is not 2-D, if a rank could not be recovered for
            every query, or if the mask does not match the number of queries.
    """
    assert sims.ndim == 2, "expected a matrix"
    num_queries, num_vids = sims.shape
    dists = -sims
    sorted_dists = np.sort(dists, axis=1)

    # Slice the ground truth distances out of the pseudo-rectangular dist matrix:
    # query `ii` is paired with video `ii // queries_per_video`.
    queries_per_video = num_queries // num_vids
    query_idx = np.arange(num_vids * queries_per_video)
    gt_dists = dists[query_idx, query_idx // queries_per_video]
    gt_dists = gt_dists[:, np.newaxis]
    rows, cols = np.where((sorted_dists - gt_dists) == 0)  # find column position of GT

    # --------------------------------
    # NOTE: Breaking ties
    # --------------------------------
    # We sometimes need to break ties (in general, these should occur extremely rarely,
    # but there are pathological cases when they can distort the scores, such as when
    # the similarity matrix is all zeros). Previous implementations (e.g. the t2i
    # evaluation function used
    # here: https://github.com/niluthpol/multimodal_vtt/blob/master/evaluation.py and
    # here: https://github.com/linxd5/VSE_Pytorch/blob/master/evaluation.py#L87) generally
    # break ties "optimistically".  However, if the similarity matrix is constant this
    # can evaluate to a perfect ranking. A principled option is to average over all
    # possible partial orderings implied by the ties. See this paper for a discussion:
    #    McSherry, Frank, and Marc Najork,
    #    "Computing information retrieval performance measures efficiently in the presence
    #    of tied scores." European conference on information retrieval. Springer, Berlin,
    #    Heidelberg, 2008.
    # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.145.8892&rep=rep1&type=pdf

    break_ties = "optimistically"
    #break_ties = "averaging"

    if rows.size > num_queries:
        # More matches than queries means at least one query is tied with others
        assert np.unique(rows).size == num_queries, "issue in metric evaluation"
        if break_ties == "optimistically":
            # keep the first (best) matching column of every query
            _, idx = np.unique(rows, return_index=True)
            cols = cols[idx]
        elif break_ties == "averaging":
            # fast implementation, based on this code:
            # https://stackoverflow.com/a/49239335
            locs = np.argwhere((sorted_dists - gt_dists) == 0)

            # Find the split indices
            steps = np.diff(locs[:, 0])
            splits = np.nonzero(steps)[0] + 1
            splits = np.insert(splits, 0, 0)

            # Compute the result columns
            summed_cols = np.add.reduceat(locs[:, 1], splits)
            counts = np.diff(np.append(splits, locs.shape[0]))
            cols = summed_cols / counts

    msg = "expected ranks to match queries ({} vs {}) "
    assert cols.size == num_queries, msg.format(cols.size, num_queries)

    if query_masks is not None:
        # remove invalid queries
        assert query_masks.size == num_queries, "invalid query mask shape"
        # NOTE: the `np.bool` alias was removed in NumPy 1.24, use the builtin
        cols = cols[query_masks.reshape(-1).astype(bool)]
        assert cols.size == query_masks.sum(), "masking was not applied correctly"
        # update number of queries to account for those that were missing
        num_queries = query_masks.sum()

    return cols2metrics(cols, num_queries)


def v2t_metrics(sims, query_masks=None):
    """Compute video-to-text retrieval metrics from a similarity matrix.

    Args:
        sims (np.ndarray): N x M matrix of similarities between embeddings, where
             x_{i,j} = <text_embd[i], vid_embed[j]>.  It is transposed internally,
             so that each video becomes a query over the N captions.  The captions
             of video `j` are the contiguous positions
             [j * (N // M), (j + 1) * (N // M)).
        query_masks (np.ndarray): mask any missing captions from the dataset.  Once
             flattened it is indexed per caption; falsy entries mark the captions
             that are missing.

    Returns:
        (dict[str:float]): retrieval metrics, as produced by `cols2metrics`

    NOTES: We find the closest "GT caption" in the style of VSE, which corresponds
    to finding the rank of the closest relevant caption in embedding space:
    github.com/ryankiros/visual-semantic-embedding/blob/master/evaluation.py#L52-L56
    """
    # switch axes of text and video
    sims = sims.T

    assert sims.ndim == 2, "expected a matrix"
    num_queries, num_caps = sims.shape
    dists = -sims
    caps_per_video = num_caps // num_queries
    break_ties = "averaging"

    MISSING_VAL = 1E8
    if query_masks is not None:
        # Set missing queries to have a distance of infinity.  A missing query
        # refers to a query position `n` for a video that had less than `n`
        # captions (for example, a few MSRVTT videos only have 19 queries).
        # The mask is the same for every row, so it is applied once up front
        # (`dists` is a fresh array, the caller's matrix is left untouched).
        missing = np.logical_not(query_masks.reshape(-1))
        dists[:, missing] = MISSING_VAL

    query_ranks = []
    for ii in range(num_queries):
        row_dists = dists[ii, :]

        # NOTE: Using distance subtraction to perform the ranking is easier to make
        # deterministic than using argsort, which suffers from the issue of defining
        # "stability" for equal distances.  Example of distance subtraction code:
        # github.com/antoine77340/Mixture-of-Embedding-Experts/blob/master/train.py
        sorted_dists = np.sort(row_dists)

        min_rank = np.inf
        for jj in range(ii * caps_per_video, (ii + 1) * caps_per_video):
            if row_dists[jj] == MISSING_VAL:
                # skip rankings of missing captions
                continue
            ranks = np.where((sorted_dists - row_dists[jj]) == 0)[0]
            if break_ties == "optimistically":
                rank = ranks[0]
            elif break_ties == "averaging":
                # NOTE: If there is more than one caption per video, its possible for the
                # method to do "worse than chance" in the degenerate case when all
                # similarities are tied.  TODO(Samuel): Address this case.
                rank = ranks.mean()
            else:
                raise ValueError(f"unknown tie-breaking method: {break_ties}")
            if rank < min_rank:
                min_rank = rank
        # NOTE: stays at infinity when every caption of the video is missing
        query_ranks.append(min_rank)
    query_ranks = np.array(query_ranks)

    return cols2metrics(query_ranks, num_queries)


def retrieval_as_classification(sims, query_masks=None):
    """Compute classification metrics from a similarity matrix.

    Both inputs are transposed internally, after which every row is a query and
    every column a label.  Each ground-truth label of a query contributes one rank
    to the final metrics.

    Args:
        sims (np.ndarray): 2-D matrix of label-to-query similarities.
        query_masks (np.ndarray): matrix with the same layout as `sims` whose
            truthy entries mark the ground-truth labels of each query.  Required,
            despite the `None` default kept for signature compatibility.

    Returns:
        (dict[str:float]): retrieval metrics, as produced by `cols2metrics`, over
        the ranks of all ground-truth labels.

    Raises:
        ValueError: if `query_masks` is None.
    """
    assert sims.ndim == 2, "expected a matrix"
    if query_masks is None:
        # Fail with a clear message rather than an AttributeError on `None.T`
        raise ValueError("query_masks is required to locate the ground-truth labels")

    # switch axes of query-labels and video
    sims = sims.T
    query_masks = query_masks.T
    dists = -sims
    num_queries = sims.shape[0]
    break_ties = "averaging"

    query_ranks = []
    for ii in range(num_queries):
        row_dists = dists[ii, :]

        # NOTE: Using distance subtraction to perform the ranking is easier to make
        # deterministic than using argsort, which suffers from the issue of defining
        # "stability" for equal distances.  Example of distance subtraction code:
        # github.com/antoine77340/Mixture-of-Embedding-Experts/blob/master/train.py
        sorted_dists = np.sort(row_dists)

        label_ranks = []
        for gt_label in np.where(query_masks[ii, :])[0]:
            ranks = np.where((sorted_dists - row_dists[gt_label]) == 0)[0]
            if break_ties == "optimistically":
                rank = ranks[0]
            elif break_ties == "averaging":
                # NOTE: If there is more than one caption per video, its possible for the
                # method to do "worse than chance" in the degenerate case when all
                # similarities are tied.  TODO(Samuel): Address this case.
                rank = ranks.mean()
            else:
                raise ValueError(f"unknown tie-breaking method: {break_ties}")
            label_ranks.append(rank)
        # Avoid penalising for assigning higher similarity to other gt labels. This is
        # done by subtracting out the better ranked query labels.  Note that this step
        # introduces a slight skew in favour of videos with lots of labels.  We can
        # address this later with a normalisation step if needed.
        # NOTE(review): the offset is the position of the label in ascending label
        # index order, not in rank order - confirm that this is intended.
        label_ranks = [x - idx for idx, x in enumerate(label_ranks)]

        # Include all labels in the final calculation
        query_ranks.extend(label_ranks)
    query_ranks = np.array(query_ranks)

    return cols2metrics(query_ranks, num_queries=len(query_ranks))


def evaluate_qa(results, label2ans, qid2data):
    """Score question-answering predictions, overall and per answer type.

    Args:
        results: list(dict), each dict is
            {
                "question_id": int,
                "answer": index into `label2ans`
            }
            When a question id occurs more than once, the last entry wins.
        label2ans: maps a predicted answer index to its answer.
        qid2data: maps a question id to a dict holding the ground-truth "answer"
            and its "answer_type".

    Returns:
        dict with "overall_acc", one "<type>_acc" entry per known answer type
        (0 when the type has no question) and "ratios", which maps
        "<type>_ratio" to [fraction of the questions, number of questions].
        The per-type counts are also printed.
    """
    # for MSRVTTQA
    type_names = ["what", "who", "how", "where", "when", "object", "number", "color", "location"]
    answer_type2idx = {name: position for position, name in enumerate(type_names)}

    qid2label = {}
    for entry in results:
        qid2label[entry["question_id"]] = entry["answer"]
    qid2pred_ans = {qid: label2ans[label] for qid, label in qid2label.items()}

    preds, gts, answer_types = [], [], []
    for qid, pred_ans in qid2pred_ans.items():
        preds.append(pred_ans)
        record = qid2data[qid]
        ground_truth = record["answer"]
        answer_types.append(answer_type2idx[record["answer_type"]])
        gts.append(ground_truth)

    # preds and gts are array of strings
    preds = np.array(preds)
    gts = np.array(gts)
    answer_types = np.array(answer_types)

    metrics = {"overall_acc": float(np.mean(preds == gts))}
    ratios = {}
    for ans_type, ans_type_idx in answer_type2idx.items():
        selected = answer_types == ans_type_idx
        hits = preds[selected] == gts[selected]
        num_selected = len(hits)
        print(ans_type, np.sum(hits), num_selected)
        if num_selected != 0:
            type_acc = float(np.mean(hits))
        else:
            type_acc = 0
        metrics[f"{ans_type}_acc"] = type_acc
        ratios[f"{ans_type}_ratio"] = [
            1. * num_selected / len(answer_types),
            num_selected,
        ]
    metrics["ratios"] = ratios
    return metrics


def evaluate_mc(pred_id2answer, gt_id2answer):
    """Compute multiple-choice accuracy over the predicted ids.

    Only the ids present in the predictions are scored; every one of them must
    also be a key of the ground truth.  The number of ground-truth and predicted
    ids is printed.

    Args:
        pred_id2answer: dict, {id: pred_answer_idx}
        gt_id2answer: dict, {id: gt_answer_idx}

    Returns:
        dict with the single key "mc_accuracy", whose value is the accuracy as a
        percentage formatted to a string with two decimals.
    """
    num_gt = len(gt_id2answer.keys())
    scored_ids = list(pred_id2answer.keys())
    num_pred = len(scored_ids)
    print(f"There are {num_gt} gt ids, {num_pred} pred ids.")

    expected = np.array([gt_id2answer[key] for key in scored_ids])
    predicted = np.array([pred_id2answer[key] for key in scored_ids])
    accuracy = np.mean(expected == predicted)
    return {"mc_accuracy": f"{100 * accuracy:.2f}"}


def cols2metrics(cols, num_queries):
    """Convert zero-based ground-truth ranks into retrieval metrics.

    Args:
        cols (np.ndarray): zero-based rank of the ground truth for every query.
        num_queries (int): number of queries used to normalise the recalls.

    Returns:
        (dict[str:float]): "R1" (percentage of ranks equal to 0), "R5", "R10",
        "R50" (percentage of ranks below the cutoff), "MedR" and "MeanR"
        (one-based median and mean rank) and "geometric_mean_R1-R5-R10".
    """
    # R1 requires an exact top rank, so that averaged (fractional) ranks between
    # 0 and 1 are not counted as hits.
    metrics = {"R1": 100 * float(np.sum(cols == 0)) / num_queries}
    for name, cutoff in (("R5", 5), ("R10", 10), ("R50", 50)):
        hits = float(np.sum(cols < cutoff))
        metrics[name] = 100 * hits / num_queries
    metrics["MedR"] = np.median(cols) + 1
    metrics["MeanR"] = np.mean(cols) + 1
    recalls = [metrics["R1"], metrics["R5"], metrics["R10"]]
    metrics["geometric_mean_R1-R5-R10"] = scipy.stats.mstats.gmean(recalls)
    return metrics


def mean_average_precision(sims, query_masks=None):
    """Compute the mean average precision of a similarity matrix.

    Args:
        sims: similarity matrix; its transpose is passed to the meter as `output`.
        query_masks: ground-truth matrix; its transpose is passed to the meter as
            `target`.  Required in practice: the `None` default fails on `.T`.

    Returns:
        dict with the single key "mAP", the mean of the values reported by the
        meter.
    """
    # NOTE(review): `APMeter` is neither defined nor imported in the visible part
    # of this file - confirm where it is expected to come from.
    meter = APMeter()
    meter.add(output=sims.T, target=query_masks.T)
    ap_values = meter.value()
    return {"mAP": ap_values.mean()}

def acc(output, target):
    """Compute the top-1 accuracy of a batch of predictions.

    Args:
        output (torch.Tensor): scores; the prediction of a row is the index of its
            largest score along dimension 1.
        target (torch.Tensor): expected index for every row of `output`.

    Returns:
        (float): fraction of rows whose prediction equals the target.
    """
    with torch.no_grad():
        predicted = output.argmax(dim=1)
        assert predicted.shape[0] == len(target)
        num_correct = (predicted == target).sum().item()
    return num_correct / len(target)


def my_metric2(output, target, k=3):
    """Compute the top-k accuracy of a batch of predictions.

    Args:
        output (torch.Tensor): scores; the candidates of a row are the indices of
            its `k` largest scores along dimension 1.
        target (torch.Tensor): expected index for every row of `output`.
        k (int): number of candidates considered per row.

    Returns:
        (float): number of candidate/target matches divided by the number of
        targets.
    """
    with torch.no_grad():
        _, candidates = torch.topk(output, k, dim=1)
        assert candidates.shape[0] == len(target)
        # Compare one candidate column at a time against the targets
        num_correct = sum(
            torch.sum(candidates[:, column] == target).item() for column in range(k)
        )
    return num_correct / len(target)


def video_precision(output, target):
    """Fraction of predicted pairs that also appear among the target pairs.

    Args:
        output (torch.Tensor): predicted pairs; the last of its three dimensions
            must have size 2.
        target (torch.Tensor): target pairs with the same leading (batch) size and
            a last dimension of size 2.

    Returns:
        (float): number of predicted pairs found among the target pairs of the
        same batch element, divided by target.shape[0] * target.shape[1].
    """
    assert output.shape[0] == target.shape[0]
    assert output.shape[2] == target.shape[2] == 2

    num_matched = 0
    for predicted_pairs, target_pairs in zip(output, target):
        for candidate in predicted_pairs:
            # A target row matches when both of its components equal the candidate
            same_pair = torch.eq(candidate, target_pairs).all(dim=1)
            if same_pair.any():
                num_matched += 1
    total = target.shape[0] * target.shape[1]
    return num_matched / total

def video_precision_adj(output, target):
    """Video precision normalised by the number of distinct target videos.

    Counts the predicted pairs that also appear among the target pairs of the
    same batch element, as `video_precision` does, but divides by the number of
    distinct values found in the first component of the target pairs.  This is
    intended to ignore videos which have no aligning text.

    Args:
        output (torch.Tensor): predicted pairs; the last of its three dimensions
            must have size 2.
        target (torch.Tensor): target pairs with the same leading (batch) size and
            a last dimension of size 2.

    Returns:
        (float): number of matched predicted pairs divided by the number of unique
        values of target[:, :, 0].
    """
    assert output.shape[0] == target.shape[0]
    assert output.shape[2] == target.shape[2] == 2

    correct = 0
    for bout, btarg in zip(output, target):
        for pair in bout:
            # A target row matches when both of its components equal the pair
            if torch.eq(pair, btarg).all(dim=1).any():
                correct += 1
    denom = len(target[:, :, 0].unique())

    return correct / denom