import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import torch
import numpy as np
from sklearn import metrics

from synthetic import simulate_var


class GC_draw_manager:
    """Interactive viewer for a GC estimate computed over a selectable range of segments.

    Two sliders ('begin' and 'end') choose a segment range. For that range,
    :meth:`GC_var` computes a thresholded GC estimate, which is drawn on
    ``self.ax``. Each cell that disagrees with ``GC_true`` is outlined in red.
    """

    def __init__(self, GC_true, penalty_sum, seg, threshold=7):
        """Create the figure and sliders, draw the initial view, and show it.

        Args:
            GC_true: reference GC matrix, compared element-wise (``GC_true[i, j]``)
                against the estimate.
            penalty_sum: iterable of tensors forwarded to :meth:`GC_var`.
            seg: number of equal-length segments each tensor is split into; also
                the maximum value of both sliders.
            threshold: threshold forwarded to :meth:`GC_var` on every redraw
                (defaults to 7, the value previously hard-coded in ``update``).
        """
        self.penalty_sum = penalty_sum
        self.GC_true = GC_true
        self.seg = seg
        self.threshold = threshold
        self.fig, self.ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 7))
        # Reserve space at the bottom so the slider axes do not overlap the image.
        self.fig.subplots_adjust(bottom=0.3)
        axfreq1 = self.fig.add_axes([0.25, 0.2, 0.65, 0.03], facecolor='lightgoldenrodyellow')
        axfreq2 = self.fig.add_axes([0.25, 0.1, 0.65, 0.03], facecolor='lightgoldenrodyellow')
        self.sfreq1 = Slider(axfreq1, 'begin', valmin=0, valmax=seg, valfmt='%d', valinit=0, valstep=1)
        self.sfreq2 = Slider(axfreq2, 'end', valmin=0, valmax=seg, valfmt='%d', valinit=0, valstep=1)
        self.sfreq1.on_changed(self.update)
        self.sfreq2.on_changed(self.update)
        # Initial range is [0, 5), clipped to the slider range. set_val fires
        # update(), which draws the first view.
        self.sfreq2.set_val(min(5, seg))
        # Show once, after the initial view is drawn; later slider moves are
        # redrawn through draw_idle() in update().
        plt.show()

    def update(self, val):
        """Slider callback: recompute the estimate for the selected range and redraw.

        ``val`` (the value of whichever slider moved) is unused; both sliders are
        read directly.
        """
        begin = self.sfreq1.val
        end = self.sfreq2.val
        self.ax.clear()
        GC = self.GC_var(self.penalty_sum, self.seg, begin, end, threshold=self.threshold)
        self.make_fig2(self.GC_true, GC)
        self.fig.canvas.draw_idle()

    def make_fig2(self, GC, GC_est):
        """Draw ``GC_est`` on the viewer's axes, outlining cells that differ from ``GC``."""
        self._plot_estimate(self.ax, GC, GC_est)

    @staticmethod
    def _plot_estimate(ax, GC, GC_est):
        """Draw ``GC_est`` on ``ax``; outline in red each cell where it differs from ``GC``."""
        n = len(GC_est)
        ax.imshow(GC_est, cmap='Blues', vmin=0, vmax=1, extent=(0, n, n, 0))
        ax.set_title('GC estimated')
        ax.set_ylabel('Affected series')
        ax.set_xlabel('Causal series')
        ax.set_xticks([])
        ax.set_yticks([])

        # Mark disagreements (iterates an n x n grid, so GC_est is assumed square).
        for i in range(n):
            for j in range(n):
                if GC[i, j] != GC_est[i, j]:
                    rect = plt.Rectangle((j, i - 0.05), 1, 1, facecolor='none', edgecolor='red', linewidth=1)
                    ax.add_patch(rect)

    def GC_var(self, penalty_sum, seg, seg_begin, seg_end, threshold=0):
        """Estimate a GC matrix from per-segment weight norms over a range of segments.

        Each tensor in ``penalty_sum`` is cut along dim 0 into ``seg`` chunks of
        length ``len(tensor) // seg``; any trailing remainder rows are ignored.
        Each chunk is reduced with ``torch.norm`` over dims 0 and 2, giving one
        value per index of dim 1. Tensors are therefore expected to be 3-D; an
        older note gave ``[100, 10, 5]`` as an example shape (TODO confirm). The
        per-segment values are then combined into a ``(seg, len(penalty_sum), dim1)``
        tensor, and its norm over segments ``[seg_begin:seg_end]`` is the estimate.

        Args:
            penalty_sum: non-empty iterable of tensors, one row of the result each.
            seg: number of segments to split each tensor into.
            seg_begin: first segment included (may be a float slider value).
            seg_end: segment after the last one included (may be a float slider value).
            threshold: if non-zero, binarize the result as ``norm > threshold``.

        Returns:
            A ``(len(penalty_sum), dim1)`` tensor. It holds the raw norms when
            ``threshold == 0``, otherwise an int tensor of 0/1.

        Raises:
            ValueError: if ``penalty_sum`` is empty.
        """
        per_series = []
        for penalty_t in penalty_sum:
            seg_len = len(penalty_t) // seg
            seg_norms = [torch.norm(penalty_t[i * seg_len:(i + 1) * seg_len], dim=(0, 2))
                         for i in range(seg)]
            per_series.append(torch.stack(seg_norms))
        if not per_series:
            raise ValueError('penalty_sum must contain at least one tensor')
        # (n_series, seg, dim1) -> (seg, n_series, dim1)
        GC = torch.stack(per_series).transpose(0, 1)
        # Slider values can be floats; tensor slicing requires integers.
        GC1 = torch.norm(GC[int(seg_begin):int(seg_end), :, :], dim=0)
        if threshold != 0:
            return (GC1 > threshold).int()
        return GC1

    # GC: ground truth, GC_est: the estimated result
    def make_fig(self, GC, GC_est):
        """Show ``GC`` and ``GC_est`` side by side in a new figure, marking disagreements."""
        fig, axarr = plt.subplots(1, 2, figsize=(16, 5))
        axarr[0].imshow(GC, cmap='Blues')
        axarr[0].set_title('GC actual')
        axarr[0].set_ylabel('Affected series')
        axarr[0].set_xlabel('Causal series')
        axarr[0].set_xticks([])
        axarr[0].set_yticks([])

        self._plot_estimate(axarr[1], GC, GC_est)

        plt.show()

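# Given two GC matrices, compute (and optionally plot) the ROC curve and return the area under it. GC is the ground truth; GC_est is the prediction.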
def draw_GC_ROC_curve(GC, GC_est_original, draw=True):
    """Score a GC estimate against ground truth by ROC AUC, optionally plotting the curve.

    Args:
        GC: ground-truth GC matrix (binary labels); flattened before scoring.
        GC_est_original: estimated scores with the same number of elements; flattened.
        draw: if True, plot the ROC curve via ``draw_ROC_curve`` and return the AUC
            it computes; otherwise return ``metrics.roc_auc_score`` directly.

    Returns:
        The area under the ROC curve.
    """
    # np.float was removed in NumPy 1.24; the builtin float gives the same float64 dtype.
    GC = np.asarray(GC).flatten().astype(float)
    GC_est = np.asarray(GC_est_original).flatten().astype(float)
    # Normalize by the peak only when it is positive. Dividing by 0 would produce
    # NaNs, and dividing by a negative peak would reverse the ranking (inverting the AUC).
    peak = GC_est.max()
    if peak > 0:
        GC_est = GC_est / peak
    if draw:
        return draw_ROC_curve(GC, GC_est)
    return metrics.roc_auc_score(GC, GC_est)

# Plot the ROC curve
def draw_ROC_curve(label, pre):
    """Plot the ROC curve of scores ``pre`` against labels ``label`` and return its AUC.

    The curve and the chance diagonal are drawn on the current pyplot figure,
    then ``plt.show()`` is called. The AUC is computed from the same FPR/TPR
    points returned by ``metrics.roc_curve``.
    """
    fpr, tpr, _thresholds = metrics.roc_curve(label, pre)
    series = (
        (fpr, tpr, 'b*-', 'roc'),
        ([0, 1], [0, 1], 'r--', "45°"),
    )
    for xs, ys, fmt, name in series:
        plt.plot(xs, ys, fmt, label=name)
    plt.legend()
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.show()
    area = metrics.auc(fpr, tpr)
    return area
