#!/usr/bin/env python
# coding: utf-8

import numpy as np
import sys
from lettau2020jeconom import RPPCA

# Short aliases for the NumPy routines used throughout this module.
inv, qr = np.linalg.inv, np.linalg.qr
sqrt, diag, sign = np.sqrt, np.diag, np.sign


def mean(a):
    """Return the mean of `a` taken along axis 0 (one value per column)."""
    return np.mean(a, axis=0)


def cov(a):
    """
    Return the covariance matrix of the columns of `a` as a 2-D array.

    The reshape keeps the result square of size ``a.shape[1]`` even when `a`
    has a single column, where np.cov alone would return a 0-d array.
    """
    n_cols = a.shape[1]
    return np.cov(a, rowvar=False).reshape(n_cols, n_cols)


class RPPCAadj(RPPCA):
    """
    RP-PCA estimator adjusted for a fair comparison between RP-PCA and PCA-XC.

    This class differs from the class RPPCA() in
    /utilities/lettau2020jeconom/lettau2020jeconom.py in two ways:

    1. The constructor takes 'eta' and passes 'eta - 1' to RPPCA() as gamma.
    The relative significance of the regularization term in the objective function of
    RP-PCA and PCA-XC is the same if [gamma (of RP-PCA) = eta (of PCA-XC) - 1].

    2. PCA_XC()'s implementation of the normalization is applied to the output
    of RPPCA() (see normalize()).
    """
    def __init__(self, X:np.ndarray, eta:float, K:int, orthogonalize_lambda:bool, normalization_of_factors:str, signnormalization:bool):
        """
        Parameters
        ----------
        X : data matrix; it is forwarded unchanged to RPPCA.__init__.
            NOTE(review): presumably of shape (T, N) -- confirm against RPPCA.
        eta : PCA-XC style regularization weight; gamma = eta - 1 is what is
            passed to RPPCA.
        K : number of factors, forwarded to RPPCA.__init__.
        orthogonalize_lambda : if True, normalize() rescales the loadings so
            that Lambda.T @ Lambda / N is the identity.
        normalization_of_factors : 'orthogonal', 'uncorrelated' or
            'no_normalization'; selects the rotation applied in normalize().
        signnormalization : if True, normalize() flips signs so that every
            factor has a non-negative mean.
        """
        # The three trailing zeros are passed positionally to RPPCA.__init__.
        # NOTE(review): presumably they switch off the parent's own
        # normalization options -- confirm against RPPCA.
        super().__init__(X, self.set_gamma_from_eta(eta), K, 0, 0, 0)
        self.orthogonalize_lambda = orthogonalize_lambda
        self.normalization_of_factors = normalization_of_factors
        self.signnormalization = signnormalization

        self.M1 = self._centering_matrix(self.T) # Annihilation matrix


    @staticmethod
    def _centering_matrix(T:int) -> np.ndarray:
        """Return the T x T annihilation matrix I - (1/T) * ones((T, T))."""
        return np.identity(T) - np.full((T, T), 1.0/T)


    def run(self):
        """
        Run RPPCA.run(), apply normalize() to its factors and loadings, and
        recompute the quantities derived from them.

        Returns
        -------
        dict with the keys 'loadings', 'factors', 'eigenval', 'SDF time-series',
        'SDF weights', 'betas', 'alphas', 'residuals' and 'weights'. The entries
        'SDF weights', 'betas' and 'residuals' are dicts keyed by the number of
        factors k = 1, ..., K; column k-1 of 'SDF time-series' and 'alphas'
        corresponds to the model with the first k factors. The same dict is
        stored in self.output.
        """
        X = self.X
        T, N, K = self.T, self.N, self.K

        # Run code from RPPCA
        output = super().run()
        Fhat = output['factors'] # factor time series
        Lambdahat = output['loadings'] # loadings
        DDWPCA = output['eigenval'] # eigenvalues of RP-PCA matrix
        # NOTE(review): the weights are passed through as returned by RPPCA;
        # they are not rotated together with Fhat/Lambdahat below -- confirm
        # that this is intended.
        factorweight = output['weights'] # factor portfolio weights

        # Normalize to be consistent with the normalization of PCA_XC().
        Fhat, Lambdahat = self.normalize(F=Fhat, Lambda=Lambdahat)

        beta = {}
        residual = {}
        SDFweightsassets = {}
        SDF = np.zeros((T,K))
        a = np.zeros((N,K))
        intercept = np.ones((T,1))
        for k in range(1,K+1):
            F_k = Fhat[:,:k] # first k factors
            Lambda_k = Lambdahat[:,:k] # loadings of the first k factors

            # Mean variance optimization
            SDFweights = inv(cov(F_k)) @ mean(F_k).T
            SDF[:,k-1] = F_k @ SDFweights
            SDFweightsassets[k] = Lambda_k @ inv(Lambda_k.T @ Lambda_k) @ SDFweights

            # Time-series regressions of X on a constant and the first k factors
            regressors = np.append(intercept, F_k, axis=1)
            coef = inv(regressors.T @ regressors) @ regressors.T @ X
            residual[k] = X - regressors @ coef
            a[:,k-1] = coef[0,:].T # intercept row
            beta[k] = coef[1:,:].T # slope rows


        output = {}
        output['loadings'] = Lambdahat # loadings
        output['factors'] = Fhat # factor time series
        output['eigenval'] = DDWPCA # eigenvalues of RP-PCA matrix
        output['SDF time-series'] = SDF # SDF time-series
        output['SDF weights'] = SDFweightsassets # SDF weights
        output['betas'] = beta # time-series regression betas on RP-PCA factors
        output['alphas'] = a # time-series regression alpha
        output['residuals'] = residual #  residuals from time-series regression
        output['weights'] = factorweight # factor portfolio weights
        self.output = output

        return output


    def normalize(self, F:np.ndarray, Lambda:np.ndarray) -> 'tuple[np.ndarray, np.ndarray]':
        """
        This function normalizes Lambda and F.

        Every step applies the same invertible transformation to both
        matrices, so the product F @ Lambda.T is left unchanged.

        Parameters
        ----------
        F : factor time series, one column per factor.
        Lambda : loadings, one column per factor.

        Returns
        -------
        (F, Lambda) : the normalized factors and loadings (new arrays; the
            inputs are not modified).

        Raises
        ------
        AssertionError
            If self.normalization_of_factors is not one of 'orthogonal',
            'uncorrelated' or 'no_normalization'.
        """
        # 1. Make Lambda orthonormal
        if self.orthogonalize_lambda:
            # Lambda = S @ diag(L) @ Tt, hence (Lambda @ Tt.T) / L = S, which has
            # orthonormal columns. After the rescaling Lambda.T @ Lambda / N = I,
            # and F absorbs the inverse transformation.
            _, L, Tt = np.linalg.svd(Lambda, full_matrices=False)
            Lambda = (Lambda @ Tt.T) * (1/L) * sqrt(self.N)
            F = (F @ Tt.T) * L / sqrt(self.N)

        # 2. Normalization of the factors F
        mode = self.normalization_of_factors
        if mode == 'orthogonal':
            gram = F.T @ F # rotated factors get a diagonal second-moment matrix
        elif mode == 'uncorrelated':
            gram = F.T @ self.M1 @ F # rotated factors get a diagonal covariance matrix
        elif mode == 'no_normalization':
            gram = None
        else:
            raise AssertionError(
                f"Unknown normalization_of_factors: {mode!r} "
                "(expected 'orthogonal', 'uncorrelated' or 'no_normalization')"
            )

        if gram is not None:
            # eigh returns eigenvalues in ascending order; reversing the columns
            # orders the rotated factors by decreasing eigenvalue.
            _, U = np.linalg.eigh(gram)
            Rotation = U[:,::-1]
            F = F @ Rotation
            Lambda = Lambda @ Rotation

        # 3. Normalize the signs of the factors. This makes each factor have non-negative mean.
        if self.signnormalization:
            signs = np.sign(np.mean(F, axis=0))
            # np.sign returns 0 for a factor whose mean is exactly zero; multiplying
            # by it would erase that factor and its loadings, so keep them as they are.
            signs[signs == 0] = 1.0
            F = F * signs[np.newaxis,:]
            Lambda = Lambda * signs[np.newaxis,:]

        return F, Lambda


    def set_problem_specifiers(self, X:np.ndarray=None, eta:float=None, K:int=None):
        """
        This function sets parameters, specifying the optimization problem
        of the RP-PCA estimator, which includes X, eta (stored as
        gamma = eta - 1) and K.

        Arguments left as None keep their current value.
        """
        # Set the parameters that specify the optimization problem.
        if X is not None:
            self.X = X.copy()
            self.T, self.N = X.shape
            # The annihilation matrix depends on T, so rebuild it together with X.
            self.M1 = self._centering_matrix(self.T) # Annihilation matrix

        if eta is not None:
            self.set_gamma_from_eta(eta)

        if K is not None:
            self.K = K


    def set_gamma_from_eta(self, eta:float):
        """Store gamma = eta - 1 in self.gamma and return it."""
        self.gamma = eta - 1
        return self.gamma

