import math
from macro_modules.custom_dataset import *
import gc
from tqdm import tqdm


def train_one_epoch(model, optimizer, loss_2, training_loader, scheduler, beta, count, T):
    """Run one optimisation pass over ``training_loader``.

    Each batch ``x`` is passed to ``model(x, math.exp(beta * count))``. The
    output is scored with ``loss_2`` against ``x[:, -T:]`` (the last ``T``
    entries along dimension 1), back-propagated, and followed by one
    optimizer step.

    Args:
        model: callable invoked as ``model(x, scale)``.
        optimizer: object providing ``zero_grad()`` and ``step()``.
        loss_2: ``loss_2(prediction, target)``; its result must provide
            ``backward()`` and ``item()``.
        training_loader: iterable of batches supporting ``x[:, -T:]``.
        scheduler: unused here; kept for interface compatibility.
        beta: rate in the exponent of the model's second argument.
        count: number of epochs completed so far.
        T: number of trailing positions (dimension 1) used as the target.

    Returns:
        ``(mean_batch_loss, count + 1)``.

    Raises:
        ValueError: if ``training_loader`` yields no batches.
    """
    # The model's second argument depends only on the epoch counter, so it is
    # constant for the whole pass.
    scale = math.exp(beta * count)

    running_loss = 0.0
    n_batches = 0
    for x in training_loader:
        optimizer.zero_grad()
        p = model(x, scale)
        loss = loss_2(p, x[:, -T:])
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Without this guard an empty loader would divide by zero (the original
    # raised an UnboundLocalError on the loop index instead).
    if n_batches == 0:
        raise ValueError("training_loader yielded no batches")

    return running_loss / n_batches, count + 1
        
def train_model(epochs, batch_size, trainset, model, optimizer, validation_loader, loss_2, scheduler, T, beta=-0.03):
    """Train ``model`` for up to ``epochs`` epochs with flatness-based early stopping.

    Each epoch calls ``train_one_epoch`` on a shuffled ``DataLoader`` over
    ``trainset``. It then computes the mean of
    ``loss_2(model(x, 0), x[:, -T:])`` over ``validation_loader`` with
    gradients disabled, and calls ``scheduler.step(avg_vloss)``. A metric-driven
    scheduler such as ``ReduceLROnPlateau`` is presumably expected; confirm
    with callers.

    Training stops early once the four most recent validation losses each
    differ from their predecessor by less than ``1e-3``.

    Args:
        epochs: maximum number of epochs.
        batch_size: training batch size.
        trainset: dataset wrapped in a shuffled ``DataLoader``.
        model: model trained in place. It is left with ``train(False)`` set.
        optimizer: optimizer forwarded to ``train_one_epoch``.
        validation_loader: iterable of validation batches.
        loss_2: loss function ``loss_2(prediction, target)``.
        scheduler: stepped once per epoch with the mean validation loss.
        T: number of trailing positions (dimension 1) used as the target.
        beta: rate for the training-time second model argument
            ``exp(beta * epoch)``.

    Raises:
        ValueError: if ``validation_loader`` yields no batches.
    """
    import torch  # local import: only needed for torch.no_grad() below

    training_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    EPOCHS = epochs
    count = 0

    # Validation-loss history, offset by 3 placeholder entries. The
    # placeholders are spaced 1 apart, so they never satisfy the 1e-3
    # flatness test themselves.
    vloss_log = np.zeros(EPOCHS + 3)
    vloss_log[:3] = [100, 99, 98]
    progress_bar = tqdm(range(EPOCHS), desc='EPOCH', ascii=True, miniters=int(EPOCHS / 5))
    for epoch_number in progress_bar:
        model.train(True)
        avg_loss, count = train_one_epoch(model, optimizer, loss_2, training_loader, scheduler, beta, count, T)
        model.train(False)

        # Validation losses are only read as floats, so no autograd graph is
        # needed (building one wastes memory and time).
        running_vloss = 0.0
        n_vbatches = 0
        with torch.no_grad():
            for x in validation_loader:
                p = model(x, 0)
                running_vloss += float(loss_2(p, x[:, -T:]))
                n_vbatches += 1
        if n_vbatches == 0:
            raise ValueError("validation_loader yielded no batches")
        avg_vloss = running_vloss / n_vbatches

        scheduler.step(avg_vloss)
        progress_bar.set_postfix({'loss': avg_loss, 'vloss': avg_vloss})

        # Window = this epoch's loss plus the three entries before it.
        vloss_log[3 + epoch_number] = avg_vloss
        window = vloss_log[epoch_number:epoch_number + 4]
        if np.all(np.abs(np.diff(window)) < 1e-3):
            print('Early stopping at epoch', epoch_number, 'with validation loss', avg_vloss)
            break


def test_run_point(testset, model, batch_size=16):
    """Run ``model`` over ``testset`` in inference mode and collect predictions.

    The model is switched to ``eval()`` and called as ``model(x, 0)`` on each
    batch of an unshuffled ``DataLoader``. Only the last 3 entries of the
    final axis of each output (``p[..., -3:]``) are kept.

    Args:
        testset: dataset to predict on, iterated in order.
        model: model called as ``model(x, 0)``.
        batch_size: inference batch size (default 16, as before).

    Returns:
        numpy array of the per-batch ``p[..., -3:]`` slices concatenated
        along axis 0.

    Raises:
        ValueError: if ``testset`` yields no batches.
    """
    import torch  # local import: only needed for torch.no_grad() below

    model.eval()
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    len_test = len(testloader)

    prediction = []
    # Outputs are only copied out, so skip building an autograd graph.
    with torch.no_grad():
        for x in tqdm(testloader, desc='TEST', total=len_test, ascii=True, miniters=int(len_test / 5)):
            p = model(x, 0)
            # Slice before moving to the CPU so only the kept entries are
            # transferred. np.array copies, so the full output is not
            # kept alive by the stored slice.
            prediction.append(np.array(p.detach()[..., -3:].to('cpu')))

    if not prediction:
        raise ValueError("testset yielded no batches")

    return np.concatenate(prediction, 0)


