import torch
import numpy as np
import matplotlib.pyplot as plt


# Scatter-point colors, indexed by the integer class label of each dataset sample.
COLORS = ['blue', 'red']
# Colormap name passed to `contourf` when filling the predicted-class regions.
CMAP = 'cool'

# Number of grid points sampled along the x axis and the y axis (in that order)
# when the model is evaluated for the decision-boundary plot.
RESOLUTION = [1000, 1000]


def visualize_dataset_and_decision_boundary(model, ax, dataset, skipping=50, xlim=(-3.5, 3.5), ylim=(-1.5, 1.5)):
    ''' draw dataset and the decision boundary of a specified model

    The model is evaluated on a regular grid of RESOLUTION[0] x RESOLUTION[1]
    points spanning `xlim` x `ylim`; the predicted class of every grid point is
    drawn as a filled contour plot and a subsample of `dataset` is scattered
    on top of it.

    model    -- callable with an `eval()` method; it is called once with a
                float32 tensor of shape (RESOLUTION[0] * RESOLUTION[1], 2) and
                must return per-class scores with the classes along dim 1.
                NOTE: the model is switched to eval mode and left that way.
    ax       -- axes object providing `contourf` and `scatter`.
    dataset  -- indexable collection whose items unpack as `(x, y), c`, where
                `c` is an integer class label used to index COLORS.
    skipping -- stride used when subsampling `dataset` for the scatter plot.
    xlim     -- (min, max) extent of the grid along the x axis.
    ylim     -- (min, max) extent of the grid along the y axis.
    '''
    model.eval()
    with torch.no_grad():
        xs = np.linspace(xlim[0], xlim[1], num=RESOLUTION[0])
        ys = np.linspace(ylim[0], ylim[1], num=RESOLUTION[1])

        # indexing='ij' keeps X[i, j] == xs[i] and Y[i, j] == ys[j], so the
        # row-major flatten below can be undone with a plain reshape.
        X, Y = np.meshgrid(xs, ys, indexing='ij')
        batch_inputs = torch.tensor(np.stack((X.flatten(), Y.flatten()), axis=1)).float()
        logits = model(batch_inputs)

        # Softmax is monotonic, so the arg-max of the logits is already the
        # predicted class; skipping it also avoids the implicit-dim warning
        # raised by a Softmax built without `dim`.
        predictions = logits.argmax(dim=1)

        # .cpu() is a no-op for CPU tensors and is required before .numpy()
        # when the model returns its output on another device.
        matrix = predictions.reshape(X.shape).cpu().numpy()
        ax.contourf(X, Y, matrix, cmap=CMAP)

        for index in range(0, len(dataset), skipping):
            (x, y), c = dataset[index]
            ax.scatter(x, y, color=COLORS[c])


def draw_line(ax, w, b, color='black', linewidth=1, label=None, xlim=None, **kwargs):
    ''' draw a line given an equation in the form of w * x + b = 0

    ax        -- axes object providing `plot`, `axvline` and `get_xlim`.
    w         -- the two coefficients (w[0], w[1]) of the line; any array-like
                 holding exactly two numbers, e.g. shape (2,), (2, 1) or (1, 2).
    b         -- the offset; a scalar or an array-like holding one number.
    color, linewidth, label, **kwargs -- forwarded to `plot` / `axvline`.
    xlim      -- (min, max) x-range over which a non-vertical line is sampled;
                 defaults to the current `ax.get_xlim()`.

    Raises ValueError if `w` does not hold exactly two numbers, if `b` does
    not hold exactly one number, or if the line is degenerate (w[0] is zero
    and w[1] is within the near-zero tolerance).
    '''
    # Flatten so that row- and column-shaped weights are treated alike.
    w = np.asarray(w, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64)

    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if w.size != 2:
        raise ValueError("This function only supports 2D lines.")
    if b.size != 1:
        raise ValueError("The offset b must be a single number.")

    # Work with plain Python floats from here on, so that a one-element array
    # for `b` never leaks into the value handed to `axvline`.
    w0, w1 = w.tolist()
    b = b.item()

    if abs(w1) > 1e-10:
        if xlim is None:
            xlim = ax.get_xlim()
        x_vals = np.linspace(xlim[0], xlim[1], 200)
        # x2 = (-w1*x1 - b) / w2
        y_vals = (-w0 * x_vals - b) / w1
        ax.plot(x_vals, y_vals, color=color, linewidth=linewidth, label=label, **kwargs)
    elif w0 != 0.0:
        # w[1] is (numerically) zero: the line is vertical at x = -b / w[0].
        ax.axvline(x=-b / w0, color=color, linewidth=linewidth, label=label, **kwargs)
    else:
        # Both coefficients vanish, so the equation does not describe a line.
        raise ValueError("w must not be the zero vector.")


def draw_vector(ax, origin_x, origin_y, direction_x, direction_y, color='black', width=1, linestyle='-', label=None, **kwargs):
    ''' draw an arrow whose tail is at (origin_x, origin_y) and whose head is displaced from the tail by (direction_x, direction_y) '''
    # `width` is expressed in hundredths of the quiver width unit.
    shaft_width = width * 0.01

    # angles / scale_units / units are all 'xy' with scale=1 so that the arrow
    # components are interpreted in the same coordinates as its origin.
    ax.quiver(
        origin_x, origin_y, direction_x, direction_y,
        angles='xy', scale_units='xy', scale=1, units='xy',
        color=color, width=shaft_width, linestyle=linestyle, label=label,
        **kwargs,
    )