import argparse
from time import time
import numpy as np
from DyCausal import DyCausal
from utils.simulator import simulate_lorenz_96

# Experiment settings live on an empty argparse namespace: no command-line
# arguments are parsed (args=[]), the values are simply attached as attributes.
parser = argparse.ArgumentParser(description='Configuration')
config = parser.parse_args(args=[])

# Size / lag settings.
vars(config).update(n=20, t=50, p=10, e=2, lag=1, es=[2])

# Layer sizes are derived from the number of variables p set just above.
config.dims = [config.p, 10, 1]

# Data-generation, model and device settings.
vars(config).update(
    graph_type='ER',
    sem_type='ode',
    model_type='AdditiveNoiseModel',
    ins=False,
    kernel_size=10,
    stride=5,
    bias=True,
    w_th=0.3,
    device_type='gpu',
    device_ids=2,
)

# Benchmark sweep: one entry per system size. Each tuple is
# (number of variables p, lambda1, lambda2, threshold w_th); lambda1/lambda2
# are forwarded to DyCausal.train and w_th is stored on the config.
for p, lambda1, lambda2, w_th in [
    (10, 0.05, 0.05, 0.3),
    (20, 0.03, 0.05, 0.3),
    (40, 0.01, 0.04, 0.3),
    (80, 0.001, 0.005, 0.3),
]:
    config.p = p
    # dims depends on p, so it has to be rebuilt for every system size.
    config.dims = [config.p, 10, 1]
    config.w_th = w_th

    # Ten repetitions per setting; the simulator is called again each time.
    for _ in range(10):
        # Simulate one long trajectory, then cut it into n segments of
        # (t + lag) steps each.
        data, graph = simulate_lorenz_96(config.p, config.n * (config.t + config.lag), delta_t=0.05)
        data = data.reshape((config.n, config.t + config.lag, config.p))

        # X stacks the `lag` shifted windows along the last axis; Y is the
        # same data with the first `lag` steps dropped.
        X = np.concatenate(
            [data[:, shift:shift - config.lag, :] for shift in range(config.lag)],
            axis=2,
        )
        Y = data[:, config.lag:, :]

        dycausal = DyCausal(config, [X, Y])

        # Only the training call is timed.
        # NOTE(review): max_iter is passed as the float 1.5e4 - confirm that
        # DyCausal.train accepts a non-integer iteration count.
        begin_time = time()
        dycausal.train(lambda1=lambda1, lambda2=lambda2, T=1, max_iter=1.5e4, lr=0.005)
        end_time = time()

        W_est = dycausal.get_adj()
        # The simulator's graph is transposed before being compared with the estimate.
        met = dycausal.station_metric(W_est, graph.T)

        # Append one result line per run; the context manager guarantees the
        # file is closed even if formatting or writing fails.
        with open('DyCausal_result_ode.txt', 'a') as file_handle:
            file_handle.write('type:{},time:{},matrics{}\n'.format('DyCausal-ode' + str(config.p), end_time - begin_time, str(met)))
        print(met)