import numpy as np
import torch
import torch.nn.functional as F

import torchvision.transforms as transforms


# Module-level torchvision ToTensor transform instance. It is not referenced in
# the code visible here — presumably imported by other modules (verify before
# removing).
tp = transforms.ToTensor()

# For the distance correlation defense
def pairwise_dist(A, B):
    """
    Compute the Euclidean distance between every row of A and every row of B.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, so the whole
    matrix comes from one matrix product plus two broadcast row norms.

    Args:
        A,    [m,d] matrix
        B,    [n,d] matrix
    Returns:
        D,    [m,n] matrix of pairwise distances
    """
    # Squared L2 norm of each row: A's as a column vector [m, 1] and B's as a
    # row vector [1, n], so adding them broadcasts to [m, n].
    row_sq_a = torch.square(A).sum(dim=1).reshape(-1, 1)
    row_sq_b = torch.square(B).sum(dim=1).reshape(1, -1)

    inner = torch.mm(A, B.T)
    sq_dist = row_sq_a - 2 * inner + row_sq_b + 1e-20

    # Clamp negative round-off to zero before the square root; the 1e-20
    # offset above keeps exactly-zero distances away from sqrt(0).
    clamped = torch.maximum(sq_dist, torch.tensor(0.0))
    return torch.sqrt(clamped)


def _masked_precision_recall(B_true, B_pred_thresh, mask=None):
    """
    Return ``(precision, recall)`` of the predicted edges, optionally restricted to ``mask``.

    True positives are judged by truthiness (``astype(bool)``), while the
    denominators are plain sums of the (masked) matrices. Each denominator is
    clipped to at least 1, so an empty selection yields 0.0 instead of a
    division by zero.
    """
    hits = B_true.astype(bool) & B_pred_thresh.astype(bool)
    if mask is None:
        n_true = B_true.sum()
        n_pred = B_pred_thresh.sum()
    else:
        hits = hits & mask
        n_true = (B_true * mask).sum()
        n_pred = (B_pred_thresh * mask).sum()
    n_hits = hits.sum()
    precision = n_hits / np.clip(n_pred, 1, None)
    recall = n_hits / np.clip(n_true, 1, None)
    return precision, recall


def _multi_parent_column_mask(B_true, group_ids, n_groups):
    """
    Return a bool mask shaped like ``B_true`` selecting every column ``j`` whose
    node has parents (nonzero entries of ``B_true[:, j]``) in more than one group.
    """
    n_nodes = len(group_ids)
    # membership[i, g] == 1 iff node i belongs to group g.
    membership = np.zeros((n_nodes, n_groups), dtype=np.int64)
    membership[np.arange(n_nodes), group_ids] = 1
    # parents_per_group[g, j] = number of parents of node j lying in group g.
    parents_per_group = membership.T @ (B_true != 0).astype(np.int64)
    is_multi = (parents_per_group > 0).sum(axis=0) > 1
    # Repeat the per-column flag down every row: mask[i, j] = is_multi[j].
    return np.tile(is_multi, (B_true.shape[0], 1))


def compute_metrics(B_pred_thresh, B_true, dims=None):
    """
    Compare a thresholded predicted adjacency matrix against the ground truth.

    Entry ``[i, j]`` is treated as the edge ``i -> j``: column ``j`` lists the
    parents of node ``j``. Both matrices are expected to be numpy arrays of the
    same shape.

    Args:
        B_pred_thresh: predicted adjacency matrix (presumably already binarized).
        B_true: ground-truth adjacency matrix, or None when it is unknown; every
            metric except ``n_edges_pred`` is then reported as "na".
        dims: optional sizes of consecutive node groups (must sum to the number
            of nodes). When given together with ``B_true``, precision/recall are
            also split into intra- vs. inter-group edges, and into edges that
            point at "multi" nodes (parents in more than one group) vs. "single"
            nodes (parents in at most one group).

    Returns:
        dict with ``score`` (number of mismatching entries), ``shd`` (``score``
        with a pair mismatching in both directions, e.g. a reversed edge,
        counted once), global ``precision``/``recall``, ``n_edges_pred`` and the
        intra_/inter_/multi_/single_ precision and recall.

    Raises:
        ValueError: if ``B_true`` is not square with side ``sum(dims)``.
    """
    n_edges_pred = B_pred_thresh.sum()
    if B_true is not None:
        mismatch = B_true != B_pred_thresh
        score = mismatch.sum()
        # Off-diagonal pairs mismatching in both directions count once, not
        # twice. The diagonal is excluded: a self-loop mismatch is a single
        # cell, and halving it would make the SHD fractional.
        both_ways = mismatch & mismatch.T
        shd = score - (both_ways.sum() - np.diagonal(both_ways).sum()) / 2
        precision, recall = _masked_precision_recall(B_true, B_pred_thresh)
    else:
        recall = "na"
        precision = "na"
        score = "na"
        shd = "na"

    if dims is not None and B_true is not None:
        # Group id of each node; groups are consecutive blocks of sizes `dims`.
        group_ids = np.repeat(np.arange(len(dims)), dims)
        n_nodes = len(group_ids)
        if B_true.shape != (n_nodes, n_nodes):
            raise ValueError(
                f"dims sum to {n_nodes} nodes but B_true has shape {B_true.shape}"
            )

        intra_mask = group_ids[:, None] == group_ids[None, :]
        inter_mask = ~intra_mask
        intra_precision, intra_recall = _masked_precision_recall(
            B_true, B_pred_thresh, intra_mask
        )
        inter_precision, inter_recall = _masked_precision_recall(
            B_true, B_pred_thresh, inter_mask
        )

        # Nodes without parents, or with parents from a single group, are "single".
        multi_mask = _multi_parent_column_mask(B_true, group_ids, len(dims))
        single_mask = ~multi_mask
        multi_precision, multi_recall = _masked_precision_recall(
            B_true, B_pred_thresh, multi_mask
        )
        single_precision, single_recall = _masked_precision_recall(
            B_true, B_pred_thresh, single_mask
        )
    else:
        intra_precision = "na"
        intra_recall = "na"
        inter_precision = "na"
        inter_recall = "na"
        multi_precision = "na"
        multi_recall = "na"
        single_precision = "na"
        single_recall = "na"
    return {
        "score": score,
        "shd": shd,
        "precision": precision,
        "recall": recall,
        "n_edges_pred": n_edges_pred,
        "intra_precision": intra_precision,
        "intra_recall": intra_recall,
        "inter_precision": inter_precision,
        "inter_recall": inter_recall,
        "multi_precision": multi_precision,
        "multi_recall": multi_recall,
        "single_precision": single_precision,
        "single_recall": single_recall
    }

def tf_distance_cov_cor(input1, input2, debug=False):
    """
    Sample distance correlation between two batches (PyTorch port of a
    TensorFlow implementation).

    Args:
        input1: [n, d1] tensor (``pairwise_dist`` uses ``torch.mm``, so 2-D).
        input2: [n, d2] tensor with the same number of rows as ``input1``.
        debug: if True, print the intermediate statistics.

    Returns:
        0-dim tensor holding dCor(input1, input2).
    """
    n = torch.tensor(float(input1.size()[0]))
    a = pairwise_dist(input1, input1)
    b = pairwise_dist(input2, input2)

    # Double centering. `a` is symmetric (up to round-off), so subtracting the
    # row-mean vector broadcast along the last axis removes each column's mean,
    # and the unsqueezed column-mean vector removes each row's mean.
    A = a - torch.mean(a, dim=1) - torch.unsqueeze(torch.mean(a, dim=0), dim=1) + torch.mean(a)
    B = b - torch.mean(b, dim=1) - torch.unsqueeze(torch.mean(b, dim=0), dim=1) + torch.mean(b)

    # The squared V-statistics are non-negative in exact arithmetic, but
    # round-off can push them slightly below zero (sqrt -> NaN). So clamp them
    # at 0. Then add the same epsilon to all three so a constant input cannot
    # cause a division by zero or an infinite sqrt gradient.
    eps = 1e-16
    dCovXY = torch.sqrt(torch.clamp(torch.sum(A * B) / (n ** 2), min=0.0) + eps)
    dVarXX = torch.sqrt(torch.clamp(torch.sum(A * A) / (n ** 2), min=0.0) + eps)
    dVarYY = torch.sqrt(torch.clamp(torch.sum(B * B) / (n ** 2), min=0.0) + eps)

    dCorXY = dCovXY / torch.sqrt(dVarXX * dVarYY)
    if debug:
        print(("tf distance cov: {} and cor: {}, dVarXX: {}, dVarYY:{}").format(
            dCovXY, dCorXY, dVarXX, dVarYY))
    return dCorXY


def sharpen(probabilities, T):
    """
    Temperature-sharpen probabilities (T < 1 sharpens, T > 1 flattens).

    For a 1-D tensor, each entry is treated as the probability p of a binary
    event and mapped to p^(1/T) / (p^(1/T) + (1 - p)^(1/T)). For any other
    rank, entries are raised to 1/T and renormalized along the last dimension.
    """
    exponent = 1 / T
    powered = torch.pow(probabilities, exponent)

    if probabilities.ndim != 1:
        return powered / powered.sum(dim=-1, keepdim=True)

    powered_complement = torch.pow((1 - probabilities), exponent)
    return powered / (powered_complement + powered)


def cross_entropy_for_onehot(pred, target):
    """
    Batch-mean cross-entropy between logits ``pred`` and (one-hot or soft)
    ``target`` distributions.

    log-softmax is taken over the last dimension, per-sample losses are summed
    over dimension 1, and the result is averaged.
    """
    log_probs = F.log_softmax(pred, dim=-1)
    per_sample = torch.sum(-target * log_probs, 1)
    return per_sample.mean()

def append_exp_res(path, res):
    """Append the string ``res`` plus a newline to the UTF-8 text file at ``path``."""
    with open(path, mode='a', encoding='utf-8') as out_file:
        line = res + '\n'
        out_file.write(line)
