# coding=utf-8
# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import numpy as np
import pandas as pd


class MetricsDAG(object):
    """
    Compute various accuracy metrics for B_est.
    true positive(TP): an edge estimated with correct direction.
    true negative(TN): an edge that is neither in estimated graph nor in true graph.
    false positive(FP): an edge that is in estimated graph but not in the true graph.
    false negative(FN): an edge that is not in estimated graph but in the true graph.
    reverse = an edge estimated with reversed direction.

    fdr: FP / max(nnz, 1), where FP already contains the reversed edges
    tpr: TP / max(TP + FN, 1)
    fpr: FP / max(number of zero entries of B_true, 1)
    shd: undirected extra + undirected missing + reverse
    nnz: number of non-zero entries of B_est
    precision: TP/(TP + FP)
    recall: TP/(TP + FN)
    F1: 2*(recall*precision)/(recall+precision)
    gscore: max(0, (TP-FP))/(TP+FN), A score ranges from 0 to 1

    The result is stored as a dict in the ``metrics`` attribute.

    Parameters
    ----------
    B_est: np.ndarray
        [d, d] estimate, {0, 1, -1}, -1 is undirected edge in CPDAG.
    B_true: np.ndarray
        [d, d] ground truth graph, {0, 1}.

    Raises
    ------
    TypeError
        If B_est or B_true is not a numpy.ndarray.
    ValueError
        If the inputs are not 2-D or their shapes differ.
    AssertionError
        If B_true contains no edge (gscore is undefined).
    """

    def __init__(self, B_est, B_true):

        if not isinstance(B_est, np.ndarray):
            raise TypeError("Input B_est is not numpy.ndarray!")

        if not isinstance(B_true, np.ndarray):
            raise TypeError("Input B_true is not numpy.ndarray!")

        # Keep private copies so the stored matrices never alias caller data.
        self.B_est = copy.deepcopy(B_est)
        self.B_true = copy.deepcopy(B_true)

        self.metrics = MetricsDAG._count_accuracy(self.B_est, self.B_true)

    @staticmethod
    def _count_accuracy(B_est, B_true, decimal_num=4):
        """
        Compute all metrics for one (estimate, ground truth) pair.

        Parameters
        ----------
        B_est: np.ndarray
            [d, d] estimate, {0, 1, -1}, -1 is undirected edge in CPDAG.
        B_true: np.ndarray
            [d, d] ground truth graph, {0, 1}.
        decimal_num: int
            Result decimal numbers.

        Return
        ------
        metrics: dict
            fdr: float
                FP / max(nnz, 1); FP already contains the reversed edges
            tpr: float
                TP / max(TP + FN, 1)
            fpr: float
                FP / max(number of zero entries of B_true, 1)
            shd: int
                undirected extra + undirected missing + reverse
            nnz: int
                number of non-zero entries of B_est
            precision: float
                TP/(TP + FP)
            recall: float
                TP/(TP + FN)
            F1: float
                2*(recall*precision)/(recall+precision)
            gscore: float
                max(0, (TP-FP))/(TP+FN), A score ranges from 0 to 1

        Raises
        ------
        ValueError
            If the inputs are not 2-D or their shapes differ.
        AssertionError
            If B_true contains no edge (raised by ``_cal_gscore``).
        """
        # The flat indices compared below are only meaningful when both
        # matrices share one 2-D shape.
        if B_est.ndim != 2 or B_true.ndim != 2:
            raise ValueError("B_est and B_true must be 2-D arrays.")
        if B_est.shape != B_true.shape:
            raise ValueError(
                "B_est and B_true must have the same shape, got "
                "{} and {}.".format(B_est.shape, B_true.shape))

        n_rows, n_cols = B_true.shape
        if n_rows == n_cols:
            true_mirror = B_true.T
            # Only the lower triangle of the symmetrised matrices is kept so
            # that every undirected edge {i, j} is counted once, not twice.
            est_skeleton = np.tril(B_est + B_est.T)
            true_skeleton = np.tril(B_true + B_true.T)
        else:
            # Non-square input: the rows are cut into consecutive blocks of
            # n_cols rows and stacked again.
            # NOTE(review): the blocks are not transposed, so for n_rows being
            # a multiple of n_cols the "mirror" equals the input and no edge
            # can be counted as reversed -- confirm this is intended.
            n_blocks = int(n_rows / n_cols)
            true_mirror = np.concatenate(
                [B_true[i * n_cols: (i + 1) * n_cols, :]
                 for i in range(n_blocks)], axis=0)
            est_mirror = np.concatenate(
                [B_est[i * n_cols: (i + 1) * n_cols, :]
                 for i in range(n_blocks)], axis=0)
            est_skeleton = B_est + est_mirror
            true_skeleton = B_true + true_mirror

        # linear index of nonzeros
        pred = np.flatnonzero(B_est != 0)
        cond = np.flatnonzero(B_true)
        cond_reversed = np.flatnonzero(true_mirror)
        # true pos
        true_pos = np.intersect1d(pred, cond, assume_unique=True)
        # false pos: every predicted entry that is not a true edge
        false_pos = np.setdiff1d(pred, cond, assume_unique=True)
        # reverse: false positives whose mirrored entry is a true edge
        reverse = np.intersect1d(false_pos, cond_reversed, assume_unique=True)
        # compute ratio; max(..., 1) keeps empty graphs from dividing by zero
        pred_size = len(pred)
        cond_neg_size = n_rows * n_cols - len(cond)
        fdr = float(len(false_pos)) / max(pred_size, 1)
        tpr = float(len(true_pos)) / max(len(cond), 1)
        fpr = float(len(false_pos)) / max(cond_neg_size, 1)
        # structural hamming distance
        pred_lower = np.flatnonzero(est_skeleton)
        cond_lower = np.flatnonzero(true_skeleton)
        extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True)
        missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True)
        shd = len(extra_lower) + len(missing_lower) + len(reverse)

        W_p = pd.DataFrame(B_est)
        W_true = pd.DataFrame(B_true)
        gscore = MetricsDAG._cal_gscore(W_p, W_true)
        precision, recall, F1 = MetricsDAG._cal_precision_recall(W_p, W_true)
        mt = {'fdr': fdr, 'tpr': tpr, 'fpr': fpr, 'shd': shd, 'nnz': pred_size,
              'precision': precision, 'recall': recall, 'F1': F1, 'gscore': gscore}

        return {key: round(value, decimal_num) for key, value in mt.items()}

    @staticmethod
    def _cal_gscore(W_p, W_true):
        """
        Parameters
        ----------
        W_p: pd.DataFrame
            [d, d] estimate, {0, 1, -1}, -1 is undirected edge in CPDAG.
        W_true: pd.DataFrame
            [d, d] ground truth graph, {0, 1}.

        Return
        ------
        score: float
            max(0, (TP-FP))/(TP+FN), A score ranges from 0 to 1

        Raises
        ------
        AssertionError
            If W_true contains no edge, because the score is then undefined.
        """

        num_true = W_true.sum(axis=1).sum()
        # Raised explicitly (instead of an ``assert`` statement) so the check
        # is still performed when Python runs with -O.
        if num_true == 0:
            raise AssertionError("W_true must contain at least one edge.")

        # true positives: entries equal to 1 in both matrices
        num_tp = int(np.count_nonzero(np.asarray(W_p + W_true) == 2))
        # false positives + reversed edges: 1 in the estimate, 0 in the truth
        num_fp_r = int(np.count_nonzero(np.asarray(W_p - W_true) == 1))
        score = float(max(num_tp - num_fp_r, 0) / num_true)

        return score

    @staticmethod
    def _cal_precision_recall(W_p, W_true):
        """
        Parameters
        ----------
        W_p: pd.DataFrame
            [d, d] estimate, {0, 1, -1}, -1 is undirected edge in CPDAG.
        W_true: pd.DataFrame
            [d, d] ground truth graph, {0, 1}.

        Return
        ------
        precision: float
            TP/(TP + FP), 0.0 when the estimate sums to zero
        recall: float
            TP/(TP + FN), 0.0 when the ground truth sums to zero
        F1: float
            2*(recall*precision)/(recall+precision), 0.0 when both are zero
        """

        # true positives: entries equal to 1 in both matrices
        TP = int(np.count_nonzero(np.asarray(W_p + W_true) == 2))
        # NOTE(review): a plain sum means -1 entries lower TP_FP; confirm
        # whether undirected edges should be mapped to 1 beforehand.
        TP_FP = W_p.sum(axis=1).sum()
        TP_FN = W_true.sum(axis=1).sum()
        # Zero denominators yield 0.0 instead of nan / a RuntimeWarning,
        # matching the guarded ratios computed in _count_accuracy.
        precision = float(TP / TP_FP) if TP_FP != 0 else 0.0
        recall = float(TP / TP_FN) if TP_FN != 0 else 0.0
        if recall + precision != 0:
            F1 = 2 * (recall * precision) / (recall + precision)
        else:
            F1 = 0.0

        return precision, recall, F1