import numpy as np

# Short aliases for the NumPy routines used throughout this module.
eye = np.identity
inv = np.linalg.inv
sqrt = np.sqrt
diag = np.diag
sign = np.sign
qr = np.linalg.qr


def mean(a):
    """Return the column means of ``a`` (average taken over axis 0)."""
    return np.mean(a, axis=0)


def cov(a):
    """Return the covariance matrix of the columns of the 2-D array ``a``.

    Rows are treated as observations (``rowvar=False``).  ``np.cov`` collapses
    the result to a 0-d array when ``a`` has a single column, so the result is
    reshaped to always be ``a.shape[1]`` by ``a.shape[1]``.
    """
    n_cols = a.shape[1]
    return np.cov(a, rowvar=False).reshape(n_cols, n_cols)


class RPPCA:
    """
    RP-PCA estimation of factors

    Parameters
    ----------
    X : panel of excess returns T by N matrix
    gamma : risk premium parameter (weight put on the outer product of the
        column means when building the RP-PCA matrix)
    K : number of factors
    stdnorm : normalization by standard deviation.
        stdnorm=1 divides every column of X by its standard deviation before
        the eigenvalue decomposition.
        Use stdnorm=0 as standard.
    variancenormalization : normalize norm of factor or loadings.
        Use variancenormalization=0 as standard which normalizes the loadings to have unit length.
        variancenormalization=1 scales the loadings by the square roots of the eigenvalues.
    orthogonalization
        orthogonalization=1 gives uncorrelated factors.
        Use orthogonalization=0 as the standard which returns the original RP-PCA rotation.

    Call ``run()`` to perform the estimation; the result dictionary is
    returned and also stored in ``self.output``.
    """

    def __init__(self, X: np.ndarray, gamma: float, K: int, stdnorm: int,
                 variancenormalization: int, orthogonalization: int):
        assert X.ndim == 2, f"{X.ndim} must be 2."
        self.X = X
        self.T, self.N = X.shape
        self.gamma = gamma
        self.K = K
        self.stdnorm = stdnorm
        self.variancenormalization = variancenormalization
        self.orthogonalization = orthogonalization

    @staticmethod
    def _sign_matrix(values):
        """Return diag(sign(values)), treating an exact zero as +1.

        ``np.sign`` returns 0 for 0; using that as a column multiplier would
        wipe out the whole column and make later matrix inversions singular.
        """
        signs = np.sign(values)
        return np.diag(np.where(signs == 0, 1.0, signs))

    @staticmethod
    def _regression_weights(loadings):
        """Return loadings @ inv(loadings' loadings)."""
        return loadings @ np.linalg.inv(loadings.T @ loadings)

    @staticmethod
    def _unit_columns(weights):
        """Rescale every column of ``weights`` to unit Euclidean length."""
        column_norms = np.sqrt(np.diag(np.diag(weights.T @ weights)))
        return weights @ np.linalg.inv(column_norms)

    def run(self) -> dict:
        """Estimate the RP-PCA factors and the derived quantities.

        Returns
        -------
        dict with the keys
            'loadings'        : N by K loadings
            'factors'         : T by K factor time series
            'eigenval'        : eigenvalues of the RP-PCA matrix, descending
            'SDF time-series' : T by K, column k-1 uses the first k factors
            'SDF weights'     : dict k -> SDF weights on the N assets
            'betas'           : dict k -> N by k time-series regression betas
            'alphas'          : N by K time-series regression intercepts
            'residuals'       : dict k -> T by N regression residuals
            'weights'         : N by K factor portfolio weights

        The same dictionary is stored in ``self.output``.
        """
        X = self.X
        T, N, K, gamma = self.T, self.N, self.K, self.gamma
        variancenormalization = self.variancenormalization
        orthogonalization = self.orthogonalization

        # WN is the inverse of an N by N diagonal matrix of standard deviations of each asset's return.
        # np.var divides by T, i.e. it equals diag(X' (I - 11'/T) X / T) without
        # building any T by T matrix.
        if self.stdnorm == 1:
            WN = np.linalg.inv(np.sqrt(np.diag(np.var(X, axis=0))))
        else:
            WN = np.identity(N)

        # Generic estimator for general weighting matrices
        Xtilde = X @ WN

        # Covariance matrix with weighted mean.
        # Xtilde' (I + gamma 11'/T) Xtilde / T
        #   = Xtilde' Xtilde / T + gamma * mean(Xtilde) mean(Xtilde)'
        mean_tilde = np.mean(Xtilde, axis=0)
        VarWPCA = Xtilde.T @ Xtilde / T + gamma * np.outer(mean_tilde, mean_tilde)

        # Eigenvalue decomposition; eigh returns ascending eigenvalues, so
        # reverse to get the largest first.
        DDWPCA, VWPCA = np.linalg.eigh(VarWPCA)
        DDWPCA = DDWPCA[::-1]
        VWPCA = VWPCA[:, ::-1]

        # Lambdahat are the eigenvectors after reverting the cross-sectional transformation
        Lambdahat = np.linalg.inv(WN.T) @ VWPCA[:, :K]

        # Normalizing the signs of the loadings (sign of the implied factor means)
        implied_means = np.mean(X @ self._regression_weights(Lambdahat), axis=0)
        Lambdahat = Lambdahat @ self._sign_matrix(implied_means)

        # Constructing the latent factors
        factorweight = self._unit_columns(self._regression_weights(Lambdahat))
        Fhat = X @ factorweight

        # if variancenormalization==1 then the loadings are scaled by the square
        # roots of the eigenvalues. Otherwise they have unit length.
        if variancenormalization == 1 and orthogonalization == 0:
            eig_scale = np.diag(np.sqrt(DDWPCA[:K]))
            Lambdahat = Lambdahat @ eig_scale
            factorweight = factorweight @ np.linalg.inv(eig_scale)
            Fhat = X @ factorweight

        if orthogonalization == 1 and variancenormalization in (0, 1):
            # QR decomposition of the demeaned, 1/sqrt(T)-scaled factors;
            # only the triangular factor is needed.
            demeaned = (Fhat - np.mean(Fhat, axis=0)) / np.sqrt(T)
            _, R = np.linalg.qr(demeaned)
            R = R[:K, :K]
            Rotation = np.linalg.inv(R)
            if variancenormalization == 0:
                Rotation = Rotation @ np.diag(np.diag(R))
            factorweight = factorweight @ Rotation
            if variancenormalization == 0:
                factorweight = self._unit_columns(factorweight)
            Fhat = X @ factorweight
            signnormalization = self._sign_matrix(np.mean(Fhat, axis=0))
            Fhat = Fhat @ signnormalization
            factorweight = factorweight @ signnormalization
            # NOTE(review): factors are rotated as F @ Rotation, for which the
            # algebraically matching loading rotation would be inv(Rotation).T;
            # kept as in the original -- confirm against the reference code.
            Lambdahat = Lambdahat @ np.linalg.inv(Rotation) @ signnormalization

        factors = {}
        beta = {}
        residual = {}
        SDFweightsassets = {}
        SDF = np.zeros((T, K))
        a = np.zeros((N, K))
        # Intercept plus all factors; the regression on the first k factors
        # uses the first k+1 columns.
        design = np.column_stack((np.ones(T), Fhat))
        for k in range(1, K + 1):
            factors[k] = Fhat[:, :k]

            # Mean variance optimization.
            # reshape: np.cov returns a 0-d array for a single factor.
            factor_cov = np.cov(factors[k], rowvar=False).reshape(k, k)
            SDFweights = np.linalg.inv(factor_cov) @ np.mean(factors[k], axis=0)
            SDF[:, k - 1] = factors[k] @ SDFweights
            SDFweightsassets[k] = self._regression_weights(Lambdahat[:, :k]) @ SDFweights

            # Time-series regressions
            tmp = design[:, :k + 1]
            dummy = np.linalg.inv(tmp.T @ tmp) @ tmp.T @ X
            residual[k] = X - tmp @ dummy
            a[:, k - 1] = dummy[0, :]
            beta[k] = dummy[1:k + 1, :].T

        output = {
            'loadings': Lambdahat,  # loadings
            'factors': Fhat,  # factor time series
            'eigenval': DDWPCA,  # eigenvalues of RP-PCA matrix
            'SDF time-series': SDF,  # SDF time-series
            'SDF weights': SDFweightsassets,  # SDF weights
            'betas': beta,  # time-series regression betas on RP-PCA factors
            'alphas': a,  # time-series regression alpha
            'residuals': residual,  # residuals from time-series regression
            'weights': factorweight,  # factor portfolio weights
        }
        self.output = output

        return output
