import copy

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator

from config import config
from train import train
# Global figure styling shared by every plot produced by this script.
plt.rc('font', family="Arial", size='14')

# First nine entries of matplotlib's 'Set1' colormap, looked up by integer index.
colors = list(map(plt.get_cmap('Set1'), range(9)))

# Marker size (`s`) and marker line width (`linewidths`) for the data-point scatters.
s = 110
lw = 2
# Use nn.MSELoss(reduction='sum') for this experiment

def xor_data(label=None):
    """Return the four-point XOR-style dataset, or a selection of it.

    Args:
        label: Optional row index (anything NumPy accepts as a first-axis
            index). ``None`` returns the full dataset; an integer returns a
            single sample, so ``x`` has shape ``(2,)`` and ``y`` shape ``(1,)``.

    Returns:
        A dict with the inputs under ``"x"`` (shape ``(4, 2)`` for the full
        set) and the +1/-1 targets under ``"y"`` (shape ``(4, 1)``).
    """
    inputs = np.array([[-4, 0], [0, -3], [2, 0], [0, 1]])
    targets = np.array([1, -1, 1, -1]).reshape(-1, 1)

    # `label == 0` is a valid selection, so compare against None explicitly.
    if label is None:
        return {"x": inputs, "y": targets}
    return {"x": inputs[label, :], "y": targets[label, :]}


def xor_vis(W):
    """Draw the rows of ``W`` as arrows from the origin alongside the XOR points.

    The figure is written to ``xor.pdf`` in the current working directory and
    then displayed with ``plt.show()``.

    Args:
        W: 2-D array; one arrow is drawn per row, with column 0 used as the
            x component and column 1 as the y component.
    """
    fig, ax = plt.subplots(figsize=(3, 3))

    # Every arrow starts at the origin.
    origin = np.zeros(W.shape[0])
    ax.quiver(origin, origin, W[:, 0], W[:, 1],
              color='k', width=0.002, headaxislength=4, headwidth=14, headlength=15)

    # The four data points: '+' marks the targets equal to +1, '_' those equal to -1.
    # `color=` is used instead of `c=` because a bare RGBA tuple passed as `c`
    # is ambiguous with value-mapping and makes matplotlib emit a warning.
    ax.scatter(-4, 0, s=s, linewidths=lw, marker='+', color=colors[3])
    ax.scatter(0, -3, s=s, linewidths=lw, marker='_', color=colors[2])
    ax.scatter(2, 0, s=s, linewidths=lw, marker='+', color=colors[1])
    ax.scatter(0, 1, s=s, linewidths=lw, marker='_', color=colors[0])

    ax.set_xlim((-4.5, 4.5))
    ax.set_ylim((-4.5, 4.5))
    # Minor grid lines every unit, labelled major ticks every two units.
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.yaxis.set_minor_locator(MultipleLocator(1))
    ax.grid(which='minor', linestyle='--', linewidth=0.3)
    ax.grid(which='major', linestyle='--', linewidth=0.3)
    ax.set_aspect('equal', adjustable='box')
    ax.set_xticks([-4, -2, 0, 2, 4])
    ax.set_yticks([-4, -2, 0, 2, 4])
    fig.tight_layout()
    fig.savefig('xor.pdf')
    plt.show()


def xor_train(args):
    """Run the XOR experiment and plot its loss curves.

    The model is first trained on the full XOR dataset with ``relu = 0`` and
    the resulting weights are shown with ``xor_vis``. It is then trained with
    ``relu = 1`` and half the initialization scale on each of the four samples
    separately. Each single-sample loss curve is shifted so that it starts
    where the previous one's initial loss left off, beginning from the initial
    loss of the full run.

    Args:
        args: Option namespace forwarded to ``train``. It must provide
            ``init`` and ``epoch``. The caller's object is NOT modified: the
            per-experiment overrides are applied to a shallow copy.

    NOTE(review): judging from the legend labels, ``relu = 0`` is the ReLU
    network and ``relu = 1`` the linear one - confirm against ``train``.
    """
    # Work on a private copy so the `relu` / `init` overrides below do not
    # leak into the caller's namespace and contaminate later experiments.
    run_args = copy.copy(args)

    run_args.relu = 0
    results_relu = train(xor_data(), run_args)
    # First element of the last recorded `layer` entry
    # (presumably the first-layer weights at the end of training - TODO confirm).
    xor_vis(results_relu['layer'][-1][0].cpu().detach().numpy())

    plt.figure(figsize=(5.2, 3))
    plt.plot(results_relu['Ls'], c='k', lw=3, label="ReLU")

    # Running vertical offset for the stacked single-sample curves.
    Ls_tot = results_relu['Ls'][0]
    run_args.relu = 1
    run_args.init = run_args.init / 2
    for label in range(4):
        data = xor_data(label)
        results = train(data, run_args)
        initial_loss = results['Ls'][0]
        # Shift the curve to start at the current offset, then lower the
        # offset by this run's initial loss for the next curve.
        Ls = results['Ls'] - initial_loss + Ls_tot
        Ls_tot = Ls_tot - initial_loss

        target = data["y"][0]
        if target > 0:
            cap = "$+{}$".format(target)
        else:
            cap = "${}$".format(target)
        # Implicit x values, so these curves share the ReLU curve's x axis.
        plt.plot(Ls, c=colors[3 - label], linestyle='--', lw=2.5, alpha=0.9,
                 label="Lin " + cap)

    plt.xlim((0, run_args.epoch))
    plt.xticks([0, run_args.epoch // 2, run_args.epoch])
    plt.yticks([0, 0.25, 0.5, 0.75, 1])
    plt.xlabel("$t$")
    plt.ylabel("Loss")
    # Legend is placed to the right of the axes.
    plt.legend(loc=(1.04, 0.3))
    plt.tight_layout(pad=0.5)
    plt.show()


def ortho_data(label=None):
    """Return the two-point dataset of the 'ortho' experiment, or a selection.

    Args:
        label: Optional first-axis index into the dataset. ``None`` yields
            both samples; an integer yields a single sample, with ``x`` of
            shape ``(2,)`` and ``y`` of shape ``(1,)``.

    Returns:
        A dict holding the inputs under ``"x"`` and the +1/-1 targets under
        ``"y"`` (a column vector when the full set is returned).
    """
    samples = {
        "x": np.array([[2, 1], [-.5, 1]]),
        "y": np.array([1, -1]).reshape(-1, 1),
    }
    # Only slice when a label was given; note that label 0 is a valid index.
    if label is not None:
        samples = {key: values[label, :] for key, values in samples.items()}
    return samples


def ortho_vis(W):
    """Draw the rows of ``W`` as arrows from the origin alongside the two data points.

    The figure is written to ``ortho_Boursier.pdf`` in the current working
    directory and then displayed with ``plt.show()``.

    Args:
        W: 2-D array; one arrow is drawn per row, with column 0 used as the
            x component and column 1 as the y component.
    """
    fig, ax = plt.subplots(figsize=(3, 3))

    # Every arrow starts at the origin.
    origin = np.zeros(W.shape[0])
    ax.quiver(origin, origin, W[:, 0], W[:, 1],
              color='k', width=0.002, headaxislength=4, headwidth=14, headlength=15)

    # The two data points: '+' marks the target equal to +1, '_' the one equal to -1.
    # `color=` is used instead of `c=` because a bare RGBA tuple passed as `c`
    # is ambiguous with value-mapping and makes matplotlib emit a warning.
    ax.scatter(2, 1, s=s, linewidths=lw, marker='+', color=colors[1])
    ax.scatter(-.5, 1, s=s, linewidths=lw, marker='_', color=colors[0])

    ax.set_xlim((-2.2, 2.2))
    ax.set_ylim((-2.2, 2.2))
    # Minor grid lines every half unit, labelled major ticks every unit.
    ax.xaxis.set_minor_locator(MultipleLocator(0.5))
    ax.yaxis.set_minor_locator(MultipleLocator(0.5))
    ax.grid(which='minor', linestyle='--', linewidth=0.3)
    ax.grid(which='major', linestyle='--', linewidth=0.3)
    ax.set_aspect('equal', adjustable='box')
    ax.set_xticks([-2, -1, 0, 1, 2])
    ax.set_yticks([-2, -1, 0, 1, 2])
    fig.tight_layout()
    fig.savefig('ortho_Boursier.pdf')
    plt.show()


def ortho_train(args):
    """Run the two-point 'ortho' experiment and plot its loss curves.

    The model is first trained on both samples with ``relu = 0`` and the
    resulting weights are shown with ``ortho_vis``. It is then trained with
    ``relu = 1`` and the initialization scale divided by ``sqrt(2)`` on each
    sample separately. Each single-sample loss curve is shifted so that it
    starts where the previous one's initial loss left off, beginning from the
    initial loss of the full run.

    Args:
        args: Option namespace forwarded to ``train``. It must provide
            ``init`` and ``epoch``. The caller's object is NOT modified: the
            per-experiment overrides are applied to a shallow copy.

    NOTE(review): judging from the legend labels, ``relu = 0`` is the ReLU
    network and ``relu = 1`` the linear one - confirm against ``train``.
    """
    # Work on a private copy so the `relu` / `init` overrides below do not
    # leak into the caller's namespace and contaminate later experiments.
    run_args = copy.copy(args)

    run_args.relu = 0
    results_relu = train(ortho_data(), run_args)
    # First element of the last recorded `layer` entry
    # (presumably the first-layer weights at the end of training - TODO confirm).
    ortho_vis(results_relu['layer'][-1][0].cpu().detach().numpy())

    plt.figure(figsize=(4, 3))
    plt.plot(results_relu['Ls'], c='k', lw=3, label="ReLU")

    # Running vertical offset for the stacked single-sample curves.
    Ls_tot = results_relu['Ls'][0]
    run_args.relu = 1
    run_args.init = run_args.init / np.sqrt(2)
    for label in range(2):
        data = ortho_data(label)
        results = train(data, run_args)
        initial_loss = results['Ls'][0]
        # Shift the curve to start at the current offset, then lower the
        # offset by this run's initial loss for the next curve.
        Ls = results['Ls'] - initial_loss + Ls_tot
        Ls_tot = Ls_tot - initial_loss

        target = data["y"][0]
        if target > 0:
            cap = "$+{}$".format(target)
        else:
            cap = "${}$".format(target)
        plt.plot(Ls, c=colors[1 - label], linestyle='--', lw=2.5, alpha=0.9,
                 label="Lin " + cap)

    plt.xlim((0, run_args.epoch))
    plt.xticks([0, run_args.epoch // 2, run_args.epoch])
    plt.yticks([0, 0.5, 1])
    plt.xlabel("$t$")
    plt.ylabel("Loss")
    plt.legend()
    plt.tight_layout(pad=0.5)
    plt.show()


def main():
    """Parse the command-line options and run both experiments in turn.

    ``args.reduction`` is forced to ``'sum'`` before the ortho experiment and
    then the XOR experiment are run with the same option namespace.
    """
    options = config().parse_args()
    options.reduction = 'sum'
    for experiment in (ortho_train, xor_train):
        experiment(options)


if __name__ == "__main__":
    main()