import torch

# Discrete grid of 1000 candidate timesteps: 0.000, 0.001, ..., 0.999.
sigmas = torch.arange(1000) / 1000
# Sampling density over the grid is proportional to (sigma - shift) ** (2 * scale).
# It is zero at sigma == shift. With the current values (exponent 4), it is
# largest at sigma == 0, the grid point farthest from 0.6.
shift = 0.6
scale = 2
weights = (sigmas - shift).pow(2 * scale)
weights /= weights.sum()  # normalise into a probability vector over `sigmas`

def get_train_tuple(z0=None, z1=None, *, sigma_grid=None, sigma_weights=None):
    """Generate a training tuple for rectified flow learning with intermediate state and target.

    A timestep ``t`` is drawn independently for every sample from a discrete
    grid of candidate values, using the given sampling weights. The
    intermediate state is the straight-line interpolation
    ``z_t = t * z1 + (1 - t) * z0``. The regression target is the constant
    velocity ``z1 - z0`` of that line.

    Args:
        z0: Initial latent tensor representing the starting point of the flow.
            Its first dimension is treated as the batch dimension.
        z1: Target latent tensor representing the end point of the flow.
            Presumably the same shape as ``z0``; this is not checked here.
        sigma_grid: Optional 1-D tensor of candidate timesteps. Defaults to the
            module-level ``sigmas`` grid. Must be given together with
            ``sigma_weights``.
        sigma_weights: Optional 1-D tensor of non-negative sampling weights,
            one per entry of ``sigma_grid``. Defaults to the module-level
            ``weights``.

    Returns:
        Tuple ``(z_t, t, target)``:
        - ``z_t`` is the intermediate latent tensor, cast to ``z0.dtype``.
        - ``t`` holds the sampled timesteps, shaped ``(batch, 1, ..., 1)`` so
          they broadcast over the non-batch dimensions (``(batch, 1, 1, 1)``
          for 4-D latents).
        - ``target`` is the velocity ``z1 - z0``.

    Raises:
        ValueError: If ``z0`` or ``z1`` is missing, if only one of
            ``sigma_grid``/``sigma_weights`` is given, or if their lengths differ.
    """
    if z0 is None or z1 is None:
        raise ValueError("get_train_tuple requires both z0 and z1")
    if (sigma_grid is None) != (sigma_weights is None):
        raise ValueError("sigma_grid and sigma_weights must be given together")
    if sigma_grid is None:
        sigma_grid, sigma_weights = sigmas, weights
    if sigma_grid.shape[0] != sigma_weights.shape[0]:
        raise ValueError("sigma_grid and sigma_weights must have the same length")

    batch_size = z1.shape[0]
    # Draw one grid index per sample, with replacement, proportional to the weights.
    indices = torch.multinomial(sigma_weights, num_samples=batch_size, replacement=True)
    # One singleton axis per non-batch dim, so t broadcasts per-sample for
    # latents of any rank. A fixed (-1, 1, 1, 1) mis-broadcasts non-4-D inputs.
    t = sigma_grid[indices].to(z0.device).reshape(batch_size, *([1] * (z1.dim() - 1)))

    z_t = (t * z1 + (1. - t) * z0).to(dtype=z0.dtype)
    target = z1 - z0

    return z_t, t, target

