"""
Generate Heatmap simulations.
"""

from Expe7132 import *
import pickle
import os
from collections import defaultdict

# Create the output folders used by Expe9591 up front. makedirs also creates
# the parent "Expe9591/", and exist_ok=True makes this idempotent and free of
# the check-then-create race of testing os.path.exists first.
os.makedirs("Expe9591/df1/", exist_ok=True)
os.makedirs("Expe9591/df2/", exist_ok=True)

class Expe9591(Expe7132):
    """Provide heatmaps for Experiment 9591.

    Two parameter sweeps are simulated and pickled:

    * ``df1``: every ``tau`` in ``self.tau_list`` crossed with every ``gamma``
      in ``self.gamma_list`` (``p = int(gamma * self.n)``), with ``alpha``
      pinned to ``Expe7132.alpha``.
    * ``df2``: every ``tau`` crossed with every ``alpha`` in
      ``self.alpha_list``, with ``p`` pinned to ``Expe7132.p``.

    For each requested quantity the rows are averaged per parameter pair and
    drawn as one heatmap per sweep.
    """

    expe_no = '9591'
    num_iter = 100

    # Colour-scale limits for the df2 heatmaps, keyed by quantity key.
    # Keys not listed use (None, None), i.e. seaborn infers the range.
    _DF2_VMIN_VMAX = {
        "os": (0, 12),
        "os_1": (0, 12),
        "os_1_err": (0, 12),
        "os_2": (0, 12),
        "os_2_err": (0, 12),
        'df': (0, 1),
        'trV': (0, 1),
        'nhat': (0, 1),
        'phat': (0, 1),
        'df_err': (0, 1),
        'trSigmaA': (0, 6),
        'trSigmaA_err': (0, 6),
    }
    # Annotation number format for the df2 heatmaps (default ".2g").
    _DF2_FMT = {
        "os_1_err": ".1f",
    }
    # Colormap for the df2 heatmaps (default "terrain").
    _DF2_CMAP = {
        'df': "Spectral",
        'trV': "Spectral",
        'nhat': "Spectral",
        'phat': "Spectral",
        'df_err': "Spectral",
        'trSigmaA': "nipy_spectral",
        'trSigmaA_err': "nipy_spectral",
    }

    def _ensure_output_dirs(self):
        """Create ``Expe{expe_no}/df1`` and ``Expe{expe_no}/df2`` if missing.

        The module-level setup only creates folders for the hard-coded
        experiment number, so subclasses overriding ``expe_no`` need this.
        """
        for sub in ("df1", "df2"):
            os.makedirs(f"Expe{self.expe_no}/{sub}", exist_ok=True)

    def _data_path(self, name):
        """Return the pickle path for sweep ``name`` ("df1" or "df2")."""
        return f"Expe{self.expe_no}/Expe{self.expe_no}_{name}.d"

    @staticmethod
    def _short_labels(tick_labels):
        """Reformat matplotlib tick-label texts as 2-significant-digit floats."""
        return [f"{float(label.get_text()):.2g}" for label in tick_labels]

    @staticmethod
    def _mean_table(frame, column_axis, column):
        """Average ``column`` per (tau, column_axis) pair, pivoted tau x column_axis."""
        tau = r"$\tau$"
        return (
            frame[[tau, column_axis, column]]
            .groupby([tau, column_axis], as_index=False)
            .mean()
            .pivot(index=tau, columns=column_axis, values=column)
        )

    def dump_data(self):
        """Run both parameter sweeps and pickle the resulting DataFrames.

        Writes ``Expe{expe_no}/Expe{expe_no}_df1.d`` and ``..._df2.d``.
        The progress total is ``num_iter`` steps per parameter setting;
        NOTE(review): this assumes ``self.simulate`` advances ``task`` once
        per repetition -- confirm in Expe7132.
        """
        n_settings = len(self.tau_list) * (len(self.gamma_list) + len(self.alpha_list))

        with rich.progress.Progress() as progress:
            task = progress.add_task(
                f'Performing Experiment {self.expe_no}: generating data...',
                total=n_settings * self.num_iter,
            )

            # Sweep 1: tau x p/n. alpha is read from the base class rather
            # than self, presumably so values written by set_para do not leak
            # into the fixed setting -- verify before changing.
            self.clear()
            for tau in self.tau_list:
                for gamma in self.gamma_list:
                    self.set_para(tau=tau, p=int(gamma * self.n), alpha=Expe7132.alpha)
                    self.simulate(progress=progress, task=task)
            df1 = pd.DataFrame(self)

            # Sweep 2: tau x alpha, with p pinned to the base-class value.
            self.clear()
            for tau in self.tau_list:
                for alpha in self.alpha_list:
                    self.set_para(tau=tau, p=Expe7132.p, alpha=alpha)
                    self.simulate(progress=progress, task=task)
            df2 = pd.DataFrame(self)

        self._ensure_output_dirs()
        with rich.progress.Progress() as progress:
            task = progress.add_task(f'Performing Experiment {self.expe_no}: pickling data...', total=2)
            for name, frame in (("df1", df1), ("df2", df2)):
                with open(self._data_path(name), "wb") as file:
                    pickle.dump(frame, file)
                progress.update(task, advance=1)

    def load_data(self):
        """Load both pickled sweeps and rename their columns with ``self.rename()``.

        The files are closed after reading. pickle must only be used on these
        locally produced files, never on untrusted input.

        Returns:
            tuple ``(df1, df2)`` of the renamed DataFrames.
        """
        frames = []
        for name in ("df1", "df2"):
            with open(self._data_path(name), "rb") as file:
                frames.append(pickle.load(file))
        df1, df2 = (frame.rename(columns=self.rename()) for frame in frames)
        return df1, df2

    def _plot_df1(self, df1, key, column):
        """Draw and save the tau x p/n heatmap of ``column``."""
        fig, ax = matplotlib.pyplot.subplots(figsize=(8, 6))
        sns.heatmap(self._mean_table(df1, r"$p/n$", column), ax=ax, cmap='terrain', annot=True)
        ax.set_title(column)
        ax.set_xticklabels(self._short_labels(ax.get_xticklabels()))
        ax.set_yticklabels(self._short_labels(ax.get_yticklabels()))
        fig.savefig(f'Expe{self.expe_no}/df1/Expe{self.expe_no}_{key}.pdf', dpi=300)
        plt.close(fig)

    def _plot_df2(self, df2, key, column):
        """Draw and save the tau x lambda heatmap of ``column`` with per-key styling."""
        fig, ax = matplotlib.pyplot.subplots(figsize=(11, 8.5))
        # Applied to the still-empty figure, as originally: the pad=7 margins
        # are fixed before the heatmap and its colorbar are added.
        fig.tight_layout(pad=7)

        vmin, vmax = self._DF2_VMIN_VMAX.get(key, (None, None))
        sns.heatmap(
            self._mean_table(df2, r"$\lambda$", column), ax=ax,
            cmap=self._DF2_CMAP.get(key, "terrain"), vmin=vmin, vmax=vmax,
            annot=False, fmt=self._DF2_FMT.get(key, ".2g"), annot_kws={"fontsize": 20},
        )
        ax.set_title(column, fontsize=22)
        ax.set_xticklabels(self._short_labels(ax.get_xticklabels()), fontsize=30)
        ax.set_yticklabels(self._short_labels(ax.get_yticklabels()), fontsize=30)
        ax.set_xlabel(r"$\lambda$", fontsize=30)
        ax.set_ylabel(ax.get_ylabel(), fontsize=30)
        ax.collections[0].colorbar.ax.tick_params(labelsize=20)
        ax.tick_params(labelsize=20)
        fig.savefig(f'Expe{self.expe_no}/df2/Expe{self.expe_no}_{key}.pdf', dpi=300)
        plt.close(fig)

    def plot(self, df1, df2, dict_quantities):
        """Plot one heatmap per quantity for each sweep.

        Args:
            df1: renamed DataFrame of the tau x p/n sweep.
            df2: renamed DataFrame of the tau x lambda sweep.
            dict_quantities: mapping from a short key (used in the output file
                names and the styling tables) to the value passed to
                ``self.rename`` to obtain the column/title.
        """
        self._ensure_output_dirs()
        with rich.progress.Progress() as progress:
            task = progress.add_task(
                f'Performing Experiment {self.expe_no}: plotting figures...',
                total=len(dict_quantities),
            )
            for key, quantity in dict_quantities.items():
                column = self.rename(quantity)
                self._plot_df1(df1, key, column)
                self._plot_df2(df2, key, column)
                progress.update(task, advance=1)

    def do(self, to_generate, dict_quantities=None):
        """Optionally regenerate the data, then load it and plot heatmaps.

        Args:
            to_generate: bool. Run and dump the simulations if True; otherwise
                reuse previously dumped data.
            dict_quantities: see ``plot``. If None, the data is generated
                and/or loaded but no figures are drawn.

        Returns:
            tuple ``(df1, df2)`` of the loaded DataFrames.
        """
        if to_generate:
            self.dump_data()
        df1, df2 = self.load_data()
        if dict_quantities is not None:
            self.plot(df1, df2, dict_quantities)
        return df1, df2

    def do_more(self, to_generate):
        """Plot the full set of heatmaps for this experiment.

        Args:
            to_generate: bool. Generate data if True. Load dumped data if False.
        """
        # NOTE(review): the "os_2_err" label appears to lack the closing ")"
        # present in "os_2"; left as-is since it is fed to self.rename, which
        # may need to match it exactly.
        dict_quantities = {
            'df': r'$\df/n$',
            'trV': r'$\trace[V]/n$',
            'nhat': r'$\hat{n}/n$',
            'phat': r'$\hat{p}/n$',
            'df_err': r'$|\df-\trace[\Sigma A]\trace[V]|/n$',
            'trSigmaA': r'$\trace[\Sigma A]$',
            'trSigmaA_err': r'$|\trace[\Sigma A]-\df/\trace[V]|$',
            'df_trV': r'$\df/\trace[V]$',
            "rel_A": r'$|\trace[\Sigma A]-\df/\trace[V]|/\trace[\Sigma A]$',
            "os": r'$\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2$',
            "os_1": r'$\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n-\|\varepsilon\|^2/n$',
            "os_2": r'$(\hat n-\df)^{-2}(\|\hat{\psi}\|^2(2\df-p)+\|\Sigma^{-1/2}X^{\top}\hat\psi\|^2)$',
            "os_p_noise": r'$\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2+\|\varepsilon\|^2/n$',
            "os_p_noise_1": r'$\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n$',
            "os_p_noise_2": r'$\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n$',
            "os_1_err": r'$|\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2-\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n+\|\varepsilon\|^2/n|$',
            "os_2_err": r'$|\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2-(\hat n-\df)^{-2}(\|\hat{\psi}\|^2(2\df-p)+\|\Sigma^{-1/2}X^{\top}\hat\psi\|^2|$',
            "Xh": r'$\|X(\hat{\beta}-\beta^*)\|^2/n$',
            "Xh_1": r'$(1-2\df/n)\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n+\trace[\Sigma A]^2\|\hat{\psi}\|^2/n+2(\df/n)\|\varepsilon\|^2/n-\|\varepsilon\|^2/n$',
            "Xh_1_err": r'$|\|X(\hat{\beta}-\beta^*)\|^2/n - (1-2\df/n)\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n-\trace[\Sigma A]^2\|\hat{\psi}\|^2/n-2(\df/n)\|\varepsilon\|^2/n+\|\varepsilon\|^2/n|$',
            "Xh_2": r'$\|r\|^2/n+2(\df/n)\|\varepsilon\|^2/n-\|\varepsilon\|^2/n$',
            "Xh_3": r'$(1-2\df/n)\|\hat{r}+(\df/\trace[V])\hat{\psi}\|^2/n+(\df/\trace[V])^2\|\hat{\psi}\|^2/n+2(\df/n)\|\hat{r}\|^2/n-\|\varepsilon\|^2/n$',
            "Xh_3_err": r'$|\|X(\hat{\beta}-\beta^*)\|^2/n - (1-2\df/n)\|\hat{r}+(\df/\trace[V])\hat{\psi}\|^2/n-(\df/\trace[V])^2\|\hat{\psi}\|^2/n-2(\df/n)\|\hat{r}\|^2/n+\|\varepsilon\|^2/n|$',
            "noise": r'$\|\varepsilon\|^2/n$',
        }

        self.do(to_generate=to_generate, dict_quantities=dict_quantities)

if __name__ == "__main__":
    # Re-plot from previously dumped data; switch to True to rerun the simulations.
    experiment = Expe9591()
    experiment.do_more(to_generate=False)
