# Column-wise min-max normalization of the heuristic features.
# Statistics are computed from h_values only, and the same transform is
# applied to the test/validation heuristics so all splits share one scale.
# NOTE(review): assumes h_values is a 2-D torch tensor of shape
# (num_samples, num_columns) and that every heuristic tensor has the same
# number of columns -- confirm against the code that builds them.

_NORM_EPS = 1e-8  # added to the range so constant columns don't divide by zero

# Per-column extrema; keepdim=True gives shape (1, num_columns) for broadcasting.
max_vals, _ = h_values.max(dim=0, keepdim=True)
min_vals, _ = h_values.min(dim=0, keepdim=True)
value_range = max_vals - min_vals + _NORM_EPS  # computed once, reused below

# Normalize h_values on whatever device it already lives on.
normalized_h_values = (h_values - min_vals) / value_range

# Move the heuristics AND the statistics to the target device *before*
# doing arithmetic. h_values may sit on a different device than the
# heuristic tensors. Mixing the two in one expression would then raise a
# device-mismatch error.
min_vals_on_device = min_vals.to(device)
value_range_on_device = value_range.to(device)

pos_test_heuristics = (pos_test_heuristics.to(device) - min_vals_on_device) / value_range_on_device
neg_test_heuristics = (neg_test_heuristics.to(device) - min_vals_on_device) / value_range_on_device
pos_valid_heuristics = (pos_valid_heuristics.to(device) - min_vals_on_device) / value_range_on_device
neg_valid_heuristics = (neg_valid_heuristics.to(device) - min_vals_on_device) / value_range_on_device

# Debug output: these always run and print entire tensors, which can be large.
print("Normalized h_values:", normalized_h_values)
print("Normalized pos_test_heuristics:", pos_test_heuristics)
print("Normalized neg_test_heuristics:", neg_test_heuristics)
print("Normalized pos_valid_heuristics:", pos_valid_heuristics)
print("Normalized neg_valid_heuristics:", neg_valid_heuristics)
