import numpy as np
import os
import pdb


def contract(x, tau=1000, eps=1e-6):
    """Rescale ``x`` so that its values are expressed in units of ``tau / 10``.

    The peak absolute value is taken over every element of ``x``.

    - If the peak is below ``tau``, ``x`` is used as-is.
    - Otherwise ``x`` is first shrunk so its peak magnitude is just under ``tau``.

    In both cases the result is then divided by ``tau`` and multiplied by 10.
    For finite input the output magnitude is therefore at most about 10.

    Args:
        x: numeric array; documented by the caller as (S, T), though any
            shape works since the peak is taken over all elements. An empty
            array makes ``np.amax`` raise ``ValueError``.
        tau: magnitude threshold above which ``x`` is shrunk.
        eps: small constant added to the peak to avoid division by zero.

    Returns:
        Rescaled array with the same shape as ``x``.
    """
    peak = np.amax(np.absolute(x))
    # Only large-amplitude signals are squashed; quiet ones keep their shape.
    limited = x if peak < tau else x / (peak + eps) * tau
    return limited / tau * 10.0


def load_data(full_path):
    """Load the pre-split train/validation/test arrays from a directory.

    Reads ``train_data.npy``, ``val_data.npy`` and ``test_data.npy`` from
    ``full_path``, in that order.

    Args:
        full_path: directory containing the three ``.npy`` files.

    Returns:
        Tuple ``(train_data, val_data, test_data)``. Each is presumably of
        shape (N, 3, 3000) per the original note — TODO confirm against the data.
    """
    filenames = ("train_data.npy", "val_data.npy", "test_data.npy")
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only point this at trusted files.
    return tuple(
        np.load(os.path.join(full_path, fname), allow_pickle=True)
        for fname in filenames
    )


def stride_data(array, length, stride):
    """Cut every example into fixed-length, possibly overlapping windows.

    Each example ``array[i]`` is first rescaled with ``contract`` once, so all
    of its windows share the same scaling. Windows of ``length`` samples are
    then taken along the last axis, starting at offsets 0, stride,
    2*stride, ... A window is kept only if it fits entirely inside the
    example; trailing partial windows are dropped.

    Args:
        array: 3-D array of shape (N, S, T), i.e. (num_examples,
            num sensors [e,n,z], time [3000]).
        length: window length in samples.
        stride: step between consecutive window starts; must be positive
            (``range`` raises ``ValueError`` for 0).

    Returns:
        Array of shape (M, S, length), ordered example by example.
        If no window fits (N == 0 or ``length > T``), an empty
        (0, S, length) array is returned instead of raising.
    """
    N, S, T = array.shape
    windows = []
    for instance_idx in range(N):
        scaled = contract(array[instance_idx])
        for s_begin in range(0, T, stride):
            s_end = s_begin + length
            # Starts only increase, so once one window overruns the end of the
            # series every later one does too.
            if s_end > T:
                break
            windows.append(scaled[:, s_begin:s_end])

    if not windows:
        # np.stack raises "need at least one array to stack" on an empty list.
        return np.empty((0, S, length), dtype=np.result_type(array, 1.0))
    return np.stack(windows, axis=0)


def run(
    full_path="/data/dummy/mts_v2_datasets/earthquake_clean_randomsplit",
    save_path="/data/dummy/mvts/seismo/ts1000_stride50",
    length=1000,
    stride=50,
    strided=False,
):
    """Export the train/val/test splits from ``full_path`` into ``save_path``.

    By default (``strided=False``) the loaded arrays are written unchanged,
    matching the script's previous behavior. With ``strided=True`` each split
    is first cut into ``length``-sample windows every ``stride`` samples via
    ``stride_data``. That path was previously left as dead code inside a
    string literal.

    Both modes write ``train_time.npy``, ``val_time.npy`` and
    ``test_time.npy`` into ``save_path``. The directory is created if it is
    missing.

    NOTE(review): the default save_path name ("ts1000_stride50") suggests
    strided output, while the default mode saves the raw arrays — confirm
    which is intended.
    """
    train_data, val_data, test_data = load_data(full_path)

    # np.save does not create missing parent directories.
    os.makedirs(save_path, exist_ok=True)

    outputs = (
        ("train_time.npy", train_data),
        ("val_time.npy", val_data),
        ("test_time.npy", test_data),
    )
    for filename, data in outputs:
        if strided:
            data = stride_data(data, length, stride)
        np.save(os.path.join(save_path, filename), data)


if __name__ == "__main__":
    run()
