import torch
import numpy
import matplotlib.pyplot as plt


def multi_plot(data):
    """Scatter-plot every unordered pair of series in *data*, side by side.

    One subplot is drawn per pair ``(i, j)`` with ``i < j`` (in the order
    (0, 1), (0, 2), ..., (1, 2), ...). The x axis shows ``data[i]`` and the
    y axis ``data[j]``; all subplots share a single row, then the figure is
    shown with ``plt.show()``.

    Args:
        data: An indexable collection of series; ``data[k]`` is passed
            straight to ``Axes.scatter``, so the series presumably must all
            have the same length (TODO confirm with callers).

    Raises:
        ValueError: If *data* holds fewer than two series (nothing to pair).
    """
    import itertools

    pairs = list(itertools.combinations(range(len(data)), 2))
    if not pairs:
        raise ValueError(
            f"multi_plot needs at least two series to pair, got {len(data)}"
        )

    # squeeze=False keeps the axes array 2-D even when there is only one
    # pair; the default would return a bare Axes and indexing would fail.
    _, axs = plt.subplots(1, len(pairs), figsize=(18, 5), squeeze=False)
    for ax, (x, y) in zip(axs[0], pairs):
        ax.scatter(data[x], data[y], alpha=1, s=1)
        ax.set_xlabel(f"{x}")
        ax.set_ylabel(f"{y}")
        ax.set_title(f"{x} vs {y}")

    plt.tight_layout()
    plt.show()


def get_predictions(model, data, *, device=None):
    """Run *model* for inference on *data* and return its outputs as NumPy arrays.

    The model is put into eval mode *before* the forward pass, so layers whose
    behaviour depends on ``model.training`` (e.g. dropout, batch norm) run in
    inference mode. The model is left in eval mode afterwards. The forward
    pass runs under ``torch.no_grad()``, so no autograd graph is built.

    Args:
        model: A ``torch.nn.Module`` whose forward returns a 3-tuple of
            tensors ``(yhat, mu, logvar)``.
        data: Array-like input. It is copied and cast to float32, so the
            caller's array is never aliased by the tensor fed to the model.
        device: Device to run on. Defaults to the device of the model's
            first parameter. A parameterless model falls back to ``"mps"``
            when available, otherwise CPU.

    Returns:
        Tuple ``(yhat, mu, logvar)`` of NumPy arrays on the host.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    # numpy.array copies by default; casting to float32 on the host also
    # avoids moving float64 data to devices that do not support it.
    inputs = torch.from_numpy(numpy.array(data, dtype=numpy.float32)).to(device)

    # eval() only affects forward passes made after it is called.
    model.eval()
    with torch.no_grad():
        yhat, mu, logvar = model(inputs)

    return (
        yhat.detach().cpu().numpy(),
        mu.detach().cpu().numpy(),
        logvar.detach().cpu().numpy(),
    )
