
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from smote_variants import polynom_fit_SMOTE, ProWSyn, SMOTE_IPF

class OversampleResult:
    """Plain container bundling the outputs of an oversampling run.

    Attributes are stored exactly as passed in; no copying or validation
    is performed.
    """

    def __init__(self, x_all, y_all, interpolation):
        # x_all / y_all: the sample and label collections handed in by the caller.
        # interpolation: the synthetic samples handed in by the caller.
        self.x_all, self.y_all, self.interpolation = x_all, y_all, interpolation

def _border_idx(nbrs_x_min, y):
    # Note - this function assumes that nbrs_x_min[0]=sample that is why there are only k-1 neighbors
    k = len(nbrs_x_min[1])
    flipped_nn_labels = np.array([list(map(lambda idx: 1-y[idx], single)) for single in nbrs_x_min])
    n_maj = np.sum(flipped_nn_labels, axis=1)
    return np.bitwise_and(n_maj >= (k - 1)/2, n_maj < k-1)

def _num_maj_neighbors(nbrs_x_min, y):
    # Note - this function assumes that nbrs_x_min[0]=sample that is why there are only k-1 neighbors
    k = len(nbrs_x_min[1])
    flipped_nn_labels = np.array([list(map(lambda idx: 1-y[idx], single)) for single in nbrs_x_min])
    n_maj = np.sum(flipped_nn_labels, axis=1)
    return n_maj

def _not_outlier_idx(nbrs_x_min, y):
    # Note - this function assumes that nbrs_x_min[0]=sample that is why there are only k-1 neighbors
    k = len(nbrs_x_min[1])
    flipped_nn_labels = np.array([list(map(lambda idx: 1-y[idx], single)) for single in nbrs_x_min])
    n_maj = np.sum(flipped_nn_labels, axis=1)
    return n_maj != k-1

def min_border_idx(x_all, y_all, m_neighbors, knn_algorithm):
    """Return the border mask for the minority samples of ``x_all``.

    A nearest-neighbor model is fitted on every sample of ``x_all`` and queried
    with the minority rows (``y_all == 1``) for ``m_neighbors + 1`` neighbors;
    the resulting index table is handed to ``_border_idx`` together with
    ``y_all``.

    Parameters
    ----------
    x_all : samples of both classes.
    y_all : labels aligned with ``x_all``; 1 marks the minority class.
    m_neighbors : number of neighbors to inspect per minority sample
        (one extra is requested from the model).
    knn_algorithm : forwarded as ``algorithm`` to ``NearestNeighbors``.

    Returns
    -------
    Whatever ``_border_idx`` returns for the minority rows.
    """
    minority_rows = x_all[y_all == 1]
    knn = NearestNeighbors(n_neighbors=m_neighbors + 1, algorithm=knn_algorithm)
    knn.fit(x_all)
    # Only the neighbor indices are needed; the distances are discarded.
    _, neighbor_ids = knn.kneighbors(minority_rows)
    return _border_idx(neighbor_ids, y_all)

def interpolate(x_1, x_2, _lambda):
    """Linearly blend from ``x_1`` towards ``x_2``.

    Returns ``x_1 + (x_2 - x_1) * _lambda``: ``_lambda == 0`` yields ``x_1``
    and ``_lambda == 1`` yields ``x_2``.
    """
    step = (x_2 - x_1) * _lambda
    return x_1 + step

def get_interpolation_couples(x_1_src, x_2_src, x_2_src_labels, knn_indices,
                              min2min_mean, min2min_range,
                              min2maj_mean, min2maj_range,
                              num_couples):
    """Build ``num_couples`` (base, neighbor, lambda) triples for interpolation.

    Couple ``i`` pairs base row ``i % len(x_1_src)`` with that row's neighbor
    in column ``(i // len(x_1_src)) % n_neighbors`` of ``knn_indices``. All
    bases are therefore visited with their first neighbor before any base is
    paired with its second neighbor, and so on, wrapping around as needed.

    Parameters
    ----------
    x_1_src : torch.Tensor
        Base points, one per row.
    x_2_src : torch.Tensor
        Pool the neighbors are taken from.
    x_2_src_labels : indexable
        Label of every row of ``x_2_src``; 1 means minority, anything else is
        treated as majority.
    knn_indices : sequence of sequences of int
        ``knn_indices[r]`` lists the indices into ``x_2_src`` of the neighbors
        of base row ``r``. All rows are assumed to have the length of row 0.
    min2min_mean, min2min_range : float
        Center and width of the uniform interval the interpolation factor is
        drawn from when the neighbor is a minority sample.
    min2maj_mean, min2maj_range : float
        Same, for a majority neighbor.
    num_couples : int
        Number of couples to generate.

    Returns
    -------
    tuple of torch.Tensor
        ``(x_1, x_2, _lambda)`` where ``x_1`` and ``x_2`` stack the chosen base
        and neighbor rows and ``_lambda`` has shape ``(num_couples, 1)``. For
        ``num_couples <= 0`` all three are empty along the first dimension.

    Raises
    ------
    ValueError
        If couples are requested but there are no base points or no neighbor
        columns to draw from.
    """
    if num_couples <= 0:
        # Nothing requested: hand back correctly shaped empty tensors instead
        # of failing on never-assigned outputs.
        return x_1_src[:0], x_2_src[:0], torch.empty((0, 1))

    n_src = len(x_1_src)
    if n_src == 0:
        raise ValueError("x_1_src is empty: there are no base points to interpolate from")
    n_nbrs = len(knn_indices[0])
    if n_nbrs == 0:
        raise ValueError("knn_indices has no neighbor columns to interpolate towards")

    bases, neighbors, lambdas = [], [], []
    for i in range(num_couples):
        row = i % n_src
        col = (i // n_src) % n_nbrs
        nbr = knn_indices[row][col]
        bases.append(x_1_src[row])
        neighbors.append(x_2_src[nbr])
        if x_2_src_labels[nbr] == 1:  # minority
            mean, spread = min2min_mean, min2min_range
        else:  # majority
            mean, spread = min2maj_mean, min2maj_range
        # One draw per couple, in couple order, uniform over
        # [mean - spread / 2, mean + spread / 2).
        lambdas.append(torch.rand(1) * spread + mean - spread / 2)

    # Stack once at the end; growing the tensors with torch.cat on every
    # iteration re-copies everything accumulated so far.
    return torch.stack(bases, 0), torch.stack(neighbors, 0), torch.stack(lambdas, 0)

def proWSYN_importance(X, y, L=5, n_neighbors=5, theta=1):
    """Compute ProWSyn-style sampling probabilities for the minority samples.

    The minority samples (``y == 1``) are peeled off in up to ``L`` rounds:
    in each round every majority sample (``y == 0``) nominates its nearest
    still-unassigned minority samples, and the nominated samples form the next
    partition. Whatever is left after the rounds forms one final partition.
    Partition number ``p`` (1-based) has proximity level ``p`` and weight
    ``exp(-theta * (p - 1))``, so samples closer to the majority class are
    sampled more often.

    Parameters
    ----------
    X : torch.Tensor or array-like
        Samples of both classes, one per row.
    y : torch.Tensor or array-like
        Labels aligned with ``X``; 1 is minority, 0 is majority.
    L : int, optional
        Maximum number of peeling rounds.
    n_neighbors : int, optional
        Number of minority neighbors each majority sample nominates per round
        (capped at the number of unassigned minority samples).
    theta : float, optional
        Decay rate of the weight with the proximity level.

    Returns
    -------
    numpy.ndarray, shape (n_min,)
        Sampling probability of every minority sample, in the order the
        minority rows appear in ``X``. Empty if there are no minority samples.
    """
    X = X.numpy() if hasattr(X, 'numpy') else np.asarray(X)
    y = y.numpy() if hasattr(y, 'numpy') else np.asarray(y)
    X_maj = X[y == 0]
    X_min = X[y == 1]
    n_min = len(X_min)
    if n_min == 0:
        return np.zeros(0)

    # Indices (into X_min) of the minority samples not yet assigned to a partition.
    remaining = np.arange(n_min)
    partitions = []
    for _ in range(L):
        if len(remaining) == 0:
            break
        # Step 3 a - nearest unassigned minority samples of every majority sample
        nn = NearestNeighbors(n_neighbors=min(len(remaining), n_neighbors), n_jobs=1)
        nn.fit(X_min[remaining])
        nominated = nn.kneighbors(X_maj, return_distance=False)
        # Step 3 b - positions (within `remaining`) nominated by at least one majority sample
        picked = np.unique(nominated)
        # Step 3 c - the proximity level is encoded by the position in `partitions`
        partitions.append(remaining[picked])
        # Step 3 d
        remaining = np.delete(remaining, picked)
    # Step 4 - leftovers form the last, most distant partition
    if len(remaining) > 0:
        partitions.append(remaining)

    # Steps 5-6 - partition p (1-based) has proximity level p. The leftover
    # partition therefore gets the highest level and the smallest weight;
    # giving it the last loop index instead would rank it above the final
    # peeled partition.
    levels = np.arange(1, len(partitions) + 1)
    cluster_weights = np.exp(-theta * (levels - 1))
    # Probability distribution of sampling over the identified partitions.
    cluster_weights = cluster_weights / np.sum(cluster_weights)

    weights = np.zeros(n_min)
    for weight, members in zip(cluster_weights, partitions):
        weights[members] = weight
    return weights / weights.sum()

def get_importance(x_all, y_all, x_min, importance_oversampling):
    """Return the sampling probability of every minority sample.

    With ``importance_oversampling`` truthy the probabilities come from
    ``proWSYN_importance(X=x_all, y=y_all)``; otherwise a uniform distribution
    over the ``x_min.shape[0]`` minority samples is returned as a plain list.
    """
    if importance_oversampling:
        return proWSYN_importance(X=x_all, y=y_all)
    n_min = int(x_min.shape[0])
    return [1 / n_min] * n_min  # Uniform

def _knn_indices(reference, queries, n_neighbors, knn_algorithm='auto', n_jobs=None):
    """Return the indices of the ``n_neighbors + 1`` rows of ``reference`` nearest to each query row."""
    # One extra neighbor is requested because callers query with points that
    # are themselves part of `reference` and later drop the first hit.
    model = NearestNeighbors(n_neighbors=n_neighbors + 1, algorithm=knn_algorithm, n_jobs=n_jobs)
    model.fit(reference)
    return model.kneighbors(queries, return_distance=False)


def _draw_minority_pairs(x_min_enc, k_neighbors, num_samples, prob=None):
    """Randomly draw ``num_samples`` (base index, neighbor index) pairs among the rows of ``x_min_enc``."""
    ind = _knn_indices(x_min_enc, x_min_enc, k_neighbors, n_jobs=1)
    base_indices = np.random.choice(np.arange(len(x_min_enc)), num_samples, p=prob)
    # Column 0 is skipped: it is expected to be the base point itself.
    neighbor_col = np.random.choice(np.arange(1, k_neighbors + 1), num_samples)
    return base_indices, ind[base_indices, neighbor_col]


def _interpolation_sources(smote_algo_type, x_all, x_min, x_all_enc, x_min_enc, y,
                           m_neighbors, knn_algorithm):
    """Pick the base points, the neighbor pool and the pool labels for a couple-based variant."""
    if smote_algo_type in ('borderline-1', 'borderline-2'):
        # Probably not effective - see 'my_borderline' below.
        # Border minority points are detected in the encoded space.
        mask = _border_idx(_knn_indices(x_all_enc, x_min_enc, m_neighbors, knn_algorithm), y)
    elif smote_algo_type == 'borderline-1-detect-border-in-non-latent':
        # Border minority points are detected in the original (non-encoded) space.
        mask = _border_idx(_knn_indices(x_all, x_min, m_neighbors, knn_algorithm), y)
    elif smote_algo_type == 'my_borderline_ignore_outlier':
        # Exclude outlier points, detected in the original (non-encoded) space.
        mask = _not_outlier_idx(_knn_indices(x_all, x_min, m_neighbors, knn_algorithm), y)
    elif smote_algo_type in ('vanilla', 'my_borderline'):
        # 'my_borderline': don't search for border points and interpolate only on them.
        # Instead interpolate according to nearest-neighbor, with a different
        # interpolation range for minority and majority neighbors.
        # Reason - because the datasets are small + the encoder separates the classes
        #          the border points may be limited (even a single border point in extreme cases);
        #          in such cases all interpolated points originate from a few points and
        #          final results are not optimal.
        # Example - see glass4 when mse2=0 mse3=0.2 - only one border point is found
        mask = None
    else:
        raise ValueError("Argument 'smote_algo_type' not supported: {!r}".format(smote_algo_type))

    x_1_src = x_min_enc if mask is None else x_min_enc[mask]
    if smote_algo_type in ('borderline-2', 'my_borderline', 'my_borderline_ignore_outlier'):
        # Neighbors may belong to either class.
        return x_1_src, x_all_enc, y
    # Neighbors are minority samples only.
    return x_1_src, x_min_enc, torch.ones(x_min_enc.shape[0])


def deep_smote_oversample(x: torch.Tensor, x_enc: torch.Tensor, y: torch.Tensor,
                          smote_algo_type,
                          m_neighbors,
                          k_neighbors,
                          knn_algorithm,
                          num_samples,
                          importance_oversampling,
                          importance_classifier=None,
                          ratio=None,
                          importance_tenth=None):
    """Generate synthetic minority samples in the encoded space.

    Parameters
    ----------
    x : torch.Tensor
        Samples in the original space.
    x_enc : torch.Tensor
        The same samples in the encoded space; interpolation happens here.
    y : torch.Tensor
        Labels aligned with ``x``; 1 is minority, 0 is majority.
    smote_algo_type : str
        One of ``'polyfit'``, ``'prowsyn'``, ``'smote-ipf'`` (delegated to
        ``smote_variants``), ``'orig'``, ``'orig_ignore_outlier'``,
        ``'vanilla'``, ``'borderline-1'``, ``'borderline-2'``,
        ``'borderline-1-detect-border-in-non-latent'``, ``'my_borderline'``
        or ``'my_borderline_ignore_outlier'``.
    m_neighbors : int
        Neighborhood size used for border / outlier detection.
    k_neighbors : int
        Neighborhood size used to pick interpolation partners.
    knn_algorithm : str
        Forwarded as ``algorithm`` to ``NearestNeighbors``.
    num_samples : int
        Number of synthetic samples to generate. Not used by the
        ``smote_variants`` branches.
    importance_oversampling : bool
        ``'orig'`` only: weight base-point selection via ``get_importance``.
    importance_classifier, importance_tenth
        Accepted for interface compatibility; not used here.
    ratio : float, optional
        ``'orig'`` only: fixed interpolation factor; a falsy value (``None``
        or 0) means a random factor per sample.

    Returns
    -------
    tuple
        ``(interpolation, base_indices, neighbor_indices, ratio)``.
        ``interpolation`` is always a tensor of synthetic samples. The other
        three depend on the branch:

        * ``smote_variants`` branches: ``None``, ``None`` and the ``ratio``
          argument unchanged.
        * ``'orig'``: index arrays into the minority rows of ``x_enc`` and the
          ``(num_samples, 1)`` array of factors used.
        * ``'orig_ignore_outlier'``: index arrays into the outlier-filtered
          minority rows and the ``ratio`` argument unchanged.
        * couple-based branches: the stacked base points, the stacked
          neighbor points (points, not indices) and the factor tensor.

    Raises
    ------
    ValueError
        If ``smote_algo_type`` is not one of the supported values.
    """
    x_all = x
    x_all_enc = x_enc
    x_min = x[y == 1]
    x_min_enc = x_enc[y == 1]

    if smote_algo_type in ('polyfit', 'prowsyn', 'smote-ipf'):
        if smote_algo_type == 'polyfit':
            sampler = polynom_fit_SMOTE()
        elif smote_algo_type == 'prowsyn':
            sampler = ProWSyn()
        else:
            sampler = SMOTE_IPF()
        x_resampled, _ = sampler.sample(x_all_enc.numpy(), y.numpy())
        # NOTE(review): assumes the sampler returns the original rows first
        # and appends the synthetic ones - confirm, in particular for SMOTE_IPF.
        interpolation = torch.tensor(x_resampled[x.shape[0]:])
        # These samplers expose no base/neighbor bookkeeping; return None for
        # both instead of failing on never-assigned names.
        return interpolation, None, None, ratio

    if smote_algo_type == 'orig':
        prob = get_importance(x_all, y, x_min, importance_oversampling)
        x_min_np = x_min_enc.numpy()
        base_indices, neighbor_indices = _draw_minority_pairs(x_min_np, k_neighbors, num_samples, prob)
        if ratio:
            ratio = np.ones((num_samples, 1)) * ratio
        else:
            ratio = np.random.rand(num_samples, 1)
        x_base = x_min_np[base_indices]
        x_neighbor = x_min_np[neighbor_indices]
        interpolation = torch.tensor(x_base + np.multiply(ratio, x_neighbor - x_base))
        return interpolation, base_indices, neighbor_indices, ratio

    if smote_algo_type == 'orig_ignore_outlier':
        # Exclude outlier points, detected in the original (non-encoded) space.
        keep = _not_outlier_idx(_knn_indices(x_all, x_min, m_neighbors, knn_algorithm), y)
        x_min_np = x_min_enc[keep].numpy()
        base_indices, neighbor_indices = _draw_minority_pairs(x_min_np, k_neighbors, num_samples)
        x_base = x_min_np[base_indices]
        x_neighbor = x_min_np[neighbor_indices]
        factors = np.random.rand(num_samples, 1)
        interpolation = torch.tensor(x_base + np.multiply(factors, x_neighbor - x_base))
        # NOTE(review): the random factors are not returned and the `ratio`
        # argument is not applied in this branch - confirm this is intended.
        return interpolation, base_indices, neighbor_indices, ratio

    # Couple-based variants: choose sources, then pair every base point with
    # its nearest neighbors in the chosen pool.
    x_1_src, x_2_src, x_2_src_labels = _interpolation_sources(
        smote_algo_type, x_all, x_min, x_all_enc, x_min_enc, y, m_neighbors, knn_algorithm)
    indices = _knn_indices(x_2_src, x_1_src, k_neighbors, knn_algorithm)
    # Drop the first hit, expected to be the query point itself.
    knn_indices = [row[1:] for row in indices]

    x_1, x_2, _lambda = get_interpolation_couples(x_1_src, x_2_src, x_2_src_labels, knn_indices,
                                                  min2min_mean=0.5, min2min_range=1,
                                                  min2maj_mean=0.1, min2maj_range=0.2,
                                                  num_couples=num_samples)
    interpolation = interpolate(x_1, x_2, _lambda)
    return interpolation, x_1, x_2, _lambda