import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def visualize_loss(losses):
    """Plot a sequence of per-iteration loss values and display the figure.

    Args:
        losses: Sequence of scalar loss values, one per training iteration;
            plotted against its index on the x-axis.
    """
    plt.figure(figsize=(12, 4))
    plt.plot(losses, label="Loss")
    plt.title("Training Loss")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.legend()
    # Explicitly enable the grid: a bare ``plt.grid()`` *toggles* visibility,
    # which would hide the grid under styles that already enable it.
    plt.grid(True)
    plt.show()


def simple_train(ds, backbone, method, bsz=10_000, epochs=75, lr=1e-3):
    """Train ``backbone`` on ``ds`` for a fixed number of epochs using Adam.

    Each batch is passed to ``method.loss(backbone, x0, y=y)`` and the
    returned loss is minimised. Per-iteration losses are plotted at the end.

    Args:
        ds: Dataset wrapped in a shuffling ``DataLoader``. A batch may be a
            single tensor (no labels) or a tuple/list whose first element is
            the input ``x0`` and whose optional second element is ``y``.
        backbone: Model whose parameters are optimised; batches are moved to
            the device of its first parameter.
        method: Object exposing ``loss(backbone, x0, y=...)`` that returns a
            scalar tensor.
        bsz: Batch size.
        epochs: Number of passes over ``ds``.
        lr: Adam learning rate.
    """
    # Assumes the whole model lives on a single device — the first
    # parameter's device is used for every batch.
    device = next(backbone.parameters()).device
    dl = DataLoader(ds, batch_size=bsz, shuffle=True)
    opt = torch.optim.Adam(backbone.parameters(), lr=lr)

    losses = []

    for _ in tqdm(range(epochs), desc="Training Epochs"):
        for batch in dl:
            # isinstance (not an exact type match) so namedtuple batches from
            # the default collate are unpacked too. A one-element batch
            # (e.g. from TensorDataset(x)) carries no labels.
            if isinstance(batch, (tuple, list)):
                x0 = batch[0]
                y = batch[1].to(device).unsqueeze(1) if len(batch) > 1 else None
            else:
                x0 = batch
                y = None

            # Adds a singleton dimension at axis 1 (presumably a channel
            # axis expected by method.loss — confirm).
            x0 = x0.to(device).unsqueeze(1)
            loss = method.loss(backbone, x0, y=y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            losses.append(loss.item())

    visualize_loss(losses)


def simple_train_by_iterations(
    train_dl, val_dl, backbone, method, train_steps=100_000, lr=1e-3
):
    """Train ``backbone`` for a fixed number of optimiser steps using Adam.

    ``train_dl`` is re-iterated as many times as needed to reach
    ``train_steps``; epochs are not tracked. Per-step losses are plotted at
    the end.

    Args:
        train_dl: Iterable of ``(y, x0)`` pairs — note the conditioning
            tensor ``y`` comes first. Must yield at least one batch per pass.
        val_dl: Currently unused (validation is not implemented yet).
        backbone: Model whose parameters are optimised; batches are moved to
            the device of its first parameter.
        method: Object exposing ``loss(backbone, x0, y=...)`` that returns a
            scalar tensor.
        train_steps: Number of optimiser steps to take; values <= 0 perform
            no training.
        lr: Adam learning rate.

    Raises:
        ValueError: If a full pass over ``train_dl`` yields no batches, which
            would otherwise loop forever.
    """
    # Assumes the whole model lives on a single device.
    device = next(backbone.parameters()).device
    opt = torch.optim.Adam(backbone.parameters(), lr=lr)

    losses = []
    iteration = 0

    pbar = tqdm(total=train_steps, desc="Training Steps", postfix=f"Loss: {0:.4f}")
    # try/finally so the progress bar is closed even if a step raises.
    try:
        while iteration < train_steps:  # we don't track epochs
            batches_this_pass = 0
            for y, x0 in train_dl:
                y = y.to(device)
                x0 = x0.to(device)

                loss = method.loss(backbone, x0, y=y)

                opt.zero_grad()
                loss.backward()
                opt.step()

                loss_val = loss.item()
                losses.append(loss_val)

                pbar.set_postfix(Loss=f"{loss_val:.4f}")
                pbar.update(1)
                iteration += 1
                batches_this_pass += 1

                # TODO add validation here (val_dl is currently unused)

                if iteration >= train_steps:
                    break

            # An empty loader would never advance `iteration`.
            if batches_this_pass == 0:
                raise ValueError(
                    "train_dl yielded no batches; cannot reach train_steps"
                )
    finally:
        pbar.close()

    visualize_loss(losses)
