import torch
import time
from DSDGP.deep_gp import DSDGP
from EDGP.deep_gp import EDDGP
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler
from assign_data import *


def make_dgp(in_dims, n_layers, num_inducing, num_prio, hidden_dims, name='edgp'):
    """Construct a deep GP model of the requested variant.

    The five positional arguments are forwarded, unchanged and in order, to
    the selected model class.

    Args:
        in_dims: First constructor argument of the model class.
        n_layers: Second constructor argument of the model class.
        num_inducing: Third constructor argument of the model class.
        num_prio: Fourth constructor argument of the model class.
            NOTE(review): the caller in this file passes a sample count
            here -- confirm the intended meaning against the model classes.
        hidden_dims: Fifth constructor argument of the model class.
        name: ``'edgp'`` selects ``EDDGP``; ``'dsdgp'`` selects ``DSDGP``.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``name`` is not one of the supported variants.
    """
    if name == 'edgp':
        deep_gp = EDDGP
    elif name == 'dsdgp':
        deep_gp = DSDGP
    else:
        # Name the offending value so a typo in the config is easy to spot.
        raise ValueError(
            f"Unknown deep GP variant {name!r}; expected 'edgp' or 'dsdgp'"
        )

    return deep_gp(in_dims, n_layers, num_inducing, num_prio, hidden_dims)


def split_array_into_blocks(arr, M):
    """Split ``arr`` into consecutive blocks of at most ``M`` leading entries.

    The split is performed with ``np.split`` at indices ``M, 2M, ...``, so
    every block has ``M`` entries along the first axis except possibly the
    last one, which holds the remainder.

    Args:
        arr: Sequence or array to split; ``len(arr)`` must be defined.
        M: Positive integer block size.

    Returns:
        ``[]`` when ``arr`` is empty, otherwise the list of sub-arrays
        produced by ``np.split``.

    Raises:
        ValueError: If ``arr`` is non-empty and ``M`` is not positive.
    """
    N = len(arr)
    if N == 0:
        return []
    if M <= 0:
        # A non-positive step would either fail inside numpy with an
        # unrelated error or silently return the whole array as one block.
        raise ValueError(f"Block size M must be a positive integer, got {M!r}")
    split_points = np.arange(M, N, M)
    return np.split(arr, split_points)


def training_step(model, train_loader, optimizer):
    """Run one pass over ``train_loader``, updating the model on each batch.

    Args:
        model: Object providing ``train()``, ``fit(x, y)`` (returns the
            objective tensor) and ``update(objective, optimizer)``.
        train_loader: Iterable yielding ``(x_batch, y_batch)`` tensor pairs.
        optimizer: Forwarded to ``model.update``.

    Returns:
        A tuple ``(mean_elbo, durations)``: the mean over batches of the
        negated objective value, and the accumulated seconds spent inside
        ``model.fit`` (the ``model.update`` call is not timed).
    """
    elbos = []
    durations = 0
    model.train()
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.float().cuda(), y_batch.float().cuda()
        # perf_counter is monotonic, so the measured interval cannot be
        # skewed by system clock adjustments the way time.time() can.
        # NOTE(review): CUDA work may be launched asynchronously, so this
        # may under-count GPU time -- confirm whether a synchronize is wanted.
        start_time = time.perf_counter()
        objective = model.fit(x_batch, y_batch)
        durations += time.perf_counter() - start_time
        model.update(objective, optimizer)
        # The objective is negated before being recorded as the ELBO.
        elbos.append(-objective.item())
    return np.mean(elbos), durations


def evaluation_step(model, eval_loader):
    """Evaluate ``model`` on ``eval_loader`` and return averaged errors.

    For every batch the model is called on the inputs, the first element of
    the returned pair is taken as the prediction, and MSE/MAE are computed
    between the last column of the targets and the last column of the
    prediction.

    Args:
        model: Callable module providing ``eval()``; calling it on a batch
            must return a pair whose first element is the prediction tensor.
        eval_loader: Iterable yielding ``(x_batch, y_batch)`` tensor pairs.

    Returns:
        A tuple ``(mse, mae)``: the per-batch MSE and MAE, each averaged
        over batches (unweighted by batch size).
    """
    model.eval()
    mses, maes = [], []
    # Inference only: disabling autograd avoids building graphs that are
    # thrown away immediately, saving memory and time.
    with torch.no_grad():
        for x_batch, y_batch in eval_loader:
            x_batch = x_batch.float().cuda()
            y_batch = y_batch.detach().cpu().numpy()
            # The second element of the model output is not used here.
            m, _ = model(x_batch)
            m = m.detach().cpu().numpy()
            # Only the last column of targets and predictions is scored.
            y_batch, m = y_batch[:, -1], m[:, -1]
            mses.append(mean_squared_error(y_batch, m))
            maes.append(mean_absolute_error(y_batch, m))
    return np.mean(mses), np.mean(maes)


def run_training():
    """Train a deep GP on the ETTh2 data set and log every improvement.

    Builds the model on the GPU, loads the data through ``read_data`` and
    ``construct_loader``, then alternates one training pass and one
    evaluation pass per epoch. A log line is printed only for epochs whose
    evaluation MSE is at least as good as the best seen so far.
    """
    seq = 16
    # NOTE(review): 6 presumably matches the six inputs of the '6i1o'
    # data layout -- confirm against read_data.
    in_dims = 6 * seq
    n_layers = 4
    num_inducing = 256
    num_samples = 20
    hidden_dims = 64
    batch_size = 1024

    name = 'dsdgp' # or 'edgp'
    epochs = 40

    dgp = make_dgp(in_dims, n_layers, num_inducing, num_samples, hidden_dims, name=name)
    dgp = dgp.cuda()
    dgp.float()

    optimizer = torch.optim.AdamW(dgp.parameters(), lr=1e-3)

    Z, U, Z_eval, U_eval = read_data(data_name='ETTh2_6i1o.csv', io='6i1o', seq_len=seq)
    loaders = construct_loader(Z, U, Z_eval, U_eval, batch_size=batch_size, seq_len=seq)
    # Batches are materialised once, so every epoch reuses the same batches
    # in the same order.
    train_data = list(loaders[0])
    test_data = list(loaders[1])

    # Start from +inf so the first epoch always counts as an improvement,
    # whatever the scale of the targets.
    best_mse = float('inf')

    for epoch in range(epochs):
        elbo, duration = training_step(dgp, train_data, optimizer)
        mse, mae = evaluation_step(dgp, test_data)
        if mse <= best_mse:
            best_mse = mse
            print(
                f"Epoch: {epoch}, Model: {name} {n_layers}, ELBO: {elbo}, MSE: {mse}, MAE: {mae}, Duration: {duration} seconds")


# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    run_training()
