# general
import argparse
from logzero import logger

import os, sys, time, os.path as osp
import math
import numpy as np

#jupyter
import pandas as pd
import torch
import numpy as np
from collections import defaultdict
import json

# scipy
from scipy.sparse import csr_matrix
import scipy as sp
import scipy.sparse as smat

import torch
import torch.nn as nn
import torch.nn.functional as F

import scipy.sparse as smat
from scipy.sparse import csr_matrix

import collections

class Metrics(collections.namedtuple("Metrics", ["prec", "recall"])):
    """Precision@k / recall@k curves stored as two parallel sequences.

    ``prec[j]`` and ``recall[j]`` hold the metric at cut-off ``j + 1``.
    """

    __slots__ = ()

    def __str__(self):
        """Render one line per field, with values shown as percentages."""
        lines = []
        for key in self._fields:
            values = " ".join("{:4.2f}".format(100 * v) for v in getattr(self, key))
            lines.append("{:7}= {}".format(key, values))
        return "\n".join(lines)

    @classmethod
    def default(cls):
        """Return an empty ``Metrics`` (both curves are empty lists)."""
        return cls(prec=[], recall=[])

    @classmethod
    def precision(cls, tY, pY, topk=10):
        """Compute precision@k and recall@k for k = 1..topk.

        Args:
            tY: Sparse ground-truth matrix; row ``i`` stores the true label
                ids of sample ``i``. Read through ``indices``/``indptr``, so
                non-CSR sparse formats are converted to CSR first.
            pY: Per-sample ranked label ids, best first; ``pY[i]`` is the
                ranking for sample ``i``. Rankings longer than ``topk`` are
                truncated; shorter ones are padded by carrying the last
                cumulative hit count forward.
            topk: Number of cut-offs to evaluate.

        Returns:
            A ``Metrics`` whose ``prec`` and ``recall`` are float arrays of
            length ``topk``, averaged over all rows of ``tY``.
        """
        # indices/indptr only describe rows in CSR layout; a CSC matrix would
        # be silently misread as column-wise label lists.
        if getattr(tY, "format", "csr") != "csr":
            tY = tY.tocsr()

        num_rows = tY.shape[0]
        total_matched = np.zeros(topk, dtype=np.uint64)
        recall = np.zeros(topk, dtype=np.float64)

        for i in range(num_rows):
            truth = tY.indices[tY.indptr[i]: tY.indptr[i + 1]]
            # Keep only the first topk predictions so the slice assignments
            # below never exceed the accumulator length.
            ranked = np.asarray(pY[i]).ravel()[:topk]
            cum_matched = np.cumsum(np.isin(ranked, truth), dtype=np.uint64)
            n_pred = len(cum_matched)
            if n_pred == 0:
                continue

            total_matched[:n_pred] += cum_matched
            # Fewer than topk predictions: the hit count stays flat afterwards.
            total_matched[n_pred:] += cum_matched[-1]

            # A sample without true labels has undefined recall; count it as
            # 0 instead of poisoning the whole curve with NaN.
            if len(truth) > 0:
                recall[:n_pred] += cum_matched / len(truth)
                recall[n_pred:] += cum_matched[-1] / len(truth)

        prec = total_matched / num_rows / np.arange(1, topk + 1)
        recall = recall / num_rows
        return cls(prec=prec, recall=recall)

def get_top_pred(res, k=10):
    """Return, per row of ``res``, the indices of the ``k`` largest scores, best first."""
    ascending = np.argsort(res)
    # Reverse along axis 1 to get a descending ranking, then keep the head.
    descending = np.flip(ascending, axis=1)
    return descending[:, :k]
def print_res(y_truth, base_top):
    """Print precision@1, @3 and @5 (as percentages) of ``base_top`` against ``y_truth``."""
    metric = Metrics.precision(y_truth, base_top)
    # prec is indexed by cut-off minus one.
    p1, p3, p5 = (metric.prec[idx] * 100 for idx in (0, 2, 4))
    print(f"p1 {p1:.2f} p3 {p3:.2f} p5 {p5:.2f}")
def sigmoid(x):
    """Apply the logistic function ``1 / (1 + exp(-x))`` element-wise."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def ensemble(res_list):
    """Average the sigmoid-squashed scores of every entry in ``res_list``."""
    total = 0
    for scores in res_list:
        # Squash each model's raw scores before summing.
        total += sigmoid(scores)
    return total / len(res_list)

def main(args, data_dir, **kwargs):
    """Print precision@1/3/5 for the baseline, the model and their ensemble.

    Args:
        args: Parsed CLI namespace; ``baseline_result_dir``, ``result_dir``
            and ``model_name`` are read from it.
        data_dir: Dataset directory containing ``y.dev.npz``. Its base name
            is used to locate the baseline result folder.
        **kwargs: Ignored; lets the caller splat ``vars(args)``.

    Three lines are printed, in order: baseline, model, ensemble.
    """
    y_truth = smat.load_npz(osp.join(data_dir, 'y.dev.npz'))
    # normpath drops a trailing separator; basename('path/to/data/') would
    # otherwise be '' and produce a bogus baseline path.
    data = osp.basename(osp.normpath(data_dir))
    baseline_path = osp.join(args.baseline_result_dir, f'{data}_tmp/{data}.1.0--1.tmp')
    model_path = osp.join(args.result_dir, f'{args.model_name}')

    # SECURITY: allow_pickle=True executes arbitrary pickled code on load;
    # only point this at result files you produced yourself.
    # Each results.npy wraps a dict (hence .item()) with a 'raw' score matrix.
    model_dic = np.load(osp.join(model_path, 'results.npy'), allow_pickle=True).item()
    baseline_dic = np.load(osp.join(baseline_path, 'results.npy'), allow_pickle=True).item()

    baseline_top = get_top_pred(baseline_dic['raw'])
    model_top = get_top_pred(model_dic['raw'])
    en_res = ensemble([model_dic['raw'], baseline_dic['raw']])
    en_top = get_top_pred(en_res)

    print_res(y_truth, baseline_top)
    print_res(y_truth, model_top)
    print_res(y_truth, en_top)

if __name__ == '__main__':
    # Every path option is required: main() joins them unconditionally, so a
    # missing one would otherwise surface as an opaque TypeError or bad path.
    parser = argparse.ArgumentParser(
        description='Compare baseline, model and ensemble precision@k on the dev split')
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="Dataset directory containing y.dev.npz; its base name "
                             "selects the baseline result folder.")
    parser.add_argument("--result_dir", default=None, type=str, required=True,
                        help="Directory that contains one sub-directory per model.")
    parser.add_argument("--baseline_result_dir", default=None, type=str, required=True,
                        help="Root directory of the baseline results "
                             "(<data>_tmp/<data>.1.0--1.tmp/results.npy is read from it).")
    parser.add_argument("--model_name", default=None, type=str, required=True,
                        help="Name of the model sub-directory inside --result_dir "
                             "holding results.npy.")

    args = parser.parse_args()
    # main() takes data_dir by keyword and ignores the remaining entries.
    main(args, **vars(args))