import argparse
from time import time
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from DyCausal import DyCausal
from DyCausal_h import DyCausalh
from utils.simulator import simulate_dag, simulate_nonlinear_sem

# Experiment configuration. The parser declares no arguments and is given an
# empty argv, so `config` starts as an empty argparse.Namespace and every
# setting is attached to it below.
parser = argparse.ArgumentParser(description='Configuration')
config = parser.parse_args(args=[])

# Data-size settings. n, t, p and lag feed the reshape of the simulated data
# into (n, t + lag, p) in the experiment loop.
vars(config).update(
    n=20,
    t=50,
    p=10,
    e=2,
    lag=2,
    es=[2, 2],
)
# Layer sizes; the first entry tracks the number of variables p.
config.dims = [config.p, 10, 1]
# Simulation / model settings ('ER' is presumably Erdos-Renyi; it is passed
# straight to simulate_dag) and runtime options.
vars(config).update(
    graph_type='ER',
    sem_type='lag',
    model_type='AdditiveNoiseModel',
    ins=True,
    kernel_size=10,
    stride=5,
    bias=True,
    w_th=0.25,
    device_type='gpu',
    device_ids=0,
)

# Parameter sweep; each tuple is (p, lambda1, lambda2, w_th).
# Settings used previously:
#   (10, 0.009, 0.005, 0.3), (20, 0.011, 0.005, 0.3),
#   (40, 0.015, 0.005, 0.25), (80, 0.025, 0.005, 0.15)
SWEEP = [(10, 0.009, 0.005, 0.3)]
# Number of independent simulate-and-fit repetitions per setting.
N_REPEATS = 1
# Optimisation settings shared by both variants so they are trained identically.
TRAIN_KWARGS = dict(T=4, mu_init=1, warm_iter=7e3, max_iter=1e4, lr=0.001)
# Common styling for the three comparison heatmaps.
HEATMAP_KWARGS = dict(cmap="RdBu", vmin=-2, vmax=2, xticklabels=False,
                      yticklabels=False, cbar=False, linewidths=2)

for p, lambda1, lambda2, w_th in SWEEP:
    config.p = p
    config.dims = [config.p, 10, 1]
    config.w_th = w_th
    for _ in range(N_REPEATS):
        B_true = simulate_dag(config.p, config.e, config.graph_type, config.lag, config.ins, config.es)
        data = simulate_nonlinear_sem(B_true, config.n * (config.t + config.lag), config.model_type, config.p, config.lag, config.ins)
        # One row per sequence: (n, t + lag, p).
        data = data.reshape((config.n, config.t + config.lag, config.p))
        # Inputs: lag + 1 time-shifted windows of length t stacked along the
        # feature axis -> (n, t, p * (lag + 1)).
        # Targets: the last t steps of every sequence -> (n, t, p).
        X = np.concatenate([data[:, shift:config.t + shift, :] for shift in range(config.lag + 1)], axis=2)
        Y = data[:, config.lag:, :]
        # Ground-truth columns the estimates are scored and plotted against:
        # everything after the first p * lag columns of B_true (presumably the
        # lag-0 block -- confirm against simulate_dag).
        B_true_eval = B_true[:, config.p * config.lag:]

        # h_norm variant
        dycausal = DyCausal(config, [X, Y])
        begin_time = time()
        dycausal.train(lambda1=lambda1, lambda2=lambda2, **TRAIN_KWARGS)
        end_time = time()
        W_est = dycausal.get_adj()
        met = dycausal.station_metric(W_est, B_true_eval)
        print(met)
        print(f'DyCausal (h_norm) train time: {end_time - begin_time:.1f}s')
        dycausal.draw_h()

        plt.cla()

        # h_log variant
        dycausalh = DyCausalh(config, [X, Y])
        begin_time = time()
        dycausalh.train(lambda1=lambda1, lambda2=lambda2, **TRAIN_KWARGS)
        end_time = time()
        W_h_est = dycausalh.get_adj()
        met = dycausalh.station_metric(W_h_est, B_true_eval)
        print(met)
        print(f'DyCausalh (h_log) train time: {end_time - begin_time:.1f}s')
        dycausalh.draw_h()

        plt.cla()

        # Side-by-side comparison: ground truth, h_norm estimate, h_log estimate.
        # W_est[0] / W_h_est[0] take the first entry of get_adj()'s result
        # (presumably one time step's adjacency -- confirm with get_adj).
        sns.set_theme(rc={'figure.figsize': (30, 28)})
        # Open a fresh figure: rcParams such as figure.figsize only apply to
        # figures created after set_theme, and plt.cla() above cleared only the
        # current axes, so reusing the existing figure would ignore the size and
        # could keep any other axes draw_h() left behind.
        fig = plt.figure()
        for position, matrix in enumerate([B_true_eval, W_est[0], W_h_est[0]], start=1):
            plt.subplot(1, 3, position)
            sns.heatmap(matrix, **HEATMAP_KWARGS)
        fig.savefig('h_norm_vs_h_log.pdf')
        # Release the figure so repeated iterations do not accumulate figures.
        plt.close(fig)
        # met = dycausal.station_metric(W_est, B_true[:, config.p * config.lag:])
        # file_handle = open('DyCausal_result_nonlinear.txt', 'a')
        # file_handle.write('type:{},time:{},matrics{}\n'.format('DyCausal-linearlag' + str(config.p), end_time - begin_time, str(met)))
        # file_handle.close()
        # print(met)