import multiprocessing
import warnings
from collections import ChainMap
from functools import partial
from warnings import warn

import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

from symbols.data.data import load_option_partitions
from symbols.data.partitioned_option import PartitionedOption
from symbols.domain.domain import Domain
from symbols.symbols.kde import KernelDensityEstimator
from symbols.symbols.svc import SupportVectorClassifier
from symbols.symbols.svr import SupportVectorRegressor


def learn_effects(env: Domain,
                  partition_dir,
                  view='problem',
                  parallel=False,
                  verbose=False,
                  **kwargs
                  ):
    """Learn effect models for every partition of every option in ``env``.

    Args:
        env: domain whose ``action_space`` is passed to
            ``load_option_partitions``. It is iterated to enumerate options in
            serial mode; ``action_space.n`` gives the option count in parallel
            mode.
        partition_dir: directory handed to ``load_option_partitions``.
        view: unused by this function; kept for interface compatibility.
        parallel: if True, split the options into chunks processed by worker
            processes (at most one per CPU).
        verbose: print a progress line per partition (also forwarded to the
            parallel workers).
        **kwargs: ignored.

    Returns:
        dict mapping ``(option, partition id)`` to the
        ``(probabilities, rewards, outcomes)`` tuples built by
        ``learn_effects_parallel``.
    """
    partitions = load_option_partitions(env.action_space, partition_dir)

    if not parallel:
        # The serial path shares its per-option logic with the worker function
        # instead of duplicating it.
        return learn_effects_parallel(partitions, env.action_space, verbose=verbose)

    n_options = env.action_space.n
    if n_options == 0:
        # np.array_split / Pool reject zero chunks; nothing to learn anyway.
        return dict()
    n_jobs = min(multiprocessing.cpu_count(), n_options)
    splits = np.array_split(list(range(n_options)), n_jobs)
    functions = [partial(learn_effects_parallel, partitions, split, verbose=verbose)
                 for split in splits]
    # run in parallel, then reduce the per-worker dicts to a single dict
    effects = run_parallel(functions)
    return dict(ChainMap(*effects))


def learn_effects_parallel(partitions,
                           options,
                           verbose=False):
    """Fit outcome models for every partition of each option in ``options``.

    Used as the worker entry point by ``learn_effects`` (one call per chunk of
    options), but it can also be called directly.

    Args:
        partitions: mapping from option to its list of partitions; each
            partition exposes ``rules``, ``states`` and ``partition``.
        options: iterable of option keys to process.
        verbose: print a progress line per partition.

    Returns:
        dict mapping ``(option, partition.partition)`` to a tuple
        ``(probabilities, rewards, outcomes)``. It holds one
        ``rule.probability(<number of partition states>)`` value and one
        ``KernelDensityEstimator`` per rule. ``rewards`` stays empty because
        reward learning is currently disabled.
    """

    def fit_outcome(rule):
        # When the rule's mask is empty, fall back to fitting the density over
        # every column of the initiation set (the warning flags this as
        # questionable).
        if len(rule.mask) == 0:
            warnings.warn("The mask is empty. Maybe we then learn over s, not s prime?")
            start_states = np.array(list(rule.initiation_set))
            return KernelDensityEstimator(mask=range(rule.initiation_set.shape[1]),
                                          data=start_states)
        end_states = np.array(list(rule.termination_set))
        return KernelDensityEstimator(mask=rule.mask, data=end_states)

    learned = {}
    for opt in options:
        option_partitions = partitions[opt]
        if len(option_partitions) == 0:
            continue
        for idx, part in enumerate(option_partitions):
            if verbose:
                print("Calculating effects for option {}, {}:".format(opt, idx))
            rule_probs, reward_models, outcome_models = [], [], []
            for rule in part.rules:
                outcome_models.append(fit_outcome(rule))
                # TODO: reward regression (SupportVectorRegressor over the
                # initiation set) is left out for now, so reward_models is empty.
                rule_probs.append(rule.probability(part.states.shape[0]))
            learned[(opt, part.partition)] = rule_probs, reward_models, outcome_models
    return learned


#  get samples from OTHER partitions of the option as negatives!
def _augment_negative(partition_idx, partitions, negative_samples):
    neg = negative_samples
    for i, p in enumerate(partitions):
        if i == partition_idx:
            continue
        neg = np.concatenate((neg, p.states))
    return neg


def _learn_precondition(partition: PartitionedOption, negative, positive, verbose=False):
    """Learn an SVM precondition classifier for one option partition.

    A feature mask is first chosen with ``_get_classification_mask``. A
    ``SupportVectorClassifier`` is then fitted on positive (label 1) versus
    negative (label 0) samples.

    Args:
        partition: used only for its ``option`` / ``partition`` attributes in
            progress messages.
        negative: array of negative samples (rows).
        positive: array of positive samples (rows).
        verbose: print progress information.

    Returns:
        The fitted ``SupportVectorClassifier``, or None when either sample set
        is empty or the classifier construction raises.
    """
    if negative.shape[0] == 0 or positive.shape[0] == 0:
        warn("Need positive and negative samples!")
        return None

    # Both sets are non-empty here. The former "no positive / no negative"
    # fallback branches could never run, so they were removed.
    examples = np.vstack((positive, negative))
    labels = ([1] * len(positive)) + ([0] * len(negative))

    if verbose:
        print("Calculating mask for option {}, partition {} ...".format(partition.option, partition.partition))
    mask = _get_classification_mask(examples, labels, verbose=verbose)
    if verbose:
        print("Precondition mask calculated")
        print("Calculating precondition:")
    try:
        return SupportVectorClassifier(mask, examples, labels)
    except Exception as e:
        # Best-effort: a failed fit yields no precondition rather than aborting
        # the whole run, but report it through the same channel as the other
        # failures in this function.
        warn("Failed to learn precondition: {}".format(e))
        return None


def _get_classification_mask(examples, labels, improvement_threshold=0, verbose=True):
    # A 3-fold cross-validation score is computed using the support vector machine classifier
    # implementation in scikit-learn, with an RBF kernel, automatic class reweighting,
    # and parameters selected by a grid search with 3-fold cross-validation.
    # We test whether leaving out each state variable independently damaged the score,
    # keeping only variables that did. Finally, we added each state variable back when doing so improved the score.

    mask = []
    n_vars = examples.shape[1]
    all_vars = range(0, n_vars)
    (tot_score, params) = _get_orig_score_params(examples, labels)

    if verbose:
        print(tot_score)

    for m in range(0, n_vars):
        used_vars = list(all_vars[:])
        used_vars.remove(m)
        nscore = _get_subset_score(examples, labels, used_vars, params)

        if nscore < (tot_score - 0.02):
            mask.append(m)

    mxpos = -1
    mxscore = 0.0

    if len(mask) == 0:
        for m in range(0, n_vars):
            score = _get_subset_score(examples, labels, [m], params)
            if score - mxscore > improvement_threshold:
                mxscore = score
                mxpos = m

        mask.append(mxpos)

    msk_score = _get_subset_score(examples, labels, mask, params)
    if verbose:
        print("mask score: " + str(msk_score))
        print(mask)

    for m in range(0, n_vars):
        if m not in mask:
            n_score = _get_subset_score(examples, labels, mask + [m], params)
            if verbose:
                print(str(m) + " : " + str(n_score))
            if n_score - msk_score > improvement_threshold:
                msk_score = n_score
                mask = mask + [m]
                if verbose:
                    print("Adding " + str(m))

            if msk_score == 1:
                break  # can't improve

    return mask


def _get_orig_score_params(examples, labels):
    if len(set(labels)) == 1:
        # everything is in the same class! SVM can't handle :(
        warn("Everything is in the same class! SVM can't handle :(")
        return 1, {'gamma': 5, 'C': 1}

    C_range = np.arange(1, 16, 2)
    gamma_range = np.arange(5, 20)
    param_grid = dict(gamma=gamma_range, C=C_range)
    # param_grid = dict(C=C_range)
    cv = StratifiedKFold(y=labels, n_folds=3)
    grid = GridSearchCV(SVC(class_weight='balanced'), param_grid=param_grid, cv=cv)
    try:
        grid.fit(examples, labels)
    except ValueError:
        return 1, {'gamma': 5, 'C': 1}
    return grid.best_score_, grid.best_params_


def _get_subset_score(examples, labels, used_vars, best_params):
    if len(set(labels)) == 1:
        # everything is in the same class! SVM can't handle :(
        warn("Everything is in the same class! SVM can't handle :(")
        return 1
    examples = examples[:, used_vars]
    if examples.shape[1] == 0:
        return 0
    labels = np.asarray(labels)

    try:
        return np.mean(
            cross_val_score(
                SVC(class_weight='balanced', C=best_params['C'], gamma=best_params['gamma']),
                X=examples, y=labels, cv=3))
    except ValueError:
        return 1


def run_parallel(functions):
    """
    Run the list of zero-argument callables in parallel and return their results.

    One worker process is started per callable. The pool is shut down before
    returning, so no worker processes are leaked.

    :param functions: the (picklable) zero-argument callables to execute
    :return: a list of results, in the same order as ``functions``
    """
    if not functions:
        # multiprocessing.Pool(processes=0) raises ValueError; nothing to run.
        return []
    with multiprocessing.Pool(processes=len(functions)) as pool:
        pending = [pool.apply_async(fn) for fn in functions]
        # Collect every result before the context manager terminates the pool.
        results = [p.get() for p in pending]
    return results
