"""
Generate Tables, Displot and QQplot.
"""
from Expe7132 import *
sns.set_theme(style='darkgrid', palette='pastel')
import pickle
import tabulate
import os

# Make sure the output directory exists before anything is written into it.
# `exist_ok=True` replaces the check-then-create pair, which could raise
# FileExistsError if the directory appeared between the two calls.
os.makedirs("Expe2947/", exist_ok=True)

class Expe2947(Expe7132):
    """Tables, Displot and QQplot: For the normality of residual.

    Every output file is written under ``Expe<expe_no>/`` and is named
    ``Expe<expe_no><suffix>``.
    """
    expe_no = '2947'
    # (tau, alpha) pairs; each pair is simulated `num_iter` times.
    tau_alpha_pair_list = (
            (10 ** (-10), 0.036),
            (10 ** (-5), 0.036),
            (10 ** (-3), 0.036),
            (10 ** (-2), 0.036),
            (10 ** (-1), 0.036),
            (10 ** (-5), 0.054),
            (10 ** (-3), 0.024),
            (10 ** (-2), 0.054),
            (10 ** (-1), 0.024),
            )
    num_iter = 200

    def _expe_file(self, suffix):
        """Return the path of this experiment's output file ending in `suffix`."""
        return f"Expe{self.expe_no}/Expe{self.expe_no}{suffix}"

    def _rows_for_pair(self, tau, alpha):
        """Return the rows of ``self.df`` recorded for one (tau, alpha) pair."""
        # `alpha` is matched against the column labelled $\lambda$.
        # NOTE(review): exact float equality is used, so this presumably relies
        # on the stored values being the very numbers listed in
        # `tau_alpha_pair_list` -- confirm against the base class.
        mask = (self.df[r'$\tau$'] == tau) & (self.df[r'$\lambda$'] == alpha)
        return self.df[mask]

    def dump_df(self):
        """Generate and dump df.

        Clears the experiment, runs `simulate` for every pair in
        `tau_alpha_pair_list`, then pickles ``pd.DataFrame(self)`` to the
        ``.d`` output file.
        """
        self.clear()
        with rich.progress.Progress() as progress:
            task = progress.add_task(
                f'Performing Experiment {self.expe_no}: generating data...',
                total=len(self.tau_alpha_pair_list) * self.num_iter,
            )
            for tau, alpha in self.tau_alpha_pair_list:
                self.set_para(tau=tau, alpha=alpha)
                self.simulate(progress=progress, task=task, num_iter=self.num_iter)

        # Only the state after the last simulation is dumped, so build the
        # frame once here instead of on every iteration.  This also keeps `df`
        # defined when the pair list is empty.
        df = pd.DataFrame(self)

        with open(self._expe_file(".d"), "wb") as file:
            pickle.dump(df, file)

    def load_df(self):
        """Load df.

        Reads the pickle written by `dump_df` and stores it as ``self.df``.
        """
        # NOTE: unpickling can execute arbitrary code; only load files that
        # were produced by `dump_df`.
        with open(self._expe_file(".d"), "rb") as file:
            self.df = pickle.load(file)

    @staticmethod
    def _describe_to_mean_pm_std(describe):
        """
        Convert a pd.DataFrame.describe() to mean +- std columns.

        Args:
            describe: result of pd.DataFrame.describe(); must contain the
                "mean" and "std" rows.

        Returns:
            pd.Series with one LaTeX math string per column, holding the mean
            and the standard deviation, both formatted with two significant
            digits.
        """
        def mean_pm_std(col):
            mean = f'{col["mean"]:.2g}'
            std = f'{col["std"]:.2g}'
            # Raw f-string: the backslash of \pm must reach the output
            # verbatim without being parsed as an (invalid) escape sequence.
            return rf'${mean} \pm {std}$'
        return pd.Series(describe.apply(mean_pm_std, axis=0))

    def table(self, dict_column_names):
        """Generate tables.

        Args:
            dict_column_names: mapping from a column name of ``self.df`` to
                the label used for that column in the written table.

        Writes a ``latex_raw`` table to the ``.txt`` output file, with one
        table row per selected column and one table column per (tau, alpha)
        pair.
        """
        column_names = list(dict_column_names)
        # One summary row per (tau, alpha) pair.  Rows are collected first and
        # the frame is built once, because DataFrame.append no longer exists
        # in pandas >= 2.0.
        rows = [
            self._describe_to_mean_pm_std(
                self._rows_for_pair(tau, alpha)[column_names].describe()
            )
            for tau, alpha in self.tau_alpha_pair_list
        ]
        res = pd.DataFrame(rows, columns=column_names)
        # Rename once, after all rows are in place: renaming inside the loop
        # made later rows (still keyed by the original names) land in new
        # columns whenever a label differed from its column name.
        res = res.rename(columns=dict_column_names)
        tab = tabulate.tabulate(res.T, tablefmt='latex_raw', showindex=True)
        # Open the file only once the table is ready, so a failure above does
        # not leave a truncated, empty file behind.
        with open(self._expe_file(".txt"), "w") as file:
            file.write(tab)

    def qqplot(self):
        """Generate QQplots.

        Saves one PDF per (tau, alpha) pair, numbered by the pair's position
        in `tau_alpha_pair_list`.
        """
        for i, (tau, alpha) in enumerate(self.tau_alpha_pair_list):
            df = self._rows_for_pair(tau, alpha)
            fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
            sm.qqplot(df[r'$\zeta_1$'], line='45', ax=ax)
            ax.set_xlim(-5, 5)
            ax.set_ylim(-5, 5)
            # Re-set the existing labels only to enlarge their font.
            ax.set_xlabel(ax.get_xlabel(), fontsize=20)
            ax.set_ylabel(ax.get_ylabel(), fontsize=20)
            ax.tick_params(labelsize=20)
            fig.savefig(self._expe_file(f"_qq_{i}.pdf"))
            # Close this figure explicitly rather than whichever is "current".
            plt.close(fig)

    def histplot(self):
        """Generate histplots.

        Saves one PDF per (tau, alpha) pair, numbered by the pair's position
        in `tau_alpha_pair_list`.
        """
        for i, (tau, alpha) in enumerate(self.tau_alpha_pair_list):
            df = self._rows_for_pair(tau, alpha)
            fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
            sns.histplot(df[r'$\zeta_1$'], kde=True, ax=ax)
            ax.set_xlim(-5, 5)
            # Re-set the existing labels only to enlarge their font.
            ax.set_xlabel(ax.get_xlabel(), fontsize=20)
            ax.set_ylabel(ax.get_ylabel(), fontsize=20)
            ax.tick_params(labelsize=20)
            fig.savefig(self._expe_file(f"_ds_{i}.pdf"))
            # Close this figure explicitly rather than whichever is "current".
            plt.close(fig)

    def do(self, to_generate, column_names):
        """Generate data, write table, plot qq and displots.

        Args:
            to_generate: when true, run `dump_df` first to regenerate the
                pickled data; otherwise the existing pickle is reused.
            column_names: column-name-to-label mapping passed to `table`.
        """
        if to_generate:
            self.dump_df()
        self.load_df()
        self.table(column_names)
        self.qqplot()
        self.histplot()

    def do_more(self, to_generate):
        """Run `do` with the default selection of table columns.

        Args:
            to_generate: forwarded to `do`.
        """
        # A value of True means "use the column name itself as the label".
        dict_column_names = {
                r'$\lambda$': True,
                r'$\tau$': True,
                r'$\df/n$': True,
                r'$\hat{p}/n$': True,
                r'$\hat{n}/n$': True,
                r'$\trace[\Sigma A]$': True,
                r'$|\trace[\Sigma A]-\df/\trace[V]|$': True,
                r'$\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2$': True,
                r'$\zeta_1$': True,
        }

        dict_column_names = {
            key: (key if label is True else label)
            for key, label in dict_column_names.items()
        }

        self.do(to_generate, dict_column_names)

if __name__ == "__main__":
    # Script entry point: regenerate the data, then rebuild the table and
    # every figure from it.
    experiment = Expe2947()
    experiment.do_more(to_generate=True)
