import os
import numpy as np
import torch
from torch import nn
from torch.nn import init
from sklearn import preprocessing
from scipy.stats import ortho_group

# Root folder under which each generator creates its own dataset directory.
root_dir = './datasets'
# Validation split fraction (not referenced by the generators in this section).
VALIDATION_RATIO = 0.2
# Module-level scikit-learn scaler (not referenced by the generators in this section).
standard_scaler = preprocessing.StandardScaler()

def leaky_ReLU_1d(d, negSlope):
    """Scalar leaky ReLU: return ``d`` if it is positive, else ``d * negSlope``."""
    return d if d > 0 else d * negSlope

# Element-wise wrapper around the scalar leaky_ReLU_1d (kept for external callers).
leaky1d = np.vectorize(pyfunc=leaky_ReLU_1d)

def leaky_ReLU(D, negSlope):
    """Apply a leaky ReLU element-wise.

    Parameters
    ----------
    D : array_like
        Input values of any shape (empty arrays are allowed).
    negSlope : float
        Multiplier applied to the non-positive entries; must be > 0.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``D``: ``D`` where ``D > 0`` and
        ``D * negSlope`` elsewhere.
    """
    assert negSlope > 0
    D = np.asarray(D)
    # np.where runs in C and takes its dtype from both branches. A
    # np.vectorize wrapper loops in Python, fixes the output dtype from the
    # first element's result (truncating e.g. int inputs), and rejects empty
    # inputs.
    return np.where(D > 0, D, D * negSlope)

def weigth_init(m):
    """Initialise the parameters of a single module in place.

    Presumably meant to be passed to ``nn.Module.apply`` (it inspects one
    module at a time). The misspelled name is kept for existing callers.

    - ``nn.Conv2d``: Xavier-uniform weight, bias set to 0.1.
    - ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    - ``nn.Linear``: weight drawn from N(0, 0.01**2), bias set to 0.

    Other module types are left untouched. Parameters that do not exist
    (``bias=False`` or ``affine=False``, where the attribute is ``None``) are
    skipped instead of raising ``AttributeError``.
    """
    if isinstance(m, nn.Conv2d):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0.1)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        init.normal_(m.weight, mean=0.0, std=0.01)
        if m.bias is not None:
            init.zeros_(m.bias)

def sigmoidAct(x):
    """Element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    Implemented with ``scipy.special.expit``:
    - It avoids the overflow RuntimeWarning that ``np.exp(-x)`` raises for
      large negative ``x``.
    - It treats Python lists as arrays. With a list, ``-1 * x`` would be list
      repetition and yield an empty result.
    """
    from scipy.special import expit
    return expit(x)

def generateUniformMat(Ncomp, condT):
    """
    Sample a random ``Ncomp x Ncomp`` matrix with unit-norm columns whose
    condition number is not above ``condT``.

    Entries are drawn i.i.d. from U(-1, 1) and every column is rescaled to
    unit Euclidean norm. The draw is repeated while ``np.linalg.cond(A)``
    exceeds ``condT``. The loop has no iteration cap, so an unattainable
    ``condT`` never returns.

    Parameters
    ----------
    Ncomp : int
        Matrix dimension.
    condT : float
        Largest accepted condition number.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(Ncomp, Ncomp)``.
    """
    def sample():
        """Draw one U(-1, 1) matrix and normalise all its columns at once."""
        M = np.random.uniform(0, 2, (Ncomp, Ncomp)) - 1
        return M / np.sqrt((M ** 2).sum(axis=0))

    A = sample()
    while np.linalg.cond(A) > condT:
        A = sample()
    return A

def noisecoupled_gaussian_ts(lags=2, length=1):
    """Generate a time series whose latent noise scale depends on the history.

    Model:
    - The latent ``y`` starts with ``lags`` standard-normal states, standardised
      per (lag, dimension) across the batch.
    - Each new state is leaky_ReLU(sum_l leaky_ReLU(y_{t-l} @ W_l) + eps_t).
    - eps_t is N(0, noise_scale**2) noise multiplied by the mean of the lagged
      states.
    - Observations ``x`` are ``y`` passed through ``Nlayer - 1``
      (leaky ReLU, random orthogonal matrix) layers, plus emission noise.
      The emission noise scale is hard-coded to 0 here.

    Parameters
    ----------
    lags : int
        Order of the latent transition.
    length : int
        Number of generated steps after the initial ``lags`` states.

    Writes into ``<root_dir>/noisecoupled_gaussian_ts_<lags>lag``:
    - ``data.npz`` with ``yt`` and ``xt``, each of shape
      ``(batch_size, lags + length, latent_size)``.
    - ``W<k>.npy``: the transition applied to the state ``k`` steps back.

    The directory name does not include ``length``, so runs with different
    lengths overwrite each other.
    """
    Nlayer = 3
    negSlope = 0.2
    latent_size = 8
    noise_scale = 0.1
    noise_scale_emission = 0
    batch_size = 100000
    Niter4condThresh = 1e4

    path = os.path.join(root_dir, f"noisecoupled_gaussian_ts_{lags}lag")
    os.makedirs(path, exist_ok=True)

    # Reference condition numbers of column-normalised U(1, 2) matrices.
    # They are drawn as one batch; the legacy global RNG fills arrays in C
    # order, so these are the same draws as sampling one matrix at a time.
    ref = np.random.uniform(1, 2, (int(Niter4condThresh), latent_size, latent_size))
    ref /= np.sqrt((ref ** 2).sum(axis=1, keepdims=True))
    condThresh = np.percentile(np.linalg.cond(ref), 25)  # only accept those below 25% percentile

    transitions = [generateUniformMat(latent_size, condThresh) for _ in range(lags)]
    transitions.reverse()

    mixingList = [ortho_group.rvs(latent_size) for _ in range(Nlayer - 1)]

    def mix(z):
        """Apply the Nlayer - 1 (leaky ReLU, orthogonal matrix) mixing layers."""
        for A in mixingList:
            z = np.dot(leaky_ReLU(z, negSlope), A)
        return z

    # y is the latent variable and x is the observed variable
    y_l = np.random.normal(0, 1, (batch_size, lags, latent_size))
    y_l = (y_l - np.mean(y_l, axis=0, keepdims=True)) / np.std(y_l, axis=0, keepdims=True)

    x_noise = np.random.normal(0, noise_scale_emission, (batch_size, lags + length, latent_size))

    yt = [y_l[:, i, :] for i in range(lags)]
    x_l = mix(y_l)
    xt = [x_l[:, i, :] + x_noise[:, i, :] for i in range(lags)]

    for i in range(length):
        # Transition: Gaussian noise whose scale is modulated by the mean of
        # the lagged states, plus one nonlinear term per lag.
        y_t = np.random.normal(0, noise_scale, (batch_size, latent_size))
        y_t = y_t * np.mean(y_l, axis=1)
        for l in range(lags):
            y_t += leaky_ReLU(np.dot(y_l[:, l, :], transitions[l]), negSlope)
        y_t = leaky_ReLU(y_t, negSlope)
        yt.append(y_t)
        xt.append(mix(y_t) + x_noise[:, lags + i, :])
        # Slide the lag window: drop the oldest state, append y_t as the newest.
        y_l = np.concatenate((y_l, y_t[:, np.newaxis, :]), axis=1)[:, 1:, :]

    yt = np.array(yt).transpose(1, 0, 2)
    xt = np.array(xt).transpose(1, 0, 2)

    np.savez(os.path.join(path, "data"), yt=yt, xt=xt)

    # transitions[l] multiplies y_l[:, l, :], the state (lags - l) steps back.
    for l, B in enumerate(transitions):
        np.save(os.path.join(path, "W%d" % (lags - l)), B)

def noisecoupled_gaussian_laplacetemporal_ns_ts(lags=2, length=1, lookbackstep=1, noise_scale_emission=0):
    """Generate a noise-coupled time series with a Laplace random-walk context.

    Model:
    - A context ``c`` with ``c_latent_size`` (= latent_size / 2) dimensions
      starts U(0, 1) and drifts by i.i.d. Laplace(0, c_noise_scale) increments.
    - The Gaussian transition noise of the latent ``y`` is scaled by
      ``0.5 * (mean of lagged y + c tiled to latent_size)``.
    - Observations ``x`` are ``y`` mixed through ``Nlayer - 1``
      (leaky ReLU, random orthogonal matrix) layers, plus Gaussian emission
      noise.

    Parameters
    ----------
    lags : int
        Order of the latent transition.
    length : int
        Steps after the initial ``lags`` states (used in the directory name).
    lookbackstep : int
        Extra steps generated on top of ``length``.
    noise_scale_emission : float
        Standard deviation of the Gaussian emission noise added to ``x``.

    Writes ``data.npz`` and ``W<k>.npy``. With ``T = lags + length + lookbackstep``:
    - ``yt`` and ``xt`` have shape ``(batch_size, T, latent_size)``.
    - ``ct`` has shape ``(batch_size, T, c_latent_size)``.
    - ``W<k>.npy`` is the transition applied to the state ``k`` steps back.
    """
    Nlayer = 3
    negSlope = 0.2
    latent_size = 8
    noise_scale = 0.1
    c_noise_scale = 0.05
    batch_size = 100000
    Niter4condThresh = 1e4
    c_latent_size = int(latent_size / 2)

    path = os.path.join(root_dir, f"noisecoupled_gaussian_laplacetemporal_ns_ts_{lags}lag_{length}length_{lookbackstep}lookbackstep")
    os.makedirs(path, exist_ok=True)

    # The directory name uses the requested length; lookbackstep extra steps
    # are generated on top of it.
    length = length + lookbackstep

    # Reference condition numbers of column-normalised U(1, 2) matrices, drawn
    # as one batch (same legacy-RNG draws as one matrix at a time).
    ref = np.random.uniform(1, 2, (int(Niter4condThresh), latent_size, latent_size))
    ref /= np.sqrt((ref ** 2).sum(axis=1, keepdims=True))
    condThresh = np.percentile(np.linalg.cond(ref), 25)  # only accept those below 25% percentile

    transitions = [generateUniformMat(latent_size, condThresh) for _ in range(lags)]
    transitions.reverse()

    mixingList = [ortho_group.rvs(latent_size) for _ in range(Nlayer - 1)]

    def mix(z):
        """Apply the Nlayer - 1 (leaky ReLU, orthogonal matrix) mixing layers."""
        for A in mixingList:
            z = np.dot(leaky_ReLU(z, negSlope), A)
        return z

    y_l = np.random.normal(0, 1, (batch_size, lags, latent_size))
    y_l = (y_l - np.mean(y_l, axis=0, keepdims=True)) / np.std(y_l, axis=0, keepdims=True)

    x_noise = np.random.normal(0, noise_scale_emission, (batch_size, lags + length, latent_size))

    yt = [y_l[:, i, :] for i in range(lags)]
    x_l = mix(y_l)
    xt = [x_l[:, i, :] + x_noise[:, i, :] for i in range(lags)]
    c_t = np.random.uniform(0, 1, (batch_size, c_latent_size))
    ct = [c_t.copy() for _ in range(lags)]

    # The distribution object holds no state between samples; build it once.
    laplace = torch.distributions.laplace.Laplace(0, c_noise_scale)
    for i in range(length):
        # Transition of c: Laplace random walk.
        c_t += laplace.rsample((batch_size, c_latent_size)).numpy()
        ct.append(c_t.copy())
        # Map c_t onto the latent space of y_t by tiling it along the feature axis.
        c_t_proj = np.tile(c_t, 2)
        # Transition function: noise scale modulated by averaged history + c.
        y_t = np.random.normal(0, noise_scale, (batch_size, latent_size))
        y_t = y_t * 0.5 * (np.mean(y_l, axis=1) + c_t_proj)
        for l in range(lags):
            y_t += leaky_ReLU(np.dot(y_l[:, l, :], transitions[l]), negSlope)
        y_t = leaky_ReLU(y_t, negSlope)
        yt.append(y_t)
        xt.append(mix(y_t) + x_noise[:, lags + i, :])
        # Slide the lag window: drop the oldest state, append y_t as the newest.
        y_l = np.concatenate((y_l, y_t[:, np.newaxis, :]), axis=1)[:, 1:, :]

    yt = np.array(yt).transpose(1, 0, 2)
    xt = np.array(xt).transpose(1, 0, 2)
    ct = np.array(ct).transpose(1, 0, 2)

    np.savez(os.path.join(path, "data"), yt=yt, xt=xt, ct=ct)

    # transitions[l] multiplies y_l[:, l, :], the state (lags - l) steps back.
    for l, B in enumerate(transitions):
        np.save(os.path.join(path, "W%d" % (lags - l)), B)

def noisecoupled_gaussian_2laplacetemporal_ns_ts(lags=2, length=1, lookbackstep=1, noise_scale_emission=0):
    """Generate a time series where a Laplace random-walk context drives the latents.

    Model:
    - A context ``c`` with ``c_latent_size`` (= latent_size / 2) dimensions
      starts U(0, 1) and drifts by i.i.d. Laplace(0, c_noise_scale) increments.
    - Each latent step adds three terms:
      - Gaussian noise scaled by ``0.5 * mean of lagged y``.
      - ``leaky_ReLU(tile(c) @ C)``, with ``C`` a random well-conditioned matrix.
      - One nonlinear term per lag.
    - Observations ``x`` are ``y`` mixed through ``Nlayer - 1``
      (leaky ReLU, random orthogonal matrix) layers, plus Gaussian emission
      noise.

    Parameters
    ----------
    lags : int
        Order of the latent transition.
    length : int
        Steps after the initial ``lags`` states (used in the directory name).
    lookbackstep : int
        Extra steps generated on top of ``length``.
    noise_scale_emission : float
        Standard deviation of the Gaussian emission noise added to ``x``.

    Writes ``data.npz`` and ``W<k>.npy``. With ``T = lags + length + lookbackstep``:
    - ``yt`` and ``xt`` have shape ``(batch_size, T, latent_size)``.
    - ``ct`` has shape ``(batch_size, T, c_latent_size)``.
    - ``W<k>.npy`` is the transition applied to the state ``k`` steps back.
      ``C`` is not saved.
    """
    Nlayer = 3
    negSlope = 0.2
    latent_size = 8
    noise_scale = 0.1
    c_noise_scale = 0.05
    batch_size = 100000
    Niter4condThresh = 1e4
    c_latent_size = int(latent_size / 2)

    path = os.path.join(root_dir, f"noisecoupled_gaussian_2laplacetemporal_ns_ts_{lags}lag_{length}length_{lookbackstep}lookbackstep")
    os.makedirs(path, exist_ok=True)

    # The directory name uses the requested length; lookbackstep extra steps
    # are generated on top of it.
    length = length + lookbackstep

    # Reference condition numbers of column-normalised U(1, 2) matrices, drawn
    # as one batch (same legacy-RNG draws as one matrix at a time).
    ref = np.random.uniform(1, 2, (int(Niter4condThresh), latent_size, latent_size))
    ref /= np.sqrt((ref ** 2).sum(axis=1, keepdims=True))
    condThresh = np.percentile(np.linalg.cond(ref), 25)  # only accept those below 25% percentile

    transitions = [generateUniformMat(latent_size, condThresh) for _ in range(lags)]
    transitions.reverse()
    # Maps the tiled context into the latent space.
    C = generateUniformMat(latent_size, condThresh)

    mixingList = [ortho_group.rvs(latent_size) for _ in range(Nlayer - 1)]

    def mix(z):
        """Apply the Nlayer - 1 (leaky ReLU, orthogonal matrix) mixing layers."""
        for A in mixingList:
            z = np.dot(leaky_ReLU(z, negSlope), A)
        return z

    y_l = np.random.normal(0, 1, (batch_size, lags, latent_size))
    y_l = (y_l - np.mean(y_l, axis=0, keepdims=True)) / np.std(y_l, axis=0, keepdims=True)

    x_noise = np.random.normal(0, noise_scale_emission, (batch_size, lags + length, latent_size))

    yt = [y_l[:, i, :] for i in range(lags)]
    x_l = mix(y_l)
    xt = [x_l[:, i, :] + x_noise[:, i, :] for i in range(lags)]
    c_t = np.random.uniform(0, 1, (batch_size, c_latent_size))
    ct = [c_t.copy() for _ in range(lags)]

    # The distribution object holds no state between samples; build it once.
    laplace = torch.distributions.laplace.Laplace(0, c_noise_scale)
    for i in range(length):
        # Transition of c: Laplace random walk.
        c_t += laplace.rsample((batch_size, c_latent_size)).numpy()
        ct.append(c_t.copy())
        # Map c_t onto the latent space of y_t by tiling it along the feature axis.
        c_t_proj = np.tile(c_t, 2)
        # Transition function: noise scale modulated by averaged history, plus
        # a nonlinear contribution of c and one per lag.
        y_t = np.random.normal(0, noise_scale, (batch_size, latent_size))
        y_t = y_t * 0.5 * (np.mean(y_l, axis=1))
        y_t += leaky_ReLU(np.dot(c_t_proj, C), negSlope)
        for l in range(lags):
            y_t += leaky_ReLU(np.dot(y_l[:, l, :], transitions[l]), negSlope)
        y_t = leaky_ReLU(y_t, negSlope)
        yt.append(y_t)
        xt.append(mix(y_t) + x_noise[:, lags + i, :])
        # Slide the lag window: drop the oldest state, append y_t as the newest.
        y_l = np.concatenate((y_l, y_t[:, np.newaxis, :]), axis=1)[:, 1:, :]

    yt = np.array(yt).transpose(1, 0, 2)
    xt = np.array(xt).transpose(1, 0, 2)
    ct = np.array(ct).transpose(1, 0, 2)

    np.savez(os.path.join(path, "data"), yt=yt, xt=xt, ct=ct)

    # transitions[l] multiplies y_l[:, l, :], the state (lags - l) steps back.
    for l, B in enumerate(transitions):
        np.save(os.path.join(path, "W%d" % (lags - l)), B)

def noisecoupled_gaussian_laplacelineartemporal_ns_ts(lags=2, length=1, lookbackstep=1, noise_scale_emission=0):
    """Generate a noise-coupled time series driven by a linear-Laplace context.

    Model:
    - A context ``c`` with ``c_latent_size`` (= latent_size / 2) dimensions
      starts U(0, 1) and evolves as
      ``c_t = c_{t-1} @ C + Laplace(0, c_noise_scale)``.
    - The Gaussian transition noise of the latent ``y`` is scaled by
      ``0.5 * (mean of lagged y + c tiled to latent_size)``.
    - Observations ``x`` are ``y`` mixed through ``Nlayer - 1``
      (leaky ReLU, random orthogonal matrix) layers, plus Gaussian emission
      noise at every step.

    Parameters
    ----------
    lags : int
        Order of the latent transition.
    length : int
        Steps after the initial ``lags`` states (used in the directory name).
    lookbackstep : int
        Extra steps generated on top of ``length``.
    noise_scale_emission : float
        Standard deviation of the Gaussian emission noise added to ``x``.

    Writes ``data.npz`` and ``W<k>.npy``. With ``T = lags + length + lookbackstep``:
    - ``yt`` and ``xt`` have shape ``(batch_size, T, latent_size)``.
    - ``ct`` has shape ``(batch_size, T, c_latent_size)``.
    - ``W<k>.npy`` is the transition applied to the state ``k`` steps back.
      ``C`` is not saved.
    """
    Nlayer = 3
    negSlope = 0.2
    latent_size = 8
    noise_scale = 0.1
    c_noise_scale = 0.05
    batch_size = 100000
    Niter4condThresh = 1e4
    c_latent_size = int(latent_size / 2)

    # The directory says "linearlaplace" while the function name says
    # "laplacelinear"; kept unchanged so existing consumers find the data.
    path = os.path.join(root_dir, f"noisecoupled_gaussian_linearlaplacetemporal_ns_ts_{lags}lag_{length}length_{lookbackstep}lookbackstep")
    os.makedirs(path, exist_ok=True)

    # The directory name uses the requested length; lookbackstep extra steps
    # are generated on top of it.
    length = length + lookbackstep

    def cond_threshold(size):
        """Return the 25th percentile of cond() over column-normalised U(1, 2) size x size matrices."""
        # Drawn as one batch: same legacy-RNG draws as one matrix at a time.
        ref = np.random.uniform(1, 2, (int(Niter4condThresh), size, size))
        ref /= np.sqrt((ref ** 2).sum(axis=1, keepdims=True))
        return np.percentile(np.linalg.cond(ref), 25)

    # Transition matrices for y (only accept those below the 25% percentile).
    condThresh = cond_threshold(latent_size)
    transitions = [generateUniformMat(latent_size, condThresh) for _ in range(lags)]
    transitions.reverse()

    # Linear transition matrix for c. Its threshold comes from its own
    # c_latent_size reference sample; pooling it with the latent_size sample
    # would shift the intended 25th percentile.
    c_condThresh = cond_threshold(c_latent_size)
    C = generateUniformMat(c_latent_size, c_condThresh)

    mixingList = [ortho_group.rvs(latent_size) for _ in range(Nlayer - 1)]

    def mix(z):
        """Apply the Nlayer - 1 (leaky ReLU, orthogonal matrix) mixing layers."""
        for A in mixingList:
            z = np.dot(leaky_ReLU(z, negSlope), A)
        return z

    y_l = np.random.normal(0, 1, (batch_size, lags, latent_size))
    y_l = (y_l - np.mean(y_l, axis=0, keepdims=True)) / np.std(y_l, axis=0, keepdims=True)

    x_noise = np.random.normal(0, noise_scale_emission, (batch_size, lags + length, latent_size))

    yt = [y_l[:, i, :] for i in range(lags)]
    x_l = mix(y_l)
    xt = [x_l[:, i, :] + x_noise[:, i, :] for i in range(lags)]
    c_t = np.random.uniform(0, 1, (batch_size, c_latent_size))
    ct = [c_t.copy() for _ in range(lags)]

    # The distribution object holds no state between samples; build it once.
    laplace = torch.distributions.laplace.Laplace(0, c_noise_scale)
    for i in range(length):
        # Transition of c: linear map followed by additive Laplace noise.
        # NOTE(review): C only has unit-norm columns and a bounded condition
        # number; nothing bounds its spectral radius, so c_t may grow or
        # shrink geometrically over long horizons.
        c_t = np.dot(c_t, C)
        c_t += laplace.rsample((batch_size, c_latent_size)).numpy()
        ct.append(c_t.copy())
        # Map c_t onto the latent space of y_t by tiling it along the feature axis.
        c_t_proj = np.tile(c_t, 2)
        # Transition function: noise scale modulated by averaged history + c.
        y_t = np.random.normal(0, noise_scale, (batch_size, latent_size))
        y_t = y_t * 0.5 * (np.mean(y_l, axis=1) + c_t_proj)
        for l in range(lags):
            y_t += leaky_ReLU(np.dot(y_l[:, l, :], transitions[l]), negSlope)
        y_t = leaky_ReLU(y_t, negSlope)
        yt.append(y_t)
        # Emission noise is added at every step, as for the first `lags` observations.
        xt.append(mix(y_t) + x_noise[:, lags + i, :])
        # Slide the lag window: drop the oldest state, append y_t as the newest.
        y_l = np.concatenate((y_l, y_t[:, np.newaxis, :]), axis=1)[:, 1:, :]

    yt = np.array(yt).transpose(1, 0, 2)
    xt = np.array(xt).transpose(1, 0, 2)
    ct = np.array(ct).transpose(1, 0, 2)

    np.savez(os.path.join(path, "data"), yt=yt, xt=xt, ct=ct)

    # transitions[l] multiplies y_l[:, l, :], the state (lags - l) steps back.
    for l, B in enumerate(transitions):
        np.save(os.path.join(path, "W%d" % (lags - l)), B)


def noisecoupled_gaussian_changets(lags=2, length=1, noise_scale_emission=0):
    """Generate a noise-coupled time series with a domain shift in two transition edges.

    Domains:
    - The first ``int(ratio * batch_size)`` sequences form the "source" domain
      and the rest form the "target" domain.
    - Both domains share every parameter except entries (1, 2) and (3, 4) of
      ``transitions[0]``, the transition applied to the oldest lagged state.
    - Those entries are drawn from U(-1, 1) for the source and
      U(-1.25, 1.25) for the target.

    Otherwise the model matches the other generators: history-modulated
    Gaussian transition noise, and ``Nlayer - 1`` (leaky ReLU, orthogonal
    matrix) mixing layers plus Gaussian emission noise.

    Parameters
    ----------
    lags : int
        Order of the latent transition.
    length : int
        Number of generated steps after the initial ``lags`` states.
    noise_scale_emission : float
        Standard deviation of the Gaussian emission noise added to ``x``.

    Writes:
    - ``data.npz`` with ``yt`` and ``xt`` of shape
      ``(batch_size, lags + length, latent_size)``.
    - ``W<k>.npy``: the transition applied to the state ``k`` steps back.
      When ``length > 0``, ``W<lags>`` carries the target-domain edge weights.
    """
    Nlayer = 3
    ratio = 0.8
    negSlope = 0.2
    latent_size = 8
    noise_scale = 0.1
    batch_size = 100000
    source_size = int(ratio * batch_size)  # rows [:source_size] form the source domain
    Niter4condThresh = 1e4

    path = os.path.join(root_dir, f"noisecoupled_gaussian_changets_{ratio}ratio_{lags}lag_{length}length")
    os.makedirs(path, exist_ok=True)

    # Reference condition numbers of column-normalised U(1, 2) matrices, drawn
    # as one batch (same legacy-RNG draws as one matrix at a time).
    ref = np.random.uniform(1, 2, (int(Niter4condThresh), latent_size, latent_size))
    ref /= np.sqrt((ref ** 2).sum(axis=1, keepdims=True))
    condThresh = np.percentile(np.linalg.cond(ref), 25)  # only accept those below 25% percentile

    transitions = [generateUniformMat(latent_size, condThresh) for _ in range(lags)]
    transitions.reverse()

    mixingList = [ortho_group.rvs(latent_size) for _ in range(Nlayer - 1)]

    def mix(z):
        """Apply the Nlayer - 1 (leaky ReLU, orthogonal matrix) mixing layers."""
        for A in mixingList:
            z = np.dot(leaky_ReLU(z, negSlope), A)
        return z

    y_l = np.random.normal(0, 1, (batch_size, lags, latent_size))
    y_l = (y_l - np.mean(y_l, axis=0, keepdims=True)) / np.std(y_l, axis=0, keepdims=True)

    x_noise = np.random.normal(0, noise_scale_emission, (batch_size, lags + length, latent_size))

    yt = [y_l[:, i, :] for i in range(lags)]
    x_l = mix(y_l)
    xt = [x_l[:, i, :] + x_noise[:, i, :] for i in range(lags)]

    # Domain-varying edges of transitions[0].
    edge_pairs = [(1, 2), (3, 4)]
    source_edge_weights = np.random.uniform(-1, 1, (1, len(edge_pairs)))
    target_edge_weights = np.random.uniform(-1.25, 1.25, (1, len(edge_pairs)))

    # The domain-specific matrices do not change over time, so build them once
    # instead of rewriting transitions[0] in place at every step.
    if lags > 0:
        W_source = transitions[0].copy()
        W_target = transitions[0].copy()
        for p_idx, (row, col) in enumerate(edge_pairs):
            W_source[row, col] = source_edge_weights[0, p_idx]
            W_target[row, col] = target_edge_weights[0, p_idx]

    for i in range(length):
        # Transition function
        y_t = np.random.normal(0, noise_scale, (batch_size, latent_size))
        # Modulate the noise scale with averaged history
        y_t = y_t * np.mean(y_l, axis=1)
        if lags > 0:
            y_t[:source_size] += leaky_ReLU(np.dot(y_l[:source_size, 0, :], W_source), negSlope)
            y_t[source_size:] += leaky_ReLU(np.dot(y_l[source_size:, 0, :], W_target), negSlope)
        for l in range(1, lags):
            y_t += leaky_ReLU(np.dot(y_l[:, l, :], transitions[l]), negSlope)
        y_t = leaky_ReLU(y_t, negSlope)
        yt.append(y_t)
        xt.append(mix(y_t) + x_noise[:, lags + i, :])
        # Slide the lag window: drop the oldest state, append y_t as the newest.
        y_l = np.concatenate((y_l, y_t[:, np.newaxis, :]), axis=1)[:, 1:, :]

    yt = np.array(yt).transpose(1, 0, 2)
    xt = np.array(xt).transpose(1, 0, 2)

    np.savez(os.path.join(path, "data"), yt=yt, xt=xt)

    # W<lags> is written with the target-domain edge weights whenever at least
    # one step was generated.
    # NOTE(review): the source-domain matrix (W_source) is never persisted;
    # save it too if the source ground truth is needed downstream.
    if lags > 0 and length > 0:
        transitions[0] = W_target
    # transitions[l] multiplies y_l[:, l, :], the state (lags - l) steps back.
    for l, B in enumerate(transitions):
        np.save(os.path.join(path, "W%d" % (lags - l)), B)


if __name__ == "__main__":
    noisecoupled_gaussian_ts()

