import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from DL.model import *
from DL.gendata import *
from time import time
from sklearn.mixture import GaussianMixture
from sklearn.linear_model import Lasso, OrthogonalMatchingPursuit


def plot_gallery(images, title=None, n_col=4, n_row=4):
    """Plot each sequence of ``images`` as a line chart in its own subplot.

    Subplots are laid out on an ``n_row`` x ``n_col`` grid, filled row by row.
    Each sequence is plotted against its positional index, and x tick labels
    are hidden. ``title`` (when truthy) becomes the figure's super-title.
    """
    figure_size = (2. * n_col, 2.26 * n_row)
    plt.figure(figsize=figure_size)
    if title:
        plt.suptitle(title, size=16)
    for position, series in enumerate(images, start=1):
        plt.subplot(n_row, n_col, position)
        positions = np.arange(len(series))
        plt.plot(positions, series)
        plt.xticks(())
    plt.subplots_adjust(0.01, 0.05, 0.99, 0.93, 0.04, 0.)


def visualize_pred_perfomance(model, data_loader, opt='default'):
    """Plot the prediction performance on the dataset.

        Parameters
        ----------
        model : FC_no_bias object
            Two layer NN as we defined; called as ``model(data)`` on each batch.
        data_loader : data_loader object
            A data loader yielding ``(data, target)`` batches.
        opt : string
            "default": plot pred vs true value (``target``).
            otherwise: plot pred vs the feature at column index 1 of ``data``
            (labelled ``X_1`` on the plot).
    """
    use_target = opt == 'default'
    # Collect per-batch arrays and concatenate once at the end; concatenating
    # inside the loop would re-copy everything seen so far on every batch.
    pred_batches, ref_batches = [], []
    with torch.no_grad():
        for data, target in data_loader:
            pred_batches.append(model(data).detach().numpy().reshape(-1))
            reference = target if use_target else data[:, 1]
            ref_batches.append(reference.numpy().reshape(-1))
    # np.concatenate raises on an empty list, so an empty loader falls back to
    # empty arrays (matching the original's starting values).
    output = np.concatenate(pred_batches) if pred_batches else np.array([])
    y = np.concatenate(ref_batches) if ref_batches else np.array([])
    plt.scatter(y, output)
    plt.xlabel('True value' if use_target else 'X_1')
    plt.ylabel('Prediction')
    plt.show()


def variable_selection(model, indices=None, plot=True, exclude=None, raw=False, random_state=None):
    """Perform variable selection on the trained model.

        Each feature gets an importance score. The scores are split into two
        groups with a 2-component Gaussian mixture, and the group with the
        larger mean importance is treated as "selected".

        Parameters
        ----------
        model : FC_no_bias object or tree object or LASSO/OMP
            Two layer NN as we defined, or sklearn.ensemble class, or sklearn.Lasso/OMP class.
            - FC_no_bias: importance is the mean over hidden units of
              ``|w| * |a|`` (first parameter scaled by the absolute second
              parameter). The share of near-zero scaled weights is also printed.
            - Lasso / OrthogonalMatchingPursuit: importance is ``|coef_|``.
            - anything else: ``feature_importances_`` is used.
        indices : int or array-like
            Indices of effective features. If it's an integer `k' then the first `k' features.
            NOTE: indices refer to positions *after* ``exclude`` has been removed.
        plot : bool
            Whether plot variable importance figure.
        exclude : int or array-like
            Indices of excluded features. If it's an integer `k' then only the  `k'-th feature.
        raw: bool
            Return raw variable importance or clustering result.
        random_state : int, RandomState instance or None
            Forwarded to ``GaussianMixture`` so that the clustering is
            reproducible. Defaults to None (the previous behaviour).

        Returns
        -------
        If ``raw``: 1-D array of importance scores (after exclusion).
        Else if ``indices`` is given: tuple ``(n_selected_in_indices,
        n_selected_elsewhere)``.
        Else: integer array with 1 for selected features and 0 otherwise.
    """
    if isinstance(model, FC_no_bias):
        params = list(model.parameters())
        what = params[0].detach().numpy()
        ahat = params[1].detach().numpy()
        # Scale each hidden unit's incoming weights by the magnitude of its
        # output weight, so a unit switched off in the output layer counts as zero.
        what = (what.T * np.abs(ahat)).T
        zeros = int(np.isclose(what, 0, atol=1e-3).sum())
        tot = what.size
        print(f'{zeros} out of {tot} ({(zeros/tot*100):.2f}%) parameters are zero.')
        X = np.abs(what).mean(axis=0).reshape(-1, 1)
    elif isinstance(model, (Lasso, OrthogonalMatchingPursuit)):
        X = np.abs(model.coef_.reshape(-1, 1))
    else:
        X = model.feature_importances_.reshape(-1, 1)

    if exclude is not None:
        X = np.delete(X, exclude, axis=0)

    if raw:
        return X.reshape(-1)

    y_pred = GaussianMixture(n_components=2, random_state=random_state).fit_predict(X)
    # Mean importance of each cluster. An empty cluster gets -inf rather than
    # NaN. Otherwise a NaN comparison would pick the "selected" label based on
    # which arbitrary id (0 or 1) happened to be empty.
    means = [X[y_pred == k].mean() if np.any(y_pred == k) else -np.inf for k in (0, 1)]
    flag = 0 if means[0] > means[1] else 1  # cluster label of the selected group
    selected = y_pred == flag

    if plot:
        labels = {flag: 'Selected', 1 - flag: 'Unselected'}
        sns.scatterplot(x=range(len(X)), y=X.reshape(-1), hue=np.vectorize(labels.get)(y_pred))
        plt.xlabel('Features')
        plt.ylabel('Importance')
        plt.show()

    if indices is not None:
        if isinstance(indices, (int, np.integer)):
            indices = np.arange(indices)
        return (int(np.count_nonzero(selected[indices])),
                int(np.count_nonzero(np.delete(selected, indices))))
    return selected.astype(y_pred.dtype)


def paste(x, y, decimal=2):
    """Render ``x`` and ``y`` as the string ``"x(y)"``.

    Both values are rounded with ``np.round`` to ``decimal`` places before
    formatting, e.g. ``paste(1.234, 0.567) == '1.23(0.57)'``.
    """
    rounded = [np.round(value, decimals=decimal) for value in (x, y)]
    return '{}({})'.format(*rounded)


def run(train_loader, test_loader, model, lam, lr=0.001, num_epochs=100, verbose=True):
    """Train the neural network and evaluate on the test dataset.

        Each epoch calls ``train`` on ``train_loader``, then ``test`` on
        ``test_loader``, then advances the learning-rate scheduler. Adam is
        used with an MSE loss. The learning rate is halved once, after the
        10th epoch.

        Parameters
        ----------
        train_loader : data_loader object
            Training batches, passed straight to ``train``.
        test_loader : data_loader object
            Evaluation batches, passed straight to ``test``.
        model : FC_no_bias object
            Two layer NN as we defined.
        lam : float
            Penalty parameter (forwarded to ``train``).
        lr : float
            Initial learning rate.
        num_epochs : int
            Number of epochs; must be at least 1.
        verbose : bool
            True to print training details every 10 epochs.

        Returns
        -------
        train_loss : as returned by ``train``
            Train loss of the last epoch.
        test_loss : as returned by ``test``
            Test loss of the last epoch.
        model : FC_no_bias object
            Trained model (the same object, updated in place).

        Raises
        ------
        ValueError
            If ``num_epochs`` < 1 (there would be no loss to return).
    """
    if num_epochs < 1:
        raise ValueError(f'num_epochs must be a positive integer, got {num_epochs}')
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # After 10 scheduler steps (i.e. from epoch 11 on) the learning rate is halved.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.5)
    for epoch in range(1, num_epochs + 1):
        train_loss = train(model, train_loader, criterion, optimizer, lam)
        test_loss = test(model, test_loader, criterion)
        scheduler.step()
        if verbose and epoch % 10 == 0:
            print('Train({})[{:.0f}%]: Loss: {:.4f}; Test error:{:.4f}'.format(
                epoch, 100. * epoch / num_epochs, train_loss, test_loss))

    # note that the test loss has no penalty so it can be smaller than the training loss
    return train_loss, test_loss, model


if __name__ == "__main__":
    # experiment setting
    # n: sample size
    # p: number of features, or input size
    # eff_p: number of effective features
    # r: hidden layer size
    # sigma: noise level, \epsilon ~ N(0, \sigma^2)
    # Model: y = \sum a f(w^T x) + \eps
    # set_default_dtype alone makes new CPU tensors float64; the older
    # set_default_tensor_type(torch.DoubleTensor) call was redundant and deprecated.
    torch.set_default_dtype(torch.float64)
    np.set_printoptions(precision=2)

    n, p, r = 1000, 100, 16
    eff_p = np.floor(np.sqrt(p)).astype(int)
    sigma = 1
    batch_size = int(n / 50)
    n_test = 2000

    # init ground truth parameters
    w, a = gen_paras(p, r, eff_p)

    # Experiment 1
    lams = np.array([0, 0.03, 0.06, 0.1])
    sigmas = np.array([0, 0.5, 1, 5])
    nrep = len(sigmas)
    nlam = len(lams)
    test_err = np.zeros((nrep, nlam))
    variation = np.zeros((nrep, nlam))
    snr = np.zeros(nrep)  # Signal to noise ratio
    # index of model which gives best test error for each sigma; must be an
    # integer array because it is used to index `models`
    idx = np.zeros(nrep, dtype=int)
    models = np.empty((nrep, nlam), dtype=object)
    right = np.zeros(nrep)  # Number of correctly selected features / true positive
    wrong = np.zeros(nrep)  # Number of false selected features / false positive

    t = time()
    for i in range(nrep):
        train_loader, test_loader, snr[i] = gen_data_loader(n, n_test, p, sigmas[i], w, a, batch_size)
        best_err = np.inf
        for j in range(nlam):
            _, test_err[i, j], models[i, j] = run(train_loader, test_loader, FC_no_bias(p, r, 1), lam=lams[j])
            variation[i, j] = penalty(models[i, j], 1)
            if test_err[i, j] < best_err:
                idx[i], best_err = j, test_err[i, j]
            print(f"Sigma {sigmas[i]} Lam {lams[j]} cost time: {time() - t:4f} ")

        # evaluate the goodness of each model
        model = models[i, idx[i]]
        right[i], wrong[i] = variable_selection(model, eff_p, False)

    plt.plot(range(nrep), right, color='r', label='right selected')
    plt.plot(range(nrep), wrong, color='b', label='wrong selected')
    plt.xticks(range(nrep), np.round(snr, 2))
    plt.legend()
    plt.show()
