import numpy as np
import torch
import torch.nn

# Squared-error loss averaged over all elements. 'mean' is the current name of
# the deprecated 'elementwise_mean' reduction (same averaging, but without the
# deprecation warning newer PyTorch emits on every call).
mse_loss = torch.nn.MSELoss(reduction='mean')
# Compact numpy printing: 4 decimal places, fixed-point instead of scientific.
np.set_printoptions(precision=4, suppress=True)
def l0_loss(x, bandwidth=0.01):
    """Smooth surrogate for the L_0 "norm" (non-zero count) of the rows of ``x``.

    Every element is mapped through a notch filter
    ``1 - exp(-x**2 / bandwidth)`` (a flipped Gaussian): close to 0 for
    entries near zero and close to 1 for entries whose magnitude is much
    larger than ``sqrt(bandwidth)``. The filtered values are summed over
    dim 1 and the result is averaged over all remaining entries.
    """
    gaussian = torch.exp(-x ** 2 / bandwidth)
    return (1 - gaussian).sum(dim=1).mean()
def lq_loss(x, bandwidth=0.01):
    """Smooth surrogate for an L_q penalty (q < 1) on the rows of ``x``.

    Each element contributes the average of a notch filter
    ``1 - exp(-x**2 / bandwidth)`` (a flipped Gaussian) and ``|x|``. The
    contributions are summed over dim 1 and the result is averaged over all
    remaining entries.
    """
    notch = 1 - torch.exp(-x ** 2 / bandwidth)
    blended = (notch + torch.abs(x)) / 2
    per_row = blended.sum(dim=1)
    return per_row.mean()
def l1_loss(x):
    """Mean over the batch of the L_1 norm of each row of ``x`` (norm over dim 1)."""
    return torch.mean(torch.norm(x, p=1, dim=1))


def l2_loss(x):
    """Mean over the batch of the Euclidean (L_2) norm of each row of ``x`` (norm over dim 1)."""
    return torch.mean(torch.norm(x, p=2, dim=1))
# Registry mapping a norm name to the corresponding loss function defined above.
norm_loss = dict(
    l0=l0_loss,
    lq=lq_loss,
    l1=l1_loss,
    l2=l2_loss,
)

def eye_like(tensor):
    """Return an identity matrix with the size, dtype and device of ``tensor``.

    Intended for 2-D tensors (``torch.eye`` takes at most two size arguments).
    ``dtype``/``device`` are passed directly rather than allocating a
    throwaway ``empty_like`` buffer just to use as ``out=``.
    """
    return torch.eye(*tensor.size(), dtype=tensor.dtype, device=tensor.device)

def corrcoef(X):
    """Compute Pearson correlation matrix from samples of a random vector

    Parameters
    ----------
    X : 2D torch.Tensor
        N x d matrix consisting of N samples of a d-element vector.
        Tensor subclasses (e.g. ``torch.nn.Parameter``) are accepted.
        Requires N >= 2, because the estimate uses an N - 1 denominator.

    Returns
    ----------
    torch.Tensor
        The (d x d) correlation matrix of the columns of X, clamped to
        [-1, 1]. A column with zero variance yields zeros in its row and
        column (including its diagonal entry) instead of NaN, because a
        small epsilon is added to every standard deviation.
    """
    # isinstance (not `type(X) is`) so Tensor subclasses such as
    # nn.Parameter are not rejected.
    assert isinstance(X, torch.Tensor)
    assert X.dim() == 2
    n = X.size(0)
    mu = torch.mean(X, dim=0)
    # Epsilon keeps constant columns from dividing by zero.
    std = torch.std(X, dim=0) + 1e-8
    # The (d,) statistics broadcast across the N rows of X.
    Xs = (X - mu) / std
    R = 1 / (n - 1) * torch.matmul(Xs.t(), Xs)
    # Guard against rounding pushing entries slightly outside [-1, 1].
    R = torch.clamp(R, -1.0, 1.0)
    return R

def mean_xcorr(X):
    """Compute mean pairwise cross-correlation from samples of a random vector

    Computes the mean absolute value of the non-diagonal elements of corrcoef(X)

    Parameters
    ----------
    X : 2D torch.Tensor
        N x d matrix consisting of N samples of a d-element vector.
        Tensor subclasses (e.g. ``torch.nn.Parameter``) pass this function's
        own type check.

    Returns
    ----------
    torch.Tensor
        0-dim tensor holding the mean pairwise cross-correlation of X (use
        ``.item()`` for a Python float). It is NaN when d == 1, because there
        are no off-diagonal pairs to average (0 / 0).
    """
    assert isinstance(X, torch.Tensor)
    assert X.dim() == 2
    d = X.size(1)
    R = corrcoef(X)
    # Zero the diagonal (self-correlations) so that only cross terms are summed.
    off_diag = R - eye_like(R) * R
    # d**2 - d is the number of off-diagonal entries.
    return torch.sum(torch.abs(off_diag)) / (d ** 2 - d)

class Reshape(torch.nn.Module):
    """Module whose forward pass returns ``input.view(*shape)``.

    Parameters
    ----------
    args : int...
        The target size, given as separate positional ints (as for
        ``Tensor.view``).
    """
    def __init__(self, *args):
        super().__init__()
        # Kept as the raw positional tuple; forward unpacks it into view().
        self.shape = args

    def __repr__(self):
        return '{}{}'.format(type(self).__name__, self.shape)

    def forward(self, input):
        target = self.shape
        return input.view(*target)

def one_hot(x, depth):
    """Convert a batch of indices to a batch of one-hot vectors

    Parameters
    ----------
    x : 1D torch.Tensor
        Class indices, one per batch element. ``scatter_`` requires an
        int64 index tensor, and every index must lie in [0, depth).
    depth : int
        The length of each output vector

    Returns
    -------
    torch.Tensor
        (batch_size x depth) tensor of the default float dtype, on the same
        device as ``x``, with 1.0 at each index and 0.0 elsewhere.
    """
    batch_size = x.shape[0]
    ix = x.unsqueeze(1)
    # Allocate on x's device: scatter_ needs the output and index tensors on
    # the same device, so a CPU-only buffer would fail for non-CPU input.
    out = torch.zeros(batch_size, depth, device=x.device)
    return out.scatter_(1, ix, 1.0)

class Network(torch.nn.Module):
    """Module that when printed shows its total number of parameters
    """
    def __str__(self):
        s = super().__str__() + '\n'
        # numel() returns a Python int. np.prod(p.size()) gave the float 1.0
        # for 0-dim parameters (empty shape), turning the total into a float.
        n_params = sum(p.numel() for p in self.parameters())
        s += 'Total params: {}'.format(n_params)
        return s

    def summary(self):
        """Print the module description plus parameter count, and return it."""
        s = str(self)
        print(s)
        return s

def test_corrcoef():
    """Check corrcoef against numpy.corrcoef, including zero-variance columns."""
    np.random.seed(0)
    N = 100
    d = 4
    X = np.random.randn(N, d)
    R_np = np.corrcoef(X, rowvar=False)
    R_t = corrcoef(torch.tensor(X)).numpy()
    assert np.allclose(R_t, R_np)

    # Append two all-zero columns: their correlations should come out as 0
    # (not NaN), and the original d x d block should be unchanged.
    X = np.concatenate([X, np.zeros([N, 1]), np.zeros([N, 1])], axis=1)
    R_t = corrcoef(torch.tensor(X)).numpy()
    assert np.allclose(R_t[:-2, :-2], R_np)
    assert np.allclose(R_t[-2:, :], 0.0)
    assert np.allclose(R_t[:, -2:], 0.0)


def test_xcorr():
    """Check mean_xcorr against the same statistic computed with numpy."""
    np.random.seed(0)
    n_samples, n_dims = 8, 4
    samples = np.random.randn(n_samples, n_dims)
    R_ref = np.corrcoef(samples, rowvar=False)
    # Zero the diagonal, then average |r| over the d*(d-1) off-diagonal entries.
    off_diag = R_ref - np.eye(n_dims) * R_ref
    expected = np.sum(np.abs(off_diag)) / (n_dims ** 2 - n_dims)
    got = mean_xcorr(torch.tensor(samples, dtype=torch.float32)).item()
    assert np.isclose(expected, got)

def main():
    """Run this module's self-tests in order and report success."""
    for check in (test_corrcoef, test_xcorr):
        check()
    print('All tests passed.')

if __name__ == '__main__':
    main()
