
import os
import argparse
import collections
import logging
import sys
import numpy as np
import pandas as pd
import scipy.sparse as smat
from pecos.utils import smat_util
from sup_con_xmc import base_utils as cpb_utils


# Module-level logger; handler/level configuration happens in main().
logger = logging.getLogger(name=__name__)


# Default cut-offs k used by the @k metrics.
TOPK_LIST = [1, 3, 5, 10, 50, 100]


# Adapted from pyxclib's inverse-propensity estimator:
# https://github.com/kunaldahiya/pyxclib/blob/8d9af7093c32e258c1340862868ff0856a7fc235/xclib/evaluation/xc_metrics.py#L195C1-L222C25
def get_inv_prop(Y_trn, dataset_name):
    """Estimate one inverse-propensity weight per label from training-label frequencies.

    A label seen ``N_l`` times in training gets the weight
    ``1 + C * (N_l + B) ** (-A)``, where ``C = (log(N) - 1) * (B + 1) ** A`` and
    ``N`` is the number of training instances.

    ``(A, B)`` is picked from ``dataset_name`` (case-insensitive):

    * names containing "amazon" use ``(0.6, 2.6)``;
    * names containing "wiki" but not "wikiseealso" use ``(0.5, 0.4)``;
    * everything else uses ``(0.55, 1.5)``.

    Args:
        Y_trn: 2-D training label matrix (instances x labels).
        dataset_name: dataset name; only used to select ``(A, B)``.

    Returns:
        1-D ndarray with one weight per label (column of ``Y_trn``).
    """
    lowered = dataset_name.lower()
    is_amazon = "amazon" in lowered
    is_wiki = "wiki" in lowered and "wikiseealso" not in lowered
    if is_amazon:
        A, B = 0.6, 2.6
    elif is_wiki:
        A, B = 0.5, 0.4
    else:
        A, B = 0.55, 1.5

    num_instances, _ = Y_trn.shape
    # Column sums = number of training instances tagged with each label.
    label_freqs = np.ravel(Y_trn.sum(axis=0))
    C = (np.log(num_instances) - 1) * np.power(B + 1, A)
    return np.ravel(1.0 + C * np.power(label_freqs + B, -A))

def load_filter_mat(fname, shape):
    """Read whitespace-separated (row, col) integer pairs into a CSR indicator matrix.

    Args:
        fname: path to a text file containing consecutive ``row col`` pairs.
        shape: shape of the returned matrix.

    Returns:
        ``scipy.sparse`` CSR matrix with 1.0 at every listed position (duplicated
        pairs are summed by the COO->CSR conversion), or ``None`` if ``fname``
        does not exist.
    """
    if not os.path.exists(fname):
        return None
    pairs = np.fromfile(fname, sep=' ').astype(int).reshape(-1, 2)
    rows, cols = pairs[:, 0], pairs[:, 1]
    ones = np.ones(len(rows))
    return smat.coo_matrix((ones, (rows, cols)), shape).tocsr()

def _filter(score_mat, filter_mat, copy=True):
    if filter_mat is None:
        return score_mat
    if copy:
        score_mat = score_mat.copy()

    temp = filter_mat.tocoo()
    score_mat[temp.row, temp.col] = 0
    del temp
    score_mat = score_mat.tocsr()
    score_mat.eliminate_zeros()
    return score_mat


def compute_old_metrics(Yt, Yp, eval_topk=100):
    """Log instance-wise P@{1,3,5} and R@{10,50,100} computed by ``smat_util.Metrics``.

    Only cut-offs ``k <= eval_topk`` are reported, so ``eval_topk < 100`` no longer
    raises ``IndexError``. With the default ``eval_topk=100`` the logged lines are
    unchanged.

    Args:
        Yt: ground-truth label matrix.
        Yp: predicted score matrix.
        eval_topk: largest cut-off passed to ``Metrics.generate``.

    Returns:
        The object returned by ``smat_util.Metrics.generate`` (previously discarded).
    """
    # Instance-wise P/R@k for all k <= eval_topk; computed once.
    inst_res = smat_util.Metrics.generate(Yt, Yp, topk=eval_topk)
    # prec / recall are indexed by k - 1.
    prec_ks = [k for k in (1, 3, 5) if k <= eval_topk]
    recall_ks = [k for k in (10, 50, 100) if k <= eval_topk]
    header = [f"P@{k}" for k in prec_ks] + [f"R@{k}" for k in recall_ks]
    values = [inst_res.prec[k - 1] * 100. for k in prec_ks]
    values += [inst_res.recall[k - 1] * 100. for k in recall_ks]
    logger.info("\t".join(header))
    logger.info(" ".join(f"{v:.2f}" for v in values))
    return inst_res


def sp_rank(csr):
    """Replace each stored score with its 1-based position within its row after sorting.

    Relies on ``smat_util.sorted_csr`` to order every row's entries; presumably by
    descending score, so rank 1 is the top prediction (verify against pecos).

    Args:
        csr: sparse score matrix (instances x labels).

    Returns:
        CSR matrix with the sorted sparsity pattern whose data are integer ranks.
    """
    rank_mat = smat_util.sorted_csr(csr)
    # Vectorised per-row position: global entry index minus the row's start offset.
    # Avoids a Python loop over rows and also works for matrices with zero rows.
    indptr = rank_mat.indptr
    row_starts = np.repeat(indptr[:-1], np.diff(indptr))
    rank_mat.data = np.arange(1, indptr[-1] + 1) - row_starts
    return rank_mat

def _topk(rank_mat, K, inplace=False):
    topk_mat = rank_mat if inplace else rank_mat.copy()
    topk_mat.data[topk_mat.data > K] = 0
    topk_mat.eliminate_zeros()
    return topk_mat

def _compute_xmc_metrics(rank_intrsxn_mat, true_mat, K=TOPK_LIST, inv_prop=None):
    """Compute P@k, R@k, nDCG@k, MRR@k (and PSP@k when ``inv_prop`` is given).

    Args:
        rank_intrsxn_mat: sparse matrix holding the 1-based prediction rank of each
            correctly predicted label, i.e. a rank matrix masked by the ground truth
            (as built in ``compute_xmc_metrics``).
        true_mat: binary ground-truth matrix of the same shape.
        K: iterable of cut-offs.
        inv_prop: optional 1-D array of per-label inverse propensities.

    Returns:
        dict mapping 'P', 'R', 'nDCG', 'MRR' (and 'PSP') to ``{k: value in percent}``.
    """
    # Descending K lets the top-k mask shrink in place across iterations.
    K = sorted(K, reverse=True)
    # Rows without any ground-truth label would divide by zero (NaN for the whole
    # mean); like pyxclib, count them as having one label so they contribute 0.
    true_nnz = np.maximum(true_mat.getnnz(1), 1)
    topk_intrsxn_mat = rank_intrsxn_mat.copy()
    res = {'P': {}, 'R': {}, 'nDCG': {}, 'MRR': {}}
    if inv_prop is not None:
        res['PSP'] = {}
        # Ground truth weighted by each label's inverse propensity.
        psp_true_mat = true_mat.copy()
        psp_true_mat.data[:] = 1
        psp_true_mat.data *= inv_prop[psp_true_mat.indices]

    for k in K:
        topk_intrsxn_mat = _topk(topk_intrsxn_mat, k, inplace=True)
        hits = topk_intrsxn_mat.getnnz(1)
        res['R'][k] = (hits / true_nnz).mean() * 100.0
        res['P'][k] = (hits / k).mean() * 100.0

        # nDCG@k: gain 1/log2(1 + rank) per hit, normalised by the ideal DCG of
        # min(#true labels, k) hits at the top ranks.
        gains = topk_intrsxn_mat.copy()
        gains.data = 1 / np.log2(1 + gains.data)
        ideal_dcg = np.cumsum(1 / np.log2(np.arange(k) + 2), 0)
        dcg_denom = ideal_dcg[np.minimum(true_nnz, k) - 1]
        res['nDCG'][k] = (np.asarray(gains.sum(1)).ravel() / dcg_denom).mean() * 100.0

        # MRR@k: reciprocal rank of the best-ranked hit (0 when there is none).
        recip = topk_intrsxn_mat.copy()
        recip.data = 1 / recip.data
        res['MRR'][k] = recip.max(axis=1).toarray().ravel().mean() * 100.0

        if inv_prop is not None:
            weighted_hits = topk_intrsxn_mat.copy()
            weighted_hits.data[:] = 1
            weighted_hits.data *= inv_prop[weighted_hits.indices]
            # Normaliser: best achievable score, using each row's k highest-weighted
            # true labels (assumes sorted_csr(.., k) keeps the k largest entries).
            psp_topk_true_mat = smat_util.sorted_csr(psp_true_mat, k)
            psp_denom = (psp_topk_true_mat.sum(1) / k).mean()
            res['PSP'][k] = (weighted_hits.sum(1) / k).mean() * 100.0 / psp_denom

    return res

def compute_xmc_metrics(Yt, Yp, inv_prop=None, K=100):
    """Evaluate predictions against ground truth, log a summary row and return it.

    Cut-offs are {1, 3, 5, 10, 50, 100} restricted to ``<= K``; for ``K > 100`` the
    multiples of 100 up to ``K`` are added. The summary contains
    P/nDCG@{1,3,5}, MRR@10, PSP@{1,3,5} (when ``inv_prop`` is given) and R@k for
    k >= 10. Columns whose cut-off is not evaluated (small ``K``) are skipped
    instead of raising ``KeyError``.

    Args:
        Yt: ground-truth label matrix (instances x labels); treated as binary.
        Yp: predicted score matrix with the same shape.
        inv_prop: optional per-label inverse propensities for PSP@k.
        K: largest cut-off to evaluate; must be >= 1.

    Returns:
        Single-row pandas DataFrame with the logged summary columns.

    Raises:
        ValueError: if ``K < 1``.
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")
    Ks = np.array([1, 3, 5, 10, 50, 100], dtype=np.int32)
    if K <= 100:
        Ks = Ks[Ks <= K]
    else:
        extra = np.array([100 * i for i in range(2, 1 + (K // 100))], dtype=np.int32)
        Ks = np.concatenate([Ks, extra])

    Yt = Yt.copy().tocsr()
    Yt.data[:] = 1
    # Ranks of the correctly predicted labels only.
    rank_intrsxn_mat = sp_rank(Yp).multiply(Yt)

    xmc_eval_metrics = pd.DataFrame(
        _compute_xmc_metrics(rank_intrsxn_mat, Yt, K=Ks.tolist(), inv_prop=inv_prop)
    ).round(2).transpose()
    # Flatten (metric, k) into a single row with columns like 'P@1'.
    df = xmc_eval_metrics.stack().to_frame().transpose()
    df.columns = [f'{col[0]}@{col[1]}' for col in df.columns.values]

    recall_cols = [f'R@{k}' for k in Ks[Ks >= 10]]
    summary_cols = [
        'P@1', 'P@3', 'P@5', 'nDCG@1', 'nDCG@3', 'nDCG@5', 'MRR@10',
        'PSP@1', 'PSP@3', 'PSP@5', *recall_cols,
    ]
    df = df[[col for col in summary_cols if col in df.columns]].round(2)

    logger.info("\n{}".format(df.to_csv(sep='\t', index=False)))
    return df


def main(args):
    """Evaluate a prediction file, and again after removing filter pairs if requested.

    Args:
        args: namespace from ``parse_arguments()`` with ``data_name``, ``train_path``,
            ``true_path``, ``pred_path`` and ``filter_path``.
    """
    cpb_utils.setup_logging_config(level=logging.INFO)
    Y = smat_util.load_matrix(args.train_path).tocsr().astype(np.float32)
    # Propensities come from training-label frequencies; the matrix is not needed after.
    inv_prop = get_inv_prop(Y, args.data_name)
    del Y

    Yt = smat_util.load_matrix(args.true_path).tocsr().astype(np.float32)
    Yp = smat_util.load_matrix(args.pred_path).tocsr().astype(np.float32)
    logger.info("Metrics on unfiltered predictions:")
    compute_xmc_metrics(Yt, Yp, inv_prop=inv_prop)

    if args.filter_path:
        filter_mat = load_filter_mat(args.filter_path, Yt.shape)
        if filter_mat is None:
            # Previously this silently re-ran an identical, unfiltered evaluation.
            logger.warning(
                "Filter file %s not found; skipping filtered evaluation.", args.filter_path
            )
            return
        Yp_filtered = _filter(Yp, filter_mat, copy=True)
        logger.info("Metrics on filtered predictions (pairs from %s removed):", args.filter_path)
        compute_xmc_metrics(Yt, Yp_filtered, inv_prop=inv_prop)


def parse_arguments():
    """Build the command-line parser for this evaluation script.

    Returns:
        An ``argparse.ArgumentParser``; callers run ``parse_args()`` themselves.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate XMC predictions (P@k, nDCG@k, MRR, PSP@k, R@k) against ground truth.",
    )

    parser.add_argument(
        "-d", "--data-name",
        type=str,
        required=True,
        metavar="STR",
        help="Dataset Name, which determines hyper-parameters of inv_prop estimation",
    )

    parser.add_argument(
        "-t", "--train-path",
        type=str,
        required=True,
        metavar="PATH",
        help="path to the train file with ground truth output for inv_prop estimation",
    )

    parser.add_argument(
        "-y", "--true-path",
        type=str,
        required=True,
        metavar="PATH",
        help="path to the file with ground truth output (CSR: nr_insts * nr_items)",
    )

    parser.add_argument(
        "-p", "--pred-path",
        type=str,
        required=True,
        metavar="PATH",
        help="path to the file of predicted output (CSR: nr_insts * nr_items)",
    )

    parser.add_argument(
        "-f", "--filter-path",
        type=str,
        default=None,
        metavar="PATH",
        help="path to the file of reciprocal-removal pairs (for LF-XMC datasets)",
    )

    return parser


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the evaluation.
    main(parse_arguments().parse_args())

