import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def nonlinear_ode(t, z, A, b, omega, beta):
    """
    Right-hand side of the coupled nonlinear ODE used by the data generators.

    For every component i, with j = (i + 1) mod len(z):

        dz_i/dt = (A @ z + b)_i - beta * z_i + omega * sin(z_j)

    Args:
        t: Current time. Not used by the dynamics; present because the
            function is passed to solve_ivp as the derivative callback.
        z: State vector (array-like); flattened to 1-D before use.
        A: Matrix of the linear term.
        b: Constant offset vector of the linear term.
        omega: Gain of the sinusoidal coupling to the next component.
        beta: Damping coefficient applied to each component.

    Returns:
        np.ndarray: 1-D array holding the time derivative of each component.
    """
    z = np.asarray(z).reshape(-1)
    linear = A @ z + b
    # np.roll(z, -1)[i] == z[(i + 1) % len(z)], i.e. the cyclic "next"
    # component. Computing the term as a whole-array expression also avoids
    # writing float results into an integer buffer when z has an int dtype.
    nonlinear = -beta * z + omega * np.sin(np.roll(z, -1))
    return linear + nonlinear


def generate_multi_sequence_data(
        num_sequences=40, num_tokens=20, token_length=50,
        state_dim=3, total_time=10.0, seed=42, visualize=True
):
    """
    Generates multiple ODE trajectories that follow the same underlying dynamics.

    Each sequence shares the same ODE parameters (A, b, etc.) but starts
    from a different random initial condition z0, creating varied but similar curves.
    All trajectories are jointly min-max scaled to [-1, 1] and then cut into
    `num_tokens` consecutive tokens of `token_length` samples each.

    Note: this reseeds NumPy's global random generator with `seed`.

    Args:
        num_sequences (int): Number of trajectories to simulate.
        num_tokens (int): Number of tokens each trajectory is split into.
        token_length (int): Number of time samples per token.
        state_dim (int): Dimension of the ODE state.
        total_time (float): Integration horizon; time is afterwards rescaled to [0, 1].
        seed (int): Seed passed to np.random.seed.
        visualize (bool): If True, show two matplotlib figures (all states of
            up to 8 sequences, and the token segmentation of the first sequence).

    Returns:
        tuple:
            t_tokens_by_idx (np.ndarray): [num_tokens, num_sequences, token_length]
                normalized time, identical for every sequence.
            x_tokens_by_idx (np.ndarray): [num_tokens, num_sequences, token_length, state_dim]
                normalized states.

    Raises:
        RuntimeError: If the ODE solver reports a failure for any sequence.
    """
    print("Step 1: Generating multi-sequence data with shared dynamics...")
    np.random.seed(seed)

    total_points = num_tokens * token_length
    t_eval = np.linspace(0, total_time, total_points)

    # Define SHARED nonlinear ODE system parameters (defined once).
    # A is built from small negative eigenvalues via a random similarity transform.
    Q = np.random.randn(state_dim, state_dim)
    eigs = -np.abs(np.random.rand(state_dim)) * 0.1
    A = Q @ np.diag(eigs) @ np.linalg.inv(Q)
    b = np.random.uniform(-0.5, 0.5, size=state_dim)
    omega = 3.0
    beta = 0.1

    all_x_sequences = []
    # Generate num_sequences independent trajectories
    for seq_idx in range(num_sequences):
        z0 = np.random.uniform(-1, 1, size=state_dim)
        sol = solve_ivp(
            nonlinear_ode, [0, total_time], z0,
            t_eval=t_eval, args=(A, b, omega, beta)
        )
        # A failed solve returns fewer samples than requested; fail here with
        # the solver's own message instead of a later stack/reshape error.
        if not sol.success:
            raise RuntimeError(
                f"solve_ivp failed for sequence {seq_idx}: {sol.message}"
            )
        all_x_sequences.append(sol.y.T)  # [total_points, state_dim]

    x_data = np.stack(all_x_sequences, axis=0)  # [num_sequences, total_points, state_dim]

    # Normalize the state data to the range [-1, 1] using the global min/max.
    # Skipped when the data is constant, which would otherwise divide by zero.
    min_val = x_data.min()
    max_val = x_data.max()
    if max_val > min_val:
        x_data = 2 * (x_data - min_val) / (max_val - min_val) - 1

    t_norm = (t_eval - t_eval[0]) / (t_eval[-1] - t_eval[0])
    t_tokens_shared = t_norm.reshape(num_tokens, token_length)
    x_tokens_by_seq = x_data.reshape(
        num_sequences, num_tokens, token_length, state_dim
    )
    # Put the token index first so that slicing axis 0 yields one time interval
    # across all sequences.
    x_tokens_by_idx = x_tokens_by_seq.transpose(1, 0, 2, 3)
    t_tokens_by_idx = np.expand_dims(t_tokens_shared, axis=1)
    t_tokens_by_idx = np.tile(t_tokens_by_idx, (1, num_sequences, 1))

    if visualize:
        # --- FIGURE 1: Plot all states for multiple sequences ---
        plt.figure(figsize=(14, 7))
        num_to_plot = min(8, num_sequences)
        plt.title(f"All States for {num_to_plot} Different Generated Sequences (Normalized)")

        colors = plt.cm.jet(np.linspace(0, 1, num_to_plot))
        linestyles = ['-', '--', ':', '-.']

        for i in range(num_to_plot):
            full_sequence_data = x_data[i]
            for d in range(state_dim):
                style = linestyles[d % len(linestyles)]
                # Only the first sequence contributes legend entries.
                label = f'Seq {i + 1}, Dim {d}' if i == 0 else None
                plt.plot(t_norm, full_sequence_data[:, d],
                         color=colors[i], linestyle=style, label=label, alpha=0.7)

        plt.xlabel("Normalized Time [0, 1]")
        plt.ylabel("Normalized State")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

        # --- FIGURE 2: Visualize token segmentation for one sequence ---
        plt.figure(figsize=(14, 7))
        plt.title("Token Segmentation for a Single Representative Sequence (Normalized)")

        z_plot_tokens = x_tokens_by_idx[:, 0, :, :]

        # Alternate two colors so neighbouring tokens are distinguishable.
        color_cycle = ['blue', 'red']
        for i in range(num_tokens):
            color = color_cycle[i % 2]
            for d in range(state_dim):
                plt.plot(t_tokens_shared[i], z_plot_tokens[i, :, d], color=color)

        plt.xlabel("Normalized Time [0, 1]")
        plt.ylabel("Normalized State")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    return t_tokens_by_idx, x_tokens_by_idx


def prepare_dataloaders(t_tokens_by_idx, x_tokens_by_idx, batch_size):
    """
    Build one DataLoader per token time interval.

    Token k is paired with token k + 1 (current -> next), so the number of
    loaders is one less than the number of tokens. Every sample in a given
    loader comes from the same token index and therefore carries the same
    time vector.

    Args:
        t_tokens_by_idx (np.ndarray): Time data from generate_multi_sequence_data.
        x_tokens_by_idx (np.ndarray): State data from generate_multi_sequence_data.
        batch_size (int): The batch size for the DataLoaders.

    Returns:
        list: DataLoaders whose datasets hold, in order,
            (x_curr, t_curr, mask_curr, x_next, t_next, mask_next) as float32
            tensors; both masks are all ones.
    """
    print("\nStep 2: Preparing specialized DataLoaders...")

    def _to_float_tensor(array):
        return torch.tensor(array, dtype=torch.float32)

    # Pairs are (token k, token k + 1); the last token has no successor.
    group_count = x_tokens_by_idx[:-1].shape[0]
    loaders = []

    for group in range(group_count):
        states_now = _to_float_tensor(x_tokens_by_idx[:-1][group])
        times_now = _to_float_tensor(t_tokens_by_idx[:-1][group])
        observed_now = torch.ones_like(times_now)

        states_after = _to_float_tensor(x_tokens_by_idx[1:][group])
        times_after = _to_float_tensor(t_tokens_by_idx[1:][group])
        observed_after = torch.ones_like(times_after)

        group_dataset = TensorDataset(
            states_now, times_now, observed_now,
            states_after, times_after, observed_after,
        )
        # Shuffling only permutes sequences; the time vector is shared within the group.
        loaders.append(DataLoader(group_dataset, batch_size=batch_size, shuffle=True))

        # Report the first group as a sanity check.
        if group == 0:
            print(
                f" - Created DataLoader for token group 0 (time interval [{times_now[0, 0]:.2f}, {times_now[0, -1]:.2f}])")
            print(f"   Dataset size for this group: {len(group_dataset)} sequences.")

    print(
        f"\nSuccessfully created {len(loaders)} DataLoaders, one for each of the {group_count} token time intervals.")
    return loaders



def generate_ode_tokens(num_tokens=5, token_length=50, state_dim=2,
                        total_time=10.0, seed=42, visualize=True):
    """
    Simulate one nonlinear ODE trajectory and cut it into consecutive tokens.

    Note: this reseeds NumPy's global random generator with `seed`.

    Args:
        num_tokens (int): Number of tokens the trajectory is split into.
        token_length (int): Number of time samples per token.
        state_dim (int): Dimension of the ODE state.
        total_time (float): Integration horizon; time is afterwards rescaled to [0, 1].
        seed (int): Seed passed to np.random.seed.
        visualize (bool): If True, show a matplotlib figure of the tokenized trajectory.

    Returns:
        t_tokens: [num_tokens, token_length] normalized time per token
        z_tokens: [num_tokens, token_length, state_dim] state values per token
            (not rescaled)
        params: dict with the ODE parameters 'A', 'b', 'omega', 'beta' and the
            initial state 'z0'

    Raises:
        RuntimeError: If the ODE solver reports a failure.
    """
    np.random.seed(seed)
    total_points = num_tokens * token_length
    t_eval = np.linspace(0, total_time, total_points)

    # Random initial state and ODE parameters; A is built from negative
    # eigenvalues via a random similarity transform.
    z0 = np.random.uniform(-1, 1, size=state_dim)
    Q = np.random.randn(state_dim, state_dim)
    eigs = -np.abs(np.random.rand(state_dim))
    A = Q @ np.diag(eigs) @ np.linalg.inv(Q)
    b = np.random.uniform(-0.5, 0.5, size=state_dim)
    omega = 1.0
    beta = 0.2

    # Solve ODE
    sol = solve_ivp(nonlinear_ode, [0, total_time], z0,
                    t_eval=t_eval, args=(A, b, omega, beta))
    # A failed solve returns fewer samples than requested; fail here with the
    # solver's own message instead of a later reshape error.
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed: {sol.message}")
    full_t = sol.t
    full_z = sol.y.T  # [total_points, state_dim]

    # Normalize time to [0,1]
    t_norm = (full_t - full_t[0]) / (full_t[-1] - full_t[0])  # [T]

    # Slice into tokens
    t_tokens = t_norm.reshape(num_tokens, token_length)                  # [N, L]
    z_tokens = full_z.reshape(num_tokens, token_length, state_dim)      # [N, L, D]

    if visualize:
        color_cycle = ['blue', 'red']  # two alternating colors mark token boundaries
        plt.figure()
        for i in range(num_tokens):
            color = color_cycle[i % 2]
            for d in range(state_dim):
                plt.plot(t_tokens[i], z_tokens[i, :, d], color=color)
        plt.title("Nonlinear ODE trajectory with token segmentation")
        plt.xlabel("Normalized Time [0, 1]")
        plt.ylabel("State")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    # Return stacked arrays and shared parameters
    params = {'A': A, 'b': b, 'omega': omega, 'beta': beta, 'z0': z0}
    return t_tokens, z_tokens, params


def prepare_single_dataloader(t_tokens_by_idx, x_tokens_by_idx, batch_size):
    """
    Build ONE shuffled DataLoader over every (current, next) token pair.

    All token groups and sequences are flattened into a single sample axis,
    so a batch may mix samples taken from different time intervals.

    Args:
        t_tokens_by_idx (np.ndarray): Time tokens, token index on axis 0.
        x_tokens_by_idx (np.ndarray): 4-D state tokens, token index on axis 0.
        batch_size (int): Batch size of the returned DataLoader.

    Returns:
        DataLoader: Dataset tensors are, in order,
            (x_curr, t_curr, mask_curr, x_next, t_next, mask_next) as float32;
            both masks are all ones.
    """
    print("\nStep 2: Preparing a single, shuffled DataLoader...")

    def _flat_tensor(tokens, *sample_shape):
        # Merge the leading (token, sequence) axes into a single sample axis.
        return torch.tensor(tokens.reshape(-1, *sample_shape), dtype=torch.float32)

    # Pair token k with token k + 1.
    states_now, states_after = x_tokens_by_idx[:-1], x_tokens_by_idx[1:]
    times_now, times_after = t_tokens_by_idx[:-1], t_tokens_by_idx[1:]

    _, _, token_len, state_dim = states_now.shape

    flat_states_now = _flat_tensor(states_now, token_len, state_dim)
    flat_times_now = _flat_tensor(times_now, token_len)
    flat_mask_now = torch.ones_like(flat_times_now)

    flat_states_after = _flat_tensor(states_after, token_len, state_dim)
    flat_times_after = _flat_tensor(times_after, token_len)
    flat_mask_after = torch.ones_like(flat_times_after)

    dataset = TensorDataset(
        flat_states_now, flat_times_now, flat_mask_now,
        flat_states_after, flat_times_after, flat_mask_after,
    )

    # shuffle=True mixes samples from different time intervals within a batch.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print(f"Successfully created a single DataLoader with {len(dataset)} total samples.")
    return dataloader


class ODESequenceDataModule:
    """
    Generates normalized ODE trajectories with shared dynamics and serves them
    through a DataLoader, optionally with a sub-sampling / truncation mask.

    Attributes set by generate_data():
        t_scalar: [seq_length] normalized time grid in [0, 1].
        t_data:   [num_sequences, seq_length] the same grid repeated per sequence.
        x_data:   [num_sequences, seq_length, state_dim] states scaled to [-1, 1].
    """

    def __init__(
        self,
        num_sequences=40,
        seq_length=1000,
        state_dim=3,
        total_time=20.0,
        batch_size=8,
        seed=42,
        device='cpu'
    ):
        self.num_sequences = num_sequences
        self.seq_length = seq_length
        self.state_dim = state_dim
        self.total_time = total_time
        self.batch_size = batch_size
        self.seed = seed
        self.device = device

        # Filled in by generate_data(); declared here so the attributes always exist.
        self.t_scalar = None
        self.t_data = None
        self.x_data = None

    def plot_sequences(self, num_to_plot=4):
        """
        Visualize the generated ODE sequences.
        Plots up to `num_to_plot` sequences from the dataset.

        Raises:
            ValueError: If generate_data() has not been run yet.
        """
        if self.t_data is None or self.x_data is None:
            raise ValueError("No data available. Run generate_data() first.")

        num_to_plot = min(num_to_plot, self.t_data.shape[0])
        D = self.x_data.shape[2]

        plt.figure(figsize=(14, 6))
        plt.title(f"Sample Trajectories from ODE Generator (First {num_to_plot} Sequences)")

        # One color and line style per state dimension.
        colors = plt.cm.viridis(np.linspace(0, 1, D))
        linestyles = ['-', '--', ':', '-.']

        for i in range(num_to_plot):
            for d in range(D):
                plt.plot(
                    self.t_data[i].cpu(),
                    self.x_data[i, :, d].cpu(),
                    linestyle=linestyles[d % len(linestyles)],
                    color=colors[d],
                    # Only the first sequence contributes legend entries.
                    label=f"Seq {i+1}, Dim {d}" if i == 0 else None,
                    alpha=0.8
                )

        plt.xlabel("Time (s)")
        plt.ylabel("State Value (Normalized)")
        plt.grid(True)
        plt.legend(loc="upper right")
        plt.tight_layout()
        plt.show()

    def generate_data(self):
        """
        Simulate `num_sequences` trajectories and store them on `self.device`.

        All sequences share the ODE parameters and differ only in the random
        initial state. States are jointly min-max scaled to [-1, 1]; the stored
        time grid is normalized to [0, 1] while the solver integrates over
        [0, total_time].

        Note: this reseeds NumPy's global random generator with `self.seed`.

        Raises:
            RuntimeError: If the ODE solver reports a failure for any sequence.
        """
        np.random.seed(self.seed)
        t_norm = np.linspace(0, 1.0, self.seq_length)  # normalized time
        self.t_scalar = torch.tensor(t_norm, dtype=torch.float32).to(self.device)  # [T]
        self.t_data = self.t_scalar.unsqueeze(0).repeat(self.num_sequences, 1)  # [B, T]

        # Define shared dynamics; A is built from negative eigenvalues via a
        # random similarity transform.
        Q = np.random.randn(self.state_dim, self.state_dim)
        eigs = -np.abs(np.random.rand(self.state_dim))
        A = Q @ np.diag(eigs) @ np.linalg.inv(Q)
        b = np.random.uniform(-0.5, 0.5, size=self.state_dim)
        omega, beta = 2.0, 0.05

        # Unnormalized time grid for the solver; identical for every sequence.
        t_solver = np.linspace(0, self.total_time, self.seq_length)

        x_all = []
        for seq_idx in range(self.num_sequences):
            z0 = np.random.uniform(0.9, 1, size=self.state_dim)
            sol = solve_ivp(self.nonlinear_ode, [0, self.total_time], z0,
                            t_eval=t_solver,
                            args=(A, b, omega, beta))
            # A failed solve returns fewer samples than requested; fail here
            # with the solver's own message instead of a later stack error.
            if not sol.success:
                raise RuntimeError(
                    f"solve_ivp failed for sequence {seq_idx}: {sol.message}"
                )
            x_all.append(sol.y.T)

        x_data = np.stack(x_all)  # [B, T, D]
        x_min, x_max = x_data.min(), x_data.max()
        # Skip scaling for constant data to avoid a division by zero.
        if x_max > x_min:
            x_data = 2 * (x_data - x_min) / (x_max - x_min) - 1

        self.x_data = torch.tensor(x_data, dtype=torch.float32).to(self.device)

    @staticmethod
    def nonlinear_ode(t, z, A, b, omega, beta):
        """
        ODE right-hand side: (A @ z + b)_i - beta * z_i + omega * sin(z_j),
        with j = (i + 1) mod len(z). `t` is not used by the dynamics.
        """
        z = np.asarray(z).reshape(-1)
        linear = A @ z + b
        # np.roll(z, -1)[i] == z[(i + 1) % len(z)]; the whole-array form also
        # avoids truncating results into an integer buffer for int-typed z.
        nonlinear = -beta * z + omega * np.sin(np.roll(z, -1))
        return linear + nonlinear

    def create_mask(self, t_res=None, t_cutoff=None):
        """
        Create mask for multi-resolution and truncation.
        - t_res: keep every t_res steps (e.g., 4 for LR); ignored unless > 1
        - t_cutoff: mask out time points whose value in self.t_data exceeds
          this threshold (for shorter sequences)

        Returns:
            torch.Tensor: [B, T] mask on self.device; 1 = keep, 0 = masked.

        Raises:
            ValueError: If generate_data() has not been run yet.
        """
        if self.t_data is None:
            raise ValueError("Call generate_data() before create_mask()")

        B, T = self.t_data.shape
        t_data_cpu = self.t_data.cpu()  # build the mask on CPU, move at the end

        mask = torch.ones_like(t_data_cpu)

        if t_res is not None and t_res > 1:
            keep_idx = torch.arange(0, T, t_res)
            res_mask = torch.zeros(T, dtype=torch.bool)
            res_mask[keep_idx] = True
            # Broadcast the per-time-step pattern over all sequences.
            mask = mask * res_mask.unsqueeze(0)

        if t_cutoff is not None:
            cutoff_mask = (t_data_cpu <= t_cutoff).float()
            mask = mask * cutoff_mask

        mask = mask.to(self.device)
        return mask  # [B, T]

    def get_dataloader(self, shuffle=True, t_res=None, t_cutoff=None):
        """
        Return a DataLoader over (x, mask, t_scalar) samples.

        Data is generated on first use if generate_data() has not been called.
        `t_res` and `t_cutoff` are forwarded to create_mask().
        """
        if self.x_data is None:
            self.generate_data()

        mask = self.create_mask(t_res=t_res, t_cutoff=t_cutoff)  # [B, T]
        dataset = ODESequenceDataset(self.x_data, mask, self.t_scalar)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)


class ODESequenceDataset(Dataset):
    """Dataset of state sequences with per-sequence masks and one shared time grid."""

    def __init__(self, x_data, mask, t_scalar):
        # x_data:   [B, T, D] states
        # mask:     [B, T] per-sequence mask
        # t_scalar: [T] time grid, shared across all samples
        self.x_data, self.mask, self.t_scalar = x_data, mask, t_scalar

    def __len__(self):
        # One item per sequence (leading axis of x_data).
        sequence_count = self.x_data.shape[0]
        return sequence_count

    def __getitem__(self, idx):
        # Returns ([T, D], [T], [T]); the time grid is the same object for every idx.
        states = self.x_data[idx]
        observed = self.mask[idx]
        return states, observed, self.t_scalar


def read_data(folder, plot=False, max_length=200):
    """
    Read every .npz file in `folder` as one system and stack their trajectories.

    Each file must provide the keys "states" and "time" with shapes
    "states": [num_traj, traj_length, state_dim], "time": [num_traj, traj_length]
    (all files must share state_dim).

    Processing steps:
      1. Files are read in sorted name order, so the system index is reproducible.
      2. All trajectories are cut to the shortest trajectory length found.
      3. Time is divided by the largest final time stamp over all systems.
      4. Systems are stacked along axis 0 and truncated to `max_length` steps.
      5. States are min-max scaled to [-1, 1] using the global min and max.

    Args:
        folder (str): Directory containing the .npz files.
        plot (bool): If True, plot the stacked trajectories per state dimension.
        max_length (int | None): Maximum number of time steps kept per
            trajectory; None keeps all steps.

    Returns:
        tuple: (x, t, num_systems) with x: [total_traj, steps, state_dim]
            normalized states, t: [total_traj, steps] scaled time, and
            num_systems the number of files read.

    Raises:
        ValueError: If `folder` contains no .npz file.
    """
    import os

    def normalize(array):
        # Global min-max scaling to [-1, 1].
        min_val = array.min()
        max_val = array.max()
        return 2 * (array - min_val) / (max_val - min_val) - 1

    def plot_data(t_data, x_data, num_systems):
        num_traj = x_data.shape[0] // num_systems  # assume each system has same num_traj
        state_dim = x_data.shape[2]
        plt.figure(figsize=(20, 4))

        for d in range(0, state_dim):
            plt.subplot(1, state_dim, d + 1)
            plt.title(f"Dim {d + 1}")

            colors = plt.cm.viridis(np.linspace(0, 1, num_systems))

            for s in range(num_systems):
                for i in range(num_traj):
                    plt.plot(
                        t_data[s * num_traj + i, :],
                        x_data[s * num_traj + i, :, d],
                        linestyle="-",
                        color=colors[s],
                        label=f"Sys {s}" if i == 0 else None,
                    )

            plt.xlabel("Time (s)")
            plt.ylabel("State Value (-1 to 1)")
            plt.grid(True)
            plt.legend(loc="upper right")
            plt.tight_layout()
        plt.show()

    # os.listdir returns names in arbitrary order; sort for a stable system index.
    files = sorted(f for f in os.listdir(folder) if f.endswith(".npz"))

    # number of systems
    num_systems = len(files)
    if num_systems == 0:
        raise ValueError(f"No .npz files found in {folder!r}")

    # read all systems; the context manager closes each archive once the
    # arrays have been materialized
    systems = []
    for name in files:
        with np.load(os.path.join(folder, name)) as archive:
            systems.append({"states": archive["states"], "time": archive["time"]})

    # shortest trajectory length and largest final time stamp over all systems
    min_traj_length = min(entry["states"].shape[1] for entry in systems)
    max_time_index = max(entry["time"][:, -1].max() for entry in systems)

    for entry in systems:
        # cut all traj to minimum traj length
        entry["states"] = entry["states"][:, 0:min_traj_length, :]
        # scale time by the largest final time stamp
        entry["time"] = entry["time"][:, 0:min_traj_length] / max_time_index

    # connect all systems
    x = np.vstack([entry["states"] for entry in systems])[:, :max_length, :]
    t = np.vstack([entry["time"] for entry in systems])[:, :max_length]

    # plot
    # NOTE(review): the plot receives the un-normalized states although its
    # y-axis label says "-1 to 1" - confirm which is intended.
    if plot:
        plot_data(t, x, num_systems)

    return normalize(x), t, num_systems


class TimeSeriesDataset(Dataset):
    """Dataset of time series with element-wise masks and one shared time vector."""

    def __init__(self, x, mask, t):
        # x:    [B, T, D] series values
        # mask: [B, T, D] element-wise mask
        # t:    [T] time vector, shared by every sample
        self.x, self.mask, self.t = x, mask, t

    def __len__(self):
        # Number of sequences (leading axis of x).
        sequence_count = self.x.shape[0]
        return sequence_count

    def __getitem__(self, idx):
        # Returns ([T, D], [T, D], [T]); the time vector is the same object for every idx.
        series = self.x[idx]
        series_mask = self.mask[idx]
        return series, series_mask, self.t


def read_swing_data(filepath, state_dim=5, max_seq_length=1000, sample_resolution=0.005, use_seq=800, plot=True):
    """
    Load swing data from a .npy file and regroup it into 2-D trajectories.

    The array is expected as [num_seq, seq_length, channels] with at least
    20 channels. Channel i is paired with channel i + 10 for i = 0..9, each
    channel is min-max scaled to [-1, 1] on its own, and the ten pairs are
    concatenated along the sequence axis.

    Args:
        filepath (str): Path of the .npy file.
        state_dim (int): Only used to derive the returned `num_systems`
            (ceil(channels / state_dim)) and the row count of `t_data`.
        max_seq_length (int): Sequences are cut, or padded by repeating the
            last sample, to this length.
        sample_resolution (float): Requested sampling step. The data is strided
            by sample_resolution / 0.005; values finer than 0.005 fall back to
            stride 1. Also used as the first value of the time grid.
        use_seq (int): Maximum number of sequences read from the file.
        plot (bool): If True, print the array shapes and plot the data.

    Returns:
        tuple: (x_data, t_data, num_systems) with
            x_data: [10 * num_seq, max_seq_length, 2],
            t_data: [num_systems * num_seq, max_seq_length],
            where num_seq is the number of sequences actually loaded.
    """
    import math

    base_resolution = 0.005  # native sampling step of the stored data
    num_pairs = 10           # channel i is paired with channel i + pair_offset
    pair_offset = 10

    def normalize_all(array):
        # Global min-max scaling to [-1, 1].
        min_val = array.min()
        max_val = array.max()
        return 2 * (array - min_val) / (max_val - min_val) - 1

    def plot_data(t_data, x_data, num_systems):
        num_traj = x_data.shape[0] // num_systems  # assume each system has same num_traj
        state_dim = x_data.shape[2]
        plt.figure(figsize=(20, 4))

        for d in range(0, state_dim):
            plt.subplot(1, state_dim, d + 1)
            plt.title(f"Dim {d + 1}")

            colors = plt.cm.viridis(np.linspace(0, 1, num_systems))

            for s in range(num_systems):
                for i in range(num_traj):
                    plt.plot(
                        t_data[s * num_traj + i, :],
                        x_data[s * num_traj + i, :, d],
                        linestyle="-",
                        color=colors[s],
                        label=f"Sys {s}" if i == 0 else None,
                    )

            plt.xlabel("Time (s)")
            plt.ylabel("State Value (-1 to 1)")
            plt.grid(True)
            plt.legend(loc="upper right")
            plt.tight_layout()
        plt.show()

    data = np.load(filepath)
    data = data[:use_seq, ...]
    # The file may hold fewer than use_seq sequences; size everything from
    # what was actually loaded.
    num_seq = data.shape[0]

    # number of systems
    # NOTE(review): x_data below always holds `num_pairs` 2-D systems, so this
    # count (and the t_data row count) only matches x_data when
    # ceil(channels / state_dim) == 10, e.g. 20 channels with state_dim=2.
    num_systems = math.ceil(data.shape[2] / state_dim)

    # build t_data
    t_data = np.linspace(sample_resolution, 1, max_seq_length).reshape(1, -1)  # (1, max_seq_length)
    t_data = np.tile(t_data, (num_systems * num_seq, 1))

    # sampling ratio: the small epsilon protects against a float quotient that
    # lands just below an integer; the clamp avoids a zero slice step.
    sample_ratio = max(1, int(sample_resolution / base_resolution + 1e-9))
    data = data[:, ::sample_ratio, :]

    # match max_seq_length
    if data.shape[1] >= max_seq_length:
        data = data[:, :max_seq_length, :]  # cut
        mask = np.ones_like(data)
    else:
        need = max_seq_length - data.shape[1]
        tail = np.repeat(data[:, -1:, :], need, axis=1)
        data = np.concatenate([data, tail], axis=1)  # pad with the last sample
        mask = np.ones_like(data)
        mask[:, -need:, :] = 0  # padded steps are marked invalid

    # pair channel i with channel i + pair_offset (1+11, 2+12, ...), each
    # channel normalized independently
    pairs = [
        np.stack(
            [normalize_all(data[:, :, i]), normalize_all(data[:, :, i + pair_offset])],
            axis=-1,
        )
        for i in range(num_pairs)
    ]
    x_data = np.concatenate(pairs, axis=0)  # (num_pairs * num_seq, max_seq_length, 2)
    mask_pairs = [
        np.stack([mask[:, :, i], mask[:, :, i + pair_offset]], axis=-1)
        for i in range(num_pairs)
    ]
    # NOTE(review): the mask is only reported in the debug print below; it is
    # not part of the return value.
    mask = np.concatenate(mask_pairs, axis=0)

    # plot
    if plot:
        print(f"size of x_data: {x_data.shape}, t_data: {t_data.shape}, mask: {mask.shape}")
        plot_data(t_data, x_data, num_systems)

    return x_data, t_data, num_systems



def read_corey_matlab_data(filepath, state_dim=3, max_seq_length=1000, sample_resolution=0.005, use_seq=800, plot=True):
    """
    Load a fixed list of dynamical-system .npy files and stack them.

    For every name in the built-in file list, `filepath + name + ".npy"` is
    loaded, so `filepath` is a prefix and must end with a path separator when
    it is a directory. Each file is normalized per channel to [-1, 1],
    sub-sampled, cut or padded to `max_seq_length`, and its channels are
    regrouped into blocks of `state_dim`.

    Args:
        filepath (str): Prefix prepended to every file name.
        state_dim (int): Number of channels per returned trajectory; the
            channel axis is padded by repeating trailing channels when it is
            not a multiple of `state_dim`.
        max_seq_length (int): Sequences are cut, or padded by repeating the
            last sample, to this length.
        sample_resolution (float): Requested sampling step. The data is strided
            by sample_resolution / 0.02; values finer than 0.02 fall back to
            stride 1. Also used as the first value of the time grid.
        use_seq (int): Maximum number of sequences read from each file.
        plot (bool): If True, print the array shapes and plot the data.

    Returns:
        tuple: (x_data, t_data, num_systems) with
            x_data: [total_traj, max_seq_length, state_dim],
            t_data: [total_traj, max_seq_length],
            num_systems: number of files in the built-in list.
    """
    base_resolution = 0.02  # native sampling step of the stored data

    def normalize_each_dim(x):
        # Per-channel min-max scaling to [-1, 1] over sequences and time.
        mins = x.min(axis=(0, 1), keepdims=True)
        maxs = x.max(axis=(0, 1), keepdims=True)
        return 2 * (x - mins) / (maxs - mins) - 1

    def reshape(x, num_seq, max_seq_length, state_dim):
        # Split the channel axis into groups of state_dim and move the group
        # axis next to the sequence axis.
        return (x.reshape(num_seq, max_seq_length, -1, state_dim)  # (num_seq, max_seq_length, num_sys, d)
                .swapaxes(1, 2)                                     # (num_seq, num_sys, max_seq_length, d)
                .reshape(-1, max_seq_length, state_dim))            # (num_seq*num_sys, max_seq_length, d)

    def plot_data(t_data, x_data, num_systems):
        num_traj = x_data.shape[0] // num_systems  # assume each system has same num_traj
        state_dim = x_data.shape[2]
        plt.figure(figsize=(20, 4))

        for d in range(0, state_dim):
            plt.subplot(1, state_dim, d + 1)
            plt.title(f"Dim {d+1}")

            colors = plt.cm.viridis(np.linspace(0, 1, num_systems))

            for s in range(num_systems):
                for i in range(num_traj):
                    plt.plot(
                        t_data[s * num_traj + i, :],
                        x_data[s * num_traj + i, :, d],
                        linestyle="-",
                        color=colors[s],
                        label=f"Sys {s}" if i == 0 else None,
                    )

            plt.xlabel("Time (s)")
            plt.ylabel("State Value (-1 to 1)")
            plt.grid(True)
            plt.legend(loc="upper right")
            plt.tight_layout()
        plt.show()

    def read_each_system(filepath, state_dim=3, max_seq_length=1000, sample_resolution=0.02, use_seq=800):
        data = normalize_each_dim(np.load(filepath)[:use_seq, ...])
        # The file may hold fewer than use_seq sequences; reshape with the
        # number actually loaded.
        num_seq = data.shape[0]

        # sampling ratio: the small epsilon protects against a float quotient
        # that lands just below an integer; the clamp avoids a zero slice step
        # when sample_resolution is finer than the native step.
        sample_ratio = max(1, int(sample_resolution / base_resolution + 1e-9))
        data = data[:, ::sample_ratio, :]

        # match max_seq_length
        if data.shape[1] >= max_seq_length:
            data = data[:, :max_seq_length, :]                 # cut
            mask = np.ones_like(data)
        else:
            need = max_seq_length - data.shape[1]
            tail = np.repeat(data[:, -1:, :], need, axis=1)
            data = np.concatenate([data, tail], axis=1)        # pad with the last sample
            mask = np.ones_like(data)
            mask[:, -need:, :] = 0                             # padded steps are marked invalid

        # pad the channel axis up to a multiple of state_dim by repeating the
        # trailing channels; the repeated channels are masked out
        need = (-data.shape[2]) % state_dim
        if need != 0:
            mask = np.concatenate([mask, np.zeros_like(data[..., -need:])], axis=2)
            data = np.concatenate([data, data[..., -need:]], axis=2)
        # reshape
        x_data = reshape(data, num_seq, max_seq_length, state_dim)
        mask = reshape(mask, num_seq, max_seq_length, state_dim)

        return x_data, mask

    sys_files = [
        # continuous systems
        'Data-C0ModelRoslerA52',
        'Data-C1LinearOscillator01',
        'Data-C1LinearOscillator02',
        'Data-C2LinearFading01',
        'Data-C2LinearFading02',
        'Data-C0ModelLorenzA51',
        'Data-C0ModelRucklideA515',
        'Data-C0ModelUedaA45',

        #    'Data-C0ModelHalvorsenA513' # has NaN; too complex
    ]

    num_systems = len(sys_files)

    x_data_array = []
    mask_array = []
    for sys_name in sys_files:
        x_data_this_sys, mask_this_sys = read_each_system(
            filepath + sys_name + ".npy", state_dim, max_seq_length, sample_resolution, use_seq
        )
        x_data_array.append(x_data_this_sys)
        mask_array.append(mask_this_sys)

    x_data = np.concatenate(x_data_array, axis=0)
    # NOTE(review): the mask is only reported in the debug print below; it is
    # not part of the return value.
    mask = np.concatenate(mask_array, axis=0)

    # build t_data with one row per trajectory actually present in x_data
    t_data = np.linspace(sample_resolution, 1, max_seq_length).reshape(1, -1)  # (1, max_seq_length)
    t_data = np.tile(t_data, (x_data.shape[0], 1))

    # plot
    if plot:
        print(f"size of x_data: {x_data.shape}, t_data: {t_data.shape}, mask: {mask.shape}")
        plot_data(t_data, x_data, num_systems)

    return x_data, t_data, num_systems
