# data_generator.py
import numpy as np

def generate_trial_sequence(dataset_config):
    """
    Generate the token sequence and trial-type labels for one epoch.

    Trial types are drawn uniformly from {0, 1} with ``np.random.choice``
    (global NumPy RNG) and redrawn until, ignoring the last trial, both
    types occur more than once. Each trial then contributes
    ``trial_length`` tokens: ``trial1x`` for type 0, ``trial2x`` otherwise.

    Args:
        dataset_config (dict): Configuration with the keys
            - "num_trials": number of trials in the epoch (must be >= 5),
            - "trial_length": number of tokens per trial,
            - "trial1x": token template written for trials of type 0,
            - "trial2x": token template written for trials of type 1.
            Each template must be broadcastable to ``trial_length`` entries.

    Returns:
        tuple: A tuple containing:
            - x_int (np.ndarray): int64 array of length
              ``num_trials * trial_length`` with the token IDs for the epoch.
            - trials (np.ndarray): The sequence of trial types (0 or 1).

    Raises:
        ValueError: If ``num_trials`` is smaller than 5. The acceptance
            condition needs two trials of each type among the first
            ``num_trials - 1`` trials, so smaller values could never be
            satisfied and the sampling loop would not terminate.
    """
    num_trials = dataset_config["num_trials"]
    tr_len = dataset_config["trial_length"]
    trial1x = dataset_config["trial1x"]
    trial2x = dataset_config["trial2x"]

    # Two 0s and two 1s are required among trials[:-1], i.e. at least four
    # trials before the last one. Fail fast instead of looping forever.
    min_trials = 5
    if num_trials < min_trials:
        raise ValueError(
            f"num_trials must be at least {min_trials} so that both trial "
            f"types can occur more than once before the last trial, "
            f"got {num_trials}"
        )

    # Rejection sampling: redraw until both trial types appear more than once
    # among all trials except the last one.
    while True:
        trials = np.random.choice(2, num_trials)
        leading = trials[:-1]
        if np.sum(leading == 1) > 1 and np.sum(leading == 0) > 1:
            break

    # One row per trial; every row is written below because the two masks
    # are complementary, so the buffer needs no initialisation.
    rows = np.empty((num_trials, tr_len), dtype=np.int64)
    is_type0 = trials == 0
    rows[is_type0] = trial1x
    rows[~is_type0] = trial2x

    # Flatten row-major so trial t occupies x_int[t * tr_len:(t + 1) * tr_len].
    x_int = rows.reshape(-1)
    return x_int, trials