import argparse
from time import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score

from DyCausal import DyCausal
from utils.evaluation import MetricsDAG

parser = argparse.ArgumentParser(description='Configuration')
# No command-line arguments are parsed: the resulting Namespace starts empty
# and is used purely as an attribute bag that is handed to DyCausal.
config = parser.parse_args(args=[])

# Default settings, applied in this exact order. n, t, p and dims are
# overwritten per dataset from the loaded data's shape further down the file.
_n_vars = 5
_defaults = {
    'n': 200,
    't': 50,
    'p': _n_vars,
    'e': 1,
    'lag': 1,
    'es': [1],
    'dims': [_n_vars, 1],
    'graph_type': 'ER',
    'sem_type': 'lag',
    'model_type': 'AdditiveNoiseModel',
    'ins': False,
    'kernel_size': 1,
    'stride': 1,
    'bias': True,
    'w_th': 0.015,
    'device_type': 'gpu',
    'device_ids': 0,
}
for _key, _value in _defaults.items():
    setattr(config, _key, _value)
# Keep the module namespace limited to what the original script exposed.
del _n_vars, _defaults, _key, _value


# Root directory of the CausalTime-generated datasets.
DATA_ROOT = 'datasets/causaltime_gen_ver1.0'
# Number of independent DyCausal fits per dataset; W is averaged over them.
N_RUNS = 100
# Only this many leading time steps (axis 1) of each sample are used.
N_STEPS = 20


def _standardize_and_merge(data):
    """Normalise two paired variable blocks and return their difference.

    The last axis of the 3-D array ``data`` is split into two equal halves.
    For every sample, the first half is z-scored along axis 1 (the axis the
    lag is applied to), and the second half is divided by that same
    per-variable std. Returns ``first - second``, with the last axis halved.

    Variables whose first-half std is zero keep a std of 1, so they are only
    centred instead of becoming NaN/inf.

    Raises:
        ValueError: if the last axis has an odd length.
    """
    n_cols = data.shape[2]
    half = n_cols // 2
    if n_cols != 2 * half:
        raise ValueError(f'expected an even number of columns, got {n_cols}')
    first = data[:, :, :half]
    second = data[:, :, half:]
    mean = first.mean(axis=1, keepdims=True)
    std = first.std(axis=1, keepdims=True)
    std = np.where(std == 0, 1.0, std)
    # Same per-element arithmetic as normalising each half in place and then
    # subtracting them.
    return (first - mean) / std - second / std


def _summary_graph(W_est, p, lag):
    """Collapse a lagged adjacency estimate into a symmetric binary p x p graph.

    ``W_est`` (3-D) is summed over its leading axis and thresholded at > 0.
    The first ``lag * p`` rows, restricted to the first ``p`` columns, are
    split into ``lag`` consecutive (p, p) blocks. The blocks are OR-ed
    together and then symmetrised.
    """
    # NOTE(review): assumes axis 0 of W_est indexes time steps and that rows are
    # ordered lag-major, like X's columns; confirm against DyCausal.get_adj.
    active = np.sum(W_est, axis=0) > 0
    per_lag = active[:lag * p, :p].reshape(lag, p, p)
    graph = per_lag.sum(axis=0)
    # NOTE(review): symmetrising drops edge direction, but B_true is compared
    # as-is; confirm this undirected evaluation is intended.
    graph = graph + graph.T
    return (graph > 0).astype(int)


def _show_matrix(matrix):
    """Display a matrix in grayscale, then release its figure."""
    image = plt.matshow(matrix, cmap=plt.cm.gray)
    plt.show()
    # With a non-interactive backend, show() is a no-op. Closing explicitly
    # keeps figures from accumulating across runs.
    plt.close(image.figure)


for data_name in ['medical']:  # also available: 'traffic', 'pm25'
    data_dir = f'{DATA_ROOT}/{data_name}'
    data = np.load(f'{data_dir}/gen_data.npy').astype(np.double)
    B_true = np.load(f'{data_dir}/graph.npy')
    data = _standardize_and_merge(data[:, :N_STEPS, :])
    config.n = data.shape[0]
    config.t = data.shape[1] - config.lag
    config.p = data.shape[2]
    config.dims = [config.p, 1]
    # Inputs stack the `lag` preceding steps along the variable axis; the
    # targets are the steps that follow them.
    X = np.concatenate([data[:, k:config.t + k, :] for k in range(config.lag)], axis=2)
    Y = data[:, config.lag:, :]

    # The ground-truth graph is identical for every run, so show it once.
    _show_matrix(B_true)

    W_sum = None
    for _ in range(N_RUNS):
        dycausal = DyCausal(config, [X, Y])
        begin_time = time()
        dycausal.train(lambda1=0.01, lambda2=0.01, T=1, warm_iter=1e3, max_iter=1e3, lr=0.001)
        train_time = time() - begin_time
        W_est, W = dycausal.get_adj()
        B_est = _summary_graph(W_est, config.p, config.lag)
        met = MetricsDAG(B_est, B_true).metrics
        y_score = B_est.reshape(-1)
        y_test = B_true.reshape(-1)
        auroc = roc_auc_score(y_test, y_score)
        auprc = average_precision_score(y_test, y_score)
        print(met, auroc, auprc, f'train_time={train_time:.1f}s')
        _show_matrix(B_est)
        # Size the accumulator from the first W rather than a hard-coded
        # (39, p, p), so a different W shape cannot break the averaging.
        if W_sum is None:
            W_sum = np.zeros(np.shape(W))
        W_sum += W
    mean_W = W_sum / N_RUNS
    # NOTE(review): this path does not include data_name, so enabling more than
    # one dataset above overwrites the file on each iteration.
    np.save('mean_W.npy', mean_W)