#!/usr/bin/env python
# coding: utf-8
import os, pickle, json, sys, time
import pandas as pd
import numpy as np
from numpy.testing import *
import matplotlib.pyplot as plt
from itertools import product

# import packages from /workspace/ (current working directory)
sys.path.append('/workspace/')
from utils import sharpe, estimate_residual_by_time_series_regression, estimate_coef_by_time_series_regression, get_V_matrices

# import packages from /workspace/Estimators
sys.path.append('/workspace/Estimators/')
from rppca_adj import RPPCAadj as RPPCA
from pcaxc import PCA_XC
from Data.Kozak.anomolies import anomalies

def cov(a):
    """Return the sample covariance of the columns of a 2-D array as a (k, k) matrix.

    ``np.cov`` collapses a single-column input to a 0-d scalar; reshaping to
    ``(k, k)`` (k = number of columns) keeps the result a matrix for any k.
    """
    sample_cov = np.cov(a, rowvar=False)
    n_cols = a.shape[1]
    return sample_cov.reshape(n_cols, n_cols)

def get_scenario(r_tst, r_ffc, factors):
    """Collect descriptive statistics of test-asset returns under a factor model.

    The returns in ``r_tst`` are put in excess of ``r_ffc["RF"]`` and regressed
    (without intercept) on the ``r_ffc`` columns listed in ``factors``. Moments
    of the resulting loadings and residuals, of the factors, and of the raw
    test-asset returns are gathered into a dict.

    Parameters
    ----------
    r_tst : pandas.DataFrame
        Test-asset returns; ``shape[0]`` is reported as T and ``shape[1]`` as N.
    r_ffc : pandas.DataFrame
        Factor returns; must contain every name in ``factors`` and ``"RF"``.
    factors : iterable of str
        Names of the ``r_ffc`` columns used as regressors (K of them).

    Returns
    -------
    dict
        ``"T"``, ``"N"``, ``"K"``; ``"Sigma_B"``/``"mu_B"`` (loadings, N x K,
        moments taken across assets); ``"Sigma_F"``/``"mu_F"`` (factors);
        ``"Sigma_U"`` (regression residuals); ``"mu_R"``/``"Sigma_R"`` (raw,
        not excess, test-asset returns); ``"Date range"`` (min-max of
        ``r_ffc``'s index, not ``r_tst``'s).
    """
    factor_names = list(factors)
    factor_returns = r_ffc[factor_names]
    excess_returns = r_tst.subtract(r_ffc["RF"], axis=0)

    # Both regressions share one specification. RF has already been subtracted
    # above, which is presumably why the helpers get rf=None.
    regression_kwargs = dict(rf=None, intercept=False, min_non_missing=0)

    # Loadings: one row per test asset, one column per factor (N x K).
    coefs = estimate_coef_by_time_series_regression(excess_returns, factor_returns, **regression_kwargs)
    loadings = coefs.loc[factor_names].T.values

    # Residuals left after removing the factor component.
    residuals = estimate_residual_by_time_series_regression(excess_returns, factor_returns, **regression_kwargs)

    factor_values = factor_returns.values
    raw_returns = r_tst.values
    return {
        "T": r_tst.shape[0],
        "N": r_tst.shape[1],
        "K": len(factor_names),
        "Sigma_B": cov(loadings),
        "mu_B": np.mean(loadings, axis=0),
        "Sigma_F": cov(factor_values),
        "mu_F": np.mean(factor_values, axis=0),
        "Sigma_U": cov(residuals),
        # NOTE(review): these use raw r_tst, not excess returns -- confirm intended.
        "mu_R": np.mean(raw_returns, axis=0),
        "Sigma_R": cov(raw_returns),
        "Date range": f'{r_ffc.index.min()}-{r_ffc.index.max()}',
    }


if __name__ == '__main__':
    strt_time = time.time()

    # Codes that mark missing observations in the raw (percent-scale) CSV files.
    # They must be checked BEFORE the data are divided by 100; after rescaling
    # they no longer match these literal values.
    MISSING_CODES = [-99.99, -999]

    #############################################################################
    # Read a configuration file and parse some important info.
    #############################################################################
    if len(sys.argv) > 1:
        path_config_file = sys.argv[1]
    else:
        path_config_file = "/workspace/Config/numerical_performance.json"

    with open(path_config_file, "r") as f:
        config = json.load(f)

    np.random.seed(config["random_seed"])  # Set random seed for reproducibility.
    do_sanity_check = config["do_sanity_check"]
    num_random_init = config["num_random_init"]

    #############################################################################
    # Make a folder that will contain experiment results.
    #############################################################################
    out_directory = config["outpath_dir"]
    os.makedirs(out_directory, exist_ok=True)

    ############################################################################################################
    ######################################### 1. Read Fama-French data #########################################
    ############################################################################################################
    # 1.1. Read FFC factors: used for the RHS variables.
    # The raw (percent) panel is kept for the sanity checks; the returns used
    # downstream are raw / 100. dropna() deliberately discards dates that are
    # not covered by every factor file.
    ffc_raw = pd.concat(
        [
            pd.read_csv(info['path'], skiprows=info['skiprows'], skipfooter=info['skipfooter'],
                        index_col=0, engine='python', dtype=float)
            for info in config["ffc_monthly_file"].values()
        ],
        axis=1,
    )
    ffc_raw = ffc_raw.dropna().sort_index()
    ffc_raw.columns = ffc_raw.columns.map(lambda x: x.strip())
    r_ffc = ffc_raw / 100

    ############################################################################################################
    ######################################### 2. Test asset return data ########################################
    ############################################################################################################
    # Read data on the LHS
    info = config["test_asset"]
    df = pd.read_csv(info["path"], skiprows=info["skiprows"], nrows=info["nrows"], index_col=0,
                     encoding='mac-roman', engine="python", dtype=float)
    r_tst1 = df / 100
    common_idx = r_tst1.index.intersection(r_ffc.index)
    size_value_25 = r_tst1.subtract(r_ffc.loc[common_idx, 'RF'], axis=0).dropna()

    # 2.1. Data sanity check.
    # The checks run on the raw inputs: the processed frames have been divided by
    # 100 (so the missing-value codes cannot match) and passed through dropna()
    # (so they cannot contain NaN), which would make the checks vacuous.
    if do_sanity_check:
        assert_array_equal(ffc_raw.isin(MISSING_CODES), False, 'FFC data set contains N/A value.')

        tst_raw = df.loc[common_idx]
        assert_array_equal(tst_raw.isin(MISSING_CODES), False, 'Test data set contains N/A value.')
        assert_array_equal(tst_raw.isna(), False, 'Test data set contains N/A value.')

    # Read anomaly portfolios
    anomaly_dir = './Data/Kozak/FT_portfolio_sorts-monthly-05FEB2020/monthly/'
    rtrns = {}
    for anomaly in np.concatenate(list(anomalies.values())):
        port = pd.read_csv(os.path.join(anomaly_dir, f'ret10_{anomaly}.csv'), index_col=0, parse_dates=True)
        # Re-key dates as YYYYMM integers so they can be looked up in r_ffc's index below.
        port.index = port.index.map(lambda x: int(x.strftime('%Y%m')))
        assert_array_equal(port.isna(), False, f'Anomaly portfolio {anomaly} contains N/A value.')
        rtrns[anomaly] = port

    rtrns = pd.concat(rtrns, axis=1)
    rtrns.columns.names = ['anomaly', 'decile']
    anomalies37 = rtrns.subtract(r_ffc.loc[rtrns.index, 'RF'], axis=0)

    ############################################################################################################
    ################################################ 3. Main job ###############################################
    ############################################################################################################
    # Settings shared by every (N, T) experiment.
    K = 4
    eta = config["eta"]
    T_values = [60, 240, 600]
    figsize = 3
    fontsize = 19

    def _format_axis(ax, row, title):
        """Apply the shared title/tick/grid styling to one panel in grid row ``row``."""
        if row == 0:
            ax.set_title(title, fontsize=fontsize)
        if row == len(T_values) - 1:
            # Only the bottom row keeps its x tick labels.
            ax.xaxis.set_tick_params(labelsize=fontsize - 2)
        else:
            ax.xaxis.set_ticklabels([])
        ax.yaxis.set_tick_params(labelsize=fontsize - 2)
        ax.grid(True)

    for N in [25, 370]:
        if N == 25:
            rt = size_value_25
        elif N == 74:
            # Extreme deciles only; add 74 to the list above to run this panel.
            rt = anomalies37.loc[:, (slice(None), ['p1', 'p10'])]
        else:
            rt = anomalies37
        # N sizes V, Lambda_init and the output file name, so it must match the data.
        if rt.shape[1] != N:
            raise ValueError(f'Expected N={N} test assets, but the selected data have {rt.shape[1]} columns.')

        fig, axes = plt.subplots(len(T_values), 3, figsize=(figsize * 5.5, figsize * 4.))
        fig.subplots_adjust(hspace=0.1, wspace=0.3)

        for ii1, T in enumerate(T_values):
            Y = rt.iloc[-T:].values
            # F_init below is drawn with T rows, so Y must really have T rows.
            if Y.shape[0] != T:
                raise ValueError(f'Need at least T={T} observations for N={N}, only {Y.shape[0]} available.')

            # eta=0 fit; the cross-product of its residuals, U, is used later as
            # the non-identity weighting matrix V.
            pca = RPPCA(Y,
                        eta=0,
                        K=K,
                        orthogonalize_lambda=True,
                        normalization_of_factors='uncorrelated',
                        signnormalization=True)
            resid_K = pca.run()['residuals'][K]
            U = resid_K.T @ resid_K

            rp_pca = RPPCA(Y,
                           eta=eta,
                           K=K,
                           orthogonalize_lambda=True,
                           normalization_of_factors='uncorrelated',
                           signnormalization=True)

            pca_xc = PCA_XC(Y,
                            V=np.identity(N),
                            eta=eta,
                            K=K,
                            orthogonalize_lambda=True,
                            normalization_of_factors='uncorrelated',
                            signnormalization=True,
                            max_iter=config["max_iter_suboptimality"],
                            sanity_check_full_rank=True,
                            compute_objvals=True,
                            compute_gradient_norm=True)

            for _ in range(num_random_init):
                Lambda_init = np.random.randn(N, K)
                F_init = np.random.randn(T, K)

                ##########################################################################################
                # 1) V = I_N -- the suboptimality phi_n - phi_* can be computed.
                ##########################################################################################
                # Run ALS method to get local minimum and non-increasing sequence
                pca_xc.max_iter = config["max_iter_suboptimality"]
                pca_xc.set_problem_specifiers(V=np.identity(N))
                pca_xc.run(Lambda_init=Lambda_init, F_init=F_init, debug=True)

                phi_n = pca_xc.obj_vals
                Lambda_iterates = pca_xc.Lambda_iterates
                F_iterates = pca_xc.F_iterates

                # Compute the exact minimum value obtained from RP-PCA
                rp_pca.run()
                Lambdahat_rp, Fhat_rp = rp_pca.output['loadings'], rp_pca.output['factors']
                phi_min = pca_xc.compute_obj(Fhat_rp, Lambdahat_rp)

                # Frobenius distance between the exact minimiser's F Lambda^T and each iterate's.
                fitted_rp = Fhat_rp @ Lambdahat_rp.T
                dist_between_iterates = np.array([
                    np.sqrt(((fitted_rp - FF @ LL.T) ** 2).sum())
                    for LL, FF in zip(Lambda_iterates, F_iterates)
                ])

                axes[ii1, 0].semilogy(np.clip(phi_n - phi_min, a_min=1e-20, a_max=1e16), label=r'$\phi_n$')
                axes[ii1, 1].semilogy(dist_between_iterates,
                                      label=r'$||F_n \Lambda_n^T - F_* \Lambda_*^T||_F$')

                ##########################################################################################
                # 2) V != I_N -- the suboptimality cannot be computed; only phi_n is plotted.
                # NOTE(review): V is set to U (residual cross-product), not to an inverse of a
                # covariance matrix -- confirm which weighting is intended.
                ##########################################################################################
                pca_xc.max_iter = config["max_iter_objective_function_value"]
                pca_xc.set_problem_specifiers(V=U)
                # Debug output. NOTE(review): U has rank <= T, so for T < N the inverse may be
                # ill-conditioned or raise LinAlgError -- confirm this print is still wanted.
                print(np.trace(np.linalg.inv(pca_xc.V)))

                # Run ALS method to get local minimum and non-increasing sequence
                pca_xc.run(Lambda_init=Lambda_init, F_init=F_init, debug=True)
                axes[ii1, 2].semilogy(pca_xc.obj_vals, label=rf'$\eta={eta}$')

            # Style the row once, after all random initialisations have been plotted.
            _format_axis(axes[ii1, 0], ii1, r'(a) $\phi_n - \phi_*$')
            axes[ii1, 0].set_ylabel(f'T={T}', fontsize=fontsize)
            _format_axis(axes[ii1, 1], ii1, r'(b) $||F_n \Lambda_n^T - F_* \Lambda_*^T||_F$')
            _format_axis(axes[ii1, 2], ii1, r'(c) $\phi_n$')

        # Set legend and x-label
        fig.supxlabel('Number of iterations', fontsize=fontsize, x=0.5, y=0.06)

        # One legend entry per random initialisation.
        fig.legend(
            [f'init#{k + 1}' for k in range(num_random_init)],
            fontsize=fontsize - 2,
            ncol=1,
            loc="upper right",
            bbox_to_anchor=(0.9, 0.88, 0, 0),
            handlelength=1
        )

        ########################################################################
        # Save figure
        ########################################################################
        filepath = os.path.join(out_directory, f"numerical_performance_N{N}.eps")
        fig.savefig(filepath, format="eps", dpi=1200, bbox_inches="tight")
        plt.close(fig)  # release the figure before building the next one
        print(f"The figures are saved to: {filepath}")

        elapsed_time = int(time.time() - strt_time)
        print(f"Done. -- Elapsed time: {elapsed_time//86400} day {elapsed_time%86400//3600} hr {elapsed_time%86400%3600//60} min {elapsed_time%86400%3600%60} sec")

