# Deep Embedded Validation
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split

def check_normalization(normalization):
    """
    Validate the ``normalization`` option.

    :param normalization: requested weight normalization; the accepted values are
    None, "max" and "standardize"
    :raises ValueError: if ``normalization`` is any other value
    """
    accepted = (None, "max", "standardize")
    if normalization in accepted:
        return
    raise ValueError("normalization must be one of [None, 'max', 'standardize']")


def normalize_weights(weights, normalization):
    """
    Rescale importance weights so that their mean is 1.

    The input array is never modified: whenever a normalization is applied a new
    array is returned. With ``normalization=None`` the input object itself is
    returned unchanged.

    :param weights: numpy array of importance weights
    :param normalization: None (no rescaling), "max" (divide by the largest weight,
    then shift so the mean is 1) or "standardize" (zero mean and unit standard
    deviation, then shift so the mean is 1)
    :return: the normalized weights
    :raises ValueError: if ``normalization`` is not one of the accepted values
    """
    check_normalization(normalization)
    if normalization == "max":
        # Out-of-place arithmetic: in-place operators would overwrite the caller's
        # array and fail on integer dtypes.
        weights = weights / np.max(weights)  # largest weight becomes 1
        weights = weights - (np.mean(weights) - 1)  # shift to have mean of 1
    elif normalization == "standardize":
        weights = (weights - np.mean(weights)) / np.std(weights)  # standardize
        weights = weights + 1  # shift to have mean of 1
    return weights

def get_dev_risk(weight, error, normalization):
    """
    Compute the importance-weighted error with a control variate built from the
    weights: mean(w * e) + eta * mean(w) - eta, where eta = -Cov(w * e, w) / Var(w).

    :param weight: shape [N, 1], the importance weight for N source samples in the validation set
    :param error: shape [N, 1], the error value for each source sample in the validation set
    (typically 0 for correct classification and 1 for wrong classification)
    :param normalization: forwarded to ``normalize_weights``; one of None, "max", "standardize"
    :return: scalar risk estimate; the caller's ``weight`` array is left untouched
    """
    N, d = weight.shape
    _N, _d = error.shape
    assert N == _N and d == _d, 'dimension mismatch!'
    # Normalize a private copy so the caller's array is never modified by the
    # normalization step.
    weight = normalize_weights(np.array(weight), normalization)
    weighted_error = weight * error
    var_w = np.var(weight, ddof=1)
    if var_w > 0:
        # Off-diagonal entry of the 2x2 covariance matrix of (w * e, w).
        cov = np.cov(np.concatenate((weighted_error, weight), axis=1), rowvar=False)[0][1]
        eta = - cov / var_w
    else:
        # Constant weights (Var = 0) or a single sample (Var = NaN): the ratio is
        # undefined, so drop the control variate instead of returning NaN/inf.
        eta = 0.0
    return np.mean(weighted_error) + eta * np.mean(weight) - eta


def get_weight(source_feature, target_feature, validation_feature):
    """
    Estimate importance weights for the validation samples from a source-vs-target
    domain classifier.

    :param source_feature: shape [N_tr, d], features from source set
    :param target_feature: shape [N_te, d], features from target set
    :param validation_feature: shape [N_v, d], features from source validation set
    :return: shape [N_v, 1], P(target | x) / P(source | x) * N_s / N_t for each
    validation sample, where N_s and N_t are the (possibly subsampled) set sizes
    """
    n_source, feat_dim = source_feature.shape
    n_target, _ = target_feature.shape

    # Keep classifier training tractable: once either set exceeds 10000 rows,
    # take every stride-th row of both sets.
    largest = max(n_source, n_target)
    if largest > 10000:
        stride = int(largest / 10000)
        source_feature = source_feature[::stride]
        target_feature = target_feature[::stride]
        n_source, n_target = len(source_feature), len(target_feature)

    features = np.concatenate((source_feature, target_feature))
    # Domain labels: 1 = source, 0 = target.
    domain_labels = np.concatenate(
        (np.ones(n_source, dtype=np.int32), np.zeros(n_target, dtype=np.int32))
    )

    # One classifier per weight-decay value; keep the most accurate one.
    # NOTE(review): the accuracy is measured on the same data the classifier was
    # fit on, not on a held-out split - confirm this is the intended criterion.
    candidates = []
    for alpha in (1e-3, 1e-4, 1e-5):
        classifier = MLPClassifier(hidden_layer_sizes=(feat_dim,), max_iter=100, alpha=alpha, early_stopping=True)
        classifier.fit(features, domain_labels)
        candidates.append((classifier.score(features, domain_labels), classifier))

    # max() returns the first maximal item, so ties go to the earliest decay value.
    _, best_classifier = max(candidates, key=lambda scored: scored[0])

    # Column 0 is the probability of label 0 (target), column 1 of label 1 (source).
    proba = best_classifier.predict_proba(validation_feature)
    density_ratio = proba[:, :1] / proba[:, 1:]
    return density_ratio * n_source * 1.0 / n_target


def compute_dev(source_dataset, target_dataset, normalization=None):
    """
    Score source predictions on a held-out 30% of the source samples using
    importance weights estimated against the target features.

    :param source_dataset: indexable whose items 0, 1 and 2 expose ``.numpy()``; they are
    used as features, per-class prediction scores (arg-maxed over axis 1) and labels
    :param target_dataset: indexable whose item 0 exposes ``.numpy()``; used as target features
    :param normalization: weight normalization forwarded to ``get_dev_risk``
    (None, "max" or "standardize")
    :return: tuple (metric_val, weight_val): the negated risk from ``get_dev_risk`` and the
    negated plain importance-weighted error computed from the un-normalized weights
    """
    source_features = source_dataset[0].numpy()
    source_preds = np.argmax(source_dataset[1].numpy(), axis=1)
    source_labels = source_dataset[2].numpy()
    target_features = target_dataset[0].numpy()
    # 70% of the source samples train the domain classifier; the remaining 30%
    # are the validation samples that get weighted and scored.
    source_train, source_test, _, source_preds_test, _, source_label_test = \
        train_test_split(source_features, source_preds, source_labels, test_size=0.3)
    weights = get_weight(source_train, target_features, source_test)
    # 0 for a correct prediction, 1 for a wrong one; shaped [N, 1] like the weights.
    errors = 1 - np.equal(source_preds_test, source_label_test).astype(np.float32)
    errors = np.expand_dims(errors, axis=1)
    # Take the plain weighted error first, from the raw weights, so that the
    # normalization applied while computing the risk can never leak into it.
    weight_val = - np.mean(weights * errors)
    metric_val = - get_dev_risk(weights, errors, normalization)

    return metric_val, weight_val