import numpy as np
import scipy.stats as stats


class cubecauchy_gen(stats.rv_continuous):
    '''
    Distribution of the real cube root of a standard Cauchy variable.

    If X ~ Cauchy(0, 1) and Y = cbrt(X), then
        pdf:  f(y) = 3 y^2 / (pi (1 + y^6))
        cdf:  F(y) = 1/2 + arctan(y^3) / pi
        ppf:  F^{-1}(q) = cbrt(tan(pi (q - 1/2)))
    The mean is 0 and the variance is E|X|^(2/3) = 1 / cos(pi/3) = 2,
    so `cubecauchy(scale=1/sqrt(2))` has unit variance.  E[Y^3] = E[X]
    does not exist (skewness undefined) and E[Y^4] = E|X|^(4/3) is
    infinite (kurtosis infinite).
    '''

    def _rvs(self, size=None, random_state=None):
        x = stats.cauchy.rvs(size=size, random_state=random_state)
        # np.cbrt keeps the sign and maps 0 -> 0; the algebraically equal
        # x / |x|**(2/3) evaluates to 0/0 = nan at x == 0.
        return np.cbrt(x)

    def _pdf(self, x):
        x2 = x * x
        return 3.0 * x2 / (np.pi * (1.0 + x2 ** 3))

    def _cdf(self, x):
        return 0.5 + np.arctan(x ** 3) / np.pi

    def _ppf(self, q):
        return np.cbrt(np.tan(np.pi * (q - 0.5)))

    def _stats(self):
        # mean, variance, skewness, excess kurtosis of the standard form.
        return 0.0, 2.0, np.nan, np.inf


cubecauchy = cubecauchy_gen(name='cubecauchy')


def get_random_variable(distr):
    '''
    Build a frozen scipy distribution with mean 0 and variance 1.

    Inputs:
        `distr`: one of 'gaussian', 'uniform', 'truncated_gaussian',
        'laplace', 'cubecauchy'.
    Output:
        the frozen distribution for that name:
        - 'gaussian': N(0, 1)
        - 'uniform': U[-sqrt(3), sqrt(3)]
        - 'truncated_gaussian': N(0, 1) truncated to [-2, 2], rescaled
          to unit variance
        - 'laplace': Laplace with scale 1/sqrt(2)
        - 'cubecauchy': cube root of Cauchy with scale 1/sqrt(2)
    Raises:
        ValueError if `distr` is not one of the names above.
    '''
    def _truncated_gaussian():
        cutoff = 2
        std_normal = stats.norm()
        mass = std_normal.cdf(cutoff) - std_normal.cdf(-cutoff)
        # Variance of a standard normal truncated to [-cutoff, cutoff].
        truncated_var = 1 - cutoff * std_normal.pdf(cutoff) * 2 / mass
        # Rescale to unit variance (the factor is roughly 1 / 0.88 for cutoff 2).
        return stats.truncnorm(-cutoff, cutoff, scale=1 / truncated_var ** 0.5)

    half_width = 3 ** 0.5
    factories = (
        ('gaussian', stats.norm),
        ('uniform', lambda: stats.uniform(loc=-half_width, scale=2 * half_width)),
        ('truncated_gaussian', _truncated_gaussian),
        ('laplace', lambda: stats.laplace(scale=1 / 2 ** 0.5)),
        ('cubecauchy', lambda: cubecauchy(scale=1 / 2 ** 0.5)),
    )
    for name, make in factories:
        if distr == name:
            return make()
    raise ValueError("Unknown distribution: '{}'".format(distr))


def getCor(cov):
    '''
    Converts a covariance matrix into the corresponding correlation matrix.

    Inputs:
        `cov`: An array where the last 2 dimensions contain a covariance
        matrix (any leading dimensions are "batch" dimensions, as in the
        V* functions of this module).
    Output:
        an array of the same shape whose (i, j) entry in the last 2
        dimensions is cov_ij / sqrt(cov_ii * cov_jj).  A zero variance
        produces non-finite entries in its row and column.
    '''
    # np.diagonal over the last two axes equals np.diag for 2-D input and
    # also works for batched input.
    d = np.diagonal(cov, axis1=-2, axis2=-1) ** -0.5
    return d[..., :, None] * cov * d[..., None, :]


def seqs2cov(inputs):
    '''
    Converts `inputs` to an array and computes its normalized row Gram matrix.

    Inputs:
        `inputs`: anything np.array accepts; presumably a list of
        equal-length 1-D input vectors -- TODO confirm against callers.
    Output:
        (inputs, inputcov, inputidxs) where
        - `inputs` is the np.array-converted input,
        - `inputcov` is inputs @ inputs.T divided by len(inputs[0]),
        - `inputidxs` lists, for each row, the sum of the lengths of all
          preceding rows: [0, len(row0), len(row0) + len(row1), ...].
    NOTE(review): for a 2-D array every row has length equal to the number
    of columns, so these offsets step by the feature dimension rather than
    by one per row of `inputcov`; confirm this is what callers expect.
    '''
    arr = np.array(inputs)
    gram = arr @ arr.T / len(arr[0])

    starts = []
    offset = 0
    for row in arr:
        starts.append(offset)
        offset += len(row)
    return arr, gram, starts
	
def VStep(cov):
    '''
    Computes E[step(z) step(z)^T | z ~ N(0, `cov`)]
    where step is the function that takes positive numbers to 1 and
    all else to 0, and
    z is a multivariate Gaussian with mean 0 and covariance `cov`.

    For a pair with correlation rho the entry is the quadrant probability
        P(z_i > 0, z_j > 0) = (pi - arccos(rho)) / (2 pi)
                            = 1/4 + arcsin(rho) / (2 pi),
    so the diagonal equals 1/2 and independent pairs give 1/4.

    Inputs:
        `cov`: An array where the last 2 dimensions contain covariance matrix of z (and the first dimensions are "batch" dimensions)
    Output:
        a numpy array of the same shape as `cov` that equals the
        expectation above in the last 2 dimensions.
    '''
    ll = list(range(cov.shape[-1]))
    d = np.sqrt(cov[..., ll, ll])
    c = d[..., None]**(-1) * cov * d[..., None, :]**(-1)
    # Clip guards against |c| slightly above 1 from rounding. The previous
    # 2/pi * arcsin(c) is E[sign(z_i) sign(z_j)], not the 0/1 step moment.
    return (np.pi - np.arccos(np.clip(c, -1, 1))) / (2 * np.pi)

def VErf(cov):
    '''
    Computes E[erf(z) erf(z)^T | z ~ N(0, `cov`)]
    where z is a multivariate Gaussian with mean 0 and covariance `cov`.

    Entry (i, j) is given in closed form by
        (2 / pi) * arcsin(cov_ij / sqrt((cov_ii + 1/2) (cov_jj + 1/2))).

    Inputs:
        `cov`: An array where the last 2 dimensions contain covariance matrix of z (and the first dimensions are "batch" dimensions)
    Output:
        a numpy array of the same shape as `cov` that equals the
        expectation above in the last 2 dimensions.
    '''
    idx = list(range(cov.shape[-1]))
    inv_scale = 1 / np.sqrt(cov[..., idx, idx] + 0.5)
    normalized = inv_scale[..., :, None] * cov * inv_scale[..., None, :]
    # Clip keeps arcsin's argument in its domain despite rounding.
    return 2. / np.pi * np.arcsin(np.clip(normalized, -1, 1))

def VDerErf(cov):
    '''
    Computes E[erf'(z) erf'(z)^T | z ~ N(0, `cov`)]
    where erf' is the derivative of erf and
    z is a multivariate Gaussian with mean 0 and covariance `cov`.

    Since erf'(x) = 2/sqrt(pi) * exp(-x^2), entry (i, j) equals
        (4 / pi) * ((1 + 2 cov_ii) (1 + 2 cov_jj) - 4 cov_ij^2)^(-1/2),
    i.e. (4 / pi) * det(I + 2 Sigma_ij)^(-1/2) for the 2x2 sub-covariance.

    Inputs:
        `cov`: An array where the last 2 dimensions contain covariance matrix of z (and the first dimensions are "batch" dimensions)
    Output:
        a numpy array of the same shape as `cov` that equals the
        expectation above in the last 2 dimensions.
    '''
    ll = list(range(cov.shape[-1]))
    # The closed form uses the variances themselves, not their square roots.
    var = cov[..., ll, ll]
    dd = 1 + 2 * var
    return 4/np.pi * (dd[..., None] * dd[..., None, :] - 4 * cov**2)**(-1./2)

def J1(c, eps=1e-10):
    '''
    Normalized degree-1 arc-cosine kernel as a function of a correlation:
        J1(c) = (sqrt(1 - c^2) + (pi - arccos(c)) * c) / pi
              = (sin t + (pi - t) cos t) / pi,   with t = arccos(c).

    Inputs:
        `c`: scalar or array of correlations.
        `eps`: `c` is clamped to [-1 + eps, 1 - eps] before evaluation so
            sqrt and arccos stay well-defined when rounding pushes |c| >= 1.
    Output:
        J1 evaluated elementwise, same shape as `c`.  The input is not
        modified; clamping is applied to a copy.
    '''
    c = np.clip(c, -1 + eps, 1 - eps)
    return (np.sqrt(1 - c**2) + (np.pi - np.arccos(c)) * c) / np.pi

def VReLU(cov, eps=1e-5):
    '''
    Computes E[relu(z) relu(z)^T | z ~ N(0, `cov`)]
    where z is a multivariate Gaussian with mean 0 and covariance `cov`.

    Entry (i, j) is 0.5 * sqrt(cov_ii) * J1(rho_ij) * sqrt(cov_jj), where
    rho_ij is the correlation of z_i and z_j.

    Inputs:
        `cov`: An array where the last 2 dimensions contain covariance matrix of z (and the first dimensions are "batch" dimensions)
        `eps`: clamping margin forwarded to J1.
    Output:
        a numpy array of the same shape as `cov` that equals the
        expectation above in the last 2 dimensions; non-finite values
        (e.g. from zero variances) are mapped through np.nan_to_num.
    '''
    ix = list(range(cov.shape[-1]))
    sd = np.sqrt(cov[..., ix, ix])
    col, row = sd[..., :, None], sd[..., None, :]
    rho = col ** (-1) * cov * row ** (-1)
    kernel = J1(rho, eps=eps)
    return np.nan_to_num(0.5 * col * kernel * row)

from mpl_toolkits.axes_grid1 import make_axes_locatable
def colorbar(mappable):
    '''
    Draws a colorbar for `mappable` in a new axes appended to the right of
    the axes `mappable` belongs to.

    The colorbar axes is created by make_axes_locatable(...).append_axes
    with side "right", size="5%" and pad=0.05.

    Inputs:
        `mappable`: a matplotlib artist with an `.axes` attribute
            (e.g. the result of imshow).
    Output:
        whatever the owning figure's `colorbar` method returns.
    '''
    host = mappable.axes
    figure = host.figure
    bar_axes = make_axes_locatable(host).append_axes("right", size="5%", pad=0.05)
    return figure.colorbar(mappable, cax=bar_axes)