from __future__ import division, print_function
try:
    from pylab import plt
except ImportError:
    print('Unable to import pylab. R_pca.plot_fit() will not work.')
try:
    # Python 2: 'xrange' is the iterative version
    range = xrange
except NameError:
    # Python 3: 'range' is iterative - no need for 'xrange'
    pass

import numpy as np
import random
import cvxpy as cp
from scipy.stats import ortho_group

def GenerateMatrix(p, n, alpha, spike, scal = 1):
    """Draw a synthetic data matrix from the information-plus-noise model.

    Inputs:
    p      number of features (rows)
    n      number of observations (columns)
    alpha  degrees of freedom of the Student t noise entries
    spike  list of signal strengths, placed on the leading diagonal
           entries of D; every remaining diagonal entry of D is 1
    scal   optional factor by which the noise matrix is scaled

    Outputs:
    X      data matrix U D V + scal * Z / sqrt(n), where U (p x p) and
           V (n x n) are random orthogonal matrices from
           scipy.stats.ortho_group and Z is p x n with i.i.d. Student t
           entries
    u      first column of U, the first population-level principal
           component

    cf. information-plus-noise model in Section 3.1
    """
    # Draw in the order Z, U, V so a fixed seed always yields the same matrix.
    Z = np.random.standard_t(alpha, size=(p, n))
    left = ortho_group.rvs(p)
    right = ortho_group.rvs(n)

    rank = min(p, n)
    diagonal = [spike[k] if k < len(spike) else 1 for k in range(rank)]
    D = np.zeros((p, n))
    D[np.arange(rank), np.arange(rank)] = diagonal

    signal = left @ D @ right
    return signal + scal * Z / n**0.5, left[:, 0]

def Our_method(mat, R = 500, P = 50, N = 50, num_comps = 1, run = False):
    """Estimate the leading principal subspace of mat with our algorithm.

    Each of the R rounds draws N distinct observations (columns) at random,
    with probability inversely proportional to their Euclidean norm. It then
    computes a P-dimensional principal subspace of that subsample and adds
    its orthogonal projector to a running sum V. The estimated principal
    components are the leading eigenvectors of V.

    Inputs:
    mat         data matrix in form #features x #observations
    R, P, N     three parameters of our algorithm: number of rounds,
                dimension of the subspace extracted per round, and number
                of observations subsampled per round (N <= #observations)
    num_comps   dimension of principal subspace to compute
    run         Boolean parameter

    Output if run = False:
    p x num_comps matrix [u_(num_comps), ..., u_1]
    where u_i is the i'th leading estimated principal component of mat

    Output if run = True (relevant only to Figure 3a):
    [v_1, v_2, ... v_R]
    where v_i is the estimate of the leading principal component of mat
    obtained after the first i rounds (i.e. with R=i)
    """
    p, n = mat.shape
    V = np.zeros((p, p))
    running_U = []

    # The sampling weights depend only on mat, so compute them once rather
    # than in every round. Columns with a smaller norm are more likely to
    # be drawn.
    counts = 1 / np.linalg.norm(mat, axis=0)
    counts = counts / np.sum(counts)

    for _ in range(R):
        # Subselect observations
        cols = np.random.choice(n, size=N, replace=False, p=counts)
        Xt = mat[:, cols]

        if P == N:
            # Projector onto the column span of the subsample. The
            # triangular QR factor is discarded under a throwaway name so
            # it cannot shadow the parameter R.
            Q, _ = np.linalg.qr(Xt)
            V += Q @ Q.T
        else:
            # Projector onto the top-P eigenvectors of Xt Xt^T (eigh
            # returns eigenvalues in ascending order, so they are last).
            _, Ut = np.linalg.eigh(Xt @ Xt.T)
            top = Ut[:, -P:]
            V += top @ top.T

        if run:
            _, U = np.linalg.eigh(V)
            running_U.append(U[:, -1])

    if run:
        return running_U
    _, U = np.linalg.eigh(V)
    return U[:, -num_comps:]


def ECA_method(mat, num_comps = 1):
    """Elliptical component analysis via the multivariate Kendall's tau.

    K is the average over all unordered pairs of observations (i, j) of
    s s^T, where s = (x_i - x_j) / ||x_i - x_j||. If two observations are
    identical, their difference is zero; its spatial sign is taken to be
    the zero vector, so the pair contributes nothing instead of a 0/0 NaN.

    Inputs:
    mat         data matrix in form #features x #observations
    num_comps   dimension of principal subspace to compute

    Output:
    p x num_comps matrix [u_(num_comps), ..., u_1]
    where u_i is the i'th leading estimated principal component of mat

    Ref:
    Fang Han and Han Liu (2015). "ECA: high-dimensional elliptical
    component analysis in non-Gaussian distributions." Journal of
    the American Statistical Association 113(521): 252-268.
    https://doi.org/10.1080/01621459.2016.1246366
    """
    p, n = mat.shape
    K = np.zeros((p, p))
    for i in range(1, n):
        # Differences between observation i and every earlier observation.
        diffs = mat[:, [i]] - mat[:, :i]
        norms = np.linalg.norm(diffs, axis=0)
        keep = norms > 0  # drop zero differences (duplicate observations)
        signs = diffs[:, keep] / norms[keep]
        # Sum of outer products of the normalised differences.
        K += signs @ signs.T
    if n > 1:
        # Average over the n(n-1)/2 unordered pairs.
        K *= 2 / (n * (n - 1))
    _, U = np.linalg.eigh(K)
    return U[:, -num_comps:]


def Sample_Cov_method(mat, num_comps = 1):
    """Classical PCA from the (uncentred) sample second-moment matrix.

    Inputs:
    mat         data matrix in form #features x #observations
    num_comps   dimension of principal subspace to compute

    Output:
    p x num_comps matrix [u_(num_comps), ..., u_1]
    where u_i is the i'th leading eigenvector of mat @ mat.T / #observations,
    i.e. the i'th leading estimated principal component of mat
    """
    num_obs = mat.shape[1]
    second_moment = mat @ mat.T / num_obs
    # eigh sorts eigenvalues in ascending order, so the leading
    # eigenvectors are the last columns.
    eigvecs = np.linalg.eigh(second_moment)[1]
    return eigvecs[:, -num_comps:]



def Minsker_method(mat, NU=(0.5,), NUM_GROUPS=10, num_comps = 1):
    """Robust PCA via the geometric median of group covariances (Minsker, 2015).

    The procedure has four steps:

    1. Randomly split the observations into NUM_GROUPS disjoint groups of
       equal size; left-over columns are dropped.
    2. Compute each group's second-moment matrix.
    3. Find the geometric median of these matrices as a convex combination
       of them, by solving a second-order cone program.
    4. For each nu in NU, zero the convex weights below nu / NUM_GROUPS,
       renormalise the rest, and return the leading eigenvectors of the
       resulting weighted average.

    Inputs:
    mat         data matrix in form #features x #observations
    NU          iterable of values of nu, cf. Minsker (2015). For nu <= 1 at
                least one weight survives thresholding, because the weights
                sum to one.
    NUM_GROUPS  parameter k of Minsker (2015)
    num_comps   dimension of principal subspace to compute

    Output:
    list with one entry per value in NU. Each entry is a
    p x num_comps matrix [u_(num_comps), ..., u_1]
    where u_i is the i'th leading estimated principal component of mat.
    If every convex optimizer fails, each entry is a p x num_comps matrix
    of zeros.
    """
    p = mat.shape[0]

    # Rescale for numerical stability. Scaling by a positive constant
    # changes neither the median's convex weights nor the returned
    # eigenvectors. The mean absolute entry is used because the plain mean
    # of (typically centred) data can be close to zero, which would blow
    # the entries up instead.
    scale = np.mean(np.abs(mat))
    if scale > 0:
        mat = mat / scale
    size_groups = mat.shape[1] // NUM_GROUPS

    # Collect in "mats" the flattened group second-moment matrices whose
    # geometric median we seek (one column per group).
    cols = list(range(mat.shape[1]))
    random.shuffle(cols)
    mats = []
    for i in range(NUM_GROUPS):
        X_trim = mat[:, cols[size_groups * i:size_groups * (i + 1)]]
        mats.append((X_trim @ X_trim.T / X_trim.shape[1]).flatten())
    mats = np.array(mats).T
    k = mats.shape[1]

    # Geometric median as a second-order cone program:
    #   minimise sum_i c_i  s.t.  ||mats @ coefs - mats[:, i]||_2 <= c_i,
    # with coefs restricted to the probability simplex so that they are
    # convex weights (the nu-thresholding below relies on this).
    coefs = cp.Variable(k)
    c = cp.Variable(k)
    constraints = [cp.SOC(c[i], mats @ coefs - mats[:, i]) for i in range(k)]
    constraints += [coefs >= 0, cp.sum(coefs) == 1, c >= 0]
    problem = cp.Problem(cp.Minimize(cp.sum(c)), constraints)

    # Try three convex optimizers in turn. A solver counts as successful
    # if it raises nothing and leaves a value in coefs.
    failed = True
    for solver in ('CLARABEL', 'SCS', 'ECOS'):
        try:
            problem.solve(solver=solver)
        except Exception:
            continue
        if coefs.value is not None:
            failed = False
            break

    results = []
    for nu in NU:
        # Return a zero matrix if each optimizer failed.
        if failed:
            results.append(np.zeros((p, num_comps)))
            continue

        # Threshold according to nu, then renormalise the surviving weights.
        weights = np.asarray(coefs.value, dtype=float)
        alpha = np.where(weights >= nu / k, weights, 0.0)
        alpha = alpha / np.sum(alpha)
        Cov = np.reshape(mats @ alpha, (p, p))

        _, U = np.linalg.eigh(Cov)
        results.append(U[:, -num_comps:])
    return results



# Remaining code is for Robust PCA and is copied from:
#  https://github.com/dganguli/robust-pca
class R_pca:
    """Robust PCA: split D into a low-rank part L and a sparse part S.

    Uses the principal component pursuit (PCP) iteration located in the
    table on page 29 of https://arxiv.org/pdf/0912.3599.pdf.
    Copied from https://github.com/dganguli/robust-pca.
    """

    def __init__(self, D, mu=None, lmbda=None):
        """
        D      data matrix to decompose
        mu     penalty parameter. If falsy (None or 0) it defaults to
               prod(D.shape) / (4 * np.linalg.norm(D, ord=1)).
        lmbda  weight of the sparse term. If falsy it defaults to
               1 / sqrt(max(D.shape)).
        """
        self.D = D
        self.S = np.zeros(self.D.shape)
        self.Y = np.zeros(self.D.shape)

        if mu:
            self.mu = mu
        else:
            # NOTE(review): for a 2-D array, ord=1 is the maximum absolute
            # column sum. The PCP paper appears to use the entrywise l1 norm
            # here. Kept as in the upstream repository -- confirm before
            # changing, since it alters results.
            self.mu = np.prod(self.D.shape) / (4 * np.linalg.norm(self.D, ord=1))

        self.mu_inv = 1 / self.mu

        if lmbda:
            self.lmbda = lmbda
        else:
            self.lmbda = 1 / np.sqrt(np.max(self.D.shape))

    @staticmethod
    def frobenius_norm(M):
        """Frobenius norm of M."""
        return np.linalg.norm(M, ord='fro')

    @staticmethod
    def shrink(M, tau):
        """Entrywise soft-thresholding: sign(M) * max(|M| - tau, 0)."""
        return np.sign(M) * np.maximum(np.abs(M) - tau, 0)

    def svd_threshold(self, M, tau):
        """Soft-threshold the singular values of M by tau and rebuild it."""
        U, S, V = np.linalg.svd(M, full_matrices=False)
        return np.dot(U, np.dot(np.diag(self.shrink(S, tau)), V))

    def fit(self, tol=None, max_iter=1000, iter_print=100):
        """Run PCP until ||D - L - S||_F <= tol or max_iter iterations.

        tol         stopping tolerance. If falsy, 1e-7 * ||D||_F.
        max_iter    maximum number of iterations
        iter_print  unused; kept so existing callers keep working

        The iteration starts from self.S and self.Y. Returns (L, S), which
        are also stored as self.L and self.S.
        """
        it = 0
        err = np.inf  # np.Inf was removed in NumPy 2.0
        Sk = self.S
        Yk = self.Y
        Lk = np.zeros(self.D.shape)

        if tol:
            _tol = tol
        else:
            _tol = 1E-7 * self.frobenius_norm(self.D)

        # This loop implements the principal component pursuit (PCP) algorithm
        # located in the table on page 29 of https://arxiv.org/pdf/0912.3599.pdf
        while (err > _tol) and it < max_iter:
            # Step 3: singular-value thresholding for the low-rank part.
            Lk = self.svd_threshold(
                self.D - Sk + self.mu_inv * Yk, self.mu_inv)
            # Step 4: entrywise shrinkage for the sparse part.
            Sk = self.shrink(
                self.D - Lk + (self.mu_inv * Yk), self.mu_inv * self.lmbda)
            # Step 5: dual update.
            Yk = Yk + self.mu * (self.D - Lk - Sk)
            err = self.frobenius_norm(self.D - Lk - Sk)
            it += 1

        self.L = Lk
        self.S = Sk
        return Lk, Sk

    def plot_fit(self, size=None, tol=0.1, axis_on=True):
        """Plot rows of L + S (red) and L (blue). Requires fit() and pylab.

        size     optional (nrows, ncols) grid; defaults to a square grid
                 with ceil(sqrt(#rows of D)) cells per side
        tol      margin added above and below the data range on the y axis
        axis_on  if False, axes are hidden
        """
        n = self.D.shape[0]

        if size:
            nrows, ncols = size
        else:
            sq = np.ceil(np.sqrt(n))
            nrows = int(sq)
            ncols = int(sq)

        ymin = np.nanmin(self.D)
        ymax = np.nanmax(self.D)
        print('ymin: {0}, ymax: {1}'.format(ymin, ymax))

        numplots = np.min([n, nrows * ncols])
        plt.figure()

        for idx in range(numplots):
            plt.subplot(nrows, ncols, idx + 1)
            plt.ylim((ymin - tol, ymax + tol))
            plt.plot(self.L[idx, :] + self.S[idx, :], 'r')
            plt.plot(self.L[idx, :], 'b')
            if not axis_on:
                plt.axis('off')
                
def RPCA_method(mat, num_comps=1):
    """Leading principal components of the low-rank part found by Robust PCA.

    Decomposes mat into low-rank plus sparse parts with R_pca, copied from
    https://github.com/dganguli/robust-pca. Returns the top num_comps
    eigenvectors of L @ L.T / #observations, where L is the low-rank part.

    Inputs:
    mat         data matrix in form #features x #observations
    num_comps   dimension of principal subspace to compute

    Output:
    p x num_comps matrix [u_(num_comps), ..., u_1]
    """
    low_rank, _sparse = R_pca(mat).fit(max_iter=10000, iter_print=100)
    gram = low_rank @ low_rank.T / low_rank.shape[1]
    # eigh returns eigenvalues in ascending order; keep the last columns.
    eigenvectors = np.linalg.eigh(gram)[1]
    return eigenvectors[:, -num_comps:]
