from itertools import combinations_with_replacement, repeat

import matplotlib.pyplot as plt
import networkx as nx
from sklearn import datasets, metrics, svm
from sklearn.model_selection import train_test_split
from scipy.misc import derivative
import numpy as np
from multiprocessing import Pool, freeze_support
from numpy.polynomial import chebyshev as ch
from itertools import product
from scipy.linalg import sqrtm

def _extend_paths(G, max_depth, history, depth, paths, graph_search):
    """Add every extension of the walk ``history`` to ``paths`` (in place).

    Recursion stops as soon as ``depth`` exceeds ``max_depth``.
    """
    if depth > max_depth:
        return

    for v in G[history[-1]]:
        if graph_search and v in history:
            continue
        hist = history + [v]
        # Walks are stored with their vertices sorted, so two walks that
        # visit the same vertices the same number of times collapse into
        # a single entry.
        paths.add(tuple(sorted(hist)))
        _extend_paths(G, max_depth, hist, depth + 1, paths, graph_search)


def all_paths(G, max_depth, history=None, depth=1, paths=None, graph_search=False):
    """Enumerate the vertex multisets of walks in ``G``.

    Parameters
    ----------
    G : graph-like
        Iterating ``G`` must yield its vertices and ``G[v]`` must yield the
        neighbours of ``v``.
    max_depth : int
        Maximum number of vertices in a walk when called with the defaults.
    history : list, optional
        Walk to extend. When omitted, walks are started from every vertex
        and the single-vertex walks are included in the result.
    depth : int, optional
        Depth counter of the current call; nothing is collected once it
        exceeds ``max_depth``.
    paths : set, optional
        Accumulator that is updated in place. A fresh set is used when
        omitted.
    graph_search : bool, optional
        When true, a walk is never extended by a vertex it already
        contains. The flag is applied at every recursion level.

    Returns
    -------
    list of tuple
        The collected walks, each as a tuple of sorted vertices, in sorted
        order. An empty list is returned when ``depth > max_depth``.
    """
    if depth > max_depth:
        return []

    if paths is None:
        paths = set()

    if history is None:
        for v in G:
            paths.add((v,))
            _extend_paths(G, max_depth, [v], depth + 1, paths, graph_search)
    else:
        _extend_paths(G, max_depth, history, depth, paths, graph_search)

    # Sort once here instead of once per recursive call.
    return sorted(paths)

def get_basis(G, deg):
    """Turn every walk returned by ``all_paths(G, deg)`` into a basis entry.

    Each entry is a pair ``(indices, counts)`` of equally long lists: for
    every distinct vertex ``v`` of the walk, ``indices`` holds
    ``v[0]*8 + v[1]`` and ``counts`` holds how often ``v`` occurs in the walk.
    (Assumes vertices are ``(row, col)`` pairs on a grid of width 8 -- confirm
    against the caller.)
    """
    basis = []
    for walk in all_paths(G, deg):
        distinct = set(walk)
        # Both lists iterate the same unmodified set, so they stay aligned.
        flat_index = [vertex[0]*8+vertex[1] for vertex in distinct]
        multiplicity = [walk.count(vertex) for vertex in distinct]
        basis.append((flat_index, multiplicity))
    return basis

def f1(clf, X, y):
    """Return the weighted-average F1 score of ``clf`` on ``(X, y)``.

    ``clf`` must provide ``predict``; ``y`` holds the ground-truth labels.
    The ground truth is passed first because ``f1_score`` expects
    ``(y_true, y_pred)`` and, with ``average='weighted'``, weights each class
    by its support in ``y_true``.
    """
    yhat = clf.predict(X)
    return metrics.f1_score(y, yhat, pos_label=None, average='weighted')

def attack_fgsm(args):
    """Craft one adversarial example by a single normalised-gradient step.

    Parameters
    ----------
    args : tuple
        ``(clf, x, y, C)`` packed into one argument so the function can be
        used with ``Pool.map``. ``clf`` must provide ``decision_function``
        returning one column per class, ``x`` is a 1-D sample, ``y`` the
        index of the class whose score is attacked and ``C`` the step length.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1, n_features)``: ``x`` moved by ``C`` against the
        L2-normalised gradient of the class-``y`` score, clipped to
        ``[-1, 1]``. If the gradient estimate is exactly zero there is no
        descent direction and the (clipped) input is returned unchanged
        instead of a NaN-filled array.

    Notes
    -----
    The gradient is a central difference with unit step,
    ``(f(x + e_i) - f(x - e_i)) / 2``, which is what
    ``scipy.misc.derivative`` computes with its default arguments. Despite
    the function name, the step follows the normalised gradient, not its sign.
    """
    clf, x, y, C = args

    x = np.asarray(x, dtype=float)[None, :]
    n_features = x.shape[1]

    # Row i of ``x + shift`` is x with coordinate i raised by one, so all
    # perturbed samples are scored with two batched model calls.
    shift = np.eye(n_features)
    upper = clf.decision_function(x + shift)[:, y]
    lower = clf.decision_function(x - shift)[:, y]
    grad = (upper - lower) / 2

    norm = np.linalg.norm(grad, ord=2)
    if norm == 0:
        return np.clip(x, -1, 1)

    adv_dir = grad / norm
    return np.clip(x - C * adv_dir[None, :], -1, 1)

def vander(X, basis, verbose=True):
    """Evaluate every basis element on every row of ``X``.

    Parameters
    ----------
    X : numpy.ndarray
        2-D array with one sample per row.
    basis : sequence of (indices, degrees)
        For each basis element, two equally long integer sequences: the
        feature indices and the Chebyshev degree applied to each of them.
    verbose : bool, optional
        Print a ``\\r``-refreshed progress fraction every tenth row.

    Returns
    -------
    numpy.ndarray
        Array ``V`` of shape ``(len(X), len(basis))`` where ``V[n, m]`` is the
        product over ``k`` of ``T_{degrees[k]}(X[n, indices[k]])``.
    """
    N = X.shape[0]
    M = len(basis)
    V = np.ones((N, M))

    # The table only needs to reach the highest degree the basis asks for;
    # deriving it here keeps the function independent of module globals.
    max_deg = int(max((max(d) for _, d in basis if len(d)), default=0))
    index = [(np.asarray(j, dtype=int), np.asarray(d, dtype=int))
             for j, d in basis]

    for n in range(N):
        if verbose and M and n%10 == 0: print('\r%1.2f' % (n/N), end='')
        # One Chebyshev table per sample, shared by all basis elements.
        v = ch.chebvander(X[n], max_deg)
        for m, (j, d) in enumerate(index):
            V[n, m] = np.prod(v[j, d])
    print()
    return V

def Itv(bi, bj):
    """Elementwise table: 1 where ``bi == bj == 0``, 1/2 where ``bi == bj != 0``, else 0.

    NOTE(review): presumably the pairwise inner product of Chebyshev
    polynomials of degrees ``bi`` and ``bj`` up to a constant -- confirm.
    """
    value_if_equal = np.where(bi == 0, 1, 1/2)
    return np.where(bi == bj, value_if_equal, 0)

def Iuv(bi, bj):
    """Elementwise table that depends on the parities and the minimum of ``bi``, ``bj``.

    With ``m = min(bi, bj)``: 0 when the parities differ,
    ``1 + 2*floor(m/2)`` when both are even, ``2*ceil(m/2)`` when both are odd.
    """
    smaller = np.minimum(bi, bj)
    both_even = 1 + 2 * np.floor(smaller / 2)
    both_odd = 2 * np.ceil(smaller / 2)
    same_parity = np.where(bi % 2 == 0, both_even, both_odd)
    return np.where((bi % 2) == (bj % 2), same_parity, 0)

def Ib(bi, bj):
    """Combine ``Itv`` and ``Iuv`` into one scalar for two exponent vectors.

    ``bi`` and ``bj`` are 1-D numpy arrays of equal length. The result is::

        sum_k  bi[k] * bj[k] * Iuv(bi[k]-1, bj[k]-1) * prod_{l != k} Itv(bi, bj)[l]
    """
    # Broadcasting (D,) ** (D, D): the exponent 1 - I turns the diagonal
    # entries into x ** 0 == 1, so row k keeps every Itv factor but the k-th.
    others = Itv(bi, bj) ** (1 - np.eye(bi.size))
    leave_one_out = np.prod(others, axis=-1)
    return np.sum(bi * bj * Iuv(bi - 1, bj - 1) * leave_one_out)

def sync_mat_time(basis, n_features=64):
    """Build the symmetric matrix of pairwise ``Ib`` values of a basis.

    Parameters
    ----------
    basis : sequence of (indices, degrees)
        Each element gives the feature indices it touches and the degree
        used on each of them.
    n_features : int, optional
        Length of the dense exponent vectors; every index in ``basis`` must
        be smaller than this. Defaults to 64.

    Returns
    -------
    numpy.ndarray
        ``(M, M)`` array with ``Sigma[i, j] == Sigma[j, i] == Ib(b_i, b_j)``,
        where ``b_i`` is the dense exponent vector of ``basis[i]``. Progress
        is printed to stdout as a ``\\r``-refreshed fraction.
    """
    M = len(basis)
    MM = M*(M-1) / 2 + M
    Sigma = np.eye(M)

    # Expand every basis element into its dense exponent vector once,
    # instead of rebuilding both operands for each of the ~M^2/2 pairs.
    exponents = np.zeros((M, n_features))
    for row, (indices, degrees) in zip(exponents, basis):
        row[indices] = degrees

    pairs = combinations_with_replacement(range(M), r=2)
    for counter, (i, j) in enumerate(pairs):
        if counter%100 == 0: print('\r%f' % (counter / MM), end='')
        # Fill the entry and its mirror image in one assignment.
        Sigma[(i, j), (j, i)] = Ib(exponents[i], exponents[j])
    return Sigma


if __name__ == '__main__':
    freeze_support()

    # Load the digit images and flatten each one into a single feature row.
    digits = datasets.load_digits()
    data = digits.images.reshape((len(digits.images), -1))

    # Unshuffled 50/50 split: first half trains, second half tests.
    X_train, X_test, y_train, y_test = train_test_split(
        data, digits.target, shuffle=False, test_size=0.5)

    # Rescale x -> x / 8 - 1 (presumably maps 0..16 intensities onto [-1, 1]).
    X_train, X_test = X_train / 8 - 1, X_test / 8 - 1

    # 8x8 lattice graph with a self-loop added on every vertex.
    G = nx.generators.lattice.grid_2d_graph(8, 8)
    G.add_edges_from([(v, v) for v in G])

    degrees = [1, 2, 3, 4]

    # One score per degree: clean ("nat") and attacked ("adv") test data.
    nat_poly_score, nat_orth_score, nat_sync_score = [], [], []
    adv_poly_score, adv_orth_score, adv_sync_score = [], [], []

    for deg in degrees:
        basis = get_basis(G, deg)

        # --- Poly: SVM with a polynomial kernel of the current degree ---
        clf_poly = svm.SVC(kernel='poly', degree=deg).fit(X_train, y_train)
        nat_poly_score.append(f1(clf_poly, X_test, y_test))

        # Attack every test sample against the polynomial model in parallel.
        attack_jobs = zip(repeat(clf_poly), X_test, y_test, repeat(1))
        with Pool(processes=3) as pool:
            adv_poly = np.vstack(pool.map(attack_fgsm, attack_jobs))
        adv_poly_score.append(f1(clf_poly, adv_poly, y_test))

        # --- Orth: linear SVM on the Chebyshev basis features ---
        V_train_orth = vander(X_train, basis)
        V_test_orth = vander(X_test, basis)
        clf_orth = svm.LinearSVC().fit(V_train_orth, y_train)
        nat_orth_score.append(f1(clf_orth, V_test_orth, y_test))

        # The adversarial samples crafted above are reused for this model.
        V_orth_adv = vander(adv_poly, basis)
        adv_orth_score.append(f1(clf_orth, V_orth_adv, y_test))

        # --- Sync: same features multiplied by the real part of Sigma^(-1/2) ---
        Lambda = np.real(sqrtm(np.linalg.inv(sync_mat_time(basis))))
        V_train_sync = V_train_orth @ Lambda
        V_test_sync = V_test_orth @ Lambda
        clf_sync = svm.LinearSVC().fit(V_train_sync, y_train)
        nat_sync_score.append(f1(clf_sync, V_test_sync, y_test))

        V_sync_adv = V_orth_adv @ Lambda
        adv_sync_score.append(f1(clf_sync, V_sync_adv, y_test))

        # Running totals after each degree.
        print(nat_poly_score, nat_orth_score, nat_sync_score)
        print(adv_poly_score, adv_orth_score, adv_sync_score)

    # Two side-by-side panels, one curve per model in each.
    fig, ax = plt.subplots(1, 2)
    panels = [
        ('Natural', (nat_poly_score, nat_orth_score, nat_sync_score)),
        ('Adversarial', (adv_poly_score, adv_orth_score, adv_sync_score)),
    ]
    curve_styles = [
        ('polynomial kernel', '.'),
        ('orth. Cheb.', 'v'),
        ('sync. Cheb.', 'x'),
    ]
    for axis, (title, curves) in zip(ax, panels):
        axis.set_title(title)
        for scores, (label, marker) in zip(curves, curve_styles):
            axis.plot(degrees, scores, label=label, marker=marker)
        axis.set_xlabel(r'degree')
        axis.legend()
    # Only the left panel carries the y-axis label.
    ax[0].set_ylabel(r'avg. f1-score')

    fig.savefig('digits.eps', bbox_inches='tight', dpi=100)