import itertools
import numpy as np
import sdeint
import time
import pickle
import os
import sys
from tqdm import tqdm
from joblib import Parallel, delayed

# Simulation parameters read by gen_activity.
tau = 3        # time constant: the drift is divided by tau
sigma = 1      # amplitude of the additive noise on every unit
T = 1000       # total integration time
h = 0.1        # time step used to build the integration grid

def gen_activity(W, seed=None):

    if seed is not None:
        generator = np.random.default_rng(seed)
    else:
        generator = np.random.default_rng()
    # f
    def f_(x, t):
        return 1/tau * (-1 * np.eye(W.shape[0]) @ x + W @ x)

    # G: linear i.i.d noise with sigma
    def g_(x, t):
        return sigma * np.eye(W.shape[0])

    # Generate random initial condition and then integrate over the desired time period
    tspace = np.linspace(0, T, int(T/h))
    
    x0 = generator.normal(size=(W.shape[0],))

    return  sdeint.itoSRI2(f_, g_, x0, tspace, generator=generator)    

def inner_loop(task, R, Alist, savepath):
    """Simulate one ``(r, rep, inner_rep)`` task and pickle the trajectory.

    Parameters
    ----------
    task : tuple
        ``(r, rep, inner_rep)``: a value taken from ``R``, an index into
        ``Alist``, and the inner repetition number (also used as the seed).
    R : sequence
        Parameter values; the position of ``r`` in ``R`` selects the matrix
        within ``Alist[rep]`` and appears in the output filename.
    Alist : sequence of sequences
        ``Alist[rep][ridx]`` is the coupling matrix passed to ``gen_activity``.
    savepath : str
        Directory that receives ``x_<ridx>_<rep>_<inner_rep>.pkl``.

    Prints the wall-clock seconds the task took.
    """
    t0 = time.time()
    r, rep, inner_rep = task
    # First matching position; assumes the values in R are distinct.
    ridx = list(R).index(r)
    A = Alist[rep][ridx]
    # NOTE(review): the seed depends only on inner_rep, so all (r, rep) pairs
    # with the same inner_rep use identically seeded generators (identical
    # initial conditions and noise when the matrices have the same size).
    # Confirm that these common random numbers are intended.
    x = gen_activity(A, seed=inner_rep)

    outfile = os.path.join(savepath, 'x_%d_%d_%d.pkl' % (ridx, rep, inner_rep))
    with open(outfile, 'wb') as f:
        # Stream directly to the file instead of first materializing the
        # full serialized byte string with pickle.dumps.
        pickle.dump(x, f)

    print(time.time() - t0)


if __name__ == '__main__':
    # Usage: python <script> AMATS_PKL SAVEPATH NJOBS [INNER_REPS]
    if len(sys.argv) < 4:
        sys.exit('usage: %s AMATS_PKL SAVEPATH NJOBS [INNER_REPS]' % sys.argv[0])

    Amats = sys.argv[1]
    savepath = sys.argv[2]
    njobs = int(sys.argv[3])
    # Optional 4th argument; defaults to the previously hard-coded 10.
    inner_reps = int(sys.argv[4]) if len(sys.argv) > 4 else 10

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(savepath, exist_ok=True)

    # The file holds two consecutive pickles: Alist first, then R.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(Amats, 'rb') as f:
        Alist = pickle.load(f)
        R = pickle.load(f)

    reps = len(Alist)
    tasks = itertools.product(R, range(reps), range(inner_reps))
    # product() is a lazy iterator with no len(), so tqdm needs the total
    # explicitly. The bar advances as joblib pulls tasks to dispatch.
    total = len(R) * reps * inner_reps
    Parallel(n_jobs=njobs)(
        delayed(inner_loop)(task, R, Alist, savepath)
        for task in tqdm(tasks, total=total)
    )
