import argparse
from time import time
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score

from DyCausal import DyCausal
from utils.evaluation import MetricsDAG

parser = argparse.ArgumentParser(description='Configuration')
# Parse an empty argument list: no flags are declared, so the resulting
# Namespace is used purely as an attribute container for the settings below.
config = parser.parse_args(args=[])

# n / t / p mirror the axes of the data array used later (series, time steps,
# variables). They are initial values only: the training loop further down
# resets them from the shape of the loaded data.
vars(config).update(
    n=50,
    t=200,
    p=15,
    e=1,
    lag=1,
    es=[1],
)
# Presumably the network layer widths (first entry tied to the initial p)
# - confirm against DyCausal.
config.dims = [config.p, 10, 1]
# Remaining options; the whole Namespace is handed to DyCausal(config, ...).
vars(config).update(
    graph_type='ER',
    sem_type='ode',
    model_type='AdditiveNoiseModel',
    ins=False,
    kernel_size=5,
    stride=5,
    bias=True,
    w_th=0.05,
    device_type='gpu',
    device_ids=0,
)

def _load_sim3(n_subjects=50):
    """Load the NetSim sim3 series and z-score each one column-wise.

    Reads ``datasets/netsim/data/sim3_{i}_d.csv`` for ``i`` in
    ``range(n_subjects)`` and stacks them into one array of shape
    ``(n_subjects, rows, cols)``; every CSV must have the same shape.
    Each column of each series is then standardised over its rows.
    """
    series = np.stack([
        np.loadtxt(f'datasets/netsim/data/sim3_{i}_d.csv', delimiter=',')
        for i in range(n_subjects)
    ])
    mean = series.mean(axis=1, keepdims=True)
    std = series.std(axis=1, keepdims=True)
    # A constant column has std == 0 and would become all-NaN; dividing by 1
    # instead maps it to zeros.
    std[std == 0] = 1.0
    return (series - mean) / std


def _minmax_scale(a):
    """Return ``a`` linearly rescaled onto [0, 1].

    A constant array (max == min) maps to all zeros instead of dividing by
    zero and producing NaNs.
    """
    lo, hi = np.min(a), np.max(a)
    if hi == lo:
        return np.zeros_like(a, dtype=float)
    return (a - lo) / (hi - lo)


# The input series are identical for every repetition, so read and normalise
# them once rather than on each of the runs below.
datas = _load_sim3()

for _ in range(10):
    config.n = datas.shape[0]
    config.t = datas.shape[1] - config.lag
    config.p = datas.shape[2]
    # NOTE(review): config.dims[0] was fixed from the initial config.p and is
    # not refreshed here; confirm it still matches the data for DyCausal.
    B_true = np.loadtxt('datasets/netsim/graph/sim3_g.csv', delimiter=',')
    # X concatenates config.lag consecutive windows (each config.t steps long)
    # along the last axis; Y holds the steps that follow them.
    X = np.concatenate([datas[:, k:config.t + k, :] for k in range(config.lag)], axis=2)
    # Copy so the model cannot modify the shared `datas` through a view and
    # leak changes into later repetitions.
    Y = datas[:, config.lag:, :].copy()
    dycausal = DyCausal(config, [X, Y])
    begin_time = time()
    dycausal.train(lambda1=0.01, lambda2=0.1, T=1, max_iter=5e3, lr=0.001)
    end_time = time()
    W_est, W = dycausal.get_adj()

    # Predicted graph: sum W_est over its leading axis and mark entries whose
    # total is positive as edges.
    B_est = (np.sum(W_est, axis=0) > 0).astype(int)
    met = MetricsDAG(B_est, B_true).metrics

    # Continuous edge scores for the ranking metrics: min-max scale each slice
    # of W, add them up, and rescale the total onto [0, 1].
    # Assumes get_adj returns NumPy arrays with W[k] shaped like B_true - TODO confirm.
    scores = np.zeros_like(W[0], dtype=float)
    for w in W:
        scores += _minmax_scale(w)
    scores = _minmax_scale(scores)

    y_test = B_true.reshape(-1)
    # AUROC / AUPRC need graded scores; feeding the binary B_est collapses the
    # curves to a single operating point, so rank by the continuous scores.
    y_score = scores.reshape(-1)
    auroc = roc_auc_score(y_test, y_score)
    auprc = average_precision_score(y_test, y_score)
    print(met, auroc, auprc, f'train_time={end_time - begin_time:.2f}s')
    plt.imshow(B_true, cmap='gray')
    plt.show()
    plt.imshow(B_est, cmap='gray')
    plt.show()