#!/usr/bin/env python
# coding: utf-8
import numpy as np
import pandas as pd
from numpy.testing import assert_array_equal
import os, sys, json, time
from itertools import combinations
import matplotlib
import matplotlib.pyplot as plt
from factor_analyzer.factor_analyzer import calculate_bartlett_sphericity

sys.path.append("/workspace")
from utils import estimate_residual_by_time_series_regression


def compute_statistics_of_corr(resid):
    """Return the correlation matrix of ``resid`` and its mean absolute off-diagonal entry.

    Parameters
    ----------
    resid : pandas.DataFrame
        Two-dimensional table; correlations are taken between its columns.

    Returns
    -------
    corr : pandas.DataFrame
        ``resid.corr()``.
    mean_abs_corr : float
        Mean of the absolute values of the entries strictly above the
        diagonal of ``corr``.
    """
    _, n_cols = resid.shape

    corr = resid.corr()

    # Only the strict upper triangle is used: the matrix is symmetric and the
    # diagonal is excluded from the average.
    rows, cols = np.triu_indices(n_cols, k=1)
    upper_entries = corr.values[rows, cols]
    mean_abs_corr = np.abs(upper_entries).mean()

    return corr, mean_abs_corr


def sort_metric_srs(metric):
    """Group ``metric`` by its first index level (2, then 1, then 0) and sort each group.

    Parameters
    ----------
    metric : pandas.Series
        Series with a MultiIndex whose first level takes values in {0, 1, 2}
        (in this script: the number of related factors).

    Returns
    -------
    pandas.Series
        The groups for first-level values 2, 1 and 0, in that order, each
        sorted by value in ascending order. The first level is re-attached
        as the outermost index level. Groups that do not occur in ``metric``
        are skipped instead of raising ``KeyError``; if none occurs, an empty
        slice of ``metric`` is returned.
    """
    # Use the values actually present (not ``index.levels``, which can retain
    # entries that were filtered out by an earlier boolean selection).
    present = set(metric.index.get_level_values(0))

    metric_sorted = {
        n_related_factors: metric.loc[n_related_factors].sort_values(ascending=True)
        for n_related_factors in (2, 1, 0)
        if n_related_factors in present
    }

    if not metric_sorted:
        # pd.concat refuses an empty mapping.
        return metric.iloc[0:0]

    return pd.concat(metric_sorted)


def argsort_sim_mat(sm):
    """Greedily order the rows of a similarity matrix so that neighbours are similar.

    The ordering starts at the row with the largest row sum and then
    repeatedly appends the not-yet-chosen index that is most similar to the
    index chosen last. Ties go to the smallest index.

    Parameters
    ----------
    sm : array-like, shape (n, n)
        Similarity matrix. Values may be arbitrary (including ``<= -1``);
        already-chosen indices are excluded with an explicit mask rather than
        by overwriting them with a sentinel value.

    Returns
    -------
    numpy.ndarray
        Permutation of ``range(n)`` (empty for an empty matrix).
    """
    sm = np.asarray(sm)
    n = len(sm)
    if n == 0:
        return np.array([], dtype=np.intp)

    unvisited = np.ones(n, dtype=bool)

    current = np.argmax(np.sum(sm, axis=1))
    order = [current]
    unvisited[current] = False

    for _ in range(n - 1):
        # Restrict the search to indices not chosen yet, so no similarity
        # value can be confused with a "visited" marker.
        candidates = np.flatnonzero(unvisited)
        current = candidates[np.argmax(sm[current, candidates])]
        order.append(current)
        unvisited[current] = False

    return np.array(order)


class OOMFormatter(matplotlib.ticker.ScalarFormatter):
    """``ScalarFormatter`` with a caller-chosen order of magnitude and tick format.

    Parameters
    ----------
    order : int
        Value assigned to ``orderOfMagnitude``.
    fformat : str
        %-style format string used for the tick labels.
    offset : bool
        Forwarded to ``ScalarFormatter`` as ``useOffset``.
    mathText : bool
        Forwarded to ``ScalarFormatter`` as ``useMathText``.
    """

    def __init__(self, order=0, fformat="%1.1f", offset=True, mathText=True):
        # Stored before the base initialiser runs, as in the usual recipe.
        self.oom = order
        self.fformat = fformat
        super().__init__(useOffset=offset, useMathText=mathText)

    def _set_order_of_magnitude(self):
        # Override of the base-class hook: always use the configured exponent.
        self.orderOfMagnitude = self.oom

    def _set_format(self, vmin=None, vmax=None):
        # Override of the base-class hook: use the configured format string,
        # wrapped for mathtext rendering when mathtext is enabled.
        fmt = self.fformat
        if self._useMathText:
            fmt = r'$\mathdefault{%s}$' % fmt
        self.format = fmt


if __name__ == "__main__":
    # Script entry point. Steps:
    #   1. read the JSON configuration (path from argv[1] or a default path),
    #   2. read factor returns and test-asset returns, restricted to the reference period,
    #   3. for every test asset and every pair of candidate factors, estimate
    #      time-series-regression residuals and compute their correlation
    #      statistics and Bartlett's sphericity test,
    #   4. draw the correlation-matrix images and the bar graphs.
    strt_time = time.time()

    # Values rejected by the sanity checks. They are compared against the data
    # as read from file, i.e. BEFORE the division by 100.
    # NOTE(review): presumably the missing-observation codes of the input CSV files -- confirm.
    missing_value_codes = [-99.99, -999]

    # Every pair of these factors is combined with "Mkt-RF" in the regressions.
    candidate_factors = ["SMB", "HML", "RMW", "CMA", "Mom"]
    factor_pairs = list(combinations(candidate_factors, 2))

    def check_no_missing_codes(raw, label):
        """Raise ValueError if ``raw`` contains a value listed in ``missing_value_codes``.

        ``raw`` must hold the values as read from file (not yet divided by 100).
        """
        n_bad = int(raw.isin(missing_value_codes).to_numpy().sum())
        if n_bad:
            raise ValueError(f"{label}: {n_bad} value(s) equal to one of {missing_value_codes}")

    def pair_label(factor1, factor2, related_factors):
        """Return "factor1-factor2", appending "*" to each factor found in ``related_factors``."""
        return "-".join(f + "*" if f in related_factors else f for f in (factor1, factor2))

    def asset_file_tag(asset_name):
        """Return ``asset_name`` in the form used inside output file names."""
        return asset_name.replace("5x5 ", "").replace("/", "_")

    def abs_corr_in_display_order(asset_name, corr_by_key, bartlett_by_key):
        """Return ``{key: |corr| matrix}`` for one test asset.

        Keys are ordered by ``sort_metric_srs`` applied to the asset's Bartlett
        statistics; rows/columns of each matrix are reordered with
        ``argsort_sim_mat`` for better readability.
        """
        is_asset = asset_name == bartlett_by_key.index.get_level_values(level=1)
        ordered_keys = sort_metric_srs(bartlett_by_key[is_asset]).index

        reordered = {}
        for key in ordered_keys:
            abs_corr_mat = corr_by_key[key].abs()
            order = argsort_sim_mat(abs_corr_mat.values)
            reordered[key] = abs_corr_mat.iloc[order, order]
        return reordered

    #############################################################################
    # Read a configuration file.
    #############################################################################
    if len(sys.argv) > 1:
        path_config_file = sys.argv[1]
    else:
        path_config_file = "/workspace/Config/resid_correlation.json"

    with open(path_config_file, "r") as f:
        config = json.load(f)
    do_sanity_check = config["do_sanity_check"]
    period_start = config["reference_period"][0]
    period_end = config["reference_period"][1]

    #############################################################################
    # Make a folder that will contain experiment results.
    #############################################################################
    os.makedirs(config["outpath_dir"], exist_ok=True)

    #############################################################################
    # Read FFC 6 factor returns
    #############################################################################
    tmp = []
    for key, config2 in config["ffc_monthly_file"].items():
        df = pd.read_csv(
            config2["path"],
            skiprows=config2["skiprows"],
            skipfooter=config2["skipfooter"],
            index_col=0,
            engine="python",
            dtype=float,
        )
        df = df.loc[period_start:period_end]

        # Sanity check on the raw values; after the rescaling below the codes
        # would no longer be recognisable.
        if do_sanity_check:
            check_no_missing_codes(df, f"factor file '{key}'")

        tmp.append(df / 100)
    F = pd.concat(tmp, axis=1)
    F.columns = F.columns.map(lambda x: x.strip())

    #############################################################################
    # Read test assets
    #############################################################################
    test_asset = {}
    for key, config2 in config["test_asset"].items():
        df = pd.read_csv(
            config2["path"],
            skiprows=config2["skiprows"],
            nrows=config2["nrows"],
            index_col=0,
            engine="python",
            dtype=float,
            encoding='cp1252',
        )
        df = df.loc[period_start:period_end]

        # Sanity check: same dates as the factors, and no missing-value code
        # in the raw (not yet rescaled) returns.
        if do_sanity_check:
            assert_array_equal(df.index, F.index)
            check_no_missing_codes(df, f"test asset '{key}'")

        test_asset[key] = {"rtrn": df / 100, "related_factor": config2["related_factors"]}

    #############################################################################
    # Estimate residuals of time-series regression, and
    # compute their sample correlation and Bartlett (1951)'s sphericity test statistic.
    #
    # Reference
    # - Bartlett, Maurice S., (1951), The Effect of Standardization on a chi square Approximation in Factor Analysis, Biometrika, 38, 337-344.
    #############################################################################
    corr = {}
    mean_abs_corr = {}
    bartlett_stat = {}
    bartlett_p = {}

    for tst_asst_nm, tst_asst in test_asset.items():
        r, related_factors = tst_asst["rtrn"], tst_asst["related_factor"]
        T, N = r.shape

        for factors in factor_pairs:
            f = F[["Mkt-RF"] + list(factors)]
            n_related_factors = len(set(related_factors).intersection(factors))
            key = (n_related_factors, tst_asst_nm, factors[0], factors[1])

            # Do the main job. Compute:
            # (1) correlation coefficients of the cross-section of residuals and the average of their absolute values, and
            # (2) test statistic and p-value of Bartlett's sphericity test
            resid = estimate_residual_by_time_series_regression(r, f, F["RF"], config["add_intercept_to_time_series_regression"], 0)
            corr[key], mean_abs_corr[key] = compute_statistics_of_corr(resid)
            bartlett_stat[key], bartlett_p[key] = calculate_bartlett_sphericity(resid)
            if do_sanity_check:
                # Recompute the statistic from the correlation matrix and compare.
                stat_sanity_check = -np.log(np.linalg.det(corr[key])) * (T - 1 - (2 * N + 5) / 6)
                assert np.isclose(bartlett_stat[key], stat_sanity_check)

    mean_abs_corr = pd.Series(mean_abs_corr)
    bartlett_stat = pd.Series(bartlett_stat)
    bartlett_p = pd.Series(bartlett_p)

    #############################################################################
    # Plot correlation matrices using imshow()
    #############################################################################
    figsize = 12

    # One row per test asset, one column per factor pair. squeeze=False keeps
    # `axes` two-dimensional even for a single test asset.
    fig, axes = plt.subplots(len(test_asset), len(factor_pairs), figsize=(figsize, figsize), squeeze=False)
    fig.suptitle(
        "Absolute values of correlation coefficients of OLS residuals",
        fontsize=15,
        x=0.5,
        y=0.80,
        verticalalignment="top",
    )
    fig.subplots_adjust(hspace=-0.7, wspace=0.1)

    for i, tst_asst_nm in enumerate(test_asset):
        abs_corr = abs_corr_in_display_order(tst_asst_nm, corr, bartlett_stat)

        # Draw figures
        for j, (key, abs_corr_) in enumerate(abs_corr.items()):
            if do_sanity_check:
                assert tst_asst_nm == key[1]

            ax = axes[i, j]
            ax.imshow(abs_corr_, cmap="gray", vmin=0, vmax=1)
            if j == 0:
                ax.set_ylabel(tst_asst_nm.replace("5x5 ", ""), fontsize=13)
            ax.set_xticks([])
            ax.set_yticks([])

    # Save the figure (switched off; flip the flag to write the EPS file).
    save_overview_figure = False
    filepath = os.path.join(config["outpath_dir"], "residual_correlation_matrices_in_image.eps")
    if save_overview_figure:
        fig.savefig(filepath, format="eps", dpi=None, bbox_inches="tight")

    #############################################################################
    # Plot correlation matrices for Size-B/M portfolios using imshow()
    #############################################################################
    figsize = 12
    fontsize_axtitle = 23.5
    n_grid_cols = 5

    fig, axes = plt.subplots(2, n_grid_cols, figsize=(figsize, figsize))
    fig.subplots_adjust(hspace=-0.66, wspace=0.1)

    tst_asst_nm = "5x5 Size-B/M"
    related_factors = test_asset[tst_asst_nm]["related_factor"]

    abs_corr = abs_corr_in_display_order(tst_asst_nm, corr, bartlett_stat)

    # Draw figures
    for i, (key, abs_corr_) in enumerate(abs_corr.items()):
        n_related_factors, tst_asst_nm2, factor1, factor2 = key
        if do_sanity_check:
            assert tst_asst_nm == tst_asst_nm2

        grid_row, grid_col = divmod(i, n_grid_cols)
        ax = axes[grid_row, grid_col]
        im = ax.imshow(abs_corr_, cmap="gray", vmin=0, vmax=1)
        ax.set_title(pair_label(factor1, factor2, related_factors), fontsize=fontsize_axtitle)
        ax.set_xticks([])
        ax.set_yticks([])

    # Colorbar for the last image drawn (all images share vmin/vmax).
    use_vertical_colorbar = False
    if use_vertical_colorbar:
        cax = plt.axes([0.92, 0.326, 0.02, 0.3395])
        cbar = fig.colorbar(im, orientation='vertical', cax=cax)
    else:
        cax = plt.axes([0.1226, 0.28, 0.7782, 0.02])
        cbar = fig.colorbar(im, orientation='horizontal', cax=cax)
    cbar.ax.tick_params(labelsize=fontsize_axtitle)

    # Save the figure
    filepath = os.path.join(config["outpath_dir"], f"residual_correlation_matrices_in_image_{asset_file_tag(tst_asst_nm)}.pdf")
    fig.savefig(filepath, dpi=None, bbox_inches="tight")

    #############################################################################
    # Plot bar graphs
    #############################################################################
    bar_colors = ["C0", "C1"]
    bar_width = 0.35
    fontsize = {"tick": 15, "ylabel": 15, "legend": 15}

    for tst_asst_nm in test_asset:
        ########################################################################
        # Order the factor pairs with sort_metric_srs applied to Bartlett's statistic.
        ########################################################################
        is_asset = tst_asst_nm == bartlett_stat.index.get_level_values(level=1)
        bartlett_stat_ = sort_metric_srs(bartlett_stat[is_asset])
        mean_abs_corr_ = mean_abs_corr.loc[bartlett_stat_.index]

        df = pd.concat({"Avg. abs. correlation": mean_abs_corr_, "Bartlett's stat.": bartlett_stat_}, axis=1)
        related_factors = config["test_asset"][tst_asst_nm]["related_factors"]
        # Replace the tuple keys by readable "factor1-factor2" labels.
        df.index = pd.Index([pair_label(x[2], x[3], related_factors) for x in df.index])

        ########################################################################
        # Draw figure
        ########################################################################
        fig, ax = plt.subplots(figsize=(6, 3))
        ax.grid(axis="y", zorder=1)

        # first series: bars extending to the left of each tick (negative width)
        ax.bar(df.index, df.iloc[:, 0], color=bar_colors[0], width=-bar_width, align="edge", zorder=3)

        # add x-axis label
        ax.set_xticklabels(df.index, rotation=-45, ha="left", fontsize=fontsize["tick"])
        ax.tick_params(axis="both", which="major", labelsize=fontsize["tick"])

        # add y-axis label
        ax.set_ylabel(df.columns[0], color="black", fontsize=fontsize["ylabel"])

        # define second y-axis that shares x-axis with current plot
        ax2 = ax.twinx()

        # second series: bars extending to the right of each tick
        ax2.bar(df.index, df.iloc[:, 1], color=bar_colors[1], width=bar_width, align="edge", hatch="//")
        ax2.tick_params(axis="both", which="major", labelsize=fontsize["tick"])
        ax2.yaxis.set_major_formatter(OOMFormatter(3, "%0.1f"))

        # add second y-axis label
        ax2.set_ylabel(df.columns[1], color="black", fontsize=fontsize["ylabel"])

        # Set legend
        fig.legend(
            df.columns,
            fontsize=fontsize["legend"],
            ncol=1,
            loc="upper center",
            bbox_to_anchor=(-0.147, 1.055, 1, 0)
        )

        ########################################################################
        # Save figure
        ########################################################################
        filepath = os.path.join(config["outpath_dir"], f"residual_correlation_statistics_{asset_file_tag(tst_asst_nm)}.eps")
        fig.savefig(filepath, format="eps", dpi=None, bbox_inches="tight")


    #############################################################################
    # Report elapsed time. divmod on whole seconds keeps the minute count from
    # being rounded up (e.g. 90 s is "1 min 30 sec", not "2 min 30 sec").
    minutes, seconds = divmod(int(round(time.time() - strt_time)), 60)
    print(f"Done! -- Elapsed time: {minutes} min {seconds} sec")

