from pathlib import Path

import torch
from loguru import logger

# Remove every loguru handler registered so far (at import time this is loguru's
# default stderr sink), so log records only reach sinks added later, such as the
# per-task log files attached in train_test_loop.
# NOTE(review): this is a global, import-time side effect on the shared loguru
# logger and affects any other module that logs through it.
logger.remove()
# Number of consecutive epochs without a test-loss improvement after which
# train_test_loop stops training early.
EARLY_STOPPING_PATIENCE = 5


def train_loop(
    dataloader, model, loss_fn_list, loss_weights, optimizer, log=None, print_every=10
):
    """Run one training epoch and return the mean weighted loss per batch.

    For every batch the model is called as ``model(**batch_items)`` and each
    loss as ``loss_fn(pred, **batch_items)``. The individual losses are scaled
    by the matching entry of ``loss_weights`` and summed before backprop.

    Args:
        dataloader: Iterable yielding keyword-argument dicts. It must also
            expose ``.dataset`` (supporting ``len``) and ``.batch_size``,
            which are used for the progress print-out.
        model: Module to train; it is switched to training mode here.
        loss_fn_list: Callables ``loss_fn(pred, **batch_items) -> Tensor``.
        loss_weights: One weight per entry of ``loss_fn_list``.
        optimizer: Optimizer that updates ``model``'s parameters.
        log: Optional object with an ``info(str)`` method; receives one line
            per batch with the weighted loss terms and their total.
        print_every: Print the current loss every ``print_every`` batches.

    Returns:
        The average of the per-batch total losses, as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no batches (the average would be
            undefined).
    """
    size = len(dataloader.dataset)
    train_loss = 0.0
    num_batches = 0
    model.train()
    for batch, batch_items in enumerate(dataloader):
        # Compute prediction and the weighted loss terms.
        pred = model(**batch_items)
        losses = [
            weight * loss_fn(pred, **batch_items)
            for weight, loss_fn in zip(loss_weights, loss_fn_list)
        ]
        total_loss = torch.sum(torch.stack(losses))

        # Backpropagation
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # .item() forces a host/device sync, so read the value once and reuse it.
        loss_value = total_loss.item()
        train_loss += loss_value
        num_batches += 1

        if log is not None:
            term_values = [term.detach().tolist() for term in losses]
            log.info(
                f"Batch: {batch}, Losses: {term_values}, Total Loss: {loss_value}"
            )

        if batch % print_every == 0:
            current = batch * dataloader.batch_size
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

    if num_batches == 0:
        raise ValueError("train_loop: dataloader yielded no batches")
    return train_loss / num_batches


def test_loop(dataloader, model, loss_fn_list, loss_weights, metrics):
    """Evaluate ``model`` on ``dataloader``, print a summary and return results.

    The model is called as ``model(**batch_items)``, each loss as
    ``loss_fn(pred, **batch_items)`` and each metric as
    ``metric(pred, **batch_items) -> (correct, total)``.

    Args:
        dataloader: Iterable yielding keyword-argument dicts.
        model: Module to evaluate; it is switched to evaluation mode here.
        loss_fn_list: Loss callables; reported per ``loss.__class__.__name__``.
        loss_weights: One weight per entry of ``loss_fn_list``.
        metrics: Callables returning ``(correct, total)`` counts for a batch;
            reported per ``metric.__name__``.

    Returns:
        ``(test_loss, metrics_logger)``. ``test_loss`` is the per-batch average
        of the weighted sum of all losses. ``metrics_logger`` maps each metric
        name to a dict with ``"correct"``, ``"total"`` and ``"accuracy"``
        (a percentage; NaN when the metric reported a total of zero).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Evaluation mode disables dropout and makes batch-norm layers use their
    # running statistics instead of per-batch statistics.
    model.eval()
    metrics_logger = {metric.__name__: {"correct": 0, "total": 0} for metric in metrics}
    loss_logger = {loss.__class__.__name__: {"loss_value": 0} for loss in loss_fn_list}

    test_loss = 0.0
    num_batches = 0
    # no_grad skips building the autograd graph, saving memory and compute.
    with torch.no_grad():
        for batch_items in dataloader:
            pred = model(**batch_items)
            for weight, loss_fn in zip(loss_weights, loss_fn_list):
                weighted = weight * loss_fn(pred, **batch_items).item()
                # Losses that share a class name share one report entry, so the
                # total is accumulated separately; deriving it from the report
                # dict would count such losses more than once.
                loss_logger[loss_fn.__class__.__name__]["loss_value"] += weighted
                test_loss += weighted

            for metric in metrics:
                correct_batch, total_batch = metric(pred, **batch_items)
                metrics_logger[metric.__name__]["correct"] += correct_batch
                metrics_logger[metric.__name__]["total"] += total_batch
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loop: dataloader yielded no batches")

    for entry in loss_logger.values():
        entry["loss_avg"] = entry["loss_value"] / num_batches

    for entry in metrics_logger.values():
        total = entry["total"]
        # A metric may legitimately count nothing; report NaN instead of crashing.
        entry["accuracy"] = entry["correct"] / total * 100 if total else float("nan")

    test_loss /= num_batches
    loss_out = ", ".join([f'{k}={v["loss_avg"]}' for k, v in loss_logger.items()])
    metrics_out = ", ".join([f'{k}={v["accuracy"]}' for k, v in metrics_logger.items()])
    print(
        f"\nTest Results:\nTotal Loss: {test_loss}, Losses: {loss_out}, Metrics: {metrics_out}\n"
    )
    return test_loss, metrics_logger


# The ``test_`` prefix makes pytest collect this helper as a test whenever it is
# imported into a test module; opt out explicitly.
test_loop.__test__ = False


def train_test_loop(
    train_dataloader,
    test_dataloader,
    model,
    loss_fn_list,
    loss_weights,
    optimizer,
    scheduler,
    metrics,
    artifacts_dir,
    epochs=5,
):
    """Train ``model`` for up to ``epochs`` epochs, evaluating after each one.

    Checkpoints go to ``<artifacts_dir>/trained-models``:

    * ``epoch-<t>.pt`` is written after every completed epoch.
    * ``best.pt`` is written whenever the test loss improves.

    Training stops early once the test loss has not improved for
    ``EARLY_STOPPING_PATIENCE`` consecutive epochs.

    For the duration of the call, three loguru file sinks are attached under
    ``<artifacts_dir>/logs``: ``batch.log``, ``epoch.log`` and ``result.log``.
    Each receives only records bound with the matching ``task`` extra. The
    sinks are removed again when the function returns or raises. This stops
    repeated calls from stacking handlers or writing into an earlier run's
    files.

    Args:
        train_dataloader: Dataloader passed to ``train_loop`` each epoch.
        test_dataloader: Dataloader passed to ``test_loop`` each epoch.
        model: Model being trained.
        loss_fn_list: Loss callables shared by training and evaluation.
        loss_weights: One weight per entry of ``loss_fn_list``.
        optimizer: Optimizer used by ``train_loop``.
        scheduler: Object whose ``step()`` is called once per epoch.
        metrics: Metric callables passed to ``test_loop``.
        artifacts_dir: Output directory (``str`` or ``pathlib.Path``).
        epochs: Maximum number of epochs to run.
    """
    artifacts_dir = Path(artifacts_dir)
    model_save_dir = artifacts_dir / "trained-models"
    model_save_dir.mkdir(parents=True, exist_ok=True)
    log_dir = artifacts_dir / "logs"
    batch_logger = logger.bind(task="batch")
    epoch_logger = logger.bind(task="epoch")
    results_logger = logger.bind(task="result")

    handler_ids = []
    try:
        for task in ("batch", "epoch", "result"):
            handler_ids.append(
                logger.add(
                    log_dir / f"{task}.log",
                    mode="w",
                    # .get(): records logged without a "task" extra must be
                    # skipped, not raise KeyError inside the handler.
                    # task=task binds the current loop value (no late binding).
                    filter=lambda record, task=task: record["extra"].get("task")
                    == task,
                )
            )
        results_logger.info("|".join(metric.__name__ for metric in metrics))

        curr_patience_count = 0
        best_test_loss = float("inf")
        for t in range(epochs):
            batch_logger.info(f"Running Epoch: {t}")
            train_loss = train_loop(
                train_dataloader,
                model,
                loss_fn_list,
                loss_weights,
                optimizer,
                log=batch_logger,
            )
            test_loss, results = test_loop(
                test_dataloader, model, loss_fn_list, loss_weights, metrics=metrics
            )

            results_logger.info(
                "|".join(
                    str(round(results[metric.__name__]["accuracy"], 4))
                    for metric in metrics
                )
            )
            epoch_logger.info(
                f"Epoch: {t}, Train Loss: {train_loss}, Test Loss: {test_loss}"
            )
            scheduler.step()
            # Save before the early-stopping check so the epoch that triggers
            # the stop also gets a checkpoint.
            torch.save(model.state_dict(), model_save_dir / f"epoch-{t}.pt")

            if test_loss < best_test_loss:
                torch.save(model.state_dict(), model_save_dir / "best.pt")
                best_test_loss = test_loss
                curr_patience_count = 0
            else:
                curr_patience_count += 1
                if curr_patience_count >= EARLY_STOPPING_PATIENCE:
                    epoch_logger.info(
                        f"Test loss has not improved for {curr_patience_count} epochs, "
                        f"Breaking at epoch: {t}"
                    )
                    break
    finally:
        for handler_id in handler_ids:
            logger.remove(handler_id)
    print("Done!")
