import nf
import numpy as np
from copy import deepcopy
import torch
# Module-level side effect on import: tensors created without an explicit
# type/device (e.g. via ``torch.Tensor(...)``) become CUDA float32 tensors,
# so a GPU is required to use this module.
# NOTE(review): ``set_default_tensor_type`` is deprecated in recent PyTorch
# releases in favour of ``torch.set_default_dtype`` / ``torch.set_default_device``.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

def _mean_nll(model, loader, num_samples):
    """Return the mean negative log-likelihood of ``model`` over ``loader``.

    ``model.log_prob`` is summed batch by batch and the total is divided by
    ``num_samples``. Each batch sum is detached before it is accumulated, so
    only one batch's autograd graph is alive at a time rather than the graphs
    of every batch in the loader.

    Gradient tracking is deliberately NOT disabled with ``torch.no_grad()``.
    NOTE(review): some flow models may rely on autograd inside ``log_prob``;
    confirm for the ``nf`` models before switching to ``no_grad``.
    """
    total = 0
    for x, in loader:
        total -= model.log_prob(x).sum().detach()
    return float(total) / num_samples


def train_with_early_stopping(model, dataset, epochs, patience, batch_size, eval_batch_size=None, display_every=-1):
    """Train ``model`` by maximum likelihood with early stopping on a validation split.

    The data returned by ``nf.load_dataset(dataset)`` is split in its original
    order (no shuffling before the split) into 60% train, 20% validation and
    20% test. The model is optimised with Adam (default hyper-parameters) on
    the mean negative log-likelihood of each training batch.

    After every epoch the mean validation NLL is computed. Training stops once
    the validation NLL has failed to decrease by at least ``1e-4`` for
    ``patience`` consecutive epochs. The weights with the lowest validation
    NLL seen are restored before evaluating on the test split.

    Args:
        model: Module exposing ``log_prob(x)``. The losses divide summed
            ``log_prob`` values by the number of samples, so it is expected to
            return one log-probability per sample.
        dataset: Identifier passed to ``nf.load_dataset``.
        epochs: Maximum number of training epochs.
        patience: Number of consecutive non-improving epochs tolerated before
            stopping.
        batch_size: Training batch size.
        eval_batch_size: Batch size for validation/test evaluation. Defaults
            to ``batch_size``.
        display_every: Print progress every ``display_every`` epochs. Any
            value ``<= 0`` (including the default ``-1``) disables printing.

    Returns:
        Tuple ``(model, loss_test, training_val_losses, times)``:

        * ``model``: the same model object, with the best weights loaded and
          set to eval mode.
        * ``loss_test``: mean test NLL.
        * ``training_val_losses``: mean validation NLL for each completed
          epoch.
        * ``times``: duration of each training step in milliseconds, measured
          with CUDA events.

    Raises:
        ValueError: If the dataset is too small for all three splits to be
            non-empty.
    """
    if eval_batch_size is None:
        eval_batch_size = batch_size

    # Load data. ``torch.Tensor`` follows the module's default tensor type.
    data = torch.Tensor(nf.load_dataset(dataset))

    # Split, in dataset order, into train, validation and test (60%-20%-20%).
    ind1, ind2 = int(len(data) * 0.6), int(len(data) * 0.8)

    train = torch.utils.data.TensorDataset(data[:ind1])
    val = torch.utils.data.TensorDataset(data[ind1:ind2])
    test = torch.utils.data.TensorDataset(data[ind2:])

    # An empty split would otherwise crash much later with an obscure error
    # (e.g. an empty validation sum or an undefined training loss).
    if len(train) == 0 or len(val) == 0 or len(test) == 0:
        raise ValueError(
            f"dataset {dataset!r} has only {len(data)} samples; too few for "
            "non-empty 60/20/20 train/validation/test splits"
        )

    dltrain = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    dlval = torch.utils.data.DataLoader(val, batch_size=eval_batch_size, shuffle=False)
    dltest = torch.utils.data.DataLoader(test, batch_size=eval_batch_size, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters())

    min_delta = 1e-4  # minimum decrease in validation NLL that resets patience
    impatient = 0
    training_val_losses = []
    times = []
    best_loss = np.inf
    best_model = deepcopy(model.state_dict())

    for epoch in range(epochs):
        # Train step
        model.train()
        for x, in dltrain:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            optimizer.zero_grad()

            log_prob = model.log_prob(x)
            loss = -log_prob.mean()

            loss.backward()
            optimizer.step()

            end.record()
            # CUDA work is asynchronous; wait for it before reading the timer.
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))

        # Validation step
        model.eval()
        loss_val = _mean_nll(model, dlval, len(val))
        training_val_losses.append(loss_val)

        # Early stopping. Phrased as "improvement >= min_delta" so a NaN
        # validation loss counts as NO improvement: it neither resets patience
        # nor overwrites the best checkpoint with diverged weights.
        if best_loss - loss_val >= min_delta:
            best_loss = loss_val
            best_model = deepcopy(model.state_dict())
            impatient = 0
        else:
            impatient += 1
            # A below-threshold improvement still refreshes the checkpoint.
            if loss_val < best_loss:
                best_loss = loss_val
                best_model = deepcopy(model.state_dict())

        if impatient >= patience:
            print(f'Breaking due to early stopping at epoch {epoch}')
            break

        # Logging. ``loss`` is the last training batch's loss of this epoch.
        # The ``> 0`` guard is needed because ``n % -1 == 0`` for every int,
        # so the ``-1`` default would otherwise print every epoch.
        if display_every > 0 and (epoch + 1) % display_every == 0:
            print(f"Epoch {epoch+1:4d}, loss_train = {loss:.4f}, loss_val = {loss_val:.4f}")

    # Load best model
    model.load_state_dict(best_model)
    model.eval()

    # Test model
    loss_test = _mean_nll(model, dltest, len(test))

    return model, loss_test, training_val_losses, times
