# Create a labelled dataset of univariate series: half without and half with the event (50/50 split).
def create_dataset(n, seq_len, key_step, uniform_change=0.7, scale_param=0.0,
                   *, drop=0.7, level=0, max_trend=0.1, change_window=(22, 29)):
    """Generate ``n`` univariate series; the last ``n - n // 2`` contain an event.

    Every series starts at ``level`` and grows by a constant per-step trend
    drawn uniformly from ``(-max_trend, max_trend]``. At a random step drawn
    with ``np.random.randint(*change_window)`` every series (both classes) is
    shifted by ``-change``, with ``change`` drawn uniformly from
    ``[-uniform_change, uniform_change)``; this is a nuisance shift, not the
    event. The last ``n - n // 2`` series additionally drop by ``drop`` from
    index ``key_step + 1`` onward and are labelled 1; the first ``n // 2``
    are labelled 0. Gaussian noise with standard deviation ``scale_param`` is
    added to every point.

    Args:
        n: number of series to generate.
        seq_len: length of each series.
        key_step: index after which the event drop takes effect. Must satisfy
            ``0 <= key_step < seq_len - 1`` so the drop is visible.
        uniform_change: bound of the nuisance shift size.
        scale_param: standard deviation of the additive Gaussian noise.
        drop: size of the event drop (keyword-only).
        level: starting value of every series (keyword-only).
        max_trend: bound of the per-step trend (keyword-only).
        change_window: ``(low, high)`` bounds for the nuisance-shift step,
            high exclusive. If the drawn step is ``>= seq_len - 1`` the shift
            does not appear in the series (keyword-only).

    Returns:
        ``(labels, data)``: float arrays of shape ``(n, 1)`` and
        ``(n, seq_len, 1)``.

    Raises:
        ValueError: if ``key_step`` would place the event outside the series.
    """
    if not 0 <= key_step < seq_len - 1:
        raise ValueError(
            f"key_step must be in [0, {seq_len - 1}) for seq_len={seq_len}, "
            f"got {key_step}"
        )

    data = np.empty((n, seq_len, 1))
    for i in range(n):
        # The second half of the batch carries the event (label 1).
        has_event = i >= n // 2
        # numpy leaves uniform(low, high) with low > high officially undefined.
        # Negating a draw from [-max_trend, max_trend) yields, with the default
        # max_trend, exactly the values the former uniform(0.1, -0.1) call gave,
        # so seeded datasets are unchanged.
        trend = -np.random.uniform(-max_trend, max_trend)
        change = np.random.uniform(-uniform_change, uniform_change)
        change_step = np.random.randint(*change_window)

        series = []
        step = level
        for j in range(seq_len):
            series.append(step)
            step += trend
            # Adjustments are applied after index j is recorded, so they
            # first show up at index j + 1.
            if has_event and j == key_step:
                step -= drop
            if j == change_step:
                step -= change

        # One noise draw per time step, in the same order as a per-element loop.
        noise = np.random.normal(0., scale_param, size=seq_len)
        data[i, :, 0] = np.asarray(series, dtype=float) + noise

    labels = (np.arange(n) >= n // 2).astype(float).reshape(n, 1)
    return labels, data


# Create pairs of time series, one without and one with the event, that share the same underlying generating process (trend, level shift and noise) and differ only by the event.
def create_dataset_counterfactuals(n, seq_len, key_step, uniform_change=0.7,
                                   scale_param=0.0, *, drop=0.7, level=0,
                                   max_trend=0.1, change_window=(22, 29)):
    """Generate ``n`` counterfactual pairs of series differing only by the event.

    Within each pair both series share the same starting ``level``, the same
    per-step trend (uniform in ``(-max_trend, max_trend]``), the same nuisance
    shift of ``-change`` (``change`` uniform in
    ``[-uniform_change, uniform_change)``) at a step drawn with
    ``np.random.randint(*change_window)``, and the same Gaussian noise
    (standard deviation ``scale_param``). Only the second series of the pair
    drops by ``drop`` from index ``key_step + 1`` onward.

    Args:
        n: number of pairs to generate.
        seq_len: length of each series.
        key_step: index after which the event drop takes effect. Must satisfy
            ``0 <= key_step < seq_len - 1`` so the drop is visible.
        uniform_change: bound of the nuisance shift size.
        scale_param: standard deviation of the shared additive noise.
        drop: size of the event drop (keyword-only).
        level: starting value of every series (keyword-only).
        max_trend: bound of the per-step trend (keyword-only).
        change_window: ``(low, high)`` bounds for the nuisance-shift step,
            high exclusive. If the drawn step is ``>= seq_len - 1`` the shift
            does not appear in the series (keyword-only).

    Returns:
        ``(data_0, data_1)``: float arrays of shape ``(n, seq_len, 1)``
        holding the event-free and event series respectively.

    Raises:
        ValueError: if ``key_step`` would place the event outside the series.
    """
    if not 0 <= key_step < seq_len - 1:
        raise ValueError(
            f"key_step must be in [0, {seq_len - 1}) for seq_len={seq_len}, "
            f"got {key_step}"
        )

    data_0 = np.empty((n, seq_len, 1))
    data_1 = np.empty((n, seq_len, 1))
    for i in range(n):
        # numpy leaves uniform(low, high) with low > high officially undefined.
        # Negating a draw from [-max_trend, max_trend) yields, with the default
        # max_trend, exactly the values the former uniform(0.1, -0.1) call gave,
        # so seeded datasets are unchanged.
        trend = -np.random.uniform(-max_trend, max_trend)
        change = np.random.uniform(-uniform_change, uniform_change)
        change_step = np.random.randint(*change_window)

        ts0 = []
        ts1 = []
        step0 = step1 = level
        for j in range(seq_len):
            ts0.append(step0)
            ts1.append(step1)
            step0 += trend
            step1 += trend
            # Adjustments are applied after index j is recorded, so they
            # first show up at index j + 1.
            if j == key_step:
                step1 -= drop
            if j == change_step:
                step0 -= change
                step1 -= change

        # The same noise realisation is added to both members of the pair.
        noise = np.random.normal(0., scale_param, size=seq_len)
        data_0[i, :, 0] = np.asarray(ts0, dtype=float) + noise
        data_1[i, :, 0] = np.asarray(ts1, dtype=float) + noise

    return data_0, data_1