from tqdm import tqdm
import torch
import copy 

# Default compute device for this module: use CUDA when PyTorch reports a
# usable GPU, otherwise fall back to the CPU. fit/eval move every batch here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def train(model, epochs, optimizer, criterion, scheduler, train_dl, val_dl,
          max_epochs_no_improve=12):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Each epoch calls ``fit`` on ``train_dl`` and ``eval`` on ``val_dl``, and updates
    a tqdm progress bar with both accuracies. Whenever the validation accuracy
    strictly improves on the best seen so far, it keeps a deep copy of the model.
    Training stops early once the number of consecutive non-improving epochs
    exceeds ``max_epochs_no_improve``. Otherwise ``scheduler.step()`` is called
    at the end of every epoch.

    Args:
        model: the ``torch.nn.Module`` to train (updated in place).
        epochs: maximum number of epochs to run.
        optimizer: optimizer passed to ``fit``.
        criterion: loss callable passed to ``fit`` and ``eval``.
        scheduler: learning-rate scheduler stepped once per non-stopping epoch;
            ``None`` disables scheduling.
        train_dl: iterable of ``(data, labels)`` training batches.
        val_dl: iterable of ``(data, labels)`` validation batches.
        max_epochs_no_improve: early-stopping patience. Training stops when the
            count of consecutive epochs without improvement exceeds this value.
            The default of 12 matches the previously hard-coded limit.

    Returns:
        A deep copy of ``model`` taken at the first epoch that reached the
        highest validation accuracy. If no epoch runs (``epochs <= 0``), this is
        a deep copy of the untrained ``model``.
    """
    # -inf guarantees the first evaluated epoch is checkpointed even when its
    # validation accuracy is 0. Starting at 0 could leave best_model unbound.
    best_acc = float("-inf")
    best_model = None
    n_epochs_no_improve = 0

    # Using tqdm as a context manager closes the bar even if fit/eval raise.
    with tqdm(total=epochs, position=0, leave=True) as pbar:
        pbar.set_description("epoch 1/{} | acc: train={:.2f}% val={:.2f}%"
                             .format(epochs, 0, 0))

        for epoch in range(epochs):
            train_acc = fit(model, train_dl, optimizer, criterion)
            val_acc = eval(model, val_dl, criterion)

            pbar.update(1)
            pbar.set_description("epoch {}/{} | acc: train={:.2f}% val={:.2f}%"
                                 .format(epoch + 1, epochs, 100 * train_acc, 100 * val_acc))

            # Checkpoint on strict improvement; otherwise count towards early stop.
            if val_acc > best_acc:
                best_model = copy.deepcopy(model)
                best_acc = val_acc
                n_epochs_no_improve = 0
            else:
                n_epochs_no_improve += 1
                if n_epochs_no_improve > max_epochs_no_improve:
                    break

            if scheduler is not None:
                scheduler.step()

    if best_model is None:
        # No epoch ran, so fall back to a snapshot of the model as given.
        best_model = copy.deepcopy(model)
    return best_model

def fit(model, train_dl, optimizer, criterion):
    """Run one training epoch over ``train_dl`` and return the training accuracy.

    For every batch this function:

    * moves the data and labels to the module-level ``device``;
    * runs a forward/backward pass with ``criterion``;
    * lets ``optimizer`` take one step.

    Predictions used for accuracy come from the model output:

    * 1-D output (one score per sample): class 1 if the score is ``>= 0.5``,
      otherwise class 0.
    * Any other shape: ``argmax`` over the last dimension.

    Args:
        model: the ``torch.nn.Module`` being trained; switched to train mode.
        train_dl: iterable yielding ``(data, labels)`` batches.
        optimizer: optimizer that steps ``model``'s parameters.
        criterion: callable ``criterion(output, labels)`` returning a scalar loss.

    Returns:
        Fraction of correctly classified samples in ``[0, 1]``, or ``0.0`` when
        ``train_dl`` yields no samples.
    """
    correct = 0
    total = 0

    model.train()
    for data, labels in train_dl:
        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # Detach so the accuracy bookkeeping stays out of the autograd graph.
        # Vectorised ops keep preds on the same device as labels; building them
        # with torch.tensor([...]) would put them on CPU and break on CUDA.
        scores = output.detach()
        if scores.dim() == 1:
            # NOTE(review): the 0.5 threshold assumes probability outputs (e.g.
            # after a sigmoid). For raw logits, 0.0 would be the matching cut-off.
            preds = (scores >= 0.5).long()
        else:
            preds = scores.argmax(dim=-1)

        correct += preds.eq(labels).sum().item()
        total += labels.shape[0]

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0

def eval(model, val_dl, criterion):
    """Return the accuracy of ``model`` on ``val_dl`` without updating its weights.

    The model is switched to eval mode and every batch runs under
    ``torch.no_grad()`` on the module-level ``device``. Predictions follow the
    same rule as in training:

    * 1-D output: class 1 if the score is ``>= 0.5``, otherwise class 0.
    * Any other shape: ``argmax`` over the last dimension.

    NOTE: the name shadows the ``eval`` builtin inside this module. It is kept
    as is because callers import it by this name.

    Args:
        model: the ``torch.nn.Module`` to evaluate.
        val_dl: iterable yielding ``(data, labels)`` batches.
        criterion: kept for signature compatibility. It is not needed to compute
            accuracy, so it is not called (the loss used to be computed and then
            discarded).

    Returns:
        Fraction of correctly classified samples in ``[0, 1]``, or ``0.0`` when
        ``val_dl`` yields no samples.
    """
    correct = 0
    total = 0

    model.eval()
    with torch.no_grad():
        for data, labels in val_dl:
            data, labels = data.to(device), labels.to(device)
            output = model(data)

            # Vectorised ops keep preds on the same device as labels.
            # torch.tensor([...]) would put them on CPU and break on CUDA.
            if output.dim() == 1:
                # NOTE(review): 0.5 assumes probability outputs; use 0.0 for logits.
                preds = (output >= 0.5).long()
            else:
                preds = output.argmax(dim=-1)

            correct += preds.eq(labels).sum().item()
            total += labels.shape[0]

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0

