from math import sqrt
import numpy as np
from scipy import linalg  # XXX replace with jax to use GPU ?
from sklearn.utils import check_random_state

from activation_functions import (QuadraticActivationFunction,
                                  ReLUActivationFunction)
from theory import get_theoretical_values


def run_experiment(d=450, n=100000, m_max=3000, n_test=None, reg=0,
                   kind="ghorbani", regime="rf", activ="quad",
                   aligned=True, normalize_B=False, random_state=0):
    """Run one sweep over network widths and collect error / robustness.

    Labels are generated as ``y = x^T B x - trace(B)`` with standard
    Gaussian inputs ``x``, for a diagonal ground-truth matrix ``B``. For
    each width ``m`` on a grid, hidden weights ``W`` are drawn, the model
    for the requested regime is evaluated, and the empirical test error and
    ``sob2`` (the "nonrobustness" quantity) are recorded next to the values
    returned by ``get_theoretical_values``.

    Parameters
    ----------
    d : int
        Input dimension (``B`` is ``d x d``).
    n : int
        Number of training samples.
    m_max : int
        Controls the grid of widths ``range(10, m_max + freq, freq)``, where
        ``freq`` is 10, 50 or 100 depending on the size of ``m_max``.
    n_test : int or None
        Number of test samples. ``None`` means "use ``n``".
    reg : float
        Ridge term added to the diagonal of the feature Gram matrix when
        solving for the output weights ("rf" and "lazy_rf" regimes).
    kind : {"ghorbani", "id", "increasing"}
        Diagonal of ``B``: exponential(1) draws, all ones, or ``1, ..., d``.
    regime : {"nt", "lazy_nt", "init", "rf", "lazy_rf", "sgd"}
        Which model is evaluated. "sgd" only records NaN placeholders.
    activ : {"quad", "relu"}
        Activation function. The "nt" / "lazy_nt" regimes require "quad".
    aligned : bool
        If True the hidden weights have covariance ``B / trace(B)``,
        otherwise ``I / d``.
    normalize_B : bool
        If True, ``B`` is divided by its Frobenius norm.
    random_state : int, RandomState instance or None
        Forwarded to ``sklearn.utils.check_random_state``.

    Returns
    -------
    res : list of dict
        One entry per width, with keys ``m``, ``regime``, ``aligned``,
        ``err_test``, ``err_test_theory``, ``sob2`` and ``sob2_theory``.

    Raises
    ------
    NotImplementedError
        For an unknown ``kind``, ``regime`` or ``activ``, or for an
        "nt" / "lazy_nt" regime combined with a non-quadratic activation.

    Notes
    -----
    ``get_theoretical_values`` is handed ``locals()``, so the names of the
    local variables in this function are part of its contract: do not
    rename them without checking that helper.
    """
    # fail fast on unsupported configurations, before any (potentially
    # large) data matrix is generated
    if regime not in ("nt", "lazy_nt", "init", "rf", "lazy_rf", "sgd"):
        raise NotImplementedError(regime)
    if regime in ("nt", "lazy_nt") and activ != "quad":
        raise NotImplementedError("%s regime with %s activation" % (
            regime, activ))

    # the signature advertises n_test=None; fall back to the train size
    if n_test is None:
        n_test = n

    # setup random number generator
    rng = check_random_state(random_state)

    # build ground-truth
    if kind == "ghorbani":
        B = rng.exponential(1, size=d)
        B = np.diag(B)
    elif kind == "id":
        B = np.eye(d)
    elif kind == "increasing":
        B = np.arange(1, d + 1.)
        B = np.diag(B)
    else:
        raise NotImplementedError(kind)
    if normalize_B:
        B /= linalg.norm(B, ord="fro")

    X_train = rng.randn(n, d)
    X_test = rng.randn(n_test, d)

    # build train / test labels (don't forget to subtract trace of B!!!)
    B_trace = np.trace(B)

    def get_true_y(X):
        # row-wise quadratic form x^T B x, shifted by trace(B)
        return np.sum(np.multiply(X, X@B), axis=1) - B_trace

    y_train = get_true_y(X_train)
    y_test = get_true_y(X_test)

    # activation function and grad
    if activ == "quad":
        sigma = QuadraticActivationFunction()
    elif activ == "relu":
        sigma = ReLUActivationFunction()
    else:
        raise NotImplementedError(activ)

    # references used to normalize the reported error and sob2
    err_ref = 2 * linalg.norm(B, ord="fro") ** 2
    sob2_ref = 2 * err_ref

    # covariance matrix of hidden neurons
    if aligned:
        cov = B / B_trace
    else:
        cov = np.eye(d) / d

    # grid of values of m (width of neural net)
    if m_max < 100:
        freq = 10
    elif m_max < 1000:
        freq = 50
    else:
        freq = 100
    ms = range(10, m_max + freq, freq)

    # run experiment proper
    res = []

    for i, m in enumerate(ms):
        # hidden weights matrix of NN
        W = rng.multivariate_normal(mean=np.zeros(d), cov=cov, size=m)

        # handle the different regimes
        if regime in ["nt", "lazy_nt"]:
            # Initialize output weight
            a_init = rng.randn(m) / sqrt(m)

            # Get singular-vectors of hidden weights matrix
            P1, _, _ = linalg.svd(W.T, full_matrices=False)

            # By modifying the ground-truth coef matrix B, we construct an
            # equivalent NT model without term due to init.
            # Also note that here we use a theoretical trick which allows us
            # to work with W.T@A instead of A. This allows us to write
            # everything in closed-form!
            if regime == "nt":
                B_ = B
            else:
                B_ = B - W.T@(a_init[:, None] * W)  # W^T D W
            cross = P1@P1.T@B_ / 2
            sym = cross + cross.T
            err_test = 2 * linalg.norm(B_ - sym, ord="fro") ** 2 / err_ref
            sob2 = 4 * linalg.norm(sym, ord="fro") ** 2 / sob2_ref
        elif regime in ["init", "rf", "lazy_rf"]:
            # fit output weight
            XWT_train = X_train@W.T
            Z_train = sigma(XWT_train)
            Zprime_train = sigma.grad(XWT_train)
            if regime == "init":
                # no fitting: output weights stay at their random init
                a = rng.randn(m) / sqrt(m)
            else:
                if regime == "lazy_rf":
                    a_init = rng.randn(m) / sqrt(m)
                else:
                    a_init = np.zeros(m)
                # ridge regression on the residual w.r.t. the init
                # prediction, then shift back by a_init
                f_init = Z_train@a_init
                U = Z_train.T@Z_train / n
                v = Z_train.T@(y_train - f_init) / n
                a = linalg.solve(U + reg * np.eye(m), v)
                a += a_init

            # compute nonrobustness
            WWT = W@W.T
            C = Zprime_train.T@Zprime_train / n
            C *= WWT  # entry-wise product with the Gram matrix of W
            sob2 = a@(C@a) / sob2_ref

            # compute generalization error
            Z_test = sigma(X_test@W.T)
            y_pred = Z_test@a
            err_test = np.mean((y_pred - y_test) ** 2) / err_ref
        elif regime in ["sgd"]:
            # no empirical computation for this regime: placeholders only
            err_test = sob2 = np.nan

        # get theoretical values of generalization error and robustness
        # (the helper reads whatever it needs from this function's locals)
        err_test_theory, sob2_theory = get_theoretical_values(locals())

        res.append(dict(m=m, regime=regime, aligned=aligned,
                        err_test=err_test, err_test_theory=err_test_theory,
                        sob2=sob2, sob2_theory=sob2_theory))
        print(("aligned=%s, m=%d, regime=%s, err_test=%.2f, "
               "err_test_theory=%.2f, sob2=%.2f, sob2_theory=%.2f") % (
                   aligned, m, regime, err_test, err_test_theory, sob2,
                   sob2_theory))
    return res


if __name__ == "__main__":
    # Command-line driver: run the experiment (optionally several times and
    # for both values of `aligned`), pickle the collected results under
    # results/pkl, and optionally plot them under results/png.
    import os
    from datetime import datetime
    import pandas as pd
    import matplotlib.pyplot as plt
    import argparse
    from plotting import plot_results

    # plt.style.use("ggplot")

    parser = argparse.ArgumentParser('hello')
    parser.add_argument('--d', default=200, type=int)
    parser.add_argument('--n', default=100000, type=int)
    parser.add_argument('--n_test', default=10000, type=int)
    parser.add_argument('--m_max', default=1000, type=int)
    # restrict to the activations run_experiment supports, so that a typo
    # is rejected at parse time rather than after data generation
    parser.add_argument('--activ', default='quad', type=str,
                        choices=["quad", "relu"])
    parser.add_argument('--kind', default='ghorbani', type=str,
                        choices=["id", "ghorbani", "increasing"])
    parser.add_argument('--regime', default='nt', type=str,
                        choices=["sgd", "init", "rf", "lazy_rf", "nt",
                                 "lazy_nt"])
    parser.add_argument('--aligned', action='store_true')
    parser.add_argument('--compare_aligned', action='store_true')
    parser.add_argument('--normalize_B', action='store_true')
    parser.add_argument('--reg', default=0, type=float)
    parser.add_argument('--noplot', action='store_true')
    parser.add_argument('--noshow', action='store_true')
    parser.add_argument('--n_runs', default=1, type=int)
    args = parser.parse_args()
    print(vars(args))

    # with --compare_aligned both settings are run, otherwise only the one
    # selected by --aligned
    if args.compare_aligned:
        aligneds = [True, False]
    else:
        aligneds = [args.aligned]

    # NOTE: a former guard that rejected aligned "nt"-type regimes with a
    # non-identity `kind` is intentionally disabled.

    # run experiment and store results; a single RandomState instance is
    # shared by all calls, so successive runs draw different data
    random_state = check_random_state(0)
    dfs = []
    for runid in range(args.n_runs):
        print("Run %d / %d" % (runid + 1, args.n_runs))
        # separate loop name: the file name below must record the CLI value
        # of --aligned, not whichever setting happened to run last
        for run_aligned in aligneds:
            res = run_experiment(regime=args.regime, kind=args.kind,
                                 aligned=run_aligned, d=args.d, n=args.n,
                                 n_test=args.n_test, m_max=args.m_max,
                                 reg=args.reg, activ=args.activ,
                                 normalize_B=args.normalize_B,
                                 random_state=random_state)
            df = pd.DataFrame(res)
            df["runid"] = runid
            dfs.append(df)
    df = pd.concat(dfs)
    now = datetime.today().strftime('%Y-%m-%d-%Hh')
    for path in "results/pkl", "results/png":
        # exist_ok avoids the check-then-create race
        os.makedirs(path, exist_ok=True)
    df_file = ("regime=%s_aligned=%s_compare_aligned"
               "=%s_activ=%s_n=%d_d=%d_m_max=%d_kind=%s_normalize_B=%s"
               "_%s.pkl") % (args.regime, args.aligned,
                             args.compare_aligned, args.activ, args.n,
                             args.d, args.m_max, args.kind,
                             args.normalize_B, now)
    df_file = "results/pkl/%s" % df_file
    df.to_pickle(df_file)
    print(df_file)

    # generate plots
    if not args.noplot:
        ylim = None
        mag = 1
        legend_ncol = 1
        if "nt" in args.regime:
            ylim = (-.1, 2)
            mag = 1.5
            if args.regime == "nt" or not args.normalize_B:
                legend_ncol = 2
        png_file = "results/png/%s.pdf" % os.path.basename(df_file)
        plot_results(df, out_file=png_file, frameon=True,
                     legend_ncol=legend_ncol,
                     legend_loc="best", ylim=ylim, mag=mag)
        plt.savefig(png_file, dpi=200, bbox_inches="tight")
        print(png_file)
        if args.noshow:
            plt.close("all")
        else:
            plt.show()
