import argparse
from time import time
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score

from DyCausal import DyCausal
from utils.evaluation import MetricsDAG

# Experiment configuration. ``parse_args(args=[])`` ignores the real command
# line, so this just yields an empty argparse.Namespace that is populated here.
parser = argparse.ArgumentParser(description='Configuration')
config = parser.parse_args(args=[])
vars(config).update(
    # Sizes: n / t / p are replaced from the loaded data (n = number of series,
    # t = usable time steps, p = number of variables) before each training run.
    n=50,
    t=200,
    p=10,
    e=1,
    lag=1,                  # number of preceding time steps fed to the model
    es=[1],
    # Generator / model settings; their meaning is defined by DyCausal.
    graph_type='ER',
    sem_type='ode',
    model_type='AdditiveNoiseModel',
    ins=False,
    kernel_size=5,
    stride=5,
    bias=True,
    w_th=0.1,               # presumably the edge-weight threshold — confirm in DyCausal
    device_type='gpu',
    device_ids=0,
)
# dims depends on p, so it can only be set once p exists on the namespace.
config.dims = [config.p, 10, 1]

# Number of independent training runs on the same DREAM4 network.
N_RUNS = 10
GRAPH_PATH = 'datasets/dream4/graph/graph1.txt'
DATA_PATH = 'datasets/dream4/data/data1.txt'
# The data file stacks the time series row-wise; every SERIES_LEN consecutive
# rows form one series (see the reshape below).
SERIES_LEN = 21


def _minmax(mat):
    """Linearly rescale ``mat`` to [0, 1].

    Returns an all-zero float array when ``mat`` is constant instead of the
    0/0 NaNs the plain formula produces (``roc_auc_score`` rejects NaN input).
    """
    lo, hi = np.min(mat), np.max(mat)
    if hi == lo:
        return np.zeros(np.shape(mat), dtype=float)
    return (mat - lo) / (hi - lo)


# Ground truth and data are identical for every run, so load them once.
B_true = np.loadtxt(GRAPH_PATH, delimiter=',')
plt.matshow(np.abs(B_true), cmap=plt.cm.gray)
plt.show()

data = np.loadtxt(DATA_PATH, delimiter=',')
# Per-column z-score; a constant column keeps std 1 instead of becoming NaN.
data_std = np.std(data, axis=0)
data_std = np.where(data_std == 0, 1.0, data_std)
data = (data - np.mean(data, axis=0)) / data_std
data = data.reshape((-1, SERIES_LEN, data.shape[1]))

# Self-loops are excluded consistently: B_est's diagonal is zeroed for the
# structural metrics, and AUROC/AUPRC are scored on off-diagonal entries only.
off_diag = ~np.eye(B_true.shape[0], dtype=bool)
# Binarise so a signed/weighted gold standard reads as edge / no edge
# (roc_auc_score would otherwise reject labels like -1/0/1 as multiclass).
y_test = (B_true[off_diag] != 0).astype(int)

for run in range(N_RUNS):
    config.n = data.shape[0]
    config.t = data.shape[1] - config.lag
    config.p = data.shape[2]
    config.dims = [config.p, 1]
    # Block k of X (k = 0..lag-1) holds data[:, r + k] for target Y[:, r],
    # i.e. the `lag` preceding time steps, oldest first, stacked on axis 2.
    X = np.concatenate(
        [data[:, k:k - config.lag, :] for k in range(config.lag)], axis=2)
    # Copy so a model that edits its inputs in place cannot leak changes into
    # later runs (the data is no longer reloaded from disk per run).
    Y = data[:, config.lag:, :].copy()

    dycausal = DyCausal(config, [X, Y])
    begin_time = time()
    dycausal.train(lambda1=0.1, lambda2=0.1, T=1, max_iter=3e3, lr=0.001)
    elapsed = time() - begin_time
    W_est, W = dycausal.get_adj()

    # Binary estimate: sum the thresholded slices, then fold the lag blocks.
    # NOTE(review): this skips W_est[0] while the score path below uses every
    # slice of W — confirm the asymmetry is intended.
    B_est1 = np.zeros_like(W_est[0])
    for k in range(1, W_est.shape[0]):
        B_est1 += W_est[k]
        plt.imshow(W_est[k], cmap='gray')
        plt.show()
    B_est2 = np.zeros_like(B_true)
    for k in range(config.lag):
        B_est2 += B_est1[k * config.p:(k + 1) * config.p, :]
    # NOTE(review): only positive aggregate weight counts as an edge; if
    # get_adj can return negative weights, `!= 0` may be intended — confirm.
    B_est = (B_est2 > 0).astype(int)
    np.fill_diagonal(B_est, 0)
    # Pass a copy so the metrics object cannot alter the shared ground truth.
    met = MetricsDAG(B_est, B_true.copy()).metrics

    # Continuous score: normalise each slice of W, sum, renormalise.
    # Assumes each W slice has B_true's shape (holds for lag == 1) — TODO confirm.
    score = np.zeros(W.shape[1:], dtype=float)
    for layer in W:
        score += _minmax(layer)
    score = _minmax(score)
    y_score = score[off_diag]
    auroc = roc_auc_score(y_test, y_score)
    auprc = average_precision_score(y_test, y_score)
    print(f'[run {run + 1}/{N_RUNS}]', met, auroc, auprc,
          f'train_time={elapsed:.2f}s')
    plt.imshow(B_est, cmap='gray')
    plt.show()