import numpy as np
import pandas as pd
import scipy.stats as st

from scipy.spatial.distance import jaccard
import warnings

from .encoder_decoder import ColumnTransformerEnc


def vector2dict(x, feature_names):
    """
    Pair every feature name with the value found at the same position in ``x``.

    Pairing stops at the shorter of the two sequences; if a name occurs more
    than once, the value of its last occurrence is kept.

    :param x: list of values
    :param feature_names: feature names, ordered like the values of x
    :return: dictionary ``{feature_name: value}``
    """
    return dict(zip(feature_names, x))

def neuclidean(x, y):
    """
    Normalized euclidean distance: half the variance of ``x - y`` divided by
    the sum of the variances of ``x`` and ``y``.

    A small constant is added to the denominator so that two zero-variance
    inputs give 0.0 instead of nan (0/0). This matches the definition of the
    same name further down in this module.

    :param x: numeric array
    :param y: numeric array that can be subtracted from x
    :return: the normalized distance
    """
    return 0.5 * np.var(x - y) / (np.var(x) + np.var(y) + 0.0000001)

def record2str(x, feature_names, numeric_columns, encdec=None):
    """
    Render the record ``x`` as a string of the form ``{ name = value, ... }``.

    Attributes listed in ``numeric_columns`` are always printed. Any other
    attribute is skipped when its value equals 0.0; otherwise it is printed as
    ``name = value`` when no encoder is given, or, when ``encdec`` is a
    ``ColumnTransformerEnc``, by splitting the column name on ``'='`` and
    printing the first two pieces as ``attribute = category``. With any other
    encoder type the non-numeric attributes are not printed.

    Entries are joined with ``', '``, so the last value is no longer truncated
    and consecutive entries no longer run together. A record with nothing to
    print is rendered as ``'{ }'``. The encoder's ``dec`` method is no longer
    called because its result was never used.

    :param x: list of values
    :param feature_names: feature names, ordered like the values of x
    :param numeric_columns: collection of the names of the numeric features
    :param encdec: optional encoder object, only inspected for its type
    :return: the string representation of the record
    """
    entries = []
    for att, val in dict(zip(feature_names, x)).items():
        if att in numeric_columns:
            entries.append('%s = %s' % (att, val))
        elif val == 0.0:
            # a zero non-numeric column is not shown
            continue
        elif encdec is None:
            entries.append('%s = %s' % (att, val))
        elif isinstance(encdec, ColumnTransformerEnc):
            att_split = att.split('=')
            entries.append('%s = %s' % (att_split[0], att_split[1]))

    if not entries:
        return '{ }'
    return '{ ' + ', '.join(entries) + ' }'


def multilabel2str(y, class_name):
    """
    Build a comma-separated string of the labels that are switched on in ``y``.

    :param y: indicator vector; position i is on when ``y[i] == 1.0``
    :param class_name: label names, indexed like y
    :return: the active label names joined by ``', '`` (empty string if none)
    """
    active_labels = []
    for position, flag in enumerate(y):
        if flag == 1.0:
            active_labels.append(class_name[position])
    return ', '.join(active_labels)


def multi_dt_predict(X, dt_list):
    """
    Collect the predictions of several models into one matrix.

    :param X: 2-d array of records, one row per record
    :param dt_list: sequence of objects exposing ``predict(X)``, one per label
    :return: float array of shape ``(X.shape[0], len(dt_list))`` whose column j
             holds the output of ``dt_list[j].predict(X)``
    """
    predictions = np.zeros((X.shape[0], len(dt_list)))
    for column, model in enumerate(dt_list):
        predictions[:, column] = model.predict(X)
    return predictions

def mixed_distance_idx(x, y, idx, ddist=jaccard, cdist=neuclidean):
    """
    Distance between two records that mix continuous and discrete features.

    The first ``idx`` positions are compared with ``cdist`` and the remaining
    ones with ``ddist``; each partial distance is weighted by the share of
    positions it covers.

    A part that covers no positions is skipped instead of being evaluated on
    empty slices: its weight is zero anyway, and evaluating it could yield nan
    (e.g. the variance of an empty array) that would turn the whole result
    into nan. Two empty records have distance 0.0.

    :param x: first record
    :param y: second record, laid out like x
    :param idx: position where the discrete features start
    :param ddist: distance function applied to the discrete part
    :param cdist: distance function applied to the continuous part
    :return: weighted sum of the two partial distances
    """
    dim = len(x)
    xc, xd = x[:idx], x[idx:]
    yc, yd = y[:idx], y[idx:]

    has_continuous = len(xc) > 0
    has_discrete = len(xd) > 0

    wc = 1.0 * len(xc) / dim if has_continuous else 0.0
    cd = cdist(xc, yc) if has_continuous else 0.0

    wd = 1.0 * len(xd) / dim if has_discrete else 0.0
    dd = ddist(xd, yd) if has_discrete else 0.0

    return wd * dd + wc * cd
def calculate_feature_values(X, numeric_columns_index, categorical_use_prob=False, continuous_fun_estimation=False,
                             size=1000):
    """
    Build, for every column of ``X``, an array of candidate values.

    Per column:
      * a single distinct value: that value repeated ``size`` times;
      * numeric column (index in ``numeric_columns_index``): the observed
        values cast to float, followed by ``size`` generated values. These come
        from ``get_distr_values`` when ``continuous_fun_estimation`` is set,
        otherwise from a normal distribution with the column mean and
        standard deviation;
      * other columns: ``size`` values sampled with the observed frequencies
        when ``categorical_use_prob`` is set, otherwise the distinct values.

    :param X: 2-d numpy array, one column per feature
    :param numeric_columns_index: collection of the indexes of the numeric columns
    :param categorical_use_prob: sample non-numeric columns by observed frequency
    :param continuous_fun_estimation: estimate the distribution of numeric columns
    :param size: number of values to generate
    :return: list with one array per column of X
    """
    feature_values = list()
    for i in range(X.shape[1]):
        values = X[:, i]
        unique_values, counts = np.unique(values, return_counts=True)
        if len(unique_values) == 1:
            new_values = np.array([unique_values[0]] * size)
        elif i in numeric_columns_index:
            # the np.float alias was removed from NumPy; the builtin float is the same dtype
            values = values.astype(float)
            if continuous_fun_estimation:
                new_values = get_distr_values(values, size)
            else:  # suppose is gaussian
                mu = float(np.mean(values))
                sigma = float(np.std(values))
                new_values = np.random.normal(mu, sigma, size)
            new_values = np.concatenate((values, new_values), axis=0)
        elif categorical_use_prob:
            prob = 1.0 * counts / np.sum(counts)
            new_values = np.random.choice(unique_values, size=size, p=prob)
        else:  # uniform distribution
            new_values = unique_values

        feature_values.append(new_values)
    return feature_values


def get_distr_values(x, size=1000):
    """
    Return ``size`` evenly spaced values spanning the 1st to the 99th
    percentile of the distribution that ``best_fit_distribution`` selects
    for ``x``.

    :param x: sample values
    :param size: number of values to return
    :return: array produced by ``np.linspace`` between the two percentiles
    """
    bins = int(np.round(estimate_nbr_bins(x)))
    dist_name, fitted = best_fit_distribution(x, bins)
    distribution = getattr(st, dist_name)

    # fitted parameters are laid out as (*shape parameters, loc, scale)
    shape_args = fitted[:-2]
    placement = {'loc': fitted[-2], 'scale': fitted[-1]}

    lower = distribution.ppf(0.01, *shape_args, **placement)
    upper = distribution.ppf(0.99, *shape_args, **placement)

    return np.linspace(lower, upper, size)


# Candidate scipy.stats distributions tried by best_fit_distribution.
# Each distribution is listed once (st.expon used to appear twice and was
# therefore fitted twice). st.dweibull is currently disabled.
DISTRIBUTIONS = [st.uniform, st.exponweib, st.expon, st.gamma, st.beta, st.alpha,
                 st.chi, st.chi2, st.laplace, st.lognorm, st.norm, st.powerlaw]


def freedman_diaconis(x):
    """
    Number of histogram bins for ``x`` from the interquartile range:
    the bin width is ``2 * IQR / n ** (1/3)``, never smaller than 1, and the
    result is the data range divided by that width, rounded up.

    :param x: sample values
    :return: number of bins (as returned by ``np.ceil``)
    """
    upper_quartile, lower_quartile = np.percentile(x, [75, 25])
    spread = np.subtract(upper_quartile, lower_quartile)
    cube_root_n = len(x) ** (1.0 / 3.0)
    width = max(2.0 * spread / cube_root_n, 1)
    data_range = np.max(x) - np.min(x)
    return np.ceil(data_range / width)


def struges(x):
    """
    Number of histogram bins for ``x`` computed as ``ceil(log2(n)) + 1``,
    where n is the number of values.

    :param x: sample values (only its length is used)
    :return: number of bins
    """
    return np.ceil(np.log2(len(x))) + 1


def estimate_nbr_bins(x):
    """
    Estimate the number of histogram bins for ``x`` as the larger of the
    ``freedman_diaconis`` and ``struges`` estimates.

    A single value always gets one bin. With two values or fewer the
    Freedman-Diaconis estimate is taken as 1, and when that estimate is
    infinite or nan it is replaced by ``sqrt(len(x))``.

    :param x: sample values
    :return: estimated number of bins
    """
    n = len(x)
    if n == 1:
        return 1

    k_fd = 1 if n <= 2 else freedman_diaconis(x)
    k_struges = struges(x)

    unusable = k_fd == float('inf') or np.isnan(k_fd)
    if unusable:
        k_fd = np.sqrt(n)

    return max(k_fd, k_struges)


# Create models from data
def best_fit_distribution(data, bins=200, ax=None):
    """
    Fit every candidate in ``DISTRIBUTIONS`` to ``data`` and keep the one whose
    PDF has the smallest positive sum of squared errors against the density
    histogram of the data.

    Candidates that raise while being fitted or evaluated are skipped, and
    warnings emitted during fitting are silenced.

    :param data: sample values to model
    :param bins: ``bins`` argument passed to ``np.histogram``
    :param ax: optional plotting axis; when truthy each fitted PDF is plotted on it
    :return: tuple ``(name, params)`` of the selected distribution;
             ``('norm', (0.0, 1.0))`` when no candidate is selected
    """
    density, edges = np.histogram(data, bins=bins, density=True)
    # midpoints of consecutive bin edges
    centers = (edges + np.roll(edges, -1))[:-1] / 2.0

    winner, winner_params, lowest_sse = st.norm, (0.0, 1.0), np.inf

    for candidate in DISTRIBUTIONS:
        try:
            # warnings raised by data that cannot be fitted are not reported
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')

                fitted = candidate.fit(data)

                # fitted parameters are laid out as (*shape parameters, loc, scale)
                shape_args, loc, scale = fitted[:-2], fitted[-2], fitted[-1]

                pdf = candidate.pdf(centers, *shape_args, loc=loc, scale=scale)
                sse = np.sum(np.power(density - pdf, 2.0))

                # plotting is best effort and never stops the search
                try:
                    if ax:
                        pd.Series(pdf, centers).plot(ax=ax)
                except Exception:
                    pass

                if 0 < sse < lowest_sse:
                    winner, winner_params, lowest_sse = candidate, fitted, sse

        except Exception:
            continue

    return winner.name, winner_params



def sigmoid(x, x0=0.5, k=10.0, L=1.0):
    """
    Logistic function (the common "S"-shaped sigmoid curve):
    ``L / (1 + exp(-k * (x - x0)))``.

    :param x: value to transform
    :param x0: the x-value of the sigmoid's midpoint
    :param k: the steepness of the curve
    :param L: the curve's maximum value
    :return: sigmoid of x
    """
    exponent = -k * (x - x0)
    return L / (1.0 + np.exp(exponent))


def neuclidean(x, y):
    """
    Normalized euclidean distance: half the variance of ``x - y`` divided by
    the sum of the variances of ``x`` and ``y`` plus a small constant that
    keeps the denominator away from zero.

    :param x: numeric array
    :param y: numeric array that can be subtracted from x
    :return: the normalized distance
    """
    difference_variance = np.var(x - y)
    total_variance = np.var(x) + np.var(y) + 0.0000001
    return 0.5 * difference_variance / total_variance


def nmeandev(x, y):  # normalized mean deviation
    """
    Normalized mean deviation: the mean over all positions of
    ``|x - y| / max(|x|, |y|)``.

    Positions where both values are 0 contribute a deviation of 0 instead of
    the nan produced by 0/0, which would otherwise turn the whole mean into nan.

    :param x: numeric array
    :param y: numeric array laid out like x
    :return: the mean normalized deviation
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    deviation = np.abs(x - y)
    scale = np.maximum(np.abs(x), np.abs(y))
    # divide only where the scale is non-zero; the other positions keep the 0 of `out`
    ratio = np.divide(deviation, scale, out=np.zeros_like(deviation), where=scale != 0)
    return np.mean(ratio)