import torch
import torch.nn as nn
from torch.distributions import Dirichlet
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
from sklearn.datasets import make_moons
from models import MLPMAPClassifier, MLPMAPFVIClassifier
from models import MLPEnsembleClassifier, MLPEnsembleFVIClassifier
from models import MLPDropoutClassifier, MLPDropoutFVIClassifier


def make_grid(points, margin, step=.05):
    """Build a regular 2-D grid around the bounding box of ``points``.

    Args:
        points: array-like of shape (n, 2); column 0 holds the x coordinates
            and column 1 the y coordinates.
        margin: amount subtracted from each column minimum and added to each
            column maximum, so the grid extends past the data.
        step: spacing between neighbouring grid values along both axes.

    Returns:
        A tuple ``(xx, yy, grid)``: the two ``np.meshgrid`` coordinate
        matrices and an array of shape ``(xx.size, 2)`` holding every grid
        point as one ``(x, y)`` row.
    """
    points = np.asarray(points)
    x_lo, x_hi = points[:, 0].min() - margin, points[:, 0].max() + margin
    y_lo, y_hi = points[:, 1].min() - margin, points[:, 1].max() + margin
    # np.arange excludes the upper bound, so the grid stops short of *_hi.
    xx, yy = np.meshgrid(np.arange(x_lo, x_hi, step),
                         np.arange(y_lo, y_hi, step))
    return xx, yy, np.c_[xx.ravel(), yy.ravel()]


if __name__ == '__main__':
    # NOTE(review): `filename` is never read below -- presumably intended for
    # saving the figures; confirm before removing.
    filename = "./figures/two_moons_mlp_basic_prior_1_1"

    torch.manual_seed(123)
    n_samples = 100
    n_features = 2
    n_classes = 2
    X, y = make_moons(n_samples=n_samples, noise=.2, random_state=456)

    # Grid (data bounding box widened by 3) that is fed to the model when the
    # fkl_loss term is evaluated during training.
    _, _, X_test = make_grid(X, margin=3)

    n_models = 10
    # The third constructor argument is set to the number of training points.
    max_precision = X.shape[0]
    args = (n_classes, n_features, max_precision)
    kwargs = dict(hidden_dims=(25, 25), bias=True, activation=nn.ReLU)

    # Exactly one of the following model definitions should be active.
    model = MLPMAPClassifier(*args, **kwargs)
    # model = MLPMAPFVIClassifier(*args, **kwargs)

    # model = MLPDropoutClassifier(*args, dropout=.2, **kwargs)
    # model = MLPDropoutFVIClassifier(*args, dropout=.2, **kwargs)

    # model = MLPEnsembleClassifier(*args, n_models=n_models, **kwargs)
    # model = MLPEnsembleFVIClassifier(*args, n_models=n_models, **kwargs)

    # Passed to model.fkl_loss as `prior_param`.
    prior = (1., 1.)

    optimizer = torch.optim.Adam(model.parameters(), lr=.005)
    epochs = 1000
    nll_trace = []
    fkl_trace = []

    # Loop-invariant tensors: convert once instead of on every epoch.
    X_train = torch.from_numpy(X).float()
    y_train = torch.from_numpy(y)
    X_samples = torch.from_numpy(X_test).float()
    # Models without a callable `fkl_loss` are trained on the NLL alone.
    has_fkl = callable(getattr(model, 'fkl_loss', None))

    for i in range(epochs):
        print("i: {}/{}".format(i+1, epochs))
        optimizer.zero_grad()
        raw_logits_data = model(X_train)
        nll = model.nll_loss(raw_logits_data, y_train)
        nll_trace.append(nll.item())
        loss = nll
        if has_fkl:
            raw_logits_samples = model(X_samples)
            fkl = model.fkl_loss(raw_logits_samples, prior_param=prior)
            # Out-of-place add: `loss += fkl` would mutate `nll` itself,
            # because `loss` and `nll` are the same tensor object.
            loss = nll + fkl
            fkl_trace.append(fkl.item())
        loss.backward()
        optimizer.step()

    # Tighter grid (bounding box widened by 2) used only for plotting.
    xx, yy, X_test = make_grid(X, margin=2)

    # plot stuff
    cm = plt.cm.RdBu
    cm_bright = ListedColormap(['#FF0000', '#0000FF'])
    plt.figure()
    ax = plt.subplot(111)
    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    ax.set_xticks(())
    ax.set_yticks(())

    # model.predict's output is used as Dirichlet concentration parameters;
    # the surface drawn is the Dirichlet mean of the second class (column 1).
    with torch.no_grad():
        alpha = model.predict(torch.from_numpy(X_test).float())
    posterior = Dirichlet(alpha)
    z = posterior.mean

    ax.contourf(xx, yy, z[:, 1].reshape(xx.shape), cmap=cm, alpha=.8)

    # Plot the training points
    ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cm_bright,
               edgecolors='k')

    # Loss traces: NLL on the left axis, FKL (if recorded) on a twin axis.
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.plot(nll_trace, color='tab:blue')
    if fkl_trace:
        # Only add the second axis when the model actually produced an
        # fkl_loss trace; otherwise it would be an empty, misleading axis.
        ax2 = ax1.twinx()
        ax2.plot(fkl_trace, color='tab:orange')
    plt.show()
