import torch
import scipy.stats as stats
import numpy as np

from numpy import median, shape, sqrt
from numpy.random import permutation
from scipy.spatial.distance import pdist, squareform

def data_normalize(data):
    """Standardize ``data`` along axis 0 and zero out undefined entries.

    Z-scores are computed per column (``axis=0``) using the sample standard
    deviation (``ddof=1``). Every NaN in the z-scored result is then replaced
    by 0 (NaNs typically come from zero-variance columns or NaN inputs).

    Returns:
        The z-scored array with NaN entries set to 0.
    """
    standardized = stats.zscore(data, ddof=1, axis=0)
    undefined = np.isnan(standardized)
    standardized[undefined] = 0.
    return standardized

def set_median_width(X):
    """Kernel width from the median heuristic on dimension-scaled data.

    ``X`` is divided by the square root of its last-axis size, then (when it
    has more than 1000 rows) randomly subsampled to 1000 rows. The width is
    ``sqrt(2)`` times the median of the strictly positive pairwise Euclidean
    distances between the remaining rows.

    The median is taken over the condensed distance vector; this equals the
    median over the positive entries of the full symmetric matrix because
    every distance appears there exactly twice.

    Returns:
        The kernel width (NaN if there is no positive pairwise distance).
    """
    num_rows = shape(X)[0]
    scaled = X / np.sqrt(X.shape[-1])
    if num_rows > 1000:
        # Cap the cost of the pairwise computation with a random subsample.
        chosen = permutation(num_rows)[:1000]
        scaled = scaled[chosen, :]
    pairwise = pdist(scaled, 'euclidean')
    positive = pairwise[pairwise > 0]
    return sqrt(2.) * median(positive)


# Use an empirical kernel width (chosen by sample size) instead of the median heuristic.
def set_empirical_width(X):
    """Pick a kernel length scale from the sample size and dimensionality.

    A base width is selected by the number of rows in ``X``:
    1.2 for fewer than 200 rows, 0.7 for fewer than 1200 rows, and 0.4
    otherwise. The base width is then divided by the square root of the
    number of columns (``X.shape[1]``).

    Returns:
        The resulting length scale.
    """
    num_rows = np.shape(X)[0]
    if num_rows >= 1200:
        base_width = 0.4
    elif num_rows >= 200:
        base_width = 0.7
    else:
        base_width = 1.2
    return base_width / np.sqrt(X.shape[1])

def reduce_func(K, thresh, need_wx=False):
    """Build a low-rank factor of a square matrix from its leading eigenpairs.

    ``K`` is symmetrized as ``0.5 * (K + K.T)`` and eigendecomposed. The
    ``min(400, n // 4)`` largest eigenpairs are kept (``n = K.shape[0]``),
    and of those only the ones whose eigenvalue is strictly greater than
    ``thresh`` times the largest eigenvalue. Each retained eigenvector is
    scaled by the square root of its eigenvalue, so ``vx @ vx.T`` is the
    corresponding truncated reconstruction of the symmetrized matrix.

    Args:
        K: Square 2-D NumPy array.
        thresh: Relative eigenvalue cutoff (fraction of the largest one).
        need_wx: When true, also return a scale factor alongside the factor
            matrix.

    Returns:
        ``vx`` of shape ``(n, r)`` with ``r`` retained components, or the
        tuple ``(vx, 1)`` when ``need_wx`` is true.

    Raises:
        ValueError: If ``K`` has fewer than 4 rows, because the truncation
            rule would then keep no eigenpairs.
    """
    n = K.shape[0]
    top_k = min(400, n // 4)
    if top_k < 1:
        # Fail with a clear message instead of the "zero-size array"
        # reduction error that an empty eigenvalue slice would trigger.
        raise ValueError(
            "reduce_func needs a matrix with at least 4 rows, got %d" % n
        )

    wx, vx = np.linalg.eigh(0.5 * (K + K.T))

    # Indices of the top_k eigenvalues in descending order.
    order = np.argsort(-wx)[:top_k]
    wx = wx[order]
    vx = vx[:, order]

    # Drop components that are negligible relative to the leading one.
    keep = wx > wx.max() * thresh
    wx = wx[keep]
    vx = vx[:, keep]

    # Column-wise scaling by sqrt(eigenvalue); broadcasting avoids building
    # a dense diagonal matrix and a full matrix product.
    vx = vx * np.sqrt(wx)

    if need_wx:
        # The scale factor is fixed at 1 (no extra amplitude normalization).
        return vx, 1
    return vx

def totensor(x, y, z=None, device="cpu"):
    """Convert the inputs to torch tensors on ``device``.

    Each argument is handled independently: NumPy arrays are wrapped with
    ``torch.from_numpy`` (which shares memory with the array), anything else
    is assumed to already support ``.to(device)``. Mixed inputs, such as an
    array ``x`` together with a tensor ``y``, are therefore accepted.

    Args:
        x: NumPy array or torch tensor.
        y: NumPy array or torch tensor.
        z: Optional NumPy array or torch tensor.
        device: Target device passed to ``Tensor.to``.

    Returns:
        ``(x, y)`` when ``z`` is None, otherwise ``(x, y, z)``.
    """
    def _to_device(value):
        # Decide per argument rather than inferring every type from ``x``.
        if isinstance(value, np.ndarray):
            value = torch.from_numpy(value)
        return value.to(device)

    x = _to_device(x)
    y = _to_device(y)
    if z is None:
        return x, y
    return x, y, _to_device(z)

def cal_kernel(X, length, Y=None):
    """Gaussian (RBF) kernel matrix between the rows of two 2-D tensors.

    Entry ``(i, j)`` is ``exp(-0.5 * ||X[i] - Y[j]||^2 / length**2)``, with
    the squared distance expanded as ``|x|^2 + |y|^2 - 2 x.y``. When ``Y``
    is None the kernel is computed between ``X`` and itself.

    Args:
        X: Tensor whose rows are samples.
        length: Kernel length scale.
        Y: Optional second tensor with the same number of columns as ``X``.

    Returns:
        Tensor with one row per row of ``X`` and one column per row of
        ``Y`` (or of ``X`` when ``Y`` is None).
    """
    x_sq_norms = (X ** 2).sum(dim=1, keepdim=True)
    if Y is None:
        other = X
        other_sq_norms = x_sq_norms
    else:
        other = Y
        other_sq_norms = (Y ** 2).sum(dim=1, keepdim=True)
    sq_dist = x_sq_norms + other_sq_norms.T - 2 * X.mm(other.T)
    return torch.exp(- 0.5 * sq_dist / (length**2))

def Pdist2(x):
    """Compute the squared Euclidean distances between all pairs of rows of x.

    Uses the expansion ``|a|^2 + |b|^2 - 2 a.b``; entries that come out
    negative are set to 0.
    """
    sq_norms = (x ** 2).sum(dim=1, keepdim=True)
    inner = x.mm(x.T)
    sq_dists = sq_norms + sq_norms.T - 2.0 * inner
    # Negative values are not valid squared distances; clip them to zero.
    sq_dists[sq_dists < 0] = 0
    return sq_dists

def data_split(x, y, z, training_size = 0.5):
    """Randomly split x, y and z row-wise into training and test parts.

    When ``training_size`` is 0 no split is made: the full inputs are
    returned as both the training and the test part. Otherwise
    ``int(training_size * n)`` rows (``n = x.shape[0]``) are drawn without
    replacement for training and the remaining rows, in their original
    order, form the test part. The same row indices are applied to all
    three inputs.

    Returns:
        ``(z_tr, z_te, x_tr, x_te, y_tr, y_te)``.
    """
    num_rows = x.shape[0]
    if training_size == 0:
        return z, z, x, x, y, y

    num_train = int(training_size * num_rows)
    train_rows = np.random.choice(num_rows, num_train, replace=False)
    # Everything not drawn for training goes to the test part.
    test_rows = np.delete(np.arange(num_rows), train_rows)

    z_tr, x_tr, y_tr = z[train_rows, :], x[train_rows, :], y[train_rows, :]
    z_te, x_te, y_te = z[test_rows, :], x[test_rows, :], y[test_rows, :]
    return z_tr, z_te, x_tr, x_te, y_tr, y_te