import torch


# Here we simplify the problem:
# the decision boundary is not responsible for separating all m classification categories at once
# (the set where all m class scores are equal would be (n-m+1)-dimensional),
# but only for separating two classes (the class wrongly predicted for input x_t, and its ground-truth class),
# which makes it an (n-2+1) = (n-1)-dimensional hyperplane in the n-dimensional input space.
def find_decision_boundary(y_true, y_pred, W_activated, b_activated):
    """Return the hyperplane separating the ground-truth class from the predicted class.

    The score of class ``k`` at input ``x`` is ``W_activated[k] · x + b_activated[k]``.
    The binary decision boundary between ``y_true`` and ``y_pred`` is where both scores
    are equal:

        (W[y_true] - W[y_pred]) · x + (b[y_true] - b[y_pred]) = 0

    which is returned as ``(w_dec_bound, b_dec_bound)`` so that the boundary is
    ``torch.dot(w_dec_bound, x) + b_dec_bound == 0``. Points where this expression is
    positive score higher for ``y_true`` than for ``y_pred``.

    Args:
        y_true: index of the ground-truth class (must be a Python ``int``).
        y_pred: index of the (wrongly) predicted class (must be a Python ``int``).
        W_activated: weight tensor whose first dimension is indexed by class.
        b_activated: bias tensor indexed by class.

    Returns:
        Tuple ``(w_dec_bound, b_dec_bound)``, computed with the dtype/device of the inputs.

    Raises:
        AssertionError: if ``y_true``/``y_pred`` are not ints, or if the two weight rows
            are (numerically) identical, in which case no separating hyperplane exists.
    """
    assert isinstance(y_true, int) and isinstance(y_pred, int), "Invalid value: both \"y_true\" and \"y_pred\" should be integer."

    w_dec_bound = W_activated[y_true] - W_activated[y_pred]
    b_dec_bound = b_activated[y_true] - b_activated[y_pred]

    # zeros_like matches w_dec_bound's dtype and device. torch.zeros(size=...) always builds a
    # float32 CPU tensor, and torch.allclose raises on a dtype/device mismatch (e.g. float64 W).
    assert not torch.allclose(w_dec_bound, torch.zeros_like(w_dec_bound)), "Unexpected case: two row vectors of W are the same: {} and {}".format(W_activated[y_true], W_activated[y_pred])

    return w_dec_bound, b_dec_bound


# find the point on the decision boundary that is closest to a given point t
def find_closest_point_on_decision_boundary(t, w_dec_bound, b_dec_bound):
    """Orthogonally project point ``t`` onto the hyperplane ``w_dec_bound · x + b_dec_bound = 0``.

    The closest point on the hyperplane is

        x0 = t - ((w · t) + b) * w / (w · w)

    The computation is done in float64 for precision, and the result is cast to float32.

    Args:
        t: the point to project (a 1-D tensor, as required by ``torch.dot``).
        w_dec_bound: normal vector of the hyperplane (1-D tensor, same length as ``t``).
        b_dec_bound: scalar offset of the hyperplane (0-d tensor).

    Returns:
        The projected point ``x0`` as a float32 tensor.

    Raises:
        AssertionError: if any input is None, if ``w_dec_bound`` is the zero vector, or if
            the computed point fails the on-boundary sanity check.
    """
    assert (t is not None) and (w_dec_bound is not None) and (b_dec_bound is not None), "Unexpected error: one or more None input(s) to the function."

    # convert to float64 for a more precise result
    t = t.to(torch.float64)
    w_dec_bound = w_dec_bound.to(torch.float64)
    b_dec_bound = b_dec_bound.to(torch.float64)

    w_norm_sq = torch.dot(w_dec_bound, w_dec_bound)
    # A zero normal vector does not define a hyperplane. Without this check the division below
    # produces NaN/inf and the failure only shows up as the generic "Unknown error" further down.
    assert w_norm_sq > 0, "Invalid value: \"w_dec_bound\" must be a non-zero vector."

    # closest point on the decision boundary: t - ((w · t) + b) * (w / (w · w))
    x0 = t - (torch.dot(w_dec_bound, t) + b_dec_bound) * (w_dec_bound / w_norm_sq)

    residual = torch.dot(w_dec_bound, x0) + b_dec_bound
    # Round-off in w · x0 grows with the magnitude of the summed terms, so a fixed absolute
    # tolerance rejects correct projections of large inputs. Scale it by |w|·|t| + |b|,
    # keeping the original 1e-7 as the floor for small-magnitude problems.
    scale = torch.dot(w_dec_bound.abs(), t.abs()) + b_dec_bound.abs()
    tolerance = 1e-7 * max(1.0, scale.item())
    assert abs(residual) < tolerance, "Unknown error: x0 ({}) not on the decision boundary ({} \xb7 x0 + {}={}).".format(
        x0, w_dec_bound, b_dec_bound, residual)

    return x0.to(torch.float32)