import numpy as np
import torch
import matplotlib.pyplot as plt

def _to_numpy(x):
    """Return *x* as a NumPy array.

    Torch tensors are detached and moved to the CPU first, because
    ``Tensor.numpy()`` raises for tensors that require grad or live on a GPU.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _plot_line(xx, w, b, **plot_kwargs):
    """Draw the line ``w[0]*x + w[1]*y + b = 0`` on the current axes.

    ``w`` is flattened and only its first two entries are used; ``b`` must
    hold a single value. When ``w[1] == 0`` the line is vertical, so it is
    drawn with ``axvline`` instead of dividing by zero.

    Raises:
        ValueError: if ``w[0]`` and ``w[1]`` are both zero (no line defined).
    """
    w = np.ravel(_to_numpy(w))
    b = float(np.squeeze(_to_numpy(b)))
    if w[1] != 0:
        plt.plot(xx, -(w[0] * xx + b) / w[1], **plot_kwargs)
    elif w[0] != 0:
        plt.axvline(-b / w[0], **plot_kwargs)
    else:
        raise ValueError("line weights w[0] and w[1] are both zero")


def plotter(pts, ppts, wbdict=None, titlestr=None, save_title=None, quiver=True):
    """Plot two decision lines, original/perturbed points and their displacements.

    Everything is drawn on the current matplotlib axes via ``pyplot``; the
    figure is neither shown nor saved.

    Args:
        pts: original points; columns 0 and 1 are used as x and y
            (presumably shape (N, 2) -- confirm with callers). Torch tensors
            and array-likes are accepted.
        ppts: perturbed points, same layout as ``pts`` and row-aligned with
            it (arrow i goes from ``pts[i]`` to ``ppts[i]``).
        wbdict: mapping with keys ``'wts'``/``'bias'`` (drawn in blue,
            legend "svm line") and ``'w'``/``'b'`` (drawn in brown, legend
            "data gen line"). Each pair defines ``w[0]*x + w[1]*y + b = 0``.
            Required despite the ``None`` default.
        titlestr: optional plot title.
        save_title: accepted for backward compatibility but currently unused
            (saving to disk was disabled in this function).
        quiver: if True, draw an arrow from each original point to its
            perturbed counterpart.

    Raises:
        TypeError: if ``wbdict`` is None.
        ValueError: if a weight vector has both components equal to zero.
    """
    if wbdict is None:
        # Same exception type as the former ``None['w']`` failure, with a clear message.
        raise TypeError("plotter() requires wbdict with keys 'w', 'b', 'wts', 'bias'")

    orig = _to_numpy(pts)
    pert = _to_numpy(ppts)

    # xx starts just left of the x-limit of 0, so the drawn lines extend past the left edge.
    xx = np.linspace(-0.01, 1, 4)
    _plot_line(xx, wbdict['wts'], wbdict['bias'],
               linewidth=0.5, color='blue', label="svm line")
    _plot_line(xx, wbdict['w'], wbdict['b'],
               linewidth=0.5, color="brown", label="data gen line")

    plt.xlim([0, 1])
    plt.ylim([0, 1])

    plt.scatter(orig[:, 0], orig[:, 1], color='green', marker="o", label="orig")
    plt.scatter(pert[:, 0], pert[:, 1], color='red', marker="x", label="perturbed")
    # Labels are attached to each artist, so legend entries do not depend on draw order.
    plt.legend()

    if quiver:
        diff = pert - orig
        # angles/scale_units='xy' with scale=1 makes each arrow span exactly
        # from pts[i] to ppts[i] in data coordinates.
        plt.quiver(orig[:, 0], orig[:, 1], diff[:, 0], diff[:, 1],
                   angles='xy', scale_units='xy', scale=1, width=0.001)

    if titlestr is not None:
        plt.title(titlestr)