import argparse
from time import time
import numpy as np
from DyCausal import DyCausal
from utils.simulator import simulate_var, simulate_dag

# Experiment configuration. The parser declares no options, so parsing an
# empty argument list just yields an empty Namespace, which is then filled
# in by hand with the settings read by the simulators and DyCausal below.
parser = argparse.ArgumentParser(description='Configuration')
config = parser.parse_args(args=[])

# Simulation sizes (keyword order is preserved as attribute order).
vars(config).update(n=20, t=50, p=10, e=2, lag=2, es=[2, 2])
# dims is derived from p, so it is assigned only once p exists.
config.dims = [config.p, 1]
# Graph / model / training settings.
vars(config).update(
    graph_type='ER',
    sem_type='lag',
    model_type='AdditiveNoiseModel',
    ins=True,
    kernel_size=10,
    stride=5,
    bias=True,
    w_th=0.1,
    device_type='gpu',
    device_ids=0,
)

# Earlier note: "10, 0.05 0.001" -- presumably p=10 with lambda1=0.05 and lambda2=0.001 (TODO confirm).
# Scalability sweep: for each number of variables p, simulate 10 random
# datasets, fit DyCausal on each, and append the training time and the
# evaluation metrics to DyCausal_result_linear.txt.
for num_vars in [10, 20, 40, 80]:
    config.p = num_vars
    config.dims = [config.p, 1]
    for _ in range(10):
        B_true = simulate_dag(config.p, config.e, config.graph_type, config.lag, config.ins, config.es)
        data = simulate_var(B_true, config.n * (config.t + config.lag), config.p, config.lag, config.ins)
        # Split the single simulated series into n sequences of length t + lag.
        data = data.reshape((config.n, config.t + config.lag, config.p))
        # Stack lag + 1 time-shifted windows (each of length t) along the
        # feature axis, giving shape (n, t, p * (lag + 1)).
        X = np.concatenate([data[:, k:config.t + k, :] for k in range(config.lag + 1)], axis=2)
        # Targets: the last t steps of each sequence, shape (n, t, p).
        Y = data[:, config.lag:, :]
        dycausal = DyCausal(config, [X, Y])
        # Only the training call is timed.
        begin_time = time()
        dycausal.train(lambda1=0.05, lambda2=0.001, T=4, mu_init=1, warm_iter=5e3, max_iter=8e3, lr=0.005)
        end_time = time()
        W_est = dycausal.get_adj()
        # Compare against the columns of B_true after the first p * lag;
        # presumably the lag-0 (instantaneous) block -- TODO confirm
        # against simulate_dag.
        met = dycausal.station_metric(W_est, B_true[:, config.p * config.lag:])
        # Context manager guarantees the file is closed even if write fails.
        # NOTE: the 'matrics' spelling is kept as-is in case downstream
        # parsing of the result file relies on it.
        with open('DyCausal_result_linear.txt', 'a') as file_handle:
            file_handle.write('type:{},time:{},matrics{}\n'.format('DyCausal-linearlag' + str(config.p), end_time - begin_time, str(met)))
        print(met)