import sys
from typing import Union
import pandas as pd
import scipy as sp
import numpy as np
from numpy.testing import *
import torch

sys.path.append('/workspace')
from utils import estimate_coef_by_time_series_regression, estimate_residual_by_time_series_regression

# Short module-level aliases for frequently used NumPy routines.
identity, diag = np.identity, np.diag
inv, pinv, qr = np.linalg.inv, np.linalg.pinv, np.linalg.qr
sqrt, sign = np.sqrt, np.sign
tr, kron = np.trace, np.kron


def mean(a):
    """Return the mean of `a` taken over axis 0 (one value per column)."""
    return np.mean(a, axis=0)


def cov(a):
    """Return the covariance of the columns of `a` as a 2-D square array.

    Columns are treated as variables (`rowvar=False`).  The result is reshaped
    to `(a.shape[1], a.shape[1])` so that it stays two-dimensional regardless
    of the shape `np.cov` returns (e.g. for a single-column input).
    """
    n_cols = a.shape[1]
    return np.cov(a, rowvar=False).reshape(n_cols, n_cols)

class PCA_XC:
    """
    Principal Component Analysis for Cross-Sectionally Correlated Pricing Errors

    Parameters
    ----------
    X : panel of excess returns T by N matrix
    V : N by N matrix that defines the norm of alpha
    eta : regularization parameter
    K : number of factors
    orthogonalize_lambda : {True, False}
        If True, normalize Lambda to satisfy 1/N*Lambda.T @ Lambda = I_K.
    normalization_of_factors : {'no_normalization', 'orthogonal', 'uncorrelated', 'cov_of_F_equals_identity'}
        If 'orthogonal', the estimated factors are orthogonal, i.e., the following is True.
            tmp = self.output["factors"].T @ self.output["factors"]
            np.allclose(tmp, np.diag(np.diag(tmp)))
        If 'uncorrelated', the estimated factors are uncorrelated, i.e., the following is True.
            tmp = self.output["factors"].T @ self.M1 @ self.output["factors"]
            np.allclose(tmp, np.diag(np.diag(tmp)))
        If 'cov_of_F_equals_identity', the estimated factors are rescaled so that the following is True.
            np.allclose(self.output["factors"].T @ self.output["factors"] / self.K, np.identity(self.K))
        If 'no_normalization', the estimator returns the original PCA rotation.
    max_iter : the maximum number of iterations for the alternating least squares (ALS) method
    compute_objvals : {True, False}
        If True, objective function values over the ALS iterations are computed and saved.
    compute_gradient_norm : {True, False}
        If True, Frobenius norms of gradients of the objective function over the ALS iterations are computed and saved.
    """

    def __init__(self, X:np.ndarray, V:np.ndarray, eta:float, K:int, orthogonalize_lambda:bool, normalization_of_factors:str, signnormalization:bool, max_iter:int, sanity_check_full_rank:bool=False, compute_objvals:bool=False, compute_gradient_norm:bool=False, check_rank_of_A:bool=False, V_is_diagonal:bool=False):
        self.V_is_diagonal = V_is_diagonal
        self.set_problem_specifiers(X, V, eta, K, max_iter)
        self.orthogonalize_lambda = orthogonalize_lambda
        self.normalization_of_factors = normalization_of_factors
        self.signnormalization = signnormalization
        self.sanity_check_full_rank = sanity_check_full_rank
        self.compute_objvals = compute_objvals
        self.compute_gradient_norm = compute_gradient_norm
        self.check_rank_of_A = check_rank_of_A


    def set_problem_specifiers(self, X:np.ndarray=None, V:np.ndarray=None, eta:float=None, K:int=None, max_iter:int=None):
        """
        This function sets parameters that specify the optimization problem, which includes X, V, eta, max_iter,
        and compute constant values that do not change if the "specifying variables" do not vary during ALS steps.
        """
        # Set the parameters that specify the optimization problem.
        if X is not None:
            self.X = X.copy()
            self.T, self.N = X.shape

        if V is not None:
            self.V = V.copy()
            self.N = len(V)

        if eta is not None:
            self.eta = eta

        if K is not None:
            self.K = K

        if max_iter is not None:
            self.max_iter = max_iter

        # Parsing
        X, V = self.X, self.V
        T, N, K, eta = self.T, self.N, self.K, self.eta
 
        # Sanity checks
        if X is not None:
            assert X.shape == (T, N), f"X.shape={X.shape} and (T,N)=({T},{N})"

        if V is not None:
            assert V.ndim == 2, f"{V.ndim} must be 2."
            assert V.shape == (N, N)
            assert_array_equal(V, V.T) # symmetry
        
        # Compute constant values
        if X is not None: # , then T is provided from the size of X.
            self.P1 = np.full((T,T), 1.0/T) # projection matrix onto the space of scalar multiple of all-one vector
            self.M1 = identity(T) - self.P1 # annihilation matrix
        
        if (X is not None) or (V is not None) or (eta is not None):
            # In this case, X_transformed has to be updated.
            if self.V_is_diagonal:
                tmp = (eta * np.diag(V) - 1).reshape(1,-1) # tmp is an 1 x n array.
                self.X_transformed = X + (self.apply_P1(X) * tmp)
            else:
                self.X_transformed = X + self.apply_P1(X) @ (eta * V - identity(N))
            

    def apply_P1(self, X:np.ndarray) -> np.ndarray:
        T = X.shape[0]
        out = np.mean(X, axis=0).reshape(1,-1)
        out = np.repeat(out, repeats=T, axis=0)
        return out


    def demean(self, Y:np.ndarray) -> np.ndarray:
        avg_Y = np.mean(Y, axis=0).reshape(1,-1)
        Y_demeaned = Y - avg_Y
        return Y_demeaned
    
    
    def als(self, Lambda_init:np.ndarray=None, F_init:np.ndarray=None, debug:bool=False):
        """
        Implimentation of the alternating least squares method
        """
        self.obj_vals = np.array([])
        self.obj_grad_norm = np.array([])
        self.obj_grad_wrt_F = np.array([])
        self.obj_grad_wrt_Lambda = np.array([])
        self.Lambda_iterates = []
        self.F_iterates = []

        # Initialization
        Lambda = np.eye(self.N, self.K) if Lambda_init is None else Lambda_init
        F = np.eye(self.T, self.K) if F_init is None else F_init
            
        self.Lambda_iterates.append(Lambda)
        self.F_iterates.append(F)

        # OLS updates
        for j in range(self.max_iter):
            if self.sanity_check_full_rank:
                assert np.linalg.matrix_rank(Lambda) == self.K, f"Rank of Lambda is {np.linalg.matrix_rank(Lambda)}. It should be K={self.K}."
                assert np.linalg.matrix_rank(F) == self.K, f"Rank of Lambda is {np.linalg.matrix_rank(F)}. It should be K={self.K}."

            if self.compute_objvals:
                self.obj_vals = np.append(self.obj_vals, self.compute_obj(F, Lambda))

            if self.compute_gradient_norm:
                self.obj_grad_norm = np.append(self.obj_grad_norm, self.compute_norm_grad(F, Lambda))

            if debug:
                self.Lambda = Lambda
                self.F = F

            Lambda_old = Lambda
            F_old = F

            # Update
            F = self.als_update_F(Lambda)
            Lambda = self.als_update_Lambda(F)

            if debug:
                # This if-statement is for sanity check.
                # The gradient has to be very close to 0.
                tmp_F = sqrt(np.sum(self.compute_grad_wrt_F(F, Lambda)**2))
                tmp_L = sqrt(np.sum(self.compute_grad_wrt_Lambda(F, Lambda)**2))

                self.obj_grad_wrt_F = np.append(self.obj_grad_wrt_F, tmp_F)
                self.obj_grad_wrt_Lambda = np.append(self.obj_grad_wrt_Lambda, tmp_L)

                self.Lambda_iterates.append(Lambda)
                self.F_iterates.append(F)


        # Normalize
        F, Lambda = self.normalize(F, Lambda)

        if self.compute_objvals:
            self.obj_vals = np.append(self.obj_vals, self.compute_obj(F, Lambda))

        if self.compute_gradient_norm:
            self.obj_grad_norm = np.append(self.obj_grad_norm, self.compute_norm_grad(F, Lambda))

        if debug:
            self.Lambda = Lambda
            self.F = F

        return F, Lambda


    def als_update_F(self, Lambda:np.ndarray) -> np.ndarray:
        """
        Input
        -----
        Lambda : (N x K) np.ndarray

        Output
        ------
        F_new : (T x K) np.array
        """
        if self.V_is_diagonal:
            v = diag(self.V).reshape(1,-1) # v is an 1 x N array
            A = kron(Lambda.T @ Lambda, self.M1) + kron((self.eta * Lambda.T * v) @ Lambda, self.P1)
        else:
            A = kron(Lambda.T @ Lambda, self.M1) + kron(self.eta * Lambda.T @ self.V @ Lambda, self.P1)

        A = (A + A.T)/2
        b = self.X_transformed @ Lambda
        b = b.reshape((-1,1), order='F') # vec operation
        
        if self.check_rank_of_A:
            if np.linalg.matrix_rank(A, hermitian=True) == self.T * self.K: # A has full-rank, i.e., rank(A)=TK.
                vec_F_new = sp.linalg.solve(A, b, assume_a="pos", check_finite=False)
            else:
                vec_F_new, _, _, _ = sp.linalg.lstsq(A, b, check_finite=False, lapack_driver='gelsy')
        else:
            vec_F_new = sp.linalg.solve(A, b, assume_a="pos", check_finite=False)

        F_new = vec_F_new.reshape((self.T, self.K), order='F')

        return F_new


    def als_update_Lambda(self, F:np.ndarray) -> np.ndarray:
        """
        Input
        -----
        F : (T x K) np.array

        Output
        ------
        Lambda_new : (N x K) np.ndarray
        """
        A = kron(identity(self.N), F.T @ self.demean(F)) + kron(self.eta * self.V, F.T @ self.apply_P1(F))
        A = (A + A.T)/2
        b = F.T @ self.X_transformed
        b = b.reshape((-1,1), order='F') # vec operation

        if self.check_rank_of_A:
            if np.linalg.matrix_rank(A, hermitian=True) == self.N * self.K: # A has full-rank, i.e., rank(A)=NK.
                vec_Lambda_new_T = sp.linalg.solve(A, b, assume_a="pos", check_finite=False)
            else:
                vec_Lambda_new_T, _, _, _ = sp.linalg.lstsq(A, b, check_finite=False, lapack_driver='gelsy')
        else:
            vec_Lambda_new_T = sp.linalg.solve(A, b, assume_a="pos", check_finite=False)

        Lambda_new_T = vec_Lambda_new_T.reshape((self.K, self.N), order='F')
        Lambda_new = Lambda_new_T.T

        return Lambda_new

    
    def normalize(self, F:np.ndarray, Lambda:np.ndarray) -> (np.ndarray, np.ndarray):
        """
        This function normalizes Lambda and F.
        """
        # 1. Make Lambda orthonormal
        if self.orthogonalize_lambda:
            S, L, Tt = np.linalg.svd(Lambda, full_matrices=False)
            Lambda = (Lambda @ Tt.T) * (1/L) * sqrt(self.N)
            F = (F @ Tt.T) * L / sqrt(self.N)

        # 2. Normalization of the factors F
        if self.normalization_of_factors == 'orthogonal':
            L, U = np.linalg.eigh(F.T @ F)
            Rotation = U[:,::-1]
            F = F @ Rotation
            Lambda = Lambda @ Rotation
        elif self.normalization_of_factors == 'uncorrelated':
            L, U = np.linalg.eigh(F.T @ self.demean(F))
            Rotation = U[:,::-1]
            F = F @ Rotation
            Lambda = Lambda @ Rotation
        elif self.normalization_of_factors == 'cov_of_F_equals_identity':
            L, U = np.linalg.eigh(F.T @ F / self.K)
            L, U = L[::-1], U[:,::-1]
            Rotation = U * np.sqrt(1/L).reshape(1,-1)
            F = F @ Rotation
            Rotation2 = U * np.sqrt(L).reshape(1,-1)
            Lambda = Lambda @ Rotation2

        elif self.normalization_of_factors != 'no_normalization':
            raise AssertionError(f'{self.normalization_of_factors}')

        # 3. Normalize the signs of the factors. This makes each factor have positive mean.
        if self.signnormalization:
            signnormalization = np.sign(np.mean(F, axis=0))
            F = F * signnormalization[np.newaxis,:] # F = F @ diag(Rotation)
            Lambda = Lambda * signnormalization[np.newaxis,:] # Lambda = Lambda @ diag(Rotation)

        return F, Lambda


    def compute_obj(self, F:np.ndarray, Lambda:np.ndarray) -> float:
        X_mean = np.mean(self.X, axis=0).reshape((-1,1))
        F_mean = np.mean(F, axis=0).reshape((-1,1))

        f = 1.0/self.N/self.T * np.linalg.norm(self.demean(self.X - F @ Lambda.T), ord='fro') ** 2
        if self.V_is_diagonal:
            v = diag(self.V).reshape(1,-1) # v is an 1 x N array
            regul = self.eta/self.N * ((X_mean - Lambda @ F_mean).T * v) @ (X_mean - Lambda @ F_mean)
        else:
            regul = self.eta/self.N * (X_mean - Lambda @ F_mean).T @ self.V @ (X_mean - Lambda @ F_mean)
        phi = f + regul[0,0]

        return phi


    def compute_grad_wrt_F(self, F:np.ndarray, Lambda:np.ndarray) -> np.ndarray:
        X_FL = self.X - F @ Lambda.T
        out = self.demean(X_FL)
        if self.V_is_diagonal:
            v = diag(self.V).reshape(1,-1) # v is an 1 x N array
            out += self.eta * self.apply_P1(X_FL) * v
        else:
            out += self.eta * self.apply_P1(X_FL) @ self.V
        out = out @ Lambda / (self.N * self.T)
        return out


    def compute_grad_wrt_Lambda(self, F:np.ndarray, Lambda:np.ndarray) -> np.ndarray:
        X_FL = self.X - F @ Lambda.T
        out = self.demean(X_FL)
        if self.V_is_diagonal:
            v = diag(self.V).reshape(1,-1) # v is an 1-d array
            out += self.eta * self.apply_P1(X_FL) * v
        else:
            out += self.eta * self.apply_P1(X_FL) @ self.V
        out = F.T @ out / (self.N * self.T)
        return out.T


    def compute_norm_grad(self, F:np.ndarray, Lambda:np.ndarray) -> float:
        g_F = self.compute_grad_wrt_F(F, Lambda)
        g_L = self.compute_grad_wrt_Lambda(F, Lambda)
        out = sqrt(np.sum(g_F**2) + np.sum(g_L**2))
        return out


    def run(self, Lambda_init:np.ndarray=None, F_init:np.ndarray=None, debug:bool=False):
        """
        Fit the model by ALS and run the follow-up regression analysis.

        Input
        -----
        Lambda_init, F_init, debug : passed through unchanged to `als`.

        Output
        ------
        output : the dictionary returned by `ols_analysis` for the estimated loadings and factors
                 (also stored by `ols_analysis` as self.output).
        """
        Fhat, Lambdahat = self.als(Lambda_init=Lambda_init, F_init=F_init, debug=debug)

        # Compute SDF and do regression.
        return self.ols_analysis(Lambda=Lambdahat, F=Fhat)


    def ols_analysis(self, Lambda:Union[np.ndarray, torch.Tensor], F:Union[np.ndarray, torch.Tensor]) -> dict:
        """
        Build SDF series and time-series regression outputs from estimated loadings and factors.

        For every k = 1, ..., K the first k factors are used to
          * form weights inv(cov(factors)) @ mean(factors) and the series factors @ weights,
          * map those weights to asset space through Lambda[:, :k],
          * regress the data on the factors with the helpers
            `estimate_coef_by_time_series_regression` / `estimate_residual_by_time_series_regression`.

        Input
        -----
        Lambda : (N x K) np.ndarray or torch.Tensor
        F : (T x K) np.ndarray or torch.Tensor
            Tensors are moved to the CPU and converted to NumPy arrays.

        Output
        ------
        output : dict with keys 'loadings', 'factors', 'SDF time-series', 'SDF weights',
                 'betas', 'alphas', 'residuals'.  'SDF weights', 'betas' and 'residuals' are
                 dicts keyed by k; 'SDF time-series' is T x K and 'alphas' is N x K with
                 column k-1 belonging to the k-factor model.  Also stored as self.output.

        Raises
        ------
        AssertionError : if Lambda is not N x K or F is not T x K.
        """
        # Parsing: bring everything to NumPy / pandas.
        if isinstance(self.X, torch.Tensor):
            X_df = pd.DataFrame(self.X.cpu().detach().numpy())
        else:
            # np.ndarray and any other input pandas can wrap; without this fallback X_df
            # would be undefined for such inputs and fail later with a NameError.
            X_df = pd.DataFrame(self.X)

        if isinstance(Lambda, torch.Tensor):
            Lambda = Lambda.cpu().detach().numpy()

        if isinstance(F, torch.Tensor):
            F = F.cpu().detach().numpy()

        N, T, K = self.N, self.T, self.K

        # Sanity check
        assert Lambda.shape == (N, K), f"Lambda.shape={Lambda.shape}, but N={N}, K={K}"
        assert F.shape == (T, K), f"F.shape={F.shape}, but T={T}, K={K}"

        # Compute SDF and do regression.
        beta = {}
        residual = {}
        SDFweightsassets = {}
        SDF = np.zeros((T,K))
        a = np.zeros((N,K))
        for k in range(1,K+1):
            factors = F[:,:k]
            factors_df = pd.DataFrame(factors)

            # Mean variance optimization
            SDFweights = np.linalg.inv(cov(factors)) @ np.mean(factors, axis=0).T
            SDF[:,k-1] = factors @ SDFweights
            SDFweightsassets[k] = Lambda[:,:k] @ np.linalg.inv(Lambda[:,:k].T @ Lambda[:,:k]) @ SDFweights

            # Time-series regressions of data on the factors estimated by Regul-PCA.
            # Row 0 of the coefficient table is used as the intercept and the remaining rows as
            # slopes -- presumably the layout returned by the helper; verify against utils.
            param_tmp = estimate_coef_by_time_series_regression(X_df, factors_df, rf=None, intercept=True, min_non_missing=0)
            a[:,k-1] = param_tmp.iloc[0].values.T
            beta[k] = param_tmp.iloc[1:, :].values.T

            resid_tmp = estimate_residual_by_time_series_regression(X_df, factors_df, rf=None, intercept=True, min_non_missing=0)
            residual[k] = resid_tmp.values

        # Reformulate outputs
        output = {
            'loadings': Lambda,              # loadings
            'factors': F,                    # factor time series
            'SDF time-series': SDF,          # SDF time-series
            'SDF weights': SDFweightsassets, # SDF weights
            'betas': beta,                   # estimated betas     of time-series regression of returns on Regul-PCA factors
            'alphas': a,                     # estimated alphas    of time-series regression of returns on Regul-PCA factors
            'residuals': residual,           # estimated residuals of time-series regression of returns on Regul-PCA factors
        }
        self.output = output

        return output

