def _f3_u_integral(result_tensor, f2_z_values_expanded_1, start_col, end_col,
                   num_intervals, chunk_size):
    """Append exp(c(u)) / a(u) * integral(exp(-c(t)) * f2 dt) as a new last column.

    Shared implementation of ``f3_u_minus_yuan`` and ``f3_u_plus_yuan``; the
    two differ only in which column of ``result_tensor`` supplies the lower
    and the upper integration bound.

    Args:
        result_tensor: Tensor whose last axis holds (x, y, theta, u) at
            indices 0..3. The original shape comments suggest
            (num_x, num_y, numpoints_f2_minus_z, 4) -- TODO confirm. It is
            processed in chunks along dim 0.
        f2_z_values_expanded_1: f2 samples. Dim 1 is sliced with the same
            chunk rows as ``result_tensor``'s dim 0; dim 0 is broadcast
            against the ``num_intervals`` quadrature samples (presumably of
            size 1 or ``num_intervals`` -- verify against caller).
        start_col: Column of ``result_tensor`` giving the lower bound.
        end_col: Column of ``result_tensor`` giving the upper bound.
        num_intervals: Number of quadrature samples; must be >= 1.
        chunk_size: Rows of dim 0 processed per iteration; must be >= 1.
            Bounds the peak memory of the (num_intervals, chunk, ...)
            intermediates.

    Returns:
        ``result_tensor`` with the computed value appended along the last
        axis (one more column than the input).

    Raises:
        ValueError: If ``num_intervals`` or ``chunk_size`` is < 1.
    """
    if num_intervals < 1:
        raise ValueError(f"num_intervals must be >= 1, got {num_intervals}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    u = result_tensor[..., 3]
    lower = result_tensor[..., start_col]
    upper = result_tensor[..., end_col]
    num_rows = result_tensor.shape[0]

    if num_rows == 0:
        # torch.cat of an empty chunk list would raise; return an empty
        # result that still carries the extra column.
        empty_col = result_tensor.new_empty((*result_tensor.shape[:-1], 1))
        return torch.cat((result_tensor, empty_col), dim=-1)

    # The [0, 1] sample fractions are identical for every chunk, so build them
    # once, directly on the target device and in the input's float dtype
    # (avoids a float32 round-trip for float64 inputs). Reshaped to
    # (num_intervals, 1, ..., 1) so they broadcast against the per-point bounds.
    frac_dtype = result_tensor.dtype if result_tensor.is_floating_point() else None
    fractions = torch.linspace(
        0, 1, num_intervals, device=result_tensor.device, dtype=frac_dtype
    ).view(-1, *([1] * lower.dim()))

    pieces = []
    for start in range(0, num_rows, chunk_size):
        stop = min(start + chunk_size, num_rows)
        u_c = u[start:stop]
        lo = lower[start:stop]
        span = upper[start:stop] - lo
        f2_c = f2_z_values_expanded_1[:, start:stop]

        # num_intervals evenly spaced sample points from lo to lo + span,
        # both endpoints included. Shape (num_intervals, *lo.shape).
        t = lo.unsqueeze(0) + span.unsqueeze(0) * fractions

        # NOTE(review): num_intervals samples that include both endpoints are
        # weighted by span / num_intervals, i.e. (sample mean) * span. This is
        # only first-order accurate; a trapezoidal rule would be O(h^2). Kept
        # as-is to preserve existing results.
        dt = span / num_intervals
        integral = torch.sum(torch.exp(-c(t)) * f2_c, dim=0) * dt

        # Only the leading quadrature axis was reduced. Deliberately no blanket
        # .squeeze(): it would also drop any other size-1 axis (e.g. num_y == 1)
        # and make the product below broadcast to the wrong shape.
        f_u = (torch.exp(c(u_c)) / a(u_c)) * integral

        pieces.append(
            torch.cat((result_tensor[start:stop], f_u.unsqueeze(-1)), dim=-1)
        )

    return torch.cat(pieces, dim=0)


def f3_u_minus_yuan(result_tensor, M, N, f2_z_values_expanded_1, num_intervals=200, chunk_size=10):
    """Append f_u = exp(c(u)) / a(u) * integral_{u}^{theta} exp(-c(t)) * f2 dt.

    The integration runs from u (column 3) to theta (column 2) of
    ``result_tensor``. The value is appended as a new last column (the fifth
    for a 4-column (x, y, theta, u) input).

    Args:
        result_tensor: Tensor whose last axis holds (x, y, theta, u).
        M: Unused; kept for call-signature compatibility.
        N: Unused; kept for call-signature compatibility.
        f2_z_values_expanded_1: f2 samples; dim 1 aligns with
            ``result_tensor``'s dim 0 and dim 0 broadcasts against the
            quadrature samples.
        num_intervals: Number of quadrature samples (>= 1).
        chunk_size: Rows of dim 0 processed per iteration (>= 1).

    Returns:
        ``result_tensor`` with f_u appended along the last axis.

    Raises:
        ValueError: If ``num_intervals`` or ``chunk_size`` is < 1.
    """
    return _f3_u_integral(
        result_tensor,
        f2_z_values_expanded_1,
        start_col=3,  # u
        end_col=2,  # theta
        num_intervals=num_intervals,
        chunk_size=chunk_size,
    )


def f3_u_plus_yuan(result_tensor, M, N, f2_z_values_expanded_1, num_intervals=200, chunk_size=10):
    """Append f_u = exp(c(u)) / a(u) * integral_{theta}^{u} exp(-c(t)) * f2 dt.

    The integration runs from theta (column 2) to u (column 3) of
    ``result_tensor``. The value is appended as a new last column (the fifth
    for a 4-column (x, y, theta, u) input).

    Args:
        result_tensor: Tensor whose last axis holds (x, y, theta, u).
        M: Unused; kept for call-signature compatibility.
        N: Unused; kept for call-signature compatibility.
        f2_z_values_expanded_1: f2 samples; dim 1 aligns with
            ``result_tensor``'s dim 0 and dim 0 broadcasts against the
            quadrature samples.
        num_intervals: Number of quadrature samples (>= 1).
        chunk_size: Rows of dim 0 processed per iteration (>= 1).

    Returns:
        ``result_tensor`` with f_u appended along the last axis.

    Raises:
        ValueError: If ``num_intervals`` or ``chunk_size`` is < 1.
    """
    return _f3_u_integral(
        result_tensor,
        f2_z_values_expanded_1,
        start_col=2,  # theta
        end_col=3,  # u
        num_intervals=num_intervals,
        chunk_size=chunk_size,
    )
