#!/usr/bin/env python
# coding: utf-8
import os, pickle, json, sys, time
import pandas as pd
import numpy as np
from numpy.testing import *
import matplotlib.pyplot as plt

# import packages from /workspace/ (current working directory)
sys.path.append('/workspace/')
from utils import sharpe, estimate_residual_by_time_series_regression, estimate_coef_by_time_series_regression, get_V_matrices

# import packages from /workspace/Estimators
sys.path.append('/workspace/Estimators/')
from rppca_adj import RPPCAadj as RPPCA
from pcaxc import PCA_XC

def cov(a):
    """Return the column-wise covariance of *a* as a 2-D square array.

    Rows of *a* are treated as observations and columns as variables
    (``rowvar=False``). The result is reshaped to ``(a.shape[1], a.shape[1])``
    so that a single-column input yields a 1x1 array rather than a 0-d scalar.
    """
    n_vars = a.shape[1]
    return np.cov(a, rowvar=False).reshape(n_vars, n_vars)

def get_scenario(r_tst, r_ffc, factors):
    """Collect descriptive statistics of a factor-model fit to the test assets.

    The test-asset returns, in excess of the ``"RF"`` column of *r_ffc*, are
    regressed (time-series regression, no intercept) on the chosen factor
    columns of *r_ffc*. Moments of the resulting betas, of the factors, of the
    residuals and of the raw test-asset returns are gathered in a dict.

    Parameters
    ----------
    r_tst : pandas.DataFrame
        Test-asset returns; rows are time periods, columns are assets.
    r_ffc : pandas.DataFrame
        Factor returns; must contain an ``"RF"`` column and every name in
        *factors*.
    factors : iterable of str
        Names of the columns of *r_ffc* used as regressors.

    Returns
    -------
    dict
        Keys ``"T"``, ``"N"``, ``"K"``, ``"Sigma_B"``, ``"mu_B"``,
        ``"Sigma_F"``, ``"mu_F"``, ``"Sigma_U"``, ``"mu_R"``, ``"Sigma_R"``
        and ``"Date range"``.
    """
    factor_names = list(factors)
    factor_returns = r_ffc[factor_names]
    excess_returns = r_tst.subtract(r_ffc["RF"], axis=0)

    # Both regression helpers are called with the same settings.
    regression_kwargs = dict(rf=None, intercept=False, min_non_missing=0)

    # 1. Betas of the test assets on the selected factors (N by K).
    coefs = estimate_coef_by_time_series_regression(excess_returns,
                                                    factor_returns,
                                                    **regression_kwargs)
    betas = coefs.loc[factor_names].T.values

    # 2. Residuals of the test assets after regressing out the factors.
    residuals = estimate_residual_by_time_series_regression(excess_returns,
                                                            factor_returns,
                                                            **regression_kwargs)

    # 3. Descriptive statistics: dimensions, betas, factors, residuals and
    #    test assets, in that order.
    factor_values = factor_returns.values
    asset_values = r_tst.values
    return {
        "T": r_tst.shape[0],
        "N": r_tst.shape[1],
        "K": len(factor_names),
        # 3-1. betas
        "Sigma_B": cov(betas),
        "mu_B": np.mean(betas, axis=0),
        # 3-2. factors
        "Sigma_F": cov(factor_values),
        "mu_F": np.mean(factor_values, axis=0),
        # 3-3. residuals
        "Sigma_U": cov(residuals),
        # 3-4. test assets (raw returns, not excess returns)
        "mu_R": np.mean(asset_values, axis=0),
        "Sigma_R": cov(asset_values),
        "Date range": f'{r_ffc.index.min()}-{r_ffc.index.max()}',
    }



if __name__ == "__main__":
    # Script outline:
    #   1. read the JSON configuration and the Fama-French CSV files,
    #   2. run the ALS solver (PCA_XC) from several random initial points,
    #      once with V = I_N (where RP-PCA gives the reference minimizer) and
    #      once for each of two V matrices built from residual covariances,
    #   3. plot the convergence curves and save the figure as EPS.
    strt_time = time.time()

    #############################################################################
    # Read a configuration file and parse some important info.
    #############################################################################
    # The first command-line argument, if given, overrides the default path.
    if len(sys.argv) > 1:
        path_config_file = sys.argv[1]
    else:
        path_config_file = "/workspace/Config/numerical_performance.json"

    with open(path_config_file, "r") as f:
        config = json.load(f)

    np.random.seed(config["random_seed"]) # Set random seed for reproducibility.
    reference_period = config["reference_period"]
    do_sanity_check = config["do_sanity_check"]
    num_random_init = config["num_random_init"]

    #############################################################################
    # Make a folder that will contain experiment results.
    #############################################################################
    out_directory = config["outpath_dir"]
    os.makedirs(out_directory, exist_ok=True)


    ############################################################################################################
    ######################################### 1. Read Fama-French data #########################################
    ############################################################################################################
    # 1.1. Read FFC factors: used for the RHS variables
    ffc_dat = []
    for info in config["ffc_monthly_file"].values():
        tmp = pd.read_csv(info['path'], skiprows=info['skiprows'], skipfooter=info['skipfooter'], index_col=0, engine='python', dtype=float)
        tmp = tmp.loc[reference_period[0] : reference_period[1]]
        tmp /= 100
        ffc_dat.append(tmp)

    r_ffc = pd.concat(ffc_dat, axis=1)
    r_ffc.columns = r_ffc.columns.map(lambda x: x.strip())


    ############################################################################################################
    ######################################### 2. Test asset return data ########################################
    ############################################################################################################
    # 2.1. Read data on the LHS
    info = config["test_asset"]
    df = pd.read_csv(info["path"], skiprows=info["skiprows"], nrows=info["nrows"], index_col=0, encoding='mac-roman', engine="python", dtype=float)
    df = df.loc[reference_period[0] : reference_period[1]]
    r_tst1 = df/100

    Y = r_tst1.values
    T, N = Y.shape  # rows are time periods, columns are test assets

    # 2.2. Data sanity check
    if do_sanity_check:
        # The raw files flag missing values with -99.99 or -999. Both data
        # sets were already divided by 100 above, so the codes have to be
        # scaled the same way; comparing against the unscaled codes would
        # never match.
        missing_codes = [code / 100 for code in (-99.99, -999)]

        assert_array_equal(~r_ffc.isin(missing_codes), True, 'FFC data set contains N/A value.')
        assert_array_equal(r_ffc.isna(), False, 'FFC data set contains N/A value.')

        assert_array_equal(r_tst1.index, r_ffc.index, 'FFC and test data should match time index.')
        assert_array_equal(~r_tst1.isin(missing_codes), True, 'Test data set contains N/A value.')
        assert_array_equal(r_tst1.isna(), False, 'Test data set contains N/A value.')


    ############################################################################################################
    ######################################### 3. Set parameters ################################################
    ############################################################################################################
    K = 4
    eta = config["eta"]

    # My wrapper
    rp_pca = RPPCA(Y,
                   eta=eta,
                   K=K,
                   orthogonalize_lambda=True,
                   normalization_of_factors='uncorrelated',
                   signnormalization=True)

    pca_xc = PCA_XC(Y,
                    V=np.identity(N),
                    eta=eta,
                    K=K,
                    orthogonalize_lambda=True,
                    normalization_of_factors='uncorrelated',
                    signnormalization=True,
                    max_iter=config["max_iter_suboptimality"],
                    sanity_check_full_rank=True,
                    compute_objvals=True,
                    compute_gradient_norm=True)


    ############################################################################################################
    # For the case where V is an inverse of an arbitrary covariance matrix.
    ############################################################################################################
    descriptive_stats1 = get_scenario(r_tst1, r_ffc, ('Mkt-RF','SMB','HML'))
    descriptive_stats2 = get_scenario(r_tst1, r_ffc, ('Mkt-RF','RMW','CMA'))

    Sigma_U1 = descriptive_stats1['Sigma_U']
    Sigma_U2 = descriptive_stats2['Sigma_U']

    # Rescale each residual covariance so that its average variance is one.
    Sigma_U1 /= np.mean(np.diag(Sigma_U1))
    Sigma_U2 /= np.mean(np.diag(Sigma_U2))

    V1 = np.linalg.inv(Sigma_U1)
    V2 = np.linalg.inv(Sigma_U2)

    # Symmetrize to remove any asymmetry introduced by the numerical inverse.
    V1 = (V1 + V1.T)/2
    V2 = (V2 + V2.T)/2


    figsize = 3
    fontsize = 19

    fig, axes = plt.subplots(1, 4, figsize=(figsize*5.5, figsize*1))
    fig.subplots_adjust(hspace=0.1, wspace=0.3)


    for _ in range(num_random_init):
        # The same random starting point is used for all three problems below.
        Lambda_init = np.random.randn(N, K)
        F_init = np.random.randn(T, K)

        ##########################################################################################
        # 1) V = I_N -- it is able to compute suboptimality.
        ##########################################################################################
        # Run ALS method to get local minimum and non-increasing sequence
        pca_xc.max_iter = config["max_iter_suboptimality"]
        pca_xc.set_problem_specifiers(V=np.identity(N))

        pca_xc.run(Lambda_init=Lambda_init, F_init=F_init, debug=True)

        phi_n = pca_xc.obj_vals

        Lambda_iterates = pca_xc.Lambda_iterates
        F_iterates = pca_xc.F_iterates

        # Compute the exact minimum value obtained from RP-PCA
        rp_pca.run()
        Lambdahat_rp, Fhat_rp = rp_pca.output['loadings'], rp_pca.output['factors']
        phi_min = pca_xc.compute_obj(Fhat_rp, Lambdahat_rp)

        # Compute the Frobenius distance between the exact minimizer and the iterates.
        common_component_rp = Fhat_rp@Lambdahat_rp.T
        dist_between_iterates = np.array([
            np.sqrt(((common_component_rp - FF@LL.T)**2).sum())
            for LL, FF in zip(Lambda_iterates, F_iterates)
        ])

        # Plot. The gap is clipped so that it stays positive on the log axis.
        axes[0].semilogy(np.clip(phi_n - phi_min, a_min=1e-20, a_max=1e16), label=r'$\phi_n$')
        # NOTE(review): this label mentions a gradient norm although the curve
        # is the distance to the RP-PCA solution; the figure legend below is
        # given its own labels, so it looks unused -- confirm.
        axes[1].semilogy(dist_between_iterates, label=r'$\||\nabla\phi_n\||$')


        ##########################################################################################
        # 2) V /= I_N: V is an inverse of an arbitrary covariance matrix.
        # It is not able to compute suboptimality.
        ##########################################################################################
        pca_xc.max_iter = config["max_iter_objective_function_value"]

        for j, V in enumerate((V1, V2)):
            pca_xc.set_problem_specifiers(V=V)
            # Diagnostic output: trace of the inverse of the V held by the solver.
            print(np.trace(np.linalg.inv(pca_xc.V)))

            # Run ALS method to get local minimum and non-increasing sequence
            pca_xc.run(Lambda_init=Lambda_init, F_init=F_init, debug=True)
            phi_n = pca_xc.obj_vals

            axes[2+j].semilogy(phi_n, label=rf'$\eta={eta}$')

    # Axis cosmetics do not depend on the initial point, so apply them once
    # after all curves have been drawn.
    panel_titles = (
        r'(a) $\phi_n - \phi_*$',
        r'(b) $||F_n \Lambda_n^T - F_* \Lambda_*^T||_F$',
        r'(c) $\phi_n$',
        r'(d) $\phi_n$',
    )
    if num_random_init > 0:
        for ax, panel_title in zip(axes, panel_titles):
            ax.set_title(panel_title, fontsize=fontsize)
            ax.xaxis.set_tick_params(labelsize=fontsize-2)
            ax.yaxis.set_tick_params(labelsize=fontsize-2)
            ax.grid(True)


    # Set legend and x-label
    fig.supxlabel('Number of iterations', fontsize=fontsize, x=0.5, y=-0.12)

    # Set title and legend
    fig.legend(
        [f'init#{i+1}' for i in range(num_random_init)],
        fontsize=fontsize-2,
        ncol=1,
        loc="upper right",
        bbox_to_anchor=(0.282, 0.9, 0, 0),
        handlelength=1
    )


    ########################################################################
    # Save figure
    ########################################################################
    filepath = os.path.join(out_directory, "numerical_performance.eps")
    fig.savefig(filepath, format="eps", dpi=1200, bbox_inches="tight")
    print(f"The figures are saved to: {filepath}")

    elapsed_time = int(time.time() - strt_time)
    print(f"Done. -- Elapsed time: {elapsed_time//86400} day {elapsed_time%86400//3600} hr {elapsed_time%86400%3600//60} min {elapsed_time%86400%3600%60} sec")

