"""
Provide Base class for simulation.
"""

import numpy as np
from sklearn.linear_model import LassoLars
import pandas as pd
import statsmodels.api as sm
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from scipy import optimize
from scipy import linalg
import scipy
import scipy.integrate as integrate
import warnings
from tabulate import tabulate
import rich.pretty
import rich.progress
# Import-time side effect: install rich's pretty-printer hook.
rich.pretty.install()
# Import-time side effect: configure matplotlib globally to typeset all text
# with LaTeX. The preamble below defines the macros (\trace, \df, \bSigma,
# \hbbeta, \bep, ...) that appear in the column labels produced and renamed
# by the simulation class in this file, so those labels can be used directly
# as axis/legend text.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'sans-serif',
    'font.sans-serif': ['cm10'],
    # Raw string: backslashes are passed through to LaTeX unchanged.
    'text.latex.preamble':
    r"""
    \usepackage{cmbright}
    \usepackage{amsmath}
    \DeclareMathOperator*{\argmin}{argmin}
    \DeclareMathOperator{\diag}{diag}
    \DeclareMathOperator{\Var}{{Var}}
    \DeclareMathOperator{\trace}{tr}
    \DeclareMathOperator{\prox}{prox}
    \DeclareMathOperator{\Rem}{Rem}
    \def\E{\mathbb{E}}
    \def\R{\mathbb{R}}
    \def\eps{\varepsilon}
    \def\df{{\hat{\mathsf{df}}}}
    \def\tdf{{\tilde{\mathsf{df}}}}
    \def\hbeta{{\hat\beta}}
    \def\hpsi{{\hat\psi}}

    \def\bA{\boldsymbol{A}}
    \def\bX{\boldsymbol{X}}
    \def\by{\boldsymbol{y}}
    \def\bV{\boldsymbol{V}}
    \def\br{\boldsymbol{r}}
    \def\bpsi{\psi(\br)}
    \def\bep{\boldsymbol{\varepsilon}}
    \def\bSigma{\mathbf{\Sigma}}
    \def\bbeta{\boldsymbol{\beta}}
    \def\hbbeta{\widehat{\boldsymbol{\beta}}}
    """,
    }
    )

def noise_generator(n, noise_type='t2'):
    """
    Draw a random noise vector of length ``n`` from the chosen distribution.

    Args:
        n: Number of samples to draw.
        noise_type: One of ``'normal'``, ``'cauchy'``, ``'t2'`` (Student t
            with 2 degrees of freedom), ``'uniform1'`` (uniform on [-1, 1)),
            ``'uniform10'`` (uniform on [-10, 10)) or ``'poisson'``
            (Poisson with rate 1).

    Returns:
        A ``(n,)`` numpy array sampled through the global ``np.random`` state.

    Raises:
        KeyError: If ``noise_type`` is not one of the names listed above.
    """
    def _student_t2(size):
        return np.random.standard_t(df=2, size=size)

    def _uniform_unit(size):
        return np.random.uniform(-1, 1, size=size)

    def _uniform_ten(size):
        return np.random.uniform(-10, 10, size=size)

    def _poisson_one(size):
        return np.random.poisson(lam=1, size=size)

    # Dispatch table: distribution name -> sampler taking the sample size.
    samplers = dict(
        normal=np.random.randn,
        cauchy=np.random.standard_cauchy,
        t2=_student_t2,
        uniform1=_uniform_unit,
        uniform10=_uniform_ten,
        poisson=_poisson_one,
    )
    draw = samplers[noise_type]
    return draw(n)

def huber_elastic_net(n, p, X, y, alpha, tau, lambda_star):
    """
    Fit the Huber Elastic Net estimator through an augmented LassoLars problem.

    The data are rescaled by ``Lambda_star = lambda_star * sqrt(n)``. The
    design is then augmented with one extra column (and coefficient) per
    observation, and with ``p`` extra rows ``sqrt(n * tau_t) * I_p`` that
    carry the ridge part of the penalty. A single ``LassoLars`` fit on the
    augmented problem yields both coefficient groups.

    Args:
        n: Number of observations (rows of ``X``).
        p: Number of features (columns of ``X``).
        X: Design matrix, used as an ``(n, p)`` array.
        y: Response, used as an ``(n,)`` array.
        alpha: L1 penalty level before rescaling.
        tau: Ridge penalty level before rescaling.
        lambda_star: Scale parameter; residuals are clipped to
            ``[-Lambda_star, Lambda_star]`` when forming ``psi``.

    Returns:
        Tuple ``(hbeta, htheta, psi, S_hat, inliers, outliers, D)``:
            hbeta: First ``p`` fitted coefficients (the regression vector).
            htheta: Next ``n`` fitted coefficients (one per observation).
            psi: ``Lambda_star * clip(rescaled residual, -1, 1)``.
            S_hat: Indices of the nonzero entries of ``hbeta``.
            inliers: Indices whose rescaled residual has magnitude <= 1.
            outliers: Indices of the nonzero entries of ``htheta``.
            D: ``(n, n)`` boolean diagonal matrix marking the inliers.
    """

    Lambda_star = lambda_star * n ** (1/2)
    tau_t = tau * Lambda_star ** (-2)
    alpha_t = alpha * Lambda_star ** (-2)
    y_t = y * Lambda_star ** (-1)
    X_t = X * Lambda_star ** (-1)
    regr = LassoLars(
        alpha=(n / (n + p)) * alpha_t,
        fit_intercept=False,
        max_iter=int(2*10**9),
    )
    # Top block of the augmented design: [X_t | alpha_t * n * I_n].
    XX1 = np.hstack([X_t, np.eye(n) * alpha_t * n])
    # Bottom block adds the ridge rows [sqrt(n * tau_t) * I_p | 0].
    XX = np.vstack([
        XX1,
        np.hstack([np.sqrt(n * tau_t) * np.eye(p), np.zeros((p, n))]),
        ])
    yy = np.hstack([y_t, np.zeros(p)])
    regr.fit(XX, yy)
    hbeta = regr.coef_[:p]
    htheta = regr.coef_[p:p + n]

    # Rescaled residual, computed once and reused below.
    resid_t = y_t - X_t @ hbeta
    within_band = abs(resid_t) <= 1
    psi = Lambda_star * np.clip(resid_t, -1, 1)
    S_hat = np.nonzero(hbeta)[0]
    outliers = np.nonzero(htheta)[0]
    inliers = np.nonzero(within_band)[0]
    D = np.diag(within_band)

    # Sanity checks:
    # All these quantities should be equal to the vector bold{psi}
    # in the paper. Some checks may fail due to numerical instability/
    # floating point issues. In that case the norm (printed) should be small.
    # Plain `if` tests are used instead of `assert` so that the checks are
    # not stripped when Python runs with -O.
    psi_t = Lambda_star ** (-1) * psi
    augmented_resid = y_t - XX1 @ np.hstack([hbeta, htheta])
    if not np.allclose(psi_t, augmented_resid):
        vec = psi_t - augmented_resid
        rich.print(f'leng(psi - y - Xb) large: {np.linalg.norm(vec)}')
    predicted_resid = y_t - regr.predict(XX1)
    if not np.allclose(psi_t, predicted_resid):
        vec = psi_t - predicted_resid
        rich.print(
            f'leng(psi - y - regr.predict) large: {np.linalg.norm(vec)}'
        )

    # Another sanity check. May fail if some residuals fall (up to
    # floating point precision) where psi'( ) is discontinuous.
    if len(outliers) + len(inliers) != n:
        rich.print(
            f'len_outliers: {len(outliers)}, len_inliers: {len(inliers)}'
        )

    return hbeta, htheta, psi, S_hat, inliers, outliers, D

class Expe7132(list):
    r"""A virtual base class for simulation using Huber elastic net.
    """

    p = 1000
    gamma_list = (0.2,)
    n = 1001
    alpha = 0.036
    alpha_list = n ** (-1/2) * np.array([0.1 * 1.5 ** k for k in range(13)])
    lambda_star = 0.054
    tau = 10 ** (-10)
    tau_list = (10 ** (-10), 10 ** (-7), 10 ** (-5), 10 ** (-3), 10 **(-2), 10 **(-1.5) , 10 ** ( -1.2), 10 ** (-1) )
    noise_type = 't2'
    expe_no = None

    @staticmethod
    def beta(p):
        return np.concatenate([np.ones(100),np.zeros(p-100)]) * 10 ** (-1/2)

    @staticmethod
    def Sigma(p):
        rademacher_matrix = (-1) ** np.random.binomial(1, 0.5, size=(2*p, p))
        Sigma = rademacher_matrix.T @ rademacher_matrix / (2*p)
        Sigma_square_root = scipy.linalg.sqrtm(Sigma)
        Sigma_inverse = np.linalg.inv(Sigma)
        Sigma_inverse_square_root = np.linalg.inv(Sigma_square_root)
        return Sigma, Sigma_inverse, Sigma_square_root, Sigma_inverse_square_root

    def simu_one_datapt(self):
        """Simulate one data point in the simulation.
        """
        # Basic data generation.
        beta = self.beta(self.p)
        tX = np.random.randn(self.n, self.p)
        Sigma, Sigma_inverse, Sigma_square_root, Sigma_inverse_square_root = self.Sigma(self.p)
        X = tX @ Sigma_square_root
        epsilon = noise_generator(self.n, self.noise_type)
        y = X @ beta + epsilon
        hbeta, htheta, psi, S_hat, inliers, outliers, D = huber_elastic_net(self.n, self.p, X, y, self.alpha, self.tau, self.lambda_star)
        r = y - X @ hbeta
        p_hat = len(S_hat)
        n_hat = len(inliers)
        DXS = D @ X[:, S_hat]
        SigmaSS = Sigma[:, S_hat][S_hat, :]
        df = np.trace(
            DXS @ np.linalg.
            inv(DXS.T @ DXS + self.n * self.tau * np.eye(len(S_hat))) @ DXS.T
        ) if n_hat > 0 else 0
        trA = np.trace(
            np.linalg.inv(DXS.T @ DXS + self.n * self.tau * np.eye(len(S_hat)))
        ) if n_hat > 0 else 0
        trSigmaA = np.trace(
            SigmaSS @ np.linalg.inv(DXS.T @ DXS + self.n * self.tau * np.eye(len(S_hat)))
        ) if n_hat > 0 else 0
        trV = n_hat - df
        h = Sigma_square_root @ (hbeta - beta)

        # Intermediate ingredients to save number of calculations later.
        df_over_trV = df / trV
        Xth = X @ (hbeta - beta)
        h_norm = np.linalg.norm(h)
        Xth_norm = np.linalg.norm(Xth)
        epsilon_norm = np.linalg.norm(epsilon)
        r_norm = np.linalg.norm(r)
        psi_norm = np.linalg.norm(psi)

        # Dataset to be submitted.
        datapt = {}
        datapt.update({
            'n': self.n,
            'p': self.p,
            r'$p/n$': self.p / self.n,
            r'$\lambda$': self.alpha,
            r'$\tau$': self.tau,
            'lambda_star': self.lambda_star,
            r'$\|h\|$': h_norm,
            r'$\|\varepsilon\|$': epsilon_norm,
            r'$\trace[A]$': trA,
            r'$\trace[\Sigma A]$': trSigmaA,
            r'$\trace[V]/n$': trV / self.n,
            r'$\df/n$': df / self.n,
            r'$\df/\trace[V]$': df_over_trV,
            r'$\hat{p}/n$': p_hat / self.n,
            r'$\hat{n}/n$': n_hat / self.n,
            r'$|\trace[\Sigma A]-\df/\trace[V]|$': abs(trSigmaA - df_over_trV),
            r'$|\trace[\Sigma A]-\df/\trace[V]|/\trace[\Sigma A]$': abs(1 - df_over_trV / trSigmaA),
            r'$|\df-\trace[\Sigma A]\trace[V]|/n$': abs(df - trSigmaA * trV) / self.n,
            r'$\|\varepsilon\|^2/n$': epsilon_norm ** 2 / self.n,
            r'$\|y-X\hat\beta\|^2/n$': r_norm ** 2 / self.n,
            r'$\zeta_1$': (Xth[0] - trSigmaA * psi[0]) / h_norm,
        })

        # calculate ingredients to save calculations later.
        r_trSApsi_norm = np.linalg.norm(r + trSigmaA * psi)
        r_df_trVpsi_norm = np.linalg.norm(r + df_over_trV * psi)
        q8340 = r_trSApsi_norm ** 2 / self.n - epsilon_norm ** 2 / self.n
        q8341 = r_df_trVpsi_norm ** 2 / self.n - epsilon_norm ** 2 / self.n
        q8535 = r_norm **2 / self.n + 2 * df / self.n * epsilon_norm ** 2 / self.n - epsilon_norm ** 2 / self.n
        q7649 = (1 - 2 * df / self.n ) * r_trSApsi_norm ** 2 / self.n + trSigmaA ** 2 * psi_norm ** 2 / self.n + 2 * (df / self.n) * epsilon_norm ** 2 / self.n - epsilon_norm ** 2 / self.n
        q3863 = (1 - 2 * df / self.n ) * r_df_trVpsi_norm ** 2 / self.n + df_over_trV ** 2 * psi_norm ** 2 / self.n + 2 * (df / self.n) * r_norm ** 2 / self.n - epsilon_norm ** 2 / self.n
        q6255 = (n_hat - df) ** (-2) * ( psi_norm ** 2 * (2 * df - self.p) + np.linalg.norm( Sigma_inverse_square_root @ X.T @ psi ) ** 2)

        # More data to be submitted.
        datapt.update({
            # Append Out-of-Sample Error.
            r'$\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2$': h_norm ** 2,
            r'$\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2+\|\varepsilon\|^2/n$': h_norm ** 2 + epsilon_norm ** 2 / self.n,
            # Append Out-of-Sample Error Estimates.
            r'$\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n$': r_trSApsi_norm ** 2 / self.n,
            r'$\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n$': r_df_trVpsi_norm ** 2 / self.n,
            r'$\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n-\|\varepsilon\|^2/n$': q8340,
            r'$\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n-\|\varepsilon\|^2/n$': q8341,
            # Append Out-of-Sample Error Estimates in Prop 2.4 in [Bel20]:
            r'$(\hat n-\df)^{-2}(\|\hat{\psi}\|^2(2\df-p)+\|\Sigma^{-1/2}X^{\top}\hat\psi\|^2)$': q6255,
            r'$(\hat n-\df)^{-2}(\|\hat{\psi}\|^2(2\df-p)+\|\Sigma^{-1/2}X^{\top}\hat\psi\|^2)-\|\varepsilon\|^2/n$': q6255 - epsilon_norm ** 2 / self.n,
            # Append Out-of-Sample Error Estimates Error.
            r'$|\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2-\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n+\|\varepsilon\|^2/n|$': abs( h_norm ** 2 - q8340),
            r'$|\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2-\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n+\|\varepsilon\|^2/n|$': abs( h_norm ** 2 - q8341),
            # Append Out-of-Sample Error Estimates Error in Prop 2.4 in [Bel20]:
            r'$|\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2-(\hat n-\df)^{-2}(\|\hat{\psi}\|^2(2\df-p)+\|\Sigma^{-1/2}X^{\top}\hat\psi\|^2|$': abs(h_norm ** 2 - q6255),

            # Append In-sample error.
            r'$\|X(\hat{\beta}-\beta^*)\|^2/n$': Xth_norm ** 2 / self.n,
            r'$\|X(\hat{\beta}-\beta^*)\|^2/n+\|\varepsilon\|^2/n$': Xth_norm ** 2 / self.n + epsilon_norm ** 2 / self.n,
            # Append SURE
            r'$\|r\|^2/n+2(\df/n)\|\varepsilon\|^2/n$': q8535 + epsilon_norm ** 2 / self.n,
            r'$\|r\|^2/n+2(\df/n)\|\varepsilon\|^2/n-\|\varepsilon\|^2/n$': q8535,
            # Append In-sample error estimates.
            r'$(1-2\df/n)\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n+\trace[\Sigma A]^2\|\hat{\psi}\|^2/n+2(\df/n)\|\varepsilon\|^2/n-\|\varepsilon\|^2/n$': q7649,
            r'$(1-2\df/n)\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n+\trace[\Sigma A]^2\|\hat{\psi}\|^2/n+2(\df/n)\|\varepsilon\|^2/n$': q7649 + epsilon_norm ** 2 / self.n,
            r'$(1-2\df/n)\|\hat{r}+(\df/\trace[V])\hat{\psi}\|^2/n+(\df/\trace[V])^2\|\hat{\psi}\|^2/n+2(\df/n)\|\hat{r}\|^2/n-\|\varepsilon\|^2/n$': q3863,
            r'$(1-2\df/n)\|\hat{r}+(\df/\trace[V])\hat{\psi}\|^2/n+(\df/\trace[V])^2\|\hat{\psi}\|^2/n+2(\df/n)\|\hat{r}\|^2/n$': q3863 + epsilon_norm ** 2 / self.n,
            # Append In-sample error estimates error.
            r'$|\|X(\hat{\beta}-\beta^*)\|^2/n - (1-2\df/n)\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n-\trace[\Sigma A]^2\|\hat{\psi}\|^2/n-2(\df/n)\|\varepsilon\|^2/n+\|\varepsilon\|^2/n|$': abs(Xth_norm ** 2 / self.n - q7649),
            r'$|\|X(\hat{\beta}-\beta^*)\|^2/n - (1-2\df/n)\|\hat{r}+(\df/\trace[V])\hat{\psi}\|^2/n-(\df/\trace[V])^2\|\hat{\psi}\|^2/n-2(\df/n)\|\hat{r}\|^2/n+\|\varepsilon\|^2/n|$': abs(Xth_norm ** 2 / self.n - q3863),
        })

        return datapt

    def simulate(self, num_iter = None, progress = None, task = None):
        """Simulate multiple data points.

        Args:
            num_iter: Number of data points to be generated.
            progress: The rich.progress
            task: The rich.task
        """
        if num_iter is None:
            num_iter = self.num_iter
        for _ in range(num_iter):
            self.append(self.simu_one_datapt()
            if progress is not None and task is not None:
                progress.update(task, advance=1)


    def set_para(self, **kwargs):
        """Set parameters in the simulation.
        """
        for key in kwargs:
            setattr(self, key, kwargs[key])
        return self


    def rename(self, key = None):
        """
        Rename columns to plot clear figures.
        """
        rename_dict = {
                r'$\trace[V]/n$': r'$\trace[\bV]/n$',
                r'$|\df-\trace[\Sigma A]\trace[V]|/n$': r'$|\df-\trace[\bSigma\bA]\trace[\bV]|/n$',
                r'$\trace[\Sigma A]$': r'$\trace[\bSigma\bA]$',
                r'$\df/\trace[V]$': r'$\df/\trace[\bV]$',
                r'$|\trace[\Sigma A]-\df/\trace[V]|$': r'$|\trace[\bSigma\bA]-\df/\trace[\bV]|$',
                r'$|\trace[\Sigma A]-\df/\trace[V]|/\trace[\Sigma A]$': r'$|\trace[\bSigma\bA]-\df/\trace[\bV]|/\trace[\bSigma\bA]$',
                # out-of-sample
                r'$\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2$': r'$\|\bSigma^{1/2}(\hbbeta-\bbeta^*)\|^2$',
                # Implement Adaptive parameter tuning and its error.
                r'$\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n-\|\varepsilon\|^2/n$': r'$\|\br+\frac{\df}{\trace[\bV]}\psi(\br)\|^2/n-\|\bep\|^2/n$',
                r'$|\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2-\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n+\|\varepsilon\|^2/n|$': r'$|\|\bSigma^{1/2}(\hbbeta-\bbeta^*)\|^2-\|\br+\frac{\df}{\trace[\bV]}\psi(\br)\|^2/n+\|\bep\|^2/n|$',
                # Bel20 and its error.
                r'$(\hat n-\df)^{-2}(\|\hat{\psi}\|^2(2\df-p)+\|\Sigma^{-1/2}X^{\top}\hat\psi\|^2)$': r'$(\hat n-\df)^{-2}(\|\psi(\br)\|^2(2\df-p)+\|\bSigma^{-1/2}\bX^{\top}\psi(\br)\|^2)$',
                r'$|\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2-(\hat n-\df)^{-2}(\|\hat{\psi}\|^2(2\df-p)+\|\Sigma^{-1/2}X^{\top}\hat\psi\|^2|$': r'$|\|\bSigma^{1/2}(\hbbeta-\bbeta^*)\|^2-(\hat n-\df)^{-2}(\|\psi(\br)\|^2(2\df-p)+\|\bSigma^{-1/2}\bX^{\top}\psi(\br)\|^2|$',
                r'$\|X(\hat{\beta}-\beta^*)\|^2/n$': r'$\|\bX(\hbbeta-\bbeta^*)\|^2/n$',
                # SURE 
                r'$\|r\|^2/n+2(\df/n)\|\varepsilon\|^2/n-\|\varepsilon\|^2/n$': r'$\|\br\|^2/n+2(\df/n)\|\bep\|^2/n-\|\bep\|^2/n$',
                # Generalized risk est and its error.
                r'$(1-2\df/n)\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n+\trace[\Sigma A]^2\|\hat{\psi}\|^2/n+2(\df/n)\|\varepsilon\|^2/n-\|\varepsilon\|^2/n$': r'$(1-2\df/n)\|\br+\trace[\bSigma\bA]\psi(\br)\|^2/n+\trace[\bSigma\bA]^2\|\psi(\br)\|^2/n+2(\df/n)\|\bep\|^2/n-\|\bep\|^2/n$',
                r'$|\|X(\hat{\beta}-\beta^*)\|^2/n - (1-2\df/n)\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n-\trace[\Sigma A]^2\|\hat{\psi}\|^2/n-2(\df/n)\|\varepsilon\|^2/n+\|\varepsilon\|^2/n|$': r'$|\|\bX(\hbbeta-\bbeta^*)\|^2/n - (1-2\df/n)\|\br+\trace[\bSigma\bA]\psi(\br)\|^2/n-\trace[\bSigma\bA]^2\|\psi(\br)\|^2/n-2(\df/n)\|\bep\|^2/n+\|\bep\|^2/n|$',
                # error
                r'$|\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2-\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n+\|\varepsilon\|^2/n|$': r'$|\|\bSigma^{1/2}(\hbbeta-\bbeta^*)\|^2-\|\br+\frac{\df}{\trace[\bV]}\psi(\br)\|^2/n+\|\bep\|^2/n|$',
                r'$\|\varepsilon\|^2/n$': r'$\|\bep\|^2/n$',
                r'$\|y-X\hat\beta\|^2/n$': r'$\|\by-\bX\hbbeta\|^2/n$',
                # OS p noise.
                r'$\|\Sigma^{1/2}(\hat{\beta}-\beta^*)\|^2+\|\varepsilon\|^2/n$': r'$\|\bSigma^{1/2}(\hbbeta-\bbeta^*)\|^2+\|\bep\|^2/n$',
                r'$\|\hat{r}+\trace[\Sigma A]\hat{\psi}\|^2/n$': r'$\|\br+\trace[\bSigma\bA]\psi(\br)\|^2/n$',
                r'$\|\hat{r}+\frac{\df}{\trace[V]}\hat{\psi}\|^2/n$': r'$\|\br+\frac{\df}{\trace[\bV]}\psi(\br)\|^2/n$',

                r'$(1-2\df/n)\|\hat{r}+(\df/\trace[V])\hat{\psi}\|^2/n+(\df/\trace[V])^2\|\hat{\psi}\|^2/n+2(\df/n)\|\hat{r}\|^2/n-\|\varepsilon\|^2/n$':
                r'$(1-2\df/n)\|\br+(\df/\trace[\bV])\psi(\br)\|^2/n+(\df/\trace[\bV])^2\|\psi(\br)\|^2/n+2(\df/n)\|\br\|^2/n-\|\bep\|^2/n$',
                r'$|\|X(\hat{\beta}-\beta^*)\|^2/n - (1-2\df/n)\|\hat{r}+(\df/\trace[V])\hat{\psi}\|^2/n-(\df/\trace[V])^2\|\hat{\psi}\|^2/n-2(\df/n)\|\hat{r}\|^2/n+\|\varepsilon\|^2/n|$':
                r'$|\|\bX(\hbbeta-\bbeta^*)\|^2/n - (1-2\df/n)\|\br+(\df/\trace[\bV])\psi(\br)\|^2/n-(\df/\trace[\bV])^2\|\psi(\br)\|^2/n-2(\df/n)\|\br\|^2/n+\|\bep\|^2/n|$',
        }

        if key == None:
            return rename_dict
        elif key not in rename_dict:
            return key
        else:
            return rename_dict[key]
