import numpy as np
from warnings import warn

import time
from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score
from sklearn.svm import SVC



def _get_orig_score_params(examples, labels):
    if len(set(labels)) == 1:
        # everything is in the same class! SVM can't handle :(
        warn("Everything is in the same class! SVM can't handle :(")
        return 1, {'gamma': 5, 'C': 1}

    # C_range = np.arange(0.1, 16, 2)
    # gamma_range = np.arange(0, 10)
    C_range = np.logspace(0.1, 0.5, 10)
    gamma_range = np.logspace(0.1, 1, 10)
    param_grid = dict(gamma=gamma_range, C=C_range)
    # param_grid = dict(C=C_range)
    # cv = StratifiedKFold(y=labels, n_folds=3)
    grid = GridSearchCV(SVC(class_weight='balanced'), param_grid=param_grid, cv=3, iid=True)  # 3 fold CV
    try:
        grid.fit(examples, labels)
    except ValueError:
        return 1, {'gamma': 5, 'C': 1}
    return grid.best_score_, grid.best_params_


def _get_subset_score(examples, labels, used_vars, best_params):
    if len(set(labels)) == 1:
        # everything is in the same class! SVM can't handle :(
        warn("Everything is in the same class! SVM can't handle :(")
        return 1
    examples = examples[:, used_vars]
    flattened_samples = np.array([np.concatenate(sample).ravel() for sample in examples])
    if examples.shape[1] == 0:
        return 0
    labels = np.asarray(labels)

    try:
        return np.mean(
            cross_val_score(
                SVC(class_weight='balanced', C=best_params['C'], gamma=best_params['gamma']),
                X=flattened_samples, y=labels, cv=3))
    except ValueError:
        return 1


def range_except(start, end, skip):
    """
    List the integers from ``start`` (inclusive) up to ``end`` (exclusive), leaving out ``skip``.

    :param start: first candidate integer.
    :param end: upper bound (not included).
    :param skip: value to omit from the result.
    :return: list of the remaining integers, in ascending order.
    """
    return list(filter(lambda value: value != skip, range(start, end)))


def get_object_classification_mask(object_indexer, samples, labels, partition_mask, object_ids=None, verbose=True,
                                   improvement_threshold=0):
    """
    Build a column mask for a single object's classification problem.

    The mask is the partition mask plus column 0, the object's own column (via ``object_indexer``), and
    the last column of ``samples`` if adding it raises the SVM cross-validation score by more than
    ``improvement_threshold``.

    :param object_indexer: callable mapping an object id to a column index of ``samples``.
    :param samples: array whose second axis indexes columns; the last column is assumed to be the
        inventory (see TODO below).
    :param labels: class label for each row of ``samples``.
    :param partition_mask: iterable of column indices to include. It is not modified.
    :param object_ids: ids of the involved objects; must contain exactly one distinct id.
    :param verbose: if True, print the scores and the final mask.
    :param improvement_threshold: minimum score gain (strictly exceeded) required to add the last column.
    :return: sorted list of column indices.
    :raises ValueError: if ``object_ids`` is None, empty, or contains more than one distinct id.
    """
    if object_ids is None or len(set(object_ids)) != 1:
        raise ValueError(
            "object_ids must contain exactly one distinct object id, got {!r}".format(object_ids))

    # quick way of getting precondition mask: it is the partition mask, the agent view and potentially the inventory
    # (which we test for). Build a new set so the caller's partition_mask is left untouched.
    mask_set = set(partition_mask)
    mask_set.add(0)
    mask_set.add(object_indexer(object_ids[0]))

    mask = sorted(mask_set)
    idx = samples.shape[1] - 1  # TODO: assumes last one is inventory
    if idx not in mask:
        temp = samples[:, mask]
        flattened_samples = np.array([np.concatenate(sample).ravel() for sample in temp])
        mask_score, best_params = _get_orig_score_params(flattened_samples, labels)
        if verbose:
            print("Original score: " + str(mask_score))
            print("Best Params: " + str(best_params))

        # check if adding inventory improves things
        nscore = _get_subset_score(samples, labels, mask + [idx], best_params)
        if nscore - mask_score > improvement_threshold:
            mask.append(idx)
            if verbose:
                print("Adding {} to mask for new score {}".format(idx, nscore))
    mask.sort()
    if verbose:
        print("Final mask: {}".format(mask))
    return mask  # TODO: mask is sorted. Check doesn't cause bugs


    # mask = sorted(list(set(partition_mask)))
    # # return mask
    #
    # # for idx in range(1, samples.shape[1]):
    # #     # idx = samples.shape[1] - 1  # TODO: assumes last one is inventory
    # #     if idx not in mask:
    # #         temp = samples[:, mask]
    # #         flattened_samples = np.array([np.concatenate(sample).ravel() for sample in temp])
    # #         mask_score, best_params = _get_orig_score_params(flattened_samples, labels)
    # #         if verbose:
    # #             print("Original score: " + str(mask_score))
    # #             print("Best Params: " + str(best_params))
    # #
    # #         # check if adding inventory improves things
    # #         nscore = _get_subset_score(samples, labels, mask + [idx], best_params)
    # #         if nscore > mask_score:
    # #             mask.append(idx)
    # #             if verbose:
    # #                 print("Adding {} to mask for new score {}".format(idx, nscore))
    # # mask.sort()
    # # if verbose:
    # #     print("Final mask: {}".format(mask))
    # # return mask  # TODO: mask is sorted. Check doesn't cause bugs
    #
    # if partition_mask is not None:
    #
    #     # quick way of getting precondition mask: it is the partition mask, the agent view and potentially the inventory
    #     # (which we test for)
    #     if 0 not in partition_mask:
    #         partition_mask.append(0)
    #     mask = sorted(list(set(partition_mask)))
    #     # return mask
    #
    #     idx = samples.shape[1] - 1  # TODO: assumes last one is inventory
    #     if idx not in mask:
    #         temp = samples[:, mask]
    #         flattened_samples = np.array([np.concatenate(sample).ravel() for sample in temp])
    #         mask_score, best_params = _get_orig_score_params(flattened_samples, labels)
    #         if verbose:
    #             print("Original score: " + str(mask_score))
    #             print("Best Params: " + str(best_params))
    #
    #         # check if adding inventory improves things
    #         nscore = _get_subset_score(samples, labels, mask + [idx], best_params)
    #         if nscore > mask_score:
    #             mask.append(idx)
    #             if verbose:
    #                 print("Adding {} to mask for new score {}".format(idx, nscore))
    #     mask.sort()
    #     if verbose:
    #         print("Final mask: {}".format(mask))
    #     return mask  # TODO: mask is sorted. Check doesn't cause bugs



        # improvement_threshold = 0.03
    #
    # # Get the mask on an object level i.e. which objects are most important to the classification scheme? Same as
    # # is previous work!
    #
    #
    # #TODO: Need to ignore agent view since that's always in mask!!
    #
    # # get the score for all the data!
    # temp = samples[:, range(1, samples.shape[1])]
    # flattened_samples = np.array([np.concatenate(sample).ravel() for sample in temp])
    # original_score, best_params = _get_orig_score_params(flattened_samples, labels)
    # # original_score = 0.904862579281184
    # # best_params = {'C': 1.2589254117941673, 'gamma': 7.943282347242816}
    #
    # if verbose:
    #     print("Original score: " + str(original_score))
    #     print("Best Params: " + str(best_params))
    #
    # mask = [0] #TODO: Need to ignore agent view since that's always in mask!!
    # # check which objects, when left out, hurt!
    # for m in range(1, samples.shape[1]):
    #     used_vars = range_except(1, samples.shape[1], m) #TODO: Need to ignore agent view since that's always in mask!!
    #     nscore = _get_subset_score(samples, labels, used_vars, best_params)
    #
    #     print("Removing {} got {}".format(m, original_score - nscore))
    #     if original_score - nscore >= improvement_threshold:
    #         mask.append(m)
    #
    #
    # mask_score = _get_subset_score(samples, labels, mask, best_params)
    # if verbose:
    #     print("Score for initial mask {}: {}".format(mask, mask_score))
    #
    # # check which objects, when added back in, improve!
    # for m in range(0, samples.shape[1]):
    #     if m in mask:
    #         continue
    #     nscore = _get_subset_score(samples, labels, mask + [m], best_params)
    #     if nscore - mask_score >= improvement_threshold:
    #         mask.append(m)
    #         if verbose:
    #             print("Adding {} to mask for new score {}".format(m, nscore))
    #         mask_score = nscore
    #     if mask_score == 1:
    #         if verbose:
    #             print("Early break - mask is 1")
    #         break
    # mask.sort()
    # if verbose:
    #     print("Final mask: {}".format(mask))
    # return mask  # TODO: mask is sorted. Check doesn't cause bugs

def get_classification_mask(examples, labels, verbose=True):
    """
    Return the indices of the features of ``examples`` to use for classification.

    Feature selection is currently disabled: every column index of ``examples`` is returned,
    independent of ``labels``.

    :param examples: array whose second axis indexes features; only ``examples.shape[1]`` is read.
    :param labels: class labels for the rows of ``examples`` (currently unused).
    :param verbose: currently unused; kept for interface compatibility.
    :return: ``np.arange(examples.shape[1])``.
    """
    # An RFECV-based selection previously sat after this return. It was unreachable and used an
    # un-imported name (RFECV), so it was removed rather than kept as a latent NameError.
    return np.arange(examples.shape[1])