import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import glob
import tqdm
import torch
import scipy
import random
import ipdb as pdb
import numpy as np
from torch import nn
from torch.nn import init
from collections import deque
import matplotlib.pyplot as plt
from sklearn import preprocessing
from scipy.stats import ortho_group
from sklearn.preprocessing import scale
from IFactor.tools.utils import create_sparse_transitions, controlable_sparse_transitions, get_one_hot_ndarray

# Fraction of samples presumably meant for validation; not referenced by the code shown here.
VALIDATION_RATIO = 0.2
# Parent directory of every generated dataset. ``expanduser`` only rewrites a
# leading "~", so this placeholder is used as-is (relative to the CWD).
root_dir = os.path.expanduser('anonymized_path/data')
# Shared scaler instance; not used by the generators shown here.
standard_scaler = preprocessing.StandardScaler()

def leaky_ReLU_1d(d, negSlope):
    """Leaky ReLU of a single scalar: ``d`` if ``d > 0`` else ``d * negSlope``."""
    if d > 0:
        return d
    else:
        return d * negSlope

# Element-wise wrapper kept for backward compatibility. ``leaky_ReLU`` no longer
# uses it: ``np.vectorize`` is a Python-level loop (very slow on large arrays),
# takes its output dtype from the first element (so integer inputs get
# truncated), and refuses empty inputs.
leaky1d = np.vectorize(leaky_ReLU_1d)

def leaky_ReLU(D, negSlope):
    """Apply a leaky ReLU element-wise.

    Args:
        D: scalar or array-like of any shape.
        negSlope: positive slope applied to entries that are not ``> 0``.

    Returns:
        ndarray with the same shape as ``D``: ``D`` where ``D > 0``, otherwise
        ``D * negSlope``. The dtype follows NumPy promotion of ``D`` and
        ``D * negSlope`` (integer input with a float slope becomes float).
    """
    assert negSlope > 0
    D = np.asarray(D)
    # Single vectorized pass instead of a per-element Python call.
    return np.where(D > 0, D, D * negSlope)

def weigth_init(m):
    """Initialise the parameters of a single module in place.

    Takes one module, so it is suitable for ``model.apply(weigth_init)``.
    The misspelled name is kept for existing callers. Handled types:

    * ``nn.Conv2d``: Xavier-uniform weight, bias filled with 0.1.
    * ``nn.BatchNorm2d``: weight 1, bias 0.
    * ``nn.Linear``: weight ~ N(0, 0.01), bias 0.

    Parameters that do not exist are skipped instead of raising
    ``AttributeError``. This covers ``bias=False`` layers and
    ``affine=False`` batch-norm. All other module types are left untouched.
    """
    # The ``torch.nn.init`` helpers run under ``no_grad``, replacing ``.data`` access.
    if isinstance(m, nn.Conv2d):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0.1)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            init.zeros_(m.bias)

def sigmoidAct(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, applied element-wise.

    Uses ``scipy.special.expit``, which is numerically stable. Large negative
    inputs saturate to 0 without an overflow warning. Plain Python sequences
    are treated as arrays; the previous ``-1 * x`` silently turned a list into
    an empty list.
    """
    from scipy.special import expit

    return expit(x)

def generateUniformMat(Ncomp, condT):
    """Draw a random square matrix with unit-norm columns and bounded conditioning.

    Entries are sampled i.i.d. from U(-1, 1) and every column is rescaled to
    unit Euclidean norm. The draw is repeated until ``np.linalg.cond`` of the
    result is not above ``condT``.

    Args:
        Ncomp: number of rows and columns.
        condT: largest accepted condition number. A 2-norm condition number
            is always >= 1, so a value below 1 makes this loop forever.

    Returns:
        The accepted ``(Ncomp, Ncomp)`` ndarray.
    """
    while True:
        candidate = np.random.uniform(0, 2, (Ncomp, Ncomp)) - 1
        for column in candidate.T:
            # ``column`` is a view, so this rescales ``candidate`` in place.
            column /= np.sqrt((column ** 2).sum())
        if not np.linalg.cond(candidate) > condT:
            return candidate

def generateUniformMatS(shape, condT):
    """Draw a random matrix of ``shape`` with unit-norm columns and bounded conditioning.

    Entries are sampled i.i.d. from U(-1, 1) and each of the ``shape[1]``
    columns is rescaled to unit Euclidean norm. Draws are repeated until
    ``np.linalg.cond`` of the matrix is not above ``condT``. Non-square shapes
    are allowed; ``cond`` then uses the ratio of extreme singular values.

    Args:
        shape: ``(rows, cols)`` of the matrix to draw.
        condT: largest accepted condition number (a value below 1 can never be met).

    Returns:
        The accepted ndarray of the requested shape.
    """
    def draw():
        mat = np.random.uniform(0, 2, shape) - 1
        for j in range(shape[1]):
            col = mat[:, j]
            col /= np.sqrt((col ** 2).sum())
        return mat

    accepted = draw()
    while np.linalg.cond(accepted) > condT:
        accepted = draw()
    return accepted

def noisecoupled_gaussian_ts_reward_action(action_num: int = 10) -> None:
    """Generate and save a synthetic lagged latent dataset with scalar actions and rewards.

    The latent state has four 2-d blocks ``z1..z4`` (``latent_size`` = 8):

    * ``z1`` is driven by the past ``(z1, z2)`` and the action,
    * ``z2`` by the past ``(z1, z2)``,
    * ``z3`` by the full past latent and the action,
    * ``z4`` by the full past latent.

    Each of the ``lags`` (= 2) history steps contributes ``leaky_ReLU(past @ W)``
    with its own transition matrix ``W``. The sum plus history-modulated
    Gaussian noise goes through ``leaky_ReLU`` once more.

    The reward is ``leaky_ReLU((z1, z2) @ rA)``. Observations are
    ``Nlayer - 1`` rounds of ``leaky_ReLU`` followed by multiplication with a
    random orthogonal ``latent_size x latent_size`` matrix.

    Actions are scalars chosen (with replacement) from ``action_num`` values
    drawn once from a standard normal.

    Args:
        action_num: number of distinct scalar action values.

    Side effects:
        Writes ``data.npz`` with keys ``yt`` (latents), ``xt`` (observations),
        ``at`` (actions) and ``rt`` (rewards). Each has shape
        ``(batch_size, lags + length, dim)``, where dim is 8, 8, 1 and 1
        respectively. Also writes ``Z{block}{lag}.npy`` (the transition
        matrices) under
        ``root_dir/noisecoupled_gaussian_ts_2lag_IFactor_{action_num}_actions``.
        All randomness comes from the global NumPy RNG. The exact order of
        draws, including draws whose values are unused, determines the
        dataset.
    """
    lags = 2
    Nlayer = 3
    length = 1
    z1condList, z2condList, z3condList, z4condList, rcondList = [], [], [], [], []
    negSlope = 0.2
    z1_dim, z2_dim, z3_dim, z4_dim = 2, 2, 2, 2
    latent_size = z1_dim + z2_dim + z3_dim + z4_dim
    transitions = []  # unused
    z1tran, z2tran, z3tran, z4tran = [], [], [], []
    noise_scale = 0.1
    batch_size = 100000
    # batch_size = 100
    Niter4condThresh = 1e4

    path = os.path.join(root_dir, f"noisecoupled_gaussian_ts_2lag_IFactor_{action_num}_actions")
    os.makedirs(path, exist_ok=True)

    # Estimate acceptance thresholds: record the condition numbers of many
    # random column-normalised matrices of each transition shape.
    # NOTE(review): these samples come from U(1, 2), while generateUniformMatS
    # samples U(-1, 1). The 25th-percentile cut below therefore does not mean
    # "accept 25%" for the matrices actually used. Confirm this is intended.
    for i in range(int(Niter4condThresh)):
        # A = np.random.uniform(0,1, (Ncomp, Ncomp))
        z1A = np.random.uniform(1, 2, (z1_dim+z2_dim+1, z1_dim))  # - 1
        z2A = np.random.uniform(1, 2, (z1_dim+z2_dim, z2_dim))  # - 1
        z3A = np.random.uniform(1, 2, (latent_size+1, z3_dim))  # - 1
        z4A = np.random.uniform(1, 2, (latent_size, z4_dim))  # - 1
        # This rA is discarded (the reward matrix is redrawn after the loop),
        # but the draw still advances the global RNG.
        rA = np.random.uniform(1, 2, (z1_dim+z2_dim, 1))  # - 1
        # The inner loops reuse ``i``; harmless because the outer ``for`` rebinds it.
        for i in range(z1_dim):
            z1A[:, i] /= np.sqrt((z1A[:, i] ** 2).sum())
        for i in range(z2_dim):
            z2A[:, i] /= np.sqrt((z2A[:, i] ** 2).sum())
        for i in range(z3_dim):
            z3A[:, i] /= np.sqrt((z3A[:, i] ** 2).sum())
        for i in range(z4_dim):
            z4A[:, i] /= np.sqrt((z4A[:, i] ** 2).sum())

        z1condList.append(np.linalg.cond(z1A))
        z2condList.append(np.linalg.cond(z2A))
        z3condList.append(np.linalg.cond(z3A))
        z4condList.append(np.linalg.cond(z4A))
    print('get threshold')
    # reward: a single unit-norm column of U(1, 2) entries (no conditioning check)
    rA = np.random.uniform(1, 2, (z1_dim+z2_dim, 1))  # - 1
    rA[:, 0] /= np.sqrt((rA[:, 0] ** 2).sum())
    # transitions: only accept matrices whose condition number is at or below
    # the 25th percentile of the sampled condition numbers
    z1condThresh = np.percentile(z1condList, 25)
    z2condThresh = np.percentile(z2condList, 25)
    z3condThresh = np.percentile(z3condList, 25)
    z4condThresh = np.percentile(z4condList, 25)

    # One transition matrix per block and lag; the "+1" rows take the scalar action.
    for l in range(lags):
        z1B = generateUniformMatS((z1_dim+z2_dim+1, z1_dim), z1condThresh)
        z2B = generateUniformMatS((z1_dim+z2_dim, z2_dim), z2condThresh)
        z3B = generateUniformMatS((latent_size+1, z3_dim), z3condThresh)
        z4B = generateUniformMatS((latent_size, z4_dim), z4condThresh)
        z1tran.append(z1B)
        z2tran.append(z2B)
        z3tran.append(z3B)
        z4tran.append(z4B)
    # After reversal, index l is applied to history slot l (slot 0 = oldest)
    # and is saved below as lag ``lags - l``.
    z1tran.reverse()
    z2tran.reverse()
    z3tran.reverse()
    z4tran.reverse()
    print('get transition matrix')
    # Observation mixing: one random orthogonal matrix per layer.
    mixingList, rewmixingList = [], []  # rewmixingList is unused here
    for l in range(Nlayer - 1):
        # generate causal matrix first:
        A = ortho_group.rvs(latent_size)  # generateUniformMat(Ncomp, condThresh)
        # NOTE(review): B is never used. scipy's rvs presumably draws from the
        # global NumPy RNG by default, so removing this line would change every
        # later sample.
        B = ortho_group.rvs(latent_size)
        mixingList.append(A)

    # Draw the ``action_num`` possible scalar action values from a standard normal.
    random_action = np.random.randn(action_num)

    # Initial history: Gaussian latents of shape (batch_size, lags, latent_size)
    # and actions of shape (batch_size, lags, 1) chosen from random_action.
    y_l = np.random.normal(0, 1, (batch_size, lags, latent_size))
    a_l = np.random.choice(random_action, (batch_size, lags, 1))

    # Standardise every (lag, dimension) over the batch.
    y_l = (y_l - np.mean(y_l, axis=0 ,keepdims=True)) / np.std(y_l, axis=0 ,keepdims=True)

    # Per-time-step lists; stacked and transposed to (batch, time, dim) at the end.
    yt = []; xt = []; at = []; rt = []
    for i in range(lags):
        yt.append(y_l[:,i,:])
        at.append(a_l[:,i,:])
        rt.append(leaky_ReLU(np.dot(y_l[:,i,:z1_dim+z2_dim], rA), negSlope))
    # Observations for the history: (leaky ReLU, orthogonal mix) per layer.
    mixedDat = np.copy(y_l)
    for l in range(Nlayer - 1):
        mixedDat = leaky_ReLU(mixedDat, negSlope)
        mixedDat = np.dot(mixedDat, mixingList[l])
    x_l = np.copy(mixedDat)
    for i in range(lags):
        xt.append(x_l[:,i,:])

    # Roll the process forward ``length`` steps.
    for i in range(length):
        # Transition function
        y_t = np.random.normal(0, noise_scale, (batch_size, latent_size))
        # Scale the noise element-wise by the per-sample mean of the history
        # over lags (its sign and size therefore depend on the past).
        y_t = y_t * np.mean(y_l, axis=1)
        a_t = np.random.choice(random_action, (batch_size, 1))
        print(a_t.shape)  # debug output
        at.append(a_t)
        for l in range(lags):
            z1t = np.dot(np.concatenate((y_l[:,l,:z1_dim+z2_dim], a_l[:,l,:]), axis=-1), z1tran[l])
            z2t = np.dot(y_l[:,l,:z1_dim+z2_dim], z2tran[l])
            z3t = np.dot(np.concatenate((y_l[:,l,:], a_l[:,l,:]), axis=-1), z3tran[l])
            z4t = np.dot(y_l[:,l,:], z4tran[l])
            zt = np.concatenate([z1t, z2t, z3t, z4t], axis=-1)
            y_t += leaky_ReLU(zt, negSlope)
        y_t = leaky_ReLU(y_t, negSlope)
        print(y_t.shape)  # debug output
        yt.append(y_t)
        rt.append(leaky_ReLU(np.dot(y_t[:,:z1_dim+z2_dim], rA), negSlope))
        # Mixing function
        mixedDat = np.copy(y_t)
        for l in range(Nlayer - 1):
            mixedDat = leaky_ReLU(mixedDat, negSlope)
            mixedDat = np.dot(mixedDat, mixingList[l])
        x_t = np.copy(mixedDat)
        xt.append(x_t)
        # Slide the history window: append the new step and drop the oldest.
        y_l = np.concatenate((y_l, y_t[:,np.newaxis,:]),axis=1)[:,1:,:]
        a_l = np.concatenate((a_l, a_t[:, np.newaxis,:]),axis=1)[:,1:,:]

    # (time, batch, dim) -> (batch, time, dim)
    yt = np.array(yt).transpose(1,0,2)
    xt = np.array(xt).transpose(1,0,2)
    at = np.array(at).transpose(1,0,2)
    rt = np.array(rt).transpose(1,0,2)

    np.savez(os.path.join(path, "data"),
            yt = yt,
            xt = xt,
            at = at,
            rt = rt)

    # Z{block}{k}.npy holds the matrix applied to the latent k steps back.
    for l in range(lags):
        np.save(os.path.join(path, "Z1%d"%(lags-l)), z1tran[l])
        np.save(os.path.join(path, "Z2%d"%(lags-l)), z2tran[l])
        np.save(os.path.join(path, "Z3%d"%(lags-l)), z3tran[l])
        np.save(os.path.join(path, "Z4%d"%(lags-l)), z4tran[l])

def non_invertible_noisecoupled_gaussian_ts_reward_action(action_num: int = 10) -> None:
    """Generate and save a lagged latent dataset with a 32-d observation map.

    Latent dynamics are the same block structure as the scalar-action
    generator: ``z1..z4`` are 2-d each, with ``lags`` = 2 and scalar actions
    chosen from ``action_num`` standard-normal values. The differences are:

    * The transition noise is N(0, 0.2 * noise_scale) and is not modulated by
      the history (that line is commented out).
    * The reward is ``leaky_ReLU((z1, z2) @ rB)``, with ``rB`` drawn by
      ``generateUniformMatS``.
    * The observation is ``leaky_ReLU(y @ A) @ mixingB``. Here ``A`` is a
      random orthogonal 8x8 matrix and ``mixingB`` is an ``(8, 32)`` matrix
      drawn by ``generateUniformMatS``, so ``xt`` has 32 features.
    * The first sample of each saved array is printed.

    Args:
        action_num: number of distinct scalar action values.

    Side effects:
        Writes ``data.npz`` (``yt``, ``xt``, ``at``, ``rt``, each shaped
        ``(batch_size, lags + length, dim)``) and the ``Z{block}{lag}.npy``
        transition matrices. All randomness comes from the global NumPy RNG.
    """
    lags = 2
    Nlayer = 2
    length = 1
    observation_dim = 32
    z1condList, z2condList, z3condList, z4condList, rcondList, mixingcondList = [], [], [], [], [], []
    negSlope = 0.2
    z1_dim, z2_dim, z3_dim, z4_dim = 2, 2, 2, 2
    latent_size = z1_dim + z2_dim + z3_dim + z4_dim
    transitions = []  # unused
    z1tran, z2tran, z3tran, z4tran = [], [], [], []
    noise_scale = 0.1
    batch_size = 100000
    # batch_size = 100
    Niter4condThresh = 1e4

    # NOTE(review): this is the same directory noisecoupled_gaussian_ts_reward_action
    # writes to. Running both generators with the same action_num overwrites
    # the other's data.npz and Z*.npy. Confirm this is intended.
    path = os.path.join(root_dir, f"noisecoupled_gaussian_ts_2lag_IFactor_{action_num}_actions")
    os.makedirs(path, exist_ok=True)

    # Estimate acceptance thresholds from many random column-normalised matrices.
    # NOTE(review): the samples come from U(1, 2), but generateUniformMatS
    # samples U(-1, 1), so the percentiles do not describe the distribution
    # that is actually filtered.
    for i in range(int(Niter4condThresh)):
        # A = np.random.uniform(0,1, (Ncomp, Ncomp))
        z1A = np.random.uniform(1, 2, (z1_dim+z2_dim+1, z1_dim))  # - 1
        z2A = np.random.uniform(1, 2, (z1_dim+z2_dim, z2_dim))  # - 1
        z3A = np.random.uniform(1, 2, (latent_size+1, z3_dim))  # - 1
        z4A = np.random.uniform(1, 2, (latent_size, z4_dim))  # - 1
        mixingA = np.random.uniform(1, 2, (latent_size, observation_dim))  # - 1
        rA = np.random.uniform(1, 2, (z1_dim+z2_dim, 1))  # - 1
        # The inner loops reuse ``i``; harmless because the outer ``for`` rebinds it.
        for i in range(z1_dim):
            z1A[:, i] /= np.sqrt((z1A[:, i] ** 2).sum())
        for i in range(z2_dim):
            z2A[:, i] /= np.sqrt((z2A[:, i] ** 2).sum())
        for i in range(z3_dim):
            z3A[:, i] /= np.sqrt((z3A[:, i] ** 2).sum())
        for i in range(z4_dim):
            z4A[:, i] /= np.sqrt((z4A[:, i] ** 2).sum())
        for i in range(1):
            rA[:, i] /= np.sqrt((rA[:, i] ** 2).sum())
        for i in range(observation_dim):
            mixingA[:, i] /= np.sqrt((mixingA[:, i] ** 2).sum())

        z1condList.append(np.linalg.cond(z1A))
        z2condList.append(np.linalg.cond(z2A))
        z3condList.append(np.linalg.cond(z3A))
        z4condList.append(np.linalg.cond(z4A))
        rcondList.append(np.linalg.cond(rA))
        mixingcondList.append(np.linalg.cond(mixingA))
    print('get threshold')

    # Only accept matrices at or below the 25th percentile of sampled condition numbers.
    z1condThresh = np.percentile(z1condList, 25)
    z2condThresh = np.percentile(z2condList, 25)
    z3condThresh = np.percentile(z3condList, 25)
    z4condThresh = np.percentile(z4condList, 25)
    # A single-column matrix has one singular value, so its condition number
    # is 1 and this threshold is effectively a no-op filter.
    rcondThresh = np.percentile(rcondList, 25)
    mixingcondThresh = np.percentile(mixingcondList, 25)

    # reward
    rB = generateUniformMatS((z1_dim+z2_dim, 1), rcondThresh)

    # mixing: latent_size -> observation_dim output layer
    mixingB = generateUniformMatS((latent_size, observation_dim), mixingcondThresh)

    # transitions: one matrix per block and lag; the "+1" rows take the scalar action
    for l in range(lags):
        z1B = generateUniformMatS((z1_dim+z2_dim+1, z1_dim), z1condThresh)
        z2B = generateUniformMatS((z1_dim+z2_dim, z2_dim), z2condThresh)
        z3B = generateUniformMatS((latent_size+1, z3_dim), z3condThresh)
        z4B = generateUniformMatS((latent_size, z4_dim), z4condThresh)
        z1tran.append(z1B)
        z2tran.append(z2B)
        z3tran.append(z3B)
        z4tran.append(z4B)
    # After reversal, index l is applied to history slot l (slot 0 = oldest)
    # and is saved below as lag ``lags - l``.
    z1tran.reverse()
    z2tran.reverse()
    z3tran.reverse()
    z4tran.reverse()
    print('get transition matrix')
    mixingList, rewmixingList = [], []
    for l in range(Nlayer - 1):
        # generate causal matrix first:
        A = ortho_group.rvs(latent_size)  # generateUniformMat(Ncomp, condThresh)
        # R and rewmixingList are only used by the commented-out reward model
        # below. The draw presumably still advances the global RNG.
        R = ortho_group.rvs(z1_dim+z2_dim)
        mixingList.append(A)
        rewmixingList.append(R)
    mixingList.append(mixingB)
    rewmixingList.append(rB)

    # Draw the ``action_num`` possible scalar action values from a standard normal.
    random_action = np.random.randn(action_num)

    # Initial history: Gaussian latents of shape (batch_size, lags, latent_size)
    # and actions of shape (batch_size, lags, 1) chosen from random_action.
    y_l = np.random.normal(0, 1, (batch_size, lags, latent_size))
    a_l = np.random.choice(random_action, (batch_size, lags, 1))

    # Standardise every (lag, dimension) over the batch.
    y_l = (y_l - np.mean(y_l, axis=0 ,keepdims=True)) / np.std(y_l, axis=0 ,keepdims=True)

    # Per-time-step lists; stacked and transposed to (batch, time, dim) at the end.
    yt = []; xt = []; at = []; rt = []
    for i in range(lags):
        yt.append(y_l[:,i,:])
        at.append(a_l[:,i,:])
        rt.append(leaky_ReLU(np.dot(y_l[:,i, :z1_dim+z2_dim], rB), negSlope))
    # reward (disabled alternative: an MLP through rewmixingList)
    # rmixedDat = np.copy(y_l[:, :, :z1_dim+z2_dim])
    # for l in range(Nlayer - 1):
    #     rmixedDat = np.dot(rmixedDat, rewmixingList[l])
    #     rmixedDat = leaky_ReLU(rmixedDat, negSlope)
    # rmixedDat = np.dot(rmixedDat, rewmixingList[-1])
    # r_l = np.copy(rmixedDat)
    # for i in range(lags):
    #     rt.append(r_l[:,i,:])

    # observation: (orthogonal mix, leaky ReLU) per hidden layer, then the
    # latent_size -> observation_dim output matrix
    mixedDat = np.copy(y_l)
    for l in range(Nlayer - 1):
        mixedDat = np.dot(mixedDat, mixingList[l])
        mixedDat = leaky_ReLU(mixedDat, negSlope)
    mixedDat = np.dot(mixedDat, mixingList[-1])

    x_l = np.copy(mixedDat)
    for i in range(lags):
        xt.append(x_l[:,i,:])

    # Roll the process forward ``length`` steps.
    for i in range(length):
        # Transition function: unmodulated noise with std 0.2 * noise_scale
        y_t = np.random.normal(0, 0.2 * noise_scale, (batch_size, latent_size))
        # Modulate the noise scale with averaged history
        # print('original', y_t[0])
        # y_t = y_t * np.mean(y_l, axis=1)
        # print('after', y_t[0])
        a_t = np.random.choice(random_action, (batch_size, 1))
        print(a_t.shape)  # debug output
        at.append(a_t)
        for l in range(lags):
            z1t = np.dot(np.concatenate((y_l[:,l,:z1_dim+z2_dim], a_l[:,l,:]), axis=-1), z1tran[l])
            z2t = np.dot(y_l[:,l,:z1_dim+z2_dim], z2tran[l])
            z3t = np.dot(np.concatenate((y_l[:,l,:], a_l[:,l,:]), axis=-1), z3tran[l])
            z4t = np.dot(y_l[:,l,:], z4tran[l])
            zt = np.concatenate([z1t, z2t, z3t, z4t], axis=-1)
            y_t += leaky_ReLU(zt, negSlope)
        y_t = leaky_ReLU(y_t, negSlope)
        print(y_t.shape)  # debug output
        yt.append(y_t)
        # reward
        # rmixedDat = np.copy(y_t[:, :z1_dim+z2_dim])
        # rmixedDat = np.copy(y_t)
        # for l in range(Nlayer - 1):
        #     rmixedDat = np.dot(rmixedDat, rewmixingList[l])
        #     rmixedDat = leaky_ReLU(rmixedDat, negSlope)
        # rmixedDat = np.dot(rmixedDat, rewmixingList[-1])
        # r_t = np.copy(rmixedDat)
        # rt.append(r_t)
        rt.append(leaky_ReLU(np.dot(y_t[:,:z1_dim+z2_dim], rB), negSlope))
        # observation
        mixedDat = np.copy(y_t)
        for l in range(Nlayer - 1):
            mixedDat = np.dot(mixedDat, mixingList[l])
            mixedDat = leaky_ReLU(mixedDat, negSlope)
        mixedDat = np.dot(mixedDat, mixingList[-1])
        x_t = np.copy(mixedDat)
        xt.append(x_t)
        # Slide the history window: append the new step and drop the oldest.
        y_l = np.concatenate((y_l, y_t[:,np.newaxis,:]),axis=1)[:,1:,:]
        a_l = np.concatenate((a_l, a_t[:, np.newaxis,:]),axis=1)[:,1:,:]

    # (time, batch, dim) -> (batch, time, dim)
    yt = np.array(yt).transpose(1,0,2)
    xt = np.array(xt).transpose(1,0,2)
    at = np.array(at).transpose(1,0,2)
    rt = np.array(rt).transpose(1,0,2)

    # Debug: print the first sample of each array.
    print(yt[:1,:,:])
    print(xt[:1,:,:])
    print(at[:1,:,:])
    print(rt[:1,:,:])
    np.savez(os.path.join(path, "data"),
            yt = yt,
            xt = xt,
            at = at,
            rt = rt)

    # Z{block}{k}.npy holds the matrix applied to the latent k steps back.
    for l in range(lags):
        np.save(os.path.join(path, "Z1%d"%(lags-l)), z1tran[l])
        np.save(os.path.join(path, "Z2%d"%(lags-l)), z2tran[l])
        np.save(os.path.join(path, "Z3%d"%(lags-l)), z3tran[l])
        np.save(os.path.join(path, "Z4%d"%(lags-l)), z4tran[l])

def noisecoupled_gaussian_ts_reward_one_hot_action() -> None:
    """Generate and save a lagged latent dataset driven by vector-valued actions.

    The latent structure matches the scalar-action generator (``z1..z4`` are
    2-d each, ``lags`` = 2, observations use two rounds of leaky ReLU plus an
    orthogonal mix, and the reward is ``leaky_ReLU((z1, z2) @ rA)``). The
    differences are:

    * Actions come from ``get_one_hot_ndarray`` with ``action_size`` = 5
      entries per step. They are presumably one-hot encodings; confirm in
      IFactor.tools.utils. The transition matrices for ``z1`` and ``z3``
      therefore get ``action_size`` extra input rows.
    * The transition noise is N(0, noise_scale) and is not modulated by the
      history (that line is commented out).

    Side effects:
        Writes ``data.npz`` (``yt``, ``xt``, ``at``, ``rt``, each shaped
        ``(batch_size, lags + length, dim)``) and ``Z{block}{lag}.npy`` under
        ``root_dir/noisecoupled_gaussian_ts_2lag_IFactor_one_hot``. Uses the
        global NumPy RNG; ``get_one_hot_ndarray`` may also draw from it.
    """
    lags = 2
    Nlayer = 3
    length = 1
    z1condList, z2condList, z3condList, z4condList, rcondList = [], [], [], [], []
    negSlope = 0.2
    z1_dim, z2_dim, z3_dim, z4_dim = 2, 2, 2, 2
    latent_size = z1_dim + z2_dim + z3_dim + z4_dim
    action_size = 5
    transitions = []  # unused
    z1tran, z2tran, z3tran, z4tran = [], [], [], []
    noise_scale = 0.1
    batch_size = 100000
    # batch_size = 100
    Niter4condThresh = 1e4

    path = os.path.join(root_dir, "noisecoupled_gaussian_ts_2lag_IFactor_one_hot")
    os.makedirs(path, exist_ok=True)

    # Estimate acceptance thresholds from many random column-normalised matrices.
    # NOTE(review): the samples come from U(1, 2), but generateUniformMatS
    # samples U(-1, 1), so the percentiles do not describe the filtered
    # distribution.
    for i in range(int(Niter4condThresh)):
        # A = np.random.uniform(0,1, (Ncomp, Ncomp))
        z1A = np.random.uniform(1, 2, (z1_dim+z2_dim+action_size, z1_dim))  # - 1
        z2A = np.random.uniform(1, 2, (z1_dim+z2_dim, z2_dim))  # - 1
        z3A = np.random.uniform(1, 2, (latent_size+action_size, z3_dim))  # - 1
        z4A = np.random.uniform(1, 2, (latent_size, z4_dim))  # - 1
        # This rA is discarded (the reward matrix is redrawn after the loop),
        # but the draw still advances the global RNG.
        rA = np.random.uniform(1, 2, (z1_dim+z2_dim, 1))  # - 1
        # The inner loops reuse ``i``; harmless because the outer ``for`` rebinds it.
        for i in range(z1_dim):
            z1A[:, i] /= np.sqrt((z1A[:, i] ** 2).sum())
        for i in range(z2_dim):
            z2A[:, i] /= np.sqrt((z2A[:, i] ** 2).sum())
        for i in range(z3_dim):
            z3A[:, i] /= np.sqrt((z3A[:, i] ** 2).sum())
        for i in range(z4_dim):
            z4A[:, i] /= np.sqrt((z4A[:, i] ** 2).sum())

        z1condList.append(np.linalg.cond(z1A))
        z2condList.append(np.linalg.cond(z2A))
        z3condList.append(np.linalg.cond(z3A))
        z4condList.append(np.linalg.cond(z4A))
    print('get threshold')
    # reward: a single unit-norm column of U(1, 2) entries (no conditioning check)
    rA = np.random.uniform(1, 2, (z1_dim+z2_dim, 1))  # - 1
    rA[:, 0] /= np.sqrt((rA[:, 0] ** 2).sum())
    # transitions: only accept matrices whose condition number is at or below
    # the 25th percentile of the sampled condition numbers
    z1condThresh = np.percentile(z1condList, 25)
    z2condThresh = np.percentile(z2condList, 25)
    z3condThresh = np.percentile(z3condList, 25)
    z4condThresh = np.percentile(z4condList, 25)

    # One transition matrix per block and lag; z1/z3 take action_size extra input rows.
    for l in range(lags):
        z1B = generateUniformMatS((z1_dim+z2_dim+action_size, z1_dim), z1condThresh)
        z2B = generateUniformMatS((z1_dim+z2_dim, z2_dim), z2condThresh)
        z3B = generateUniformMatS((latent_size+action_size, z3_dim), z3condThresh)
        z4B = generateUniformMatS((latent_size, z4_dim), z4condThresh)
        z1tran.append(z1B)
        z2tran.append(z2B)
        z3tran.append(z3B)
        z4tran.append(z4B)
    # After reversal, index l is applied to history slot l (slot 0 = oldest)
    # and is saved below as lag ``lags - l``.
    z1tran.reverse()
    z2tran.reverse()
    z3tran.reverse()
    z4tran.reverse()
    print('get transition matrix')
    # Observation mixing: one random orthogonal matrix per layer.
    mixingList, rewmixingList = [], []  # rewmixingList is unused here
    for l in range(Nlayer - 1):
        # generate causal matrix first:
        A = ortho_group.rvs(latent_size)  # generateUniformMat(Ncomp, condThresh)
        # NOTE(review): B is never used. scipy's rvs presumably draws from the
        # global NumPy RNG by default, so removing it would change later samples.
        B = ortho_group.rvs(latent_size)
        mixingList.append(A)

    # Initial history: Gaussian latents (batch_size, lags, latent_size) and
    # action vectors; the last axis must be action_size to match z1tran/z3tran.
    y_l = np.random.normal(0, 1, (batch_size, lags, latent_size))
    a_l = get_one_hot_ndarray(batch_size, lags, action_size)

    # Standardise every (lag, dimension) over the batch.
    y_l = (y_l - np.mean(y_l, axis=0 ,keepdims=True)) / np.std(y_l, axis=0 ,keepdims=True)

    # Per-time-step lists; stacked and transposed to (batch, time, dim) at the end.
    yt = []; xt = []; at = []; rt = []
    for i in range(lags):
        yt.append(y_l[:,i,:])
        at.append(a_l[:,i,:])
        rt.append(leaky_ReLU(np.dot(y_l[:,i,:z1_dim+z2_dim], rA), negSlope))
    # Observations for the history: (leaky ReLU, orthogonal mix) per layer.
    mixedDat = np.copy(y_l)
    for l in range(Nlayer - 1):
        mixedDat = leaky_ReLU(mixedDat, negSlope)
        mixedDat = np.dot(mixedDat, mixingList[l])
    x_l = np.copy(mixedDat)
    for i in range(lags):
        xt.append(x_l[:,i,:])

    # Roll the process forward ``length`` steps.
    for i in range(length):
        # Transition function: unmodulated noise with std noise_scale
        y_t = np.random.normal(0, noise_scale, (batch_size, latent_size))
        # Modulate the noise scale with averaged history
        # y_t = y_t * np.mean(y_l, axis=1)
        # Squeeze drops the singleton time axis -> (batch_size, action_size);
        # note it would also drop the batch axis if batch_size were 1.
        a_t = np.squeeze(get_one_hot_ndarray(batch_size, 1, action_size))
        print(a_t.shape)  # debug output
        at.append(a_t)
        for l in range(lags):
            z1t = np.dot(np.concatenate((y_l[:,l,:z1_dim+z2_dim], a_l[:,l,:]), axis=-1), z1tran[l])
            z2t = np.dot(y_l[:,l,:z1_dim+z2_dim], z2tran[l])
            z3t = np.dot(np.concatenate((y_l[:,l,:], a_l[:,l,:]), axis=-1), z3tran[l])
            z4t = np.dot(y_l[:,l,:], z4tran[l])
            zt = np.concatenate([z1t, z2t, z3t, z4t], axis=-1)
            y_t += leaky_ReLU(zt, negSlope)
        y_t = leaky_ReLU(y_t, negSlope)
        print(y_t.shape)  # debug output
        yt.append(y_t)
        rt.append(leaky_ReLU(np.dot(y_t[:,:z1_dim+z2_dim], rA), negSlope))
        # Mixing function
        mixedDat = np.copy(y_t)
        for l in range(Nlayer - 1):
            mixedDat = leaky_ReLU(mixedDat, negSlope)
            mixedDat = np.dot(mixedDat, mixingList[l])
        x_t = np.copy(mixedDat)
        xt.append(x_t)
        # Slide the history window: append the new step and drop the oldest.
        y_l = np.concatenate((y_l, y_t[:,np.newaxis,:]),axis=1)[:,1:,:]
        a_l = np.concatenate((a_l, a_t[:, np.newaxis,:]),axis=1)[:,1:,:]

    # (time, batch, dim) -> (batch, time, dim)
    yt = np.array(yt).transpose(1,0,2)
    xt = np.array(xt).transpose(1,0,2)
    at = np.array(at).transpose(1,0,2)
    rt = np.array(rt).transpose(1,0,2)

    np.savez(os.path.join(path, "data"),
            yt = yt,
            xt = xt,
            at = at,
            rt = rt)

    # Z{block}{k}.npy holds the matrix applied to the latent k steps back.
    for l in range(lags):
        np.save(os.path.join(path, "Z1%d"%(lags-l)), z1tran[l])
        np.save(os.path.join(path, "Z2%d"%(lags-l)), z2tran[l])
        np.save(os.path.join(path, "Z3%d"%(lags-l)), z3tran[l])
        np.save(os.path.join(path, "Z4%d"%(lags-l)), z4tran[l])

if __name__ == "__main__":
    # Only the one-hot-action dataset is generated when run as a script; the
    # other generators must be called explicitly.
    noisecoupled_gaussian_ts_reward_one_hot_action()