from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
from statsmodels.nonparametric.kernel_density import KDEMultivariateConditional, EstimatorSettings
from sklearn.preprocessing import LabelEncoder
from pandas import get_dummies
import numpy as np


def propensity_of_treatment_score(data, covariates, treatment, model='logistic', variable_types=None):
    if model == 'logistic':
        model = LogisticRegression(solver='lbfgs')
        data, covariates = binarize_discrete(data, covariates, variable_types)
        model = model.fit(data[covariates], data[treatment].values.ravel())
        scores = model.predict_proba(data[covariates])[:, 1]
        return scores
    else:
        raise NotImplementedError


def state_propensity_score(data, covariates, treatments, variable_types=None):
    if len(set(covariates).intersection(treatments)) != 0:
        raise Exception("Can't control for causal states. Remove treatment from covariates.")
    log_propensities = {}
    for i, treatment in enumerate(treatments):
        if variable_types[treatment] in ['b']:
            log_propensities[treatment] = np.log(binary_treatment_model(data.copy(),
                                                                        covariates + treatments[i+1:],
                                                                        treatment,
                                                                        variable_types))
        elif variable_types[treatment] in ['o', 'u', 'd']:
            log_propensities[treatment] = np.log(categorical_treatment_model(data.copy(),
                                                                             covariates + treatments[i+1:],
                                                                             treatment,
                                                                             variable_types))
        elif variable_types[treatment] in ['c']:
            logging.warn("Weighting for continuous causal states is still experimental! Inferences may be invalid.")
            log_propensities[treatment] = np.log(continuous_treatment_model(data.copy(),
                                                                            covariates + treatments[i+1:],
                                                                            treatment,
                                                                            variable_types))
        else:
            raise Exception("Variable type {} for variable {} is not a recognized format type.".format(variable_types[treatment],
                                                                                                       treatment))
    scores = np.zeros(len(data))
    for treatment in treatments:
        scores += log_propensities[treatment]
    return np.exp(scores)


def binary_treatment_model(data, covariates, treatment, variable_types):
    data, covariates = binarize_discrete(data, covariates, variable_types)
    model = LogisticRegression(solver='lbfgs')
    model = model.fit(data[covariates], data[treatment])
    scores = model.predict_proba(data[covariates])
    scores = scores[range(len(scores)), data[treatment].values.astype(int)]
    return scores


def categorical_treatment_model(data, covariates, treatment, variable_types):
    data, covariates = binarize_discrete(data, covariates, variable_types)
    model = LogisticRegression(multi_class='auto', solver='lbfgs')
    data[treatment], encoder = discrete_to_integer(data[treatment])
    model = model.fit(data[covariates], data[treatment])
    scores = model.predict_proba(data[covariates])
    scores = scores[range(len(data)), data[treatment].values.astype(int)]
    return scores


def continuous_treatment_model(data, covariates, treatment, variable_types):
    data, covariates = binarize_discrete(data, covariates, variable_types)
    if len(data) > 300 or len(treatment + covariates) >= 3:
        defaults = EstimatorSettings(n_jobs=4, efficient=True)
    else:
        defaults = EstimatorSettings(n_jobs=-1, efficient=False)

    if 'c' not in variable_types.values():
        bw = 'cv_ml'
    else:
        bw = 'normal_reference'

    indep_type = get_type_string(covariates, variable_types)
    dep_type = get_type_string([treatment], variable_types)

    model = KDEMultivariateConditional(endog=data[treatment],
                                       exog=data[covariates],
                                       dep_type=''.join(dep_type),
                                       indep_type=''.join(indep_type),
                                       bw=bw,
                                       defaults=defaults)
    scores = model.pdf(endog_predict=data[treatment], exog_predict=data[covariates])
    return scores


def get_type_string(variables, variable_types):
    var_types = []
    for variable in variables:
        if variable_types[variable] in ['b', 'd', 'o', 'u']:
            if variable_types[variable] in ['o', 'u']:
                var_types.append(variable_types[variable])
            else:
                var_types.append('u')
        elif variable_types[variable] in ['c']:
            var_types.append('c')
        else:
            raise Exception("Variable type {} for variable {} not a recognized type.".format(variable_types[variable],
                                                                                             variable))
    return "".join(var_types)


def binarize_discrete(data, covariates, variable_types):
    to_remove = []
    if variable_types:
        for variable in covariates:
            variable_type = variable_types[variable]
            if variable_type in ['d', 'o', 'u']:
                dummies = get_dummies(data[variable])
                dummies.columns = [variable + str(col) for col in dummies.columns]
                dummies = dummies[dummies.columns[:-1]]
                covariates += list(dummies.columns)
                for var_name in dummies.columns:
                    variable_types[var_name] = 'b'
                data[dummies.columns] = dummies
                to_remove.append(variable)
    for variable in to_remove:
        covariates.remove(variable)
        del data[variable]
    return data, covariates


def discrete_to_integer(discrete):
    encoder = LabelEncoder()
    discrete = encoder.fit_transform(discrete)
    return discrete, encoder
