import itertools
import numpy as np
import sdeint
import pdb
import time
import pickle
import os
import sys
from tqdm import tqdm
from joblib import Parallel, delayed

# Module-level simulation parameters read by gen_activity.
tau = 3      # time constant: the drift term is scaled by 1/tau
sigma = 1    # scale of the additive (identity) noise matrix
T = 1000     # total integration time
h = 0.1      # nominal step; the time grid has int(T / h) samples

def gen_activity(Wseq, seed=None):
    """Simulate ``dx = (1/tau) * (W(t) - I) x dt + sigma dB`` on ``[0, T]``.

    The window ``[0, T]`` is split into ``len(Wseq)`` equal-width segments.
    ``W(t)`` is held piecewise-constant at ``Wseq[k]`` over the k-th segment,
    so every matrix in ``Wseq`` drives the dynamics for the same duration.

    Parameters
    ----------
    Wseq : sequence of ndarray
        Non-empty sequence of square matrices, all of size ``n x n``.
    seed : int or None, optional
        Seed for ``np.random.default_rng``. The resulting generator draws the
        initial condition and is passed to ``sdeint.itoSRI2`` for the noise.
        ``None`` uses fresh OS entropy.

    Returns
    -------
    ndarray
        Trajectory from ``sdeint.itoSRI2`` on
        ``np.linspace(0, T, int(T / h))``, shape ``(len(tspace), n)`` per the
        sdeint docs.

    Raises
    ------
    ValueError
        If ``Wseq`` is empty.
    """
    n_segments = len(Wseq)
    if n_segments == 0:
        raise ValueError('Wseq must contain at least one matrix')

    # default_rng(None) already pulls fresh entropy, so one call covers both cases.
    generator = np.random.default_rng(seed)

    n = Wseq[0].shape[0]
    # n_segments + 1 boundaries give n_segments equal-width segments. The
    # previous n_segments boundaries left the final matrix unreachable,
    # because t never exceeds T.
    boundaries = np.linspace(0, T, n_segments + 1)
    # Built once: the noise matrix is constant, and sdeint only reads it.
    noise = sigma * np.eye(n)

    def Wt(t):
        # The index of the last boundary strictly below t selects the segment
        # (boundaries[k], boundaries[k+1]]. t == 0 maps to segment 0, and the
        # clamp keeps t >= T in the last segment.
        k = int(np.searchsorted(boundaries, t, side='left')) - 1
        return Wseq[min(max(k, 0), n_segments - 1)]

    # Drift f(x, t) = (W(t) x - x) / tau, i.e. (W(t) - I) x / tau.
    def f_(x, t):
        W = Wt(t)
        return 1/tau * (W @ x - x)

    # Diffusion G(x, t): i.i.d. additive noise with scale sigma.
    def g_(x, t):
        return noise

    # NOTE(review): this grid has int(T/h) points, so the actual step is
    # T / (int(T/h) - 1), slightly larger than h. It is kept unchanged so the
    # trajectory length stays the same for downstream code.
    tspace = np.linspace(0, T, int(T/h))

    # Random initial condition, then integrate over the full window.
    x0 = generator.normal(size=(n,))

    return sdeint.itoSRI2(f_, g_, x0, tspace, generator=generator)

def inner_loop(task, R, Alist, savepath):
    """Simulate one ``(r, rep, inner_rep)`` condition and pickle the result.

    Parameters
    ----------
    task : tuple
        ``(r, rep, inner_rep)``. ``r`` must be an element of ``R``. ``rep``
        seeds the choice of matrices. ``inner_rep`` is passed as the seed to
        ``gen_activity``.
    R : iterable
        Condition values. The position of ``r`` in ``R`` (``ridx``) selects
        ``Alist[s][ridx]``.
    Alist : sequence
        Indexed as ``Alist[s][ridx]`` to obtain a square matrix. Presumably
        one entry per independent draw, each holding one matrix per value of
        ``R``; confirm against whatever produced the Amats file.
    savepath : str
        Existing directory. The file ``x_<ridx>_<rep>_<inner_rep>.pkl`` is
        written here.

    The output file contains two consecutive pickles: the simulated
    trajectory, then the array of chosen ``Alist`` indices. Elapsed
    wall-clock seconds are printed on completion.
    """
    t0 = time.time()
    r, rep, inner_rep = task
    ridx = list(R).index(r)
    rep = int(rep)
    inner_rep = int(inner_rep)

    # Draw 3 indices (with replacement) into Alist. len(Alist) is used rather
    # than a global defined only under __main__, so this also works when the
    # module is imported or the function is run in a fresh worker process.
    rnd = np.random.default_rng(rep)
    seq_idxs = rnd.choice(np.arange(len(Alist)), 3)
    Aseq = [Alist[s][ridx] for s in seq_idxs]

    # NOTE(review): seeding with inner_rep alone gives every (r, rep) pair the
    # same random stream for a given inner_rep. Confirm this is intended.
    x = gen_activity(Aseq, seed=inner_rep)

    fname = os.path.join(savepath, 'x_%d_%d_%d.pkl' % (ridx, rep, inner_rep))
    with open(fname, 'wb') as f:
        pickle.dump(x, f)
        pickle.dump(seq_idxs, f)

    print(time.time() - t0)

if __name__ == '__main__':

    # Usage: python <script> AMATS_PKL SAVEPATH NJOBS
    if len(sys.argv) < 4:
        sys.exit('usage: %s AMATS_PKL SAVEPATH NJOBS' % sys.argv[0])

    Amats = sys.argv[1]
    savepath = sys.argv[2]
    njobs = int(sys.argv[3])
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(savepath, exist_ok=True)

    inner_reps = 10

    # The file holds two consecutive pickles: Alist, then R.
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at trusted, locally generated files.
    with open(Amats, 'rb') as f:
        Alist = pickle.load(f)
        R = pickle.load(f)

    # Kept as a module-level global because worker code may read `reps`.
    reps = len(Alist)
    tasks = list(itertools.product(R, range(reps), range(inner_reps)))
    # tqdm wraps the dispatch iterator, so the bar tracks task submission
    # rather than completion.
    Parallel(n_jobs=njobs)(
        delayed(inner_loop)(task, R, Alist, savepath) for task in tqdm(tasks)
    )
