# Compute the proxy A-distance (PAD) between two domains using numpy and sklearn.
# Reference: Ben-David et al., "Analysis of Representations for Domain Adaptation", NIPS 2007.

import numpy as np
from sklearn import svm
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

def proxy_mlp_a_distance(source_X, target_X, verbose=True):
    """
    Compute the Proxy-A-Distance of a source/target representation with an MLP.

    Source rows are labelled 0 and target rows 1. The pooled data is split
    50/50 at random (seeded). An MLP domain classifier is trained on one half
    and its error ``err`` is measured on the other. The distance is
    ``2 * (1 - 2 * err)``.

    Parameters
    ----------
    source_X : array-like, 2-D (n_source, n_features)
    target_X : array-like, 2-D (n_target, n_features)
    verbose : bool
        If True, print the split shapes and the resulting A-distance.

    Returns
    -------
    float
        Proxy A-distance, in [0, 2].
    """
    X = np.vstack([source_X, target_X])
    y = np.hstack([np.zeros(len(source_X)), np.ones(len(target_X))])
    # Fixed seed so repeated calls on the same data give the same split.
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.5, random_state=42)

    if verbose:
        print('MLP on', (X_train.shape, X_val.shape), 'examples')

    # Two hidden layers: the first as wide as the input features, the second 64 units.
    n_features = X_train.shape[1]
    clf = MLPClassifier(hidden_layer_sizes=(n_features, 64), max_iter=500, random_state=42)
    clf.fit(X_train, y_train)

    test_risk = 1.0 - accuracy_score(y_val, clf.predict(X_val))
    # A worse-than-chance classifier can simply be inverted, so fold the risk into [0, 0.5].
    test_risk = min(test_risk, 1.0 - test_risk)
    a_distance = 2 * (1 - 2 * test_risk)

    if verbose:
        print(f"A-distance: {a_distance}")
    return a_distance

def proxy_a_distance(source_X, target_X, verbose=True, C_list=None):
    """
    Compute the Proxy-A-Distance of a source/target representation.

    The first half of each domain's rows trains a linear SVM to separate
    source (label 0) from target (label 1). The remaining rows measure its
    risk. Several regularization strengths ``C`` are tried, and the lowest
    folded test risk ``err`` gives ``PAD = 2 * (1 - 2 * err)``.

    The split is positional (rows are not shuffled). Input ordering that
    makes the two halves systematically different will bias the estimate.

    Parameters
    ----------
    source_X : array-like, 2-D (n_source, n_features)
    target_X : array-like, 2-D (n_target, n_features)
    verbose : bool
        Print the example counts and per-C train/test risks.
    C_list : iterable of float, optional
        SVM ``C`` values to try. Defaults to ``np.logspace(-4, 1, 6)``
        (1e-4 ... 10).

    Returns
    -------
    float
        Proxy A-distance, in [0, 2].

    Raises
    ------
    ValueError
        If either domain has fewer than 2 examples. A split half would then
        be empty, leaving the classifier with a single class. Also raised if
        ``C_list`` is empty.
    """
    # Accept any array-like (lists, DataFrames, ...) so 2-D slicing below works.
    source_X = np.asarray(source_X)
    target_X = np.asarray(target_X)
    nb_source = source_X.shape[0]
    nb_target = target_X.shape[0]
    if nb_source < 2 or nb_target < 2:
        raise ValueError(
            'PAD needs at least 2 examples per domain, got %d source and %d target'
            % (nb_source, nb_target))

    if verbose:
        print('PAD on', (nb_source, nb_target), 'examples')

    C_values = np.logspace(-4, 1, 6) if C_list is None else list(C_list)
    if len(C_values) == 0:
        raise ValueError('C_list must contain at least one value')

    half_source, half_target = nb_source // 2, nb_target // 2
    train_X = np.vstack((source_X[:half_source, :], target_X[:half_target, :]))
    train_Y = np.hstack((np.zeros(half_source, dtype=int), np.ones(half_target, dtype=int)))

    test_X = np.vstack((source_X[half_source:, :], target_X[half_target:, :]))
    test_Y = np.hstack((np.zeros(nb_source - half_source, dtype=int),
                        np.ones(nb_target - half_target, dtype=int)))

    # NOTE: C is selected on the same half used to report the risk, so the
    # resulting estimate is slightly optimistic.
    best_risk = 1.0
    for C in C_values:
        clf = svm.SVC(C=C, kernel='linear', verbose=False)
        clf.fit(train_X, train_Y)

        train_risk = np.mean(clf.predict(train_X) != train_Y)
        test_risk = np.mean(clf.predict(test_X) != test_Y)

        if verbose:
            print('[ PAD C = %f ] train risk: %f  test risk: %f' % (C, train_risk, test_risk))

        # A worse-than-chance classifier can simply be inverted, so fold the risk into [0, 0.5].
        test_risk = min(test_risk, 1. - test_risk)
        best_risk = min(best_risk, test_risk)

    return 2 * (1. - 2 * best_risk)