import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
import time
from sklearn import datasets
from sklearn.metrics import accuracy_score

# Machine epsilon of the platform float type; sigmoid() adds it to the
# exponent argument.
EPSILON = sys.float_info.epsilon

def sigmoid(x):
    """
    Logistic sigmoid, evaluated element-wise.

    Args:
        x (scalar or array-like): Input value(s). Python scalars, numpy
            scalars, lists and numpy arrays are all accepted.
    Returns:
        numpy scalar or np array with the same shape as ``x`` holding
        ``1 / (1 + exp(-x + EPSILON))``.
    """
    # np.asarray lets plain Python numbers and lists through; the previous
    # implementation read ``x.shape`` and so only worked for numpy objects.
    z = np.asarray(x)
    # EPSILON is broadcast over z, which matches adding an array filled with
    # EPSILON element by element.
    return 1 / (1 + np.exp(-z + EPSILON))

def sigmoid_derivative(x):
    """
    Derivative of the sigmoid evaluated at x.

    Args:
        x (numpy scalar or np array): Point(s) at which to evaluate.
    Returns:
        sigmoid(x) * (1 - sigmoid(x)), element-wise.
    """
    activation = sigmoid(x)
    return activation * (1 - activation)

def get_cost(y_hat, y, type_loss, alpha):
    """
    Returns cost for given type loss.

    Args:
        y_hat (scalar or np array): Prediction of the neuron.
        y (scalar or np array): Required output of the neuron.
        type_loss (str): Type of loss: "l1", "l2" or "lyapunov".
        alpha (float): Exponent parameter, only used by the Lyapunov loss.
    Returns:
        cost (float): Cost for given loss. When the error ``y_hat - y`` is a
            scalar the cost of that single error is returned; otherwise the
            per-sample costs are averaged along axis 0.
    Raises:
        ValueError: If ``type_loss`` is not one of the supported names.
    """
    error = y_hat - y

    def _average(values):
        # A scalar error has nothing to average over; arrays are reduced
        # along the sample axis.
        if np.isscalar(error):
            return values
        return np.mean(values, axis=0)

    if type_loss == "l1":
        # |e|
        cost = _average(np.abs(error))
    elif type_loss == "l2":
        # 0.5 * e^2
        cost = 0.5 * _average(error**2)
    elif type_loss == "lyapunov":
        # |e|^(alpha+1) / (alpha+1)
        cost = _average(np.power(np.abs(error), alpha + 1)) / (alpha + 1)
    else:
        # Previously an unknown name fell through every branch and failed
        # with an UnboundLocalError on the return statement.
        raise ValueError("Unknown type_loss: {!r}".format(type_loss))
    return cost

def get_gradient_update(y_hat, y_train, x_train, type_loss,
                        alpha):
    """
    Get gradient update for different types of loss function types
    and types of update.

    Args:
        y_hat (np array): Prediction of the neuron.
        y_train (np array): Required output of the neuron.
        x_train (np array): Training data input.
        type_loss (str): Type of Loss: "l1", "l2" or "lyapunov".
        alpha (float): Alpha for Lyapunov (unused by "l1" and "l2").
    Returns:
        gradient_update (np array): Gradient update from the method, with a
            trailing axis of length 1 added.
    Raises:
        ValueError: If ``type_loss`` is not one of the supported names.
    """
    error = y_hat - y_train
    if type_loss == "l1":
        # NOTE(review): the training loop in this file passes
        # y_hat = sigmoid(x.w), so sigmoid_derivative is evaluated at the
        # activation rather than at the pre-activation -- confirm intended.
        loss_derivative = np.multiply(np.squeeze(np.sign(error)),
                                      np.squeeze(sigmoid_derivative(y_hat)))
        # Derivative of loss w.r.t weights (summed over samples by the dot)
        dcostdw = np.dot(x_train.T, loss_derivative)
    elif type_loss == "l2":
        loss_derivative = np.multiply(np.squeeze(error),
                                      np.squeeze(sigmoid_derivative(y_hat)))
        # 2 comes from differentiating e^2, 0.5 from the cost definition.
        dcostdw = 2*0.5*np.dot(x_train.T, loss_derivative)
    elif type_loss == "lyapunov":
        # Loss derivative = sign(y_hat - y)|y_hat - y|^alpha
        # NOTE(review): the sigmoid-derivative factor mentioned in the
        # original formula comment is not applied by this branch.
        loss_derivative = np.squeeze(np.multiply(np.sign(error),
                                     np.power(np.abs(error), alpha)))
        dcostdw = np.dot(x_train.T, loss_derivative)
        # Signed power of the weight gradient: sign(g) * |g|^alpha
        dcostdw = np.multiply(np.power(np.abs(dcostdw), alpha),
                              np.sign(dcostdw))
    else:
        # Previously an unknown name fell through every branch and failed
        # with an UnboundLocalError on the return statement.
        raise ValueError("Unknown type_loss: {!r}".format(type_loss))
    return np.expand_dims(dcostdw, axis=1)

def single_neuron_training(x_train, y_train, x_test, y_test, w, lr, n_epoch,
                           type_loss, alpha=0.8, convergence_analysis=False):
    """
    Single neuron training function.

    Runs full-batch updates ``w <- w - lr * gradient_update`` on a sigmoid
    neuron and records cost and weight values after every epoch, both
    against the epoch index and against cumulative training time.

    Args:
        x_train (np array): Training input.
        y_train (np array): Training output.
        x_test (np array): Test input.
        y_test (np array): Test output.
        w (np array): Weights array
        lr (float): Learning rate.
        n_epoch (int): Number of epochs. Ignored when
            ``convergence_analysis`` is True.
        type_loss (str): Type of Loss.
        alpha (float): Alpha.
        convergence_analysis (bool): When True, train until the training
            cost has changed by less than 10e-6 for 10 consecutive epochs
            instead of for a fixed number of epochs.
    Returns:
        w (np array): Final weights after training.
        y_hat_test (np array): Final test predictions.
        cumulative_train_time (float): Total training time in seconds
            (update step plus training-cost evaluation only).
        dynamics_list (list): ``[train cost vs epoch, test cost vs epoch,
            train cost vs time, test cost vs time, per-weight list of
            weight vs epoch, per-weight list of weight vs time]``, each
            entry being a pd DataFrame or a list of DataFrames.
    """
    n_weights = w.shape[0]

    # Cost histories against epoch index and against training time.
    cost_train_list = []
    cost_test_list = []
    cost_time_train_list = []
    cost_time_test_list = []

    # One history per weight, against epoch index and against time.
    weights_epochs_list = [[] for _ in range(n_weights)]
    weights_time_list = [[] for _ in range(n_weights)]

    # Cumulative train time initialized
    cumulative_train_time = 0

    # Loop state: either a fixed epoch budget or a plateau counter.
    epoch = 0
    constant_loss_epoch_count = 0
    if convergence_analysis:
        condition = (constant_loss_epoch_count < 10)
    else:
        condition = (epoch < n_epoch)

    # Get initial costs before training starts
    y_hat = np.squeeze(sigmoid(np.dot(x_train, w)))
    y_hat_test = np.squeeze(sigmoid(np.dot(x_test, w)))
    cost_train = get_cost(y_hat, y_train, type_loss, alpha)
    cost_test = get_cost(y_hat_test, y_test, type_loss, alpha)
    prev_cost_train = cost_train

    if type_loss == "lyapunov":
        # Report initial_cost^(1 - alpha) / (lr * k_min * gamma * (1 - alpha)),
        # where gamma is the smallest strictly positive training input.
        gamma = np.amin(x_train[np.where(x_train > 0)])
        k_min = 1
        c = lr * k_min * gamma
        denom = c * (1 - alpha)
        theoretical_time_constraint = np.power(cost_train, 1 - alpha) / denom
        print("theoretical_time_constraint")
        print(theoretical_time_constraint)

    # Initialize loss and weight dynamics.
    cost_train_list.append([epoch, cost_train])
    cost_test_list.append([epoch, cost_test])
    for i in range(n_weights):
        weights_epochs_list[i].append([epoch, w[i]])
        weights_time_list[i].append([cumulative_train_time, w[i]])

    cost_time_train_list.append([cumulative_train_time, cost_train])
    cost_time_test_list.append([cumulative_train_time, cost_test])

    # Start training.
    while condition:
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # documented high-resolution replacement for timing short intervals.
        start_train_time = time.perf_counter()
        # Loss derivative
        gradient_update = get_gradient_update(y_hat, y_train, x_train,
                                              type_loss, alpha)
        # Weight update
        w = w - (lr * gradient_update)
        # New y_hat
        y_hat = np.squeeze(sigmoid(np.dot(x_train, w)))
        # Training loss
        cost_train = get_cost(y_hat, y_train, type_loss, alpha)
        end_train_time = time.perf_counter()
        cumulative_train_time += end_train_time - start_train_time

        # Record training cost; test evaluation below is not timed.
        cost_train_list.append([epoch+1, cost_train])
        cost_time_train_list.append([cumulative_train_time, cost_train])
        # Test loss
        y_hat_test = np.squeeze(sigmoid(np.dot(x_test, w)))
        cost_test = get_cost(y_hat_test, y_test, type_loss, alpha)
        cost_test_list.append([epoch+1, cost_test])
        cost_time_test_list.append([cumulative_train_time, cost_test])
        # Weight List
        for i in range(n_weights):
            weights_epochs_list[i].append([epoch+1, w[i]])
            weights_time_list[i].append([cumulative_train_time, w[i]])

        if epoch % 1000 == 0:
            print("epoch :{:d} cost_train:{:f}".format(epoch+1, cost_train))
            print("epoch :{:d} cost_test:{:f}".format(epoch+1, cost_test))
        epoch = epoch+1
        if convergence_analysis:
            # Count consecutive epochs with a near-constant training cost
            # (note that 10e-6 equals 1e-5).
            if np.abs(cost_train - prev_cost_train) < 10e-6:
                constant_loss_epoch_count += 1
            else:
                constant_loss_epoch_count = 0
            condition = (constant_loss_epoch_count < 10)
        else:
            condition = (epoch < n_epoch)
        prev_cost_train = cost_train

    # Convert the raw histories into DataFrames for plotting.
    cost_train_list = pd.DataFrame(cost_train_list, columns=['epoch', 'cost'])
    cost_test_list = pd.DataFrame(cost_test_list, columns=['epoch', 'cost'])
    cost_time_train_list = pd.DataFrame(cost_time_train_list,
                                        columns=['time', 'cost'])
    cost_time_test_list = pd.DataFrame(cost_time_test_list,
                                       columns=['time', 'cost'])
    for i in range(w.shape[0]):
        # Weight columns are numbered from 1: weight1, weight2, ...
        weights_epochs_list[i] = pd.DataFrame(weights_epochs_list[i],
                                              columns=['epoch',
                                                       'weight'+str(i+1)])
        weights_time_list[i] = pd.DataFrame(weights_time_list[i],
                                            columns=['time',
                                                     'weight'+str(i+1)])

    dynamics_list = [cost_train_list, cost_test_list, cost_time_train_list,
                     cost_time_test_list, weights_epochs_list,
                     weights_time_list]
    return w, y_hat_test, cumulative_train_time, dynamics_list

def prepare_iris_dataset():
    """ Build the train/test split of the Iris data used by the experiments.

    Every feature column is min-max scaled using the minimum and maximum
    over the whole loaded dataset. Only the first 100 rows are then used:
    rows 0-39 and 50-89 for training, rows 40-49 and 90-99 for testing.

    Args: None
    Returns: A 4 element tuple
        X_train: Training input with shape (80, 4) where 80 represents number of
                 training examples and 4 represents 4 features of Iris dataset.
        y_train: Training output with shape (80,) where 80 represents Training
                 examples.
        X_test:  Test input with shape (20, 4), where 20 represents number of
                 test examples.
        y_test:  Test output with shape (20,) where 20 represents Test examples.
    """
    iris = datasets.load_iris()
    features = np.array(iris.data[:, :])
    labels = iris.target

    # Column-wise min-max scaling; broadcasting applies it to every row.
    col_min = features.min(axis=0)
    col_max = features.max(axis=0)
    features = (features - col_min) / (col_max - col_min)

    # Row indices of the two splits.
    train_rows = np.r_[0:40, 50:90]
    test_rows = np.r_[40:50, 90:100]

    return (features[train_rows], labels[train_rows],
            features[test_rows], labels[test_rows])

def plot_graphs(list_l1, list_l2, list_lyapunov, y_name, x_name, legend_labels,
                xlabel, ylabel, title, figname, combined=False):
    """
    Draw the L1, L2 and Lyapunov curves on the current matplotlib axes.

    Args:
        list_l1 (pd dataframe): Dynamics of the L1 run.
        list_l2 (pd dataframe): Dynamics of the L2 run.
        list_lyapunov (pd dataframe): Dynamics of the Lyapunov run.
        y_name (str): Column plotted on the y axis.
        x_name (str): Column plotted on the x axis.
        legend_labels (list of str): Legend entries, in L1, L2, Lyapunov order.
        xlabel (str): Label of the x axis.
        ylabel (str): Label of the y axis.
        title (str): Accepted for interface compatibility; not drawn.
        figname (str): Output file, written only when ``combined`` is False.
        combined (bool): True when the caller saves the figure itself.
    """
    frames = (list_l1, list_l2, list_lyapunov)
    line_formats = ('--k', '-.b', '-r')

    # y axis spans up to the largest value of any curve; the x axis is cut
    # at the shortest curve so all three cover the visible range.
    y_upper = max(frame[y_name].max() for frame in frames)
    x_upper = min(frame[x_name].max() for frame in frames)

    for idx, frame in enumerate(frames):
        plt.plot(frame[x_name], frame[y_name], line_formats[idx],
                 linewidth=2.0, label=legend_labels[idx])

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.grid(True)
    plt.ylim((0, y_upper))
    plt.xlim((0, x_upper))
    plt.legend(framealpha=1, frameon=True)
    if combined:
        return
    plt.savefig(figname)

def plot_combined_graph(dynamics_list_l1, dynamics_list_l2,
                        dynamics_list_lyapunov, experiment_type):
    """
    Save one figure with three side-by-side panels: training loss vs time,
    test loss vs time and weight1 vs time.

    Args:
        dynamics_list_l1 (list): Dynamics list of the L1 run.
        dynamics_list_l2 (list): Dynamics list of the L2 run.
        dynamics_list_lyapunov (list): Dynamics list of the Lyapunov run.
        experiment_type: Tag appended to the output file name.
    """
    loss_labels = ["L1 loss", "L2 loss", "Lyapunov loss"]
    # (subplot code, accessor into a dynamics list, y column, legend labels,
    #  x label, y label, title, figure name)
    panels = (
        (131, lambda dyn: dyn[2], "cost", loss_labels,
         "Time (in seconds) \n (a)", "Training Loss",
         "Training vs time loss plot", "loss_train_time.png"),
        (132, lambda dyn: dyn[3], "cost", loss_labels,
         "Time (in seconds) \n (b)", "Test Loss",
         "Test vs time loss plot", "loss_test_time.png"),
        (133, lambda dyn: dyn[5][0], "weight1",
         ["L1 weight1", "L2 weight1", "Lyapunov weight1"],
         "Time (in seconds) \n (c)", "Weight Value",
         "Weight1 vs time plot", "weight1_time.png"),
    )

    plt.rcParams.update({'font.size': 18})
    plt.figure(figsize=(26, 9))
    for code, pick, column, labels, x_text, y_text, heading, fname in panels:
        plt.subplot(code)
        # combined=True: the panel is drawn but not saved on its own.
        plot_graphs(pick(dynamics_list_l1), pick(dynamics_list_l2),
                    pick(dynamics_list_lyapunov), y_name=column,
                    x_name="time", legend_labels=labels, xlabel=x_text,
                    ylabel=y_text, title=heading, figname=fname,
                    combined=True)
    plt.savefig('combined_plot_'+str(experiment_type)+'.png')

def plot_individual_graphs(dynamics_list_l1, dynamics_list_l2,
                           dynamics_list_lyapunov, experiment_type):
    """Save every loss and weight curve as its own figure.

    Args:
        dynamics_list_l1 (list): Dynamics list of the L1 run.
        dynamics_list_l2 (list): Dynamics list of the L2 run.
        dynamics_list_lyapunov (list): Dynamics list of the Lyapunov run.
        experiment_type: Tag appended to every output file name.
    """
    tag = str(experiment_type)
    loss_labels = ["L1 loss", "L2 loss", "Lyapunov loss"]

    # (slot in the dynamics list, x column, x label, y label, title,
    #  file name prefix)
    loss_figures = (
        (0, "epoch", "Epochs", "Loss",
         "Training loss plot", "loss_train_epochs_"),
        (1, "epoch", "Epochs", "Loss",
         "Test loss plot", "loss_test_epochs_"),
        (2, "time", "Time (in seconds) \n (a)", "Training Loss",
         "Training vs time loss plot", "loss_train_time_"),
        (3, "time", "Time (in seconds) \n (b)", "Test Loss",
         "Test vs time loss plot", "loss_test_time_"),
    )
    for slot, x_col, x_text, y_text, heading, prefix in loss_figures:
        plt.figure(figsize=(13, 9))
        plot_graphs(dynamics_list_l1[slot], dynamics_list_l2[slot],
                    dynamics_list_lyapunov[slot], y_name="cost",
                    x_name=x_col, legend_labels=loss_labels,
                    xlabel=x_text, ylabel=y_text, title=heading,
                    figname=prefix + tag + ".png")

    # Two figures per weight: against epochs, then against time.
    for idx in range(len(dynamics_list_l1[4])):
        number = str(idx+1)
        column = "weight" + number
        weight_labels = ["L1 weight" + number, "L2 weight" + number,
                         "Lyapunov weight" + number]

        plt.figure(figsize=(13, 9))
        plot_graphs(dynamics_list_l1[4][idx], dynamics_list_l2[4][idx],
                    dynamics_list_lyapunov[4][idx], y_name=column,
                    x_name="epoch", legend_labels=weight_labels,
                    xlabel="Epochs", ylabel="Weight Value",
                    title="Weight" + number + " vs epochs plot",
                    figname="weight" + number + "_epochs_" + tag + ".png")

        plt.figure(figsize=(13, 9))
        plot_graphs(dynamics_list_l1[5][idx], dynamics_list_l2[5][idx],
                    dynamics_list_lyapunov[5][idx], y_name=column,
                    x_name="time", legend_labels=weight_labels,
                    xlabel="Time (in seconds) \n (c)", ylabel="Weight Value",
                    title="Weight" + number + " vs time plot",
                    figname="weight" + number + "_time_" + tag + ".png")

def linear_experiment():
    """
    Train on a single two-feature sample with the L1, L2 and Lyapunov
    losses, print the value each run converges to and save the plots.
    """
    x = np.array([0.5, 0.2])
    y = 0.9
    w = np.array([[0.0], [1.0]])
    lr = 0.1
    alpha = 0.82
    n_epochs = 500

    # (loss name, banner, result prefix, extra positional arguments).
    # Only the Lyapunov run overrides the default alpha.
    runs = (
        ("l1", "Training l1", "Converged value for L1: ", ()),
        ("l2", "Training l2", "Converged value for L2: ", ()),
        ("lyapunov", "Training lyapunov",
         "Converged value for Lyapunov: ", (alpha,)),
    )

    dynamics = []
    for loss_name, banner, prefix, extra in runs:
        print(banner)
        # The same sample serves as both training and test data.
        w_out, _, _, run_dynamics = single_neuron_training(
            x, y, x, y, w, lr, n_epochs, loss_name, *extra)
        print(prefix + str(sigmoid(np.dot(x, w_out))) +
              " as compared to original label: " + str(y))
        dynamics.append(run_dynamics)

    dynamics_l1, dynamics_l2, dynamics_lyapunov = dynamics
    plot_combined_graph(dynamics_l1, dynamics_l2, dynamics_lyapunov,
                        'linear')
    plot_individual_graphs(dynamics_l1, dynamics_l2, dynamics_lyapunov,
                           'linear')


def iris_experiment():
    """
    Run the Iris experiment.

    Trains the neuron with the L1, L2 and Lyapunov losses for a fixed number
    of epochs and prints the test accuracy of each, then repeats training
    until convergence from several random initialisations to report the mean
    convergence time per loss, and finally saves the plots of the fixed-epoch
    runs.
    """
    # Define hyperparams
    n_epochs = 600
    lr = 0.01
    alpha = 0.8
    # Number of random initialisations used for the convergence timing.
    timing_analysis_epochs = 50

    # Prepare Iris dataset
    (X_train, y_train, X_test, y_test) = prepare_iris_dataset()

    # Set seed for reproducibility
    np.random.seed(0)
    w = np.random.random((len(X_train[0]), 1))

    # Run training with l1 and Gradient descent
    print("Training l1")
    _, y_pred_l1, _, dynamics_list_l1 = \
        single_neuron_training(X_train, y_train, X_test, y_test, w, lr,
                               n_epochs, "l1")
    # Predictions above 0.5 are counted as class 1.
    acc_l1 = accuracy_score(y_test, (y_pred_l1 > 0.5).astype(int))
    print("Accuracy l1: " + str(acc_l1))

    # Run training with l2 loss and Gradient descent
    print("Training l2")
    _, y_pred_l2, _, dynamics_list_l2 = \
        single_neuron_training(X_train, y_train, X_test, y_test, w, lr,
                               n_epochs, "l2")
    acc_l2 = accuracy_score(y_test, (y_pred_l2 > 0.5).astype(int))
    print("Accuracy l2: " + str(acc_l2))

    # Run training with Lyapunov loss and Gradient descent with dw/dt
    print("Training lyapunov")
    _, y_pred_lyapunov, _, dynamics_list_lyapunov = \
        single_neuron_training(X_train, y_train, X_test, y_test, w, lr,
                               n_epochs, "lyapunov", alpha)
    acc_lyapunov = accuracy_score(y_test, (y_pred_lyapunov > 0.5).astype(int))
    print("Accuracy Lyapunov: " + str(acc_lyapunov))

    # Do convergence analysis
    convergence_time_l1 = []
    convergence_time_l2 = []
    convergence_time_lyapunov = []
    for i in range(timing_analysis_epochs):
        # A different seed per repetition gives a different starting point,
        # shared by the three losses within that repetition.
        np.random.seed(i)
        w = np.random.random((len(X_train[0]), 1))
        print("Convergence for L1")
        _, _, train_time_l1, _ = \
            single_neuron_training(X_train, y_train, X_test, y_test, w, lr,
                                   n_epochs, "l1", convergence_analysis=True)
        print("Convergence for L2")
        # This run previously trained with "l1", so the reported L2
        # convergence time was a second L1 measurement.
        _, _, train_time_l2, _ = \
            single_neuron_training(X_train, y_train, X_test, y_test, w, lr,
                                   n_epochs, "l2", convergence_analysis=True)
        print("Convergence for Lyapunov")
        _, _, train_time_lyapunov, _ = \
            single_neuron_training(X_train, y_train, X_test, y_test, w, lr,
                                   n_epochs, "lyapunov", alpha,
                                   convergence_analysis=True)
        convergence_time_l1.append(train_time_l1)
        convergence_time_l2.append(train_time_l2)
        convergence_time_lyapunov.append(train_time_lyapunov)
    print("convergence Analysis Results")
    print("L1 converges in: " + str(np.mean(convergence_time_l1)) + " seconds")
    print("L2 converges in: " + str(np.mean(convergence_time_l2)) + " seconds")
    print("Lyapunov converges in: " + str(np.mean(convergence_time_lyapunov)) +
          " seconds")
    plot_combined_graph(dynamics_list_l1, dynamics_list_l2,
                        dynamics_list_lyapunov, 'iris')
    plot_individual_graphs(dynamics_list_l1, dynamics_list_l2,
                           dynamics_list_lyapunov, 'iris')

def main():
    """Parse the command line and run the selected experiment.
    """
    parser = argparse.ArgumentParser(description='Single Neuron Experiments!')
    parser.add_argument('experiment', type=str, choices=['iris', 'linear'],
                        help='Type of experiment to perform')
    selected = parser.parse_args().experiment

    # Map each accepted choice to the function that runs it.
    runners = {'iris': iris_experiment, 'linear': linear_experiment}
    runner = runners.get(selected)
    if runner is not None:
        runner()

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
