# illustrate_main.py

import os
import time
import pickle
import argparse
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
from cvxnn_utils import *
from nn_utils import *
import matplotlib.patches as patches

def get_parser():
    """Build the command-line parser for the illustration script.

    Options (as consumed by ``main`` in this file):
      --dataset     dataset name; ``main`` handles 'ortho' and 'spike_free'
      --output      file extension used for the saved figures
      --lr          learning rate forwarded to the network training routine
      --seed        seed given to ``torch.manual_seed``
      --num_points  number of samples used to draw curves
      --show_seed   flag: append the seed to the output file names
      --verbose     flag: forwarded to the convex solvers

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    cli = argparse.ArgumentParser(description='setting')

    # (flag, type, default) for every option that takes a value.
    valued_options = (
        ("--dataset", str, 'ortho'),
        ("--output", str, 'pdf'),
        ("--lr", float, 1e-1),
        ("--seed", int, 0),
        ("--num_points", int, 500),
    )
    for flag, kind, fallback in valued_options:
        cli.add_argument(flag, type=kind, default=fallback)

    # Boolean switches that default to False.
    for flag in ("--show_seed", "--verbose"):
        cli.add_argument(flag, action='store_true')

    return cli

def rot_mat(theta):
    """Return the 2x2 counter-clockwise rotation matrix for angle ``theta`` (radians)."""
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    top_row = [cos_t, -sin_t]
    bottom_row = [sin_t, cos_t]
    return np.array([top_row, bottom_row])

def find_nonzero_neuron(w,thres=1e-6):
    """Collect the rows of ``w`` whose Euclidean norm exceeds ``thres``.

    Args:
        w: 2-D array; each row is treated as one neuron's weight vector.
        thres: rows with norm less than or equal to this value are dropped.

    Returns:
        list: the retained rows (``w[i, :]`` slices), in their original order.
    """
    # Unpacking the shape also enforces that ``w`` is two-dimensional.
    row_count, _ = w.shape
    return [
        w[row, :]
        for row in range(row_count)
        if np.linalg.norm(w[row, :]) > thres
    ]

def plotLine(start, end, fig, ax, c='k', linewidth=2, number=100,linestyle='solid'):
    """Draw the straight segment from ``start`` to ``end`` on ``ax``.

    The segment is sampled at ``number`` evenly spaced points (both end points
    included) and handed to ``ax.plot``.

    Args:
        start, end: arrays with exactly two entries (x, y).
        fig: accepted for call compatibility; not used here.
        ax: object providing a matplotlib-style ``plot`` method.
        c, linewidth, linestyle: forwarded to ``ax.plot``.
        number: how many sample points to generate along the segment.
    """
    origin = start.reshape([2,1])
    target = end.reshape([2,1])

    # Same evaluation order as a single expression: multiply by the sample
    # index first, then divide by (number - 1).
    offset = (origin - target) * np.arange(number)
    samples = origin - offset / (number - 1)

    xs = samples[0,:].squeeze()
    ys = samples[1,:].squeeze()
    ax.plot(xs, ys, c=c, linewidth=linewidth, linestyle=linestyle)

def main():
    """Run the two-point ellipse experiment and save every illustration.

    Steps, in order:
      1. Parse the command line and select the dataset settings (rotation
         angle ``phi`` of the ellipse, labels ``y``, output file-name suffix).
      2. Build the 2x2 data matrix ``X = R(phi) diag(a, b) R(-phi)`` and sample
         the boundary of the matching rotated ellipse.
      3. Solve the primal and dual problems through ``cvx_nn_max_margin`` and
         ``cvx_nn_max_margin_dual`` and print the results.
      4. Save the figures ``ellipse``, ``rectified_ellipse``, ``polar`` and
         ``minmax_map`` under ``illustrations/``.
      5. Train ReLU networks with ``train_ReLU_nn`` and save the trajectory
         figures ``track_Xw`` and ``track_w``.

    ``cvx_nn_gen_mask``, ``cvx_nn_max_margin``, ``cvx_nn_max_margin_dual``,
    ``relu``, ``train_ReLU_nn`` and ``torch`` are not defined in this file;
    they are expected to be provided by the star imports at the top.

    An unsupported ``--dataset`` value terminates the program through
    ``parser.error`` (usage message, exit status 2).
    """
    parser = get_parser()
    args = parser.parse_args()
    print(args)

    # Per-dataset settings: ellipse rotation angle, label vector and the
    # suffix appended to every output file name.
    dataset_settings = {
        'ortho': (-0.2*np.pi, np.array([1,-1]), ''),
        'spike_free': (0.2*np.pi, np.array([1,1]), '_2'),
    }
    if args.dataset not in dataset_settings:
        # Without this check X and y would never be assigned and the script
        # would fail further down with an unrelated NameError.
        parser.error('unknown dataset {!r}; expected one of {}'.format(
            args.dataset, sorted(dataset_settings)))
    phi, y, name_append = dataset_settings[args.dataset]
    if args.show_seed:
        name_append += '_seed_{}'.format(args.seed)

    # Boundary of an ellipse with semi-axes a and b, given in polar form and
    # rotated by phi.
    num_points = args.num_points
    a = 2
    b = 1
    theta = np.arange(num_points)/(num_points-1)*2*np.pi
    r_theta = a*b/np.sqrt((b*np.cos(theta))**2+(a*np.sin(theta))**2)
    X_elli = r_theta*np.cos(theta+phi)
    Y_elli = r_theta*np.sin(theta+phi)

    # The data matrix maps the unit circle onto that same rotated ellipse.
    A = rot_mat(phi)@np.diag([a,b])@rot_mat(-phi)
    X = A.copy()

    print('X is printed as follows.')
    print(X)

    dmat = cvx_nn_gen_mask(X)

    # solve the primal problem
    p_star, (w_p, w_m) = cvx_nn_max_margin(X,y,dmat,verbose=args.verbose)
    print('The primal opitmal value is {:.3f}'.format(p_star.item()))

    print('The primal solution is printed as follows.')
    print(w_p)
    print(w_m)

    # Keep only the neurons that are not numerically zero, then normalise
    # each of them and record its rectified output on the data.
    w_p_nz = find_nonzero_neuron(w_p,thres=1e-6)
    w_m_nz = find_nonzero_neuron(w_m,thres=1e-6)
    w_star_set = w_p_nz+w_m_nz

    w_star_hat_set = []
    Xw_set = []
    for w_star in w_star_set:
        w_star_hat = w_star/np.linalg.norm(w_star)
        w_star_hat_set.append(w_star_hat)
        Xw_set.append(relu(X@w_star_hat))

    print('The set of hat w_star is printed as follows.')
    print(w_star_hat_set)

    print('The set of relu(X hat w_star) is printed as follows.')
    print(Xw_set)

    # solve the dual problem
    d_star, (lbd, z_p, z_m) = cvx_nn_max_margin_dual(X,y,dmat,verbose=args.verbose)
    print('The dual opitmal value is {:.3f}'.format(d_star.item()))

    print('The dual solution is printed as follows.')
    print(lbd)

    # savefig does not create missing directories, so make sure the output
    # directory exists before the first figure is written.
    os.makedirs('illustrations', exist_ok=True)

    # ---- figure: the full ellipse ----
    fig, ax = plt.subplots(1,1,figsize=(8,8))
    ax.plot(X_elli,Y_elli,c='mediumblue')
    ax.fill(X_elli,Y_elli,c='deepskyblue',alpha=0.5)
    ax.set_xlim([-2.5,2.5])
    ax.set_ylim([-2.5,2.5])
    fig.savefig('illustrations/ellipse{}.{}'.format(name_append,args.output),bbox_inches='tight')

    # ---- geometry of the rectified ellipse ----
    # Keep the boundary samples lying in the closed first quadrant.
    indices = (X_elli>=0)*(Y_elli>=0)
    X_elli_sub = X_elli[indices==1]
    Y_elli_sub = Y_elli[indices==1]
    # Cyclically shift the arc so that it starts one sample after the
    # smallest-x sample (presumably so the arc is traversed in one sweep).
    idx = np.argmin(X_elli_sub)+1
    idx_set = idx+np.arange(len(X_elli_sub))
    idx_set = idx_set-len(X_elli_sub)*(idx_set>=len(X_elli_sub))
    X_elli_sub = X_elli_sub[idx_set]
    Y_elli_sub = Y_elli_sub[idx_set]

    # Approximate intercepts of the arc with the two axes.
    idx = np.argmin(np.abs(X_elli_sub))
    Y_elli_max = Y_elli_sub[idx]

    idx = np.argmin(np.abs(Y_elli_sub))
    X_elli_max = X_elli_sub[idx]

    # Axis segments from the origin to the intercepts; they close the polygon
    # that is filled below.
    X_axisX = np.arange(num_points)/(num_points-1)*X_elli_max
    Y_axisX = np.zeros(num_points)

    X_axisY = np.zeros(num_points)
    Y_axisY = np.arange(num_points)/(num_points-1)*Y_elli_max

    X_concate = np.concatenate([X_elli_sub[1:],X_axisY[::-1],X_axisX])
    Y_concate = np.concatenate([Y_elli_sub[1:],Y_axisY[::-1],Y_axisX])

    # Axis segments reaching the largest coordinate of the whole ellipse.
    X_elli_max_all = np.max(X_elli)
    Y_elli_max_all = np.max(Y_elli)

    X_axisX_full = np.arange(num_points)/(num_points-1)*X_elli_max_all
    Y_axisX_full = np.zeros(num_points)
    X_axisY_full = np.zeros(num_points)
    Y_axisY_full = np.arange(num_points)/(num_points-1)*Y_elli_max_all

    def draw_rectified_set(axis):
        # Outline (arc plus the two axis segments) and filled interior.
        axis.plot(X_elli_sub,Y_elli_sub,c='mediumblue')
        axis.plot(X_axisX_full,Y_axisX_full,c='mediumblue')
        axis.plot(X_axisY_full,Y_axisY_full,c='mediumblue')
        axis.fill(X_concate,Y_concate,c='deepskyblue',alpha=0.5)

    # ---- figure: the rectified ellipse ----
    fig, ax = plt.subplots(1,1,figsize=(8,8))
    draw_rectified_set(ax)
    ax.set_xlim([-2.5,2.5])
    ax.set_ylim([-2.5,2.5])
    fig.savefig('illustrations/rectified_ellipse{}.{}'.format(name_append, args.output),bbox_inches='tight')

    # Long lines through the origin: along lbd, along y, and orthogonal to y.
    diag_len = 5
    base = np.arange(num_points)/(num_points-1)
    X_diag = (2*base-1)*diag_len*lbd[0]
    Y_diag = (2*base-1)*diag_len*lbd[1]

    X_y = (2*base-1)*diag_len*y[0]
    Y_y = (2*base-1)*diag_len*y[1]

    X_ytx = (2*base-1)*diag_len*y[1]
    Y_ytx = -(2*base-1)*diag_len*y[0]

    # On each of the three boundary pieces find the sample with the largest and
    # the smallest inner product with the normalised dual solution.
    curve_pair_list = [(X_axisX_full, Y_axisX_full), (X_axisY_full, Y_axisY_full), (X_elli_sub, Y_elli_sub)]

    max_points = np.zeros([3,2])
    min_points = np.zeros([3,2])

    lbd_hat = lbd/np.linalg.norm(lbd)

    for i, (X_points, Y_points) in enumerate(curve_pair_list):
        curve = np.concatenate([X_points.reshape([-1,1]),Y_points.reshape([-1,1])],axis=1)
        scores = curve@lbd_hat
        max_points[i,:] = curve[np.argmax(scores),:]
        min_points[i,:] = curve[np.argmin(scores),:]

    # Orthogonal projections of those samples onto the line spanned by lbd_hat.
    max_points_proj = (max_points@lbd_hat.reshape([-1,1]))*lbd_hat.reshape([1,-1])
    min_points_proj = (min_points@lbd_hat.reshape([-1,1]))*lbd_hat.reshape([1,-1])

    plt.rcParams.update({'font.size': 14})

    # draw polar set
    fig, ax = plt.subplots(1,1,figsize=(8,8))

    len_rect = 5
    rect = patches.Rectangle((-1/X_elli_max_all, -1/Y_elli_max_all),2/X_elli_max_all,2/Y_elli_max_all,alpha=0.4,
                             fc='grey',ec='black',linewidth=2,label=r'$\{\lambda:\max_{u:\|u\|_2\leq 1}|\lambda^T(Xu)_+|\leq 1\}$')
    ax.add_patch(rect)
    rect = patches.Rectangle((0, 0),len_rect*y[0],len_rect*y[1],alpha=0.4,
                             fc='mediumblue',ec='blue',linewidth=2,label=r'$\{\lambda:$diag$(y)\lambda\geq 0\}$')
    ax.add_patch(rect)
    ax.set_xlim([-2,2])
    ax.set_ylim([-2,2])
    ax.plot(X_y, Y_y,color='r',label='y',linestyle='dashed')
    ax.scatter(lbd[0],lbd[1],100,color='k',label=r'optimal $\lambda$',marker='*')
    ax.plot(X_diag, Y_diag,c='k',linestyle='dashdot')
    ax.legend()

    fig.savefig('illustrations/polar{}.{}'.format(name_append,args.output),bbox_inches='tight')

    plt.rcParams.update({'font.size': 14})

    # ---- figure: extreme points of the rectified ellipse along lbd ----
    fig, ax = plt.subplots(1,1,figsize=(8,8))
    draw_rectified_set(ax)

    alpha=0.7
    for i in range(3):
        # Only one piece carries the legend entry so it appears once.
        max_extra = {'label': 'maximal'} if i==1 else {}
        min_extra = {'label': 'minimal'} if i==1 else {}

        ax.scatter(max_points_proj[i,0],max_points_proj[i,1],100,c='r',marker='+',alpha=alpha)
        plotLine(max_points[i,:], max_points_proj[i,:], fig,ax,c='r',linestyle='dotted')
        ax.scatter(max_points[i,0],max_points[i,1],50,c='r',alpha=alpha,**max_extra)

        ax.scatter(min_points_proj[i,0],min_points_proj[i,1],100,c='g',marker='+',alpha=alpha)
        plotLine(min_points[i,:], min_points_proj[i,:], fig,ax,c='g',linestyle='dotted')
        ax.scatter(min_points[i,0],min_points[i,1],50,c='g',alpha=alpha,**min_extra)

    for e, Xw in enumerate(Xw_set):
        extra = {'label': r'optimal $(Xw_{1,i}^*)_+$'} if e==0 else {}
        ax.scatter(Xw[0],Xw[1],80,c='k',marker='*',alpha=alpha,**extra)

    ax.plot(X_diag,Y_diag,c='black',linestyle='dashed',label=r'optimal $\lambda$')
    ax.plot(X_ytx,Y_ytx,c='red',linestyle='dashdot',label=r'$y^Tu=0$')
    ax.set_xlim([-1,2])
    ax.set_ylim([-1,2])
    ax.legend()
    fig.savefig('illustrations/minmax_map{}.{}'.format(name_append,args.output),bbox_inches='tight')

    # train network
    torch.manual_seed(args.seed)
    X_t = torch.Tensor(X)
    y_t = torch.Tensor(y).reshape([-1,1])
    break_points = [1,10,100,1000,10000]
    iter_num = 10000
    m = 10
    loss_history, theta_list = train_ReLU_nn(X_t,y_t, m, break_points, num_trial=10, iter_num=iter_num, lr = args.lr, verbose=True)

    # Plot settings shared by the two trajectory figures.
    num_breakpoints = len(break_points)
    num_neuron_plot = 10
    color_list = ['r','g','orange','y']
    marker_list = ['o','s','d']
    marker_size = 40

    def unit_neuron(k, j):
        # j-th column of the first entry of theta_list[k], scaled to unit norm.
        # NOTE(review): assumes theta_list[k][0] holds the first-layer weights
        # recorded at the k-th break point -- confirm against train_ReLU_nn.
        w1 = theta_list[k][0][:,j]
        return w1/np.linalg.norm(w1)

    plt.rcParams.update({'font.size': 14})

    # ---- figure: trajectories of relu(X w) for the trained neurons ----
    fig, ax = plt.subplots(1,1,figsize=(8,8))
    draw_rectified_set(ax)

    ax.plot(X_diag,Y_diag,c='black',linestyle='dashed',label=r'optimal $\lambda$')
    ax.plot(X_ytx,Y_ytx,c='red',linestyle='dashdot',label=r'$y^Tu=0$')

    for k in range(num_breakpoints):
        for j in range(num_neuron_plot):
            Xw1 = relu(X@unit_neuron(k, j))
            # A single legend entry stands for all trained neurons.
            extra = {'label': r'trained $(Xw_{1,i})_+$'} if (k==0 and j==0) else {}
            ax.scatter(Xw1[0],Xw1[1],60,c=color_list[j%len(color_list)],
                       marker=marker_list[j%len(marker_list)],alpha=alpha,**extra)

    for k in range(num_breakpoints-1):
        for j in range(num_neuron_plot):
            Xw1 = relu(X@unit_neuron(k, j))
            Xw1_end = relu(X@unit_neuron(k+1, j))
            # The arrow covers half of the step to show the direction; the
            # line covers the whole step.
            ax.arrow(Xw1[0],Xw1[1],0.5*(Xw1_end[0]-Xw1[0]), 0.5*(Xw1_end[1]-Xw1[1]),width=0.01,alpha=0.7,
                     color =color_list[j%len(color_list)])
            plotLine(Xw1, Xw1_end, fig,ax,c=color_list[j%len(color_list)],linewidth=2.5)

    for e, Xw in enumerate(Xw_set):
        extra = {'label': r'optimal $(Xw_{1,i}^*)_+$'} if e==0 else {}
        ax.scatter(Xw[0],Xw[1],200,c='k',marker='*',alpha=alpha,**extra)

    ax.set_xlim([-0.2,2])
    ax.set_ylim([-0.2,2])
    ax.legend()
    fig.savefig('illustrations/track_Xw{}.{}'.format(name_append,args.output),bbox_inches='tight')

    plt.rcParams.update({'font.size': 14})

    # ---- figure: trajectories of the normalised neurons themselves ----
    fig, ax = plt.subplots(1,1,figsize=(8,8))

    # Lines through the origin orthogonal to each data point (row of X).
    num_points_sector = 100
    len_sector = 5
    for j in range(2):
        Xj = X[j,:]
        r_sector = np.arange(num_points_sector)/(num_points_sector-1)*2-1
        X_sector =  r_sector*Xj[1]*len_sector
        Y_sector = -r_sector*Xj[0]*len_sector
        extra = {'label': 'cone boundary'} if j==0 else {}
        ax.plot(X_sector, Y_sector, color='k',linestyle='dashed',**extra)

    for k in range(num_breakpoints):
        for j in range(num_neuron_plot):
            w1 = unit_neuron(k, j)
            # Each neuron gets its legend entry from its first snapshot.
            extra = {'label': r'{}-th neuron'.format(j+1)} if k==0 else {}
            ax.scatter(w1[0],w1[1],marker_size,c=color_list[j%len(color_list)],
                       marker=marker_list[j%len(marker_list)],alpha=alpha,**extra)

    for k in range(num_breakpoints-1):
        for j in range(num_neuron_plot):
            w1 = unit_neuron(k, j)
            w1_end = unit_neuron(k+1, j)
            ax.arrow(w1[0],w1[1],(w1_end[0]-w1[0])/2, (w1_end[1]-w1[1])/2,width=0.01,alpha=0.7,
                     color =color_list[j%len(color_list)])
            plotLine(w1, w1_end, fig,ax,c=color_list[j%len(color_list)],linewidth=2.5)

    for e, w_star_hat in enumerate(w_star_hat_set):
        extra = {'label': 'optimal neuron'} if e==0 else {}
        ax.scatter(w_star_hat[0],w_star_hat[1],200,c='k',marker='*',alpha=alpha,**extra)

    ax.set_xlim([-1.2,1.2])
    ax.set_ylim([-1.2,1.2])
    ax.legend(loc='center')
    fig.savefig('illustrations/track_w{}.{}'.format(name_append,args.output),bbox_inches='tight')




# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
