import argparse
from time import time
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from DyCausal import DyCausal
from utils.evaluation import MetricsDAG
from utils.simulator import simulate_var, simulate_dag, simulate_dy_var, simulate_ins_dy_var

# Run configuration held in an argparse.Namespace. parse_args(args=[]) ignores
# sys.argv, so every value below is fixed in code rather than taken from the
# command line (which also lets the script run unchanged inside a notebook).
parser = argparse.ArgumentParser(description='Configuration')
config = parser.parse_args(args=[])

# Sizes and graph settings handed to simulate_dag / simulate_dy_var below.
vars(config).update(
    n=200,
    t=50,
    p=20,
    e=2,
    lag=1,
    es=[2],
)
# dims is derived from p, so it can only be filled in once p exists.
config.dims = [config.p, 1]

# Remaining settings: graph_type and ins are also read by the simulators; the
# rest are presumably consumed by DyCausal, which receives the whole config.
vars(config).update(
    graph_type='ER',
    sem_type='lag',
    model_type='AdditiveNoiseModel',
    ins=True,
    kernel_size=1,
    stride=1,
    bias=True,
    w_th=0.15,
    device_type='gpu',
    device_ids=0,
)

# Run 10 independent trials: simulate a random lagged DAG and a time-varying VAR
# dataset from it, fit DyCausal, then score the estimated graphs against truth.
# NOTE(review): the tuples below look like (p, ..., lr) settings tried for
# different problem sizes -- confirm their meaning before relying on them.
# (10, 20, 1e-4), (50, 5e-5), (100, 1e-5), (150, 6e-6), (200, 3e-6), (300)
for _ in range(10):
    B_true = simulate_dag(config.p, config.e, config.graph_type, config.lag, config.ins, config.es)
    data, B_trues = simulate_dy_var(B_true, config.n, config.t + config.lag, config.p, config.lag, config.ins)
    # data, B_trues = simulate_ins_dy_var(B_true, config.n, config.t + config.lag, config.p, config.lag)
    print(data.shape)
    # Axis 1 of `data` is time (t + lag steps, presumably (n, t + lag, p) -- confirm).
    # X stacks the lag+1 shifted windows of length t along the last axis;
    # Y is the t steps that follow the first `lag` steps.
    X = np.concatenate([data[:, i:config.t + i, :] for i in range(config.lag + 1)], axis=2)
    Y = data[:, config.lag:, :]
    # sns.set_theme(rc={'figure.figsize':(30, 23)})
    # plt.subplot(3, 2, 1)
    # sns.heatmap(B_trues[0].T, cmap="RdBu", vmin=-2, vmax=2, xticklabels=False, yticklabels=False, cbar=False, linewidths=2)
    # plt.subplot(3, 2, 3)
    # sns.heatmap(B_trues[config.t // 2].T, cmap="RdBu", vmin=-2, vmax=2, xticklabels=False, yticklabels=False, cbar=False, linewidths=2)
    # plt.subplot(3, 2, 5)
    # sns.heatmap(B_trues[-1].T, cmap="RdBu", vmin=-2, vmax=2, xticklabels=False, yticklabels=False, cbar=False, linewidths=2)
    dycausal = DyCausal(config, [X, Y])
    begin_time = time()
    dycausal.train(lambda1=0.05, lambda2=0.001, T=4, mu_init=1, warm_iter=5e3, max_iter=8e3, lr=1e-4)
    end_time = time()
    # W_est = dycausal.get_dy_adj()
    # plt.subplot(3, 2, 2)
    # sns.heatmap(W_est[0].T, cmap="RdBu", vmin=-2, vmax=2, xticklabels=False, yticklabels=False, cbar=False, linewidths=2)
    # plt.subplot(3, 2, 4)
    # sns.heatmap(W_est[W_est.shape[0] // 2].T, cmap="RdBu", vmin=-2, vmax=2, xticklabels=False, yticklabels=False, cbar=True, linewidths=2)
    # plt.subplot(3, 2, 6)
    # sns.heatmap(W_est[-1].T, cmap="RdBu", vmin=-2, vmax=2, xticklabels=False, yticklabels=False, cbar=False, linewidths=2)
    # plt.savefig('B_true1.pdf')
    W_est, W = dycausal.get_adj()

    # Align the row (source / lag-block) axis of W_est with the ground truth,
    # using the actual row counts instead of assuming a fixed number of lag
    # blocks: prepend zero rows when the estimate is shorter, keep only the
    # trailing rows when it is longer.
    n_true_rows = B_trues[0].shape[0]
    n_est_rows = W_est.shape[1]
    if n_est_rows < n_true_rows:
        pad = np.zeros((W_est.shape[0], n_true_rows - n_est_rows, W_est.shape[2]))
        W_est = np.concatenate([pad, W_est], axis=1)
    elif n_est_rows > n_true_rows:
        W_est = W_est[:, n_est_rows - n_true_rows:, :]
    print(W_est[0].shape, B_trues[0].shape)

    # Score three snapshots (first, middle, last step); truth edges are
    # entries with |weight| > 0.1.
    # file_handle = open('DyCausal_dy_result.txt', 'a')
    met = MetricsDAG(W_est[0], (abs(B_trues[0]) > 0.1).astype(int)).metrics
    print(met)
    # file_handle.write('type:{},time:{},matrics{}\n'.format('DyCausal-W1' + str(config.p), end_time - begin_time, str(met)))
    met = MetricsDAG(W_est[W_est.shape[0] // 2], (abs(B_trues[config.t // 2]) > 0.1).astype(int)).metrics
    print(met)
    # file_handle.write('type:{},time:{},matrics{}\n'.format('DyCausal-W25' + str(config.p), end_time - begin_time, str(met)))
    met = MetricsDAG(W_est[-1], (abs(B_trues[-1]) > 0.1).astype(int)).metrics
    print(met)
    # file_handle.write('type:{},time:{},matrics{}\n'.format('DyCausal-W50' + str(config.p), end_time - begin_time, str(met)))
    # np.save('W_est.npy', W)
    # np.save('B_trues.npy', np.array(B_trues))

    # Per-step metrics over every estimated step of W (as percentages, shd raw).
    tpr, fdr, f1, shd = [], [], [], []
    for i in range(W.shape[0]):
        # Threshold on |weight| so strongly negative edges count as present;
        # any non-zero truth weight counts as a true edge here.
        W_bin = (abs(W[i]) > config.w_th).astype(int)
        met = MetricsDAG(W_bin, (abs(B_trues[i]) > 0).astype(int)).metrics
        tpr.append(f"{met['tpr'] * 100:.2f}")
        fdr.append(f"{met['fdr'] * 100:.2f}")
        f1.append(f"{met['F1'] * 100:.2f}")
        shd.append(f"{met['shd']:.2f}")
    print(tpr)
    print(fdr)
    print(f1)
    print(shd)
    # file_handle.close()