import torch
from tqdm import tqdm

# Device shared by train_classifier and eval_classifier. Every batch tensor is
# moved here, and train_classifier also moves the model here. CUDA is used when
# available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _train_one_epoch(
    model,
    data_loader,
    optimizer,
    criterion,
    embedding_column_name,
    label_column_name,
):
    """Run one optimisation pass over ``data_loader``; return (mean batch loss, accuracy).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # Enter train mode explicitly. The model may arrive in eval mode (for example
    # straight from the caller), and layers such as dropout or batch norm behave
    # differently in eval mode.
    model.train()
    model.freeze_encoder()

    total_loss = 0.0
    correct = 0
    seen = 0
    # Count batches ourselves instead of using len(data_loader), so loaders
    # without __len__ (e.g. iterable-style datasets) are supported.
    num_batches = 0

    for batch in data_loader:
        inputs = batch[embedding_column_name].to(device)
        target = batch[label_column_name].to(device)

        optimizer.zero_grad()

        logits = model(x=inputs)
        loss = criterion(logits, target)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Predicted class = index of the largest score along dim 1.
        predicted_labels = logits.argmax(dim=1)
        correct += (predicted_labels == target).sum().item()
        seen += target.size(0)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_data_loader yielded no batches")

    # Mean of per-batch losses (not sample-weighted), matching previous behaviour.
    return total_loss / num_batches, correct / seen


def train_classifier(
    model,
    train_data_loader,
    test_data_loader,
    optimizer,
    criterion,
    embedding_column_name,
    label_column_name,
    num_epochs: int = 5,
    evaluation_interval: int = 5,
):
    """Train a classifier and periodically evaluate it on a held-out loader.

    The model is moved to the module-level ``device``. Each epoch, the model is
    put in train mode, ``model.freeze_encoder()`` is called, and one pass is made
    over ``train_data_loader``. Whenever ``epoch % evaluation_interval == 0``
    (including epoch 0), the model is evaluated with ``eval_classifier`` on
    ``test_data_loader``.

    Note: epochs are numbered ``0..num_epochs`` inclusive, so ``num_epochs + 1``
    training passes are performed. This is kept for backward compatibility; with
    the defaults it makes the last epoch (5) coincide with an evaluation.

    Args:
        model: module exposing ``freeze_encoder()`` and called as ``model(x=...)``.
            Its output is reduced with ``argmax`` over dim 1, so it is presumably
            ``(batch, num_classes)`` logits (confirm against the model).
        train_data_loader: iterable of dict-like batches used for training.
        test_data_loader: iterable of dict-like batches used for evaluation.
        optimizer: optimizer stepping the trainable parameters.
        criterion: loss called as ``criterion(logits, target)``.
        embedding_column_name: batch key holding the model input.
        label_column_name: batch key holding class-index targets.
        num_epochs: index of the last epoch (``num_epochs + 1`` passes in total).
        evaluation_interval: evaluate on every epoch divisible by this value.

    Returns:
        ``(train_losses, eval_losses, train_accuracies, eval_accuracies,
        eval_indexes)``. The train lists have one entry per epoch. The eval
        lists have one entry per evaluated epoch, and ``eval_indexes`` records
        which epochs those were.

    Raises:
        ValueError: if ``evaluation_interval`` < 1, or if either loader yields no
            batches.
    """
    # Fail fast instead of hitting a modulo-by-zero after a full epoch of training.
    if evaluation_interval < 1:
        raise ValueError(f"evaluation_interval must be >= 1, got {evaluation_interval}")

    train_losses = []
    train_accuracies = []

    eval_losses = []
    eval_accuracies = []
    eval_indexes = []

    model = model.to(device)

    for epoch in (bar := tqdm(range(num_epochs + 1))):
        epoch_loss, epoch_accuracy = _train_one_epoch(
            model,
            train_data_loader,
            optimizer,
            criterion,
            embedding_column_name,
            label_column_name,
        )
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)

        if epoch % evaluation_interval == 0:
            eval_loss, eval_accuracy = eval_classifier(
                model, test_data_loader, criterion, embedding_column_name, label_column_name
            )
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_accuracy)
            eval_indexes.append(epoch)
            # eval_classifier leaves the model in eval mode. Restore train mode so
            # the model is returned in train mode, as before.
            model.train()
            bar.set_description(f"Epoch {epoch}, Test Loss: {eval_loss:.4f}, Test Accuracy: {eval_accuracy:.4f}")

    return train_losses, eval_losses, train_accuracies, eval_accuracies, eval_indexes


def eval_classifier(
    model,
    test_data_loader,
    criterion,
    embedding_column_name,
    label_column_name,
):
    """Evaluate ``model`` on ``test_data_loader`` without tracking gradients.

    The model is switched to eval mode and left there. Callers that keep
    training must call ``model.train()`` afterwards (``train_classifier`` does).

    Args:
        model: module called as ``model(x=...)``. Its output is reduced with
            ``argmax`` over dim 1, so it is presumably ``(batch, num_classes)``
            logits (confirm against the model).
        test_data_loader: iterable of dict-like batches.
        criterion: loss called as ``criterion(logits, target)``.
        embedding_column_name: batch key holding the model input.
        label_column_name: batch key holding class-index targets.

    Returns:
        ``(average_test_loss, eval_accuracy)``.
        - ``average_test_loss``: the mean of the per-batch ``criterion`` values.
          It is not sample-weighted, so a smaller final batch counts as much as
          a full one.
        - ``eval_accuracy``: the fraction of samples whose argmax prediction
          equals the target.

    Raises:
        ValueError: if ``test_data_loader`` yields no batches.
    """
    model.eval()

    total_test_loss = 0.0
    correct = 0
    total = 0
    # Count batches instead of using len(test_data_loader), so loaders without
    # __len__ (e.g. iterable-style datasets) are supported.
    num_batches = 0

    with torch.no_grad():
        for test_batch in test_data_loader:
            test_image = test_batch[embedding_column_name].to(device)
            test_target = test_batch[label_column_name].to(device)

            test_predicted = model(x=test_image)

            test_loss = criterion(test_predicted, test_target)
            total_test_loss += test_loss.item()

            # Predicted class = index of the largest score along dim 1.
            predicted_labels = test_predicted.argmax(dim=1)
            correct += (predicted_labels == test_target).sum().item()
            total += test_target.size(0)
            num_batches += 1

    # Raise a clear error instead of an opaque ZeroDivisionError on an empty loader.
    if num_batches == 0:
        raise ValueError("test_data_loader yielded no batches")

    average_test_loss = total_test_loss / num_batches
    eval_accuracy = correct / total

    return average_test_loss, eval_accuracy
