import pickle
import numpy as np
import random
from sklearn.utils import shuffle
from sklearn import preprocessing
from sklearn.datasets import load_iris
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import sys
# Apply seaborn's 'whitegrid' style to the figure drawn by this script.
sns.set_style('whitegrid')

# --- Optimisation settings -------------------------------------------------
eta = 0.05          # step size applied to every (noisy) gradient update
T = 110             # number of gradient iterations per run
recordflag = 10     # the loss is recorded every `recordflag` iterations
experiments = 5     # number of repeated runs aggregated into min/mean/max

# --- Noise / regularisation settings ---------------------------------------
# `epsilon`, `G` and `C` only appear in the noise standard deviation
# C * G * sqrt(T) * sqrt(log(1 / delta)) / (n * e), for each e in `epsilon`.
# NOTE(review): presumably `epsilon` is a privacy budget - confirm.
epsilon = [0.5]
G = 1
C = 4
# Weight of the L2 term: `L * w` in the gradient, `(L / 2) * w.w` in the loss.
L = 1

# Iteration indices at which a loss value is recorded (0, recordflag, ...).
t_count = [step for step in range(T) if step % recordflag == 0]

def sigmoid(N, data, weight):
    """Return the logistic function of the first ``N`` entries of ``data @ weight``.

    The computation is vectorised but keeps the two overflow-free forms:
    ``1 / (1 + exp(-x))`` for ``x >= 0`` and ``exp(x) / (1 + exp(x))``
    otherwise, so ``exp`` is never evaluated on a large positive argument.

    Parameters
    ----------
    N : int
        Number of leading samples to evaluate; also the length of the result.
    data : array-like
        Sample matrix passed to ``np.dot`` as the left operand.
    weight : array-like
        Weight vector passed to ``np.dot`` as the right operand.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(N,)`` with values in ``[0, 1]``.

    Raises
    ------
    IndexError
        If ``data @ weight`` has fewer than ``N`` entries.
    """
    x = np.asarray(np.dot(data, weight))[:N]
    result = np.zeros(N, )

    nonneg = x >= 0
    # x >= 0: exp(-x) <= 1, so the denominator cannot overflow.
    result[nonneg] = 1. / (1 + np.exp(-x[nonneg]))
    # x < 0 (and NaN, which propagates): exp(x) < 1, again no overflow.
    exp_x = np.exp(x[~nonneg])
    result[~nonneg] = exp_x / (1 + exp_x)
    return result

def sigmoid2(data, weight):
    """Return the logistic function of the scalar ``np.dot(data, weight)``.

    Uses ``1 / (1 + exp(-z))`` when ``z >= 0`` and ``exp(z) / (1 + exp(z))``
    otherwise, so ``exp`` is only ever called on a non-positive argument.
    """
    z = np.dot(data, weight)
    if z >= 0:
        denominator = 1 + np.exp(-z)
        return 1. / denominator
    exp_z = np.exp(z)
    return exp_z / (1 + exp_z)

def classifier(x, weights):
    """Classify one sample: 1.0 if ``sigmoid2(x, weights)`` exceeds 0.5, else 0.0."""
    return 1.0 if sigmoid2(x, weights) > 0.5 else 0.0

def cost(N, prob, labels):
    """Return the binary cross-entropy of ``prob`` against ``labels``, divided by ``N``.

    A small constant (1e-10) is added inside each logarithm so that
    probabilities of exactly 0 or 1 do not produce ``log(0)``.
    """
    positive_term = labels * np.log(prob + 1e-10)
    negative_term = (1 - labels) * np.log(1 - prob + 1e-10)
    return (-1 / N) * np.sum(positive_term + negative_term)


def _load_pickled_dataset(path):
    """Load an ``(X, y)`` pair from the pickle file at ``path``.

    NOTE: unpickling can execute arbitrary code - only use trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _load_dataset(name):
    """Return ``(X, y, train_data_num, start_flag)`` for the dataset ``name``.

    Labels are remapped in place to {0, 1} where the dataset needs it.
    Raises ``ValueError`` for an unrecognised name.
    """
    if name == 'Iris':
        data_iris = load_iris()
        X = data_iris.data
        y = data_iris.target
        # Merge class 2 into class 1 to obtain a binary problem.
        y[y == 2] = 1
        return X, y, 120, 1

    if name == 'BC':
        X, y = _load_pickled_dataset('BreastCancer_data.p')
        # Label 2 becomes 1, every other label becomes 0.
        for i in range(y.shape[0]):
            y[i] = 1 if int(y[i]) == 2 else 0
        return X, y, 600, 2

    if name == 'CC':
        X, y = _load_pickled_dataset('CreditCard_data.p')
        return X, y, 800, 3

    if name == 'Bank':
        X, y = _load_pickled_dataset('bank_data.p')
        return X, y, 25000, 4

    if name == 'Adult':
        X, y = _load_pickled_dataset('adult_data.p')
        # Labels of -1 become 0.
        for i in range(y.shape[0]):
            if y[i] == -1:
                y[i] = 0
        return X, y, 30162, 5

    raise ValueError(
        "unknown dataset {!r}; expected one of 'Iris', 'BC', 'CC', 'Bank', 'Adult'".format(name)
    )


def _run_perturbed_descent(train_X, train_y, w_init, normalize):
    """Run noisy gradient descent and return ``(ave, min, max)`` loss curves.

    For every ``e`` in ``epsilon`` the optimisation is repeated ``experiments``
    times for ``T`` iterations. Gaussian noise is added to each gradient.
    With ``normalize=True`` a gradient whose L2 norm is below 1 is rescaled
    to unit norm before the update. The regularised loss is recorded every
    ``recordflag`` iterations. The statistics are taken over the runs for the
    last ``e`` in ``epsilon``; earlier values only advance the random
    generator.
    """
    n, d = train_X.shape[0], train_X.shape[1]
    delta = 1 / n
    # One slot per recorded iteration; also correct when T is not a multiple
    # of recordflag (int(T / recordflag) would be one short in that case).
    num_records = len(range(0, T, recordflag))
    losses = np.zeros((experiments, num_records))

    for e in epsilon:
        std = C * G * (T ** 0.5) * (np.log(1 / delta) ** 0.5) / (n * e)
        for experiment in range(experiments):
            # Rebinding (never in-place updates) keeps w_init untouched.
            w = w_init
            recorded = []
            for t in range(T):
                h = sigmoid(n, train_X, w)
                error = (h - train_y) / n
                noise = np.random.normal(loc=0.0, scale=std, size=d)
                gradient = np.matmul(train_X.transpose(), error) + L * w

                if normalize:
                    gradient_l2 = np.linalg.norm(gradient, ord=2)
                    # Skip an exactly-zero gradient: dividing by its norm
                    # would turn the weights into NaN.
                    if 0 < gradient_l2 < 1:
                        gradient = gradient / gradient_l2

                w = w - eta * (gradient + noise)

                if t % recordflag == 0:
                    pre = sigmoid(n, train_X, w)
                    recorded.append(cost(n, pre, train_y) + (L / 2) * np.dot(w, w))

            losses[experiment] = recorded

    return (
        np.average(losses, axis=0),
        np.min(losses, axis=0),
        np.max(losses, axis=0),
    )


def convergence(argv):
    """Compare traditional and normalized gradient perturbation on one dataset.

    The dataset is shuffled (fixed seed), standardised, and its first
    ``train_data_num`` rows are used for training. Both optimisers start from
    an all-ones weight vector, and the random generators are re-seeded before
    each of them.

    Parameters
    ----------
    argv : str
        Dataset name: 'Iris', 'BC', 'CC', 'Bank' or 'Adult'. All but 'Iris'
        are read from pickle files in the current working directory.

    Returns
    -------
    tuple
        ``(start_flag, loss_nor_ave, loss_nor_min, loss_nor_max,
        loss_tra_ave, loss_tra_min, loss_tra_max)``. ``start_flag`` is a
        per-dataset integer (1-5); the script's ``__main__`` section uses it
        as the first recorded index to plot. Each loss array holds one value
        per recorded iteration, aggregated over the repeated runs.

    Raises
    ------
    ValueError
        If ``argv`` is not a known dataset name.
    """
    X, y, train_data_num, start_flag = _load_dataset(argv)

    X, y = shuffle(X, y, random_state=0)
    X = preprocessing.scale(X)

    train_X = X[0:train_data_num]
    train_y = y[0:train_data_num]
    w_init = np.ones(train_X.shape[1])

    # Traditional Gradient Perturbation
    random.seed(0)
    np.random.seed(0)
    print('Traditional Gradient Perturbation......')
    loss_tra_ave, loss_tra_min, loss_tra_max = _run_perturbed_descent(
        train_X, train_y, w_init, normalize=False)
    print('################################')

    # Normalized Gradient Perturbation
    random.seed(0)
    np.random.seed(0)
    print('Normalized Gradient Perturbation......')
    loss_nor_ave, loss_nor_min, loss_nor_max = _run_perturbed_descent(
        train_X, train_y, w_init, normalize=True)
    print('################################')

    return start_flag, loss_nor_ave, loss_nor_min, loss_nor_max, loss_tra_ave, loss_tra_min, loss_tra_max


if __name__ == "__main__":
    # Script entry point: run the comparison for the dataset named on the
    # command line and plot both loss curves with their min/max bands.
    if len(sys.argv) < 2:
        # Exit with a usage hint instead of an IndexError traceback.
        sys.exit('usage: {} <Iris|BC|CC|Bank|Adult>'.format(sys.argv[0]))

    start_flag, loss_nor_ave, loss_nor_min, loss_nor_max, loss_tra_ave, loss_tra_min, loss_tra_max = convergence(sys.argv[1])

    color = cm.viridis(0.7)
    color2 = cm.viridis(0.5)
    f, ax = plt.subplots(1, 1)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)

    # The first `start_flag` recorded points are left out of the plot.
    steps = t_count[start_flag:]

    ax.plot(steps, loss_tra_ave[start_flag:], color=color, label='TGP')
    ax.fill_between(steps, loss_tra_min[start_flag:], loss_tra_max[start_flag:], color=color, alpha=0.2)
    ax.plot(steps, loss_nor_ave[start_flag:], color=color2, label='m-NGP')
    ax.fill_between(steps, loss_nor_min[start_flag:], loss_nor_max[start_flag:], color=color2, alpha=0.2)

    ax.legend(fontsize=20)
    ax.set_xlabel('Iterations', fontsize=20)
    ax.set_ylabel('Loss', fontsize=20)
    plt.margins(x=0)
    plt.show()