import torch
from tqdm import tqdm

def train(
    model, loader, criterion, optimizer, eval_loader=None, num_epochs=10, log_frequency=1, lr_scheduler=None, device=None,
):
    """Run a supervised training loop, printing progress, and return validation logs.

    Each epoch puts ``model`` in train mode and iterates ``loader`` as
    ``(inputs, targets)`` pairs. For every batch it runs
    zero_grad -> forward -> ``criterion`` -> backward -> ``optimizer.step()``.
    If ``lr_scheduler`` is provided, it is stepped after every optimizer step,
    i.e. once per batch rather than once per epoch.

    Args:
        model: network to train; moved to ``device`` once at the start.
        loader: iterable of ``(inputs, targets)`` batches. After each epoch it
            is also passed to ``evaluate`` to print a training accuracy, which
            costs one extra full pass.
        criterion: callable ``criterion(outputs, targets)`` returning a loss
            tensor that supports ``.item()`` and ``.backward()``.
        optimizer: optimizer whose ``param_groups[0]["lr"]`` is printed.
        eval_loader: optional loader; when given, validation accuracy is
            computed after every epoch and recorded in the returned logs.
        num_epochs: number of passes over ``loader``.
        log_frequency: print the mean loss every this many steps. It must be
            non-zero because it is used as a modulus. The step counter and the
            running loss sum persist across epoch boundaries.
        lr_scheduler: optional scheduler stepped once per batch.
        device: target passed to ``.to()`` for the model and each batch.

    Returns:
        A list with one ``{"val_acc": ...}`` dict per epoch. It is empty when
        ``eval_loader`` is None.
    """
    history = []
    model.to(device)

    # Both counters live outside the epoch loop, so logging windows can span epochs.
    global_step = 0
    running_loss = 0.0

    for epoch_idx in range(1, num_epochs + 1):
        print(f"****** Epoch {epoch_idx} ******")
        model.train()

        for batch_inputs, batch_targets in loader:
            optimizer.zero_grad()
            global_step += 1

            batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
            loss = criterion(model(batch_inputs), batch_targets)
            running_loss += loss.item()

            if global_step % log_frequency == 0:
                current_lr = optimizer.param_groups[0]["lr"]
                mean_loss = running_loss / log_frequency
                print("step {}, train_loss: {:.4f}, lr: {}".format(global_step, mean_loss, current_lr))
                running_loss = 0.0

            loss.backward()
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

        train_acc = evaluate(model, loader, device=device)
        print("train_acc: {:.4f}".format(train_acc))

        if eval_loader is None:
            continue

        val_acc = evaluate(model, eval_loader, device=device)
        print("val_acc: {:.4f}".format(val_acc))
        history.append({"val_acc": val_acc})

    return history


def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over ``loader``, as a percentage.

    The model is switched to eval mode and moved to ``device``. Its previous
    training mode is NOT restored; callers that keep training must call
    ``model.train()`` again. Predictions are ``argmax`` over the last output
    dimension and are compared against ``targets.flatten()``, so targets shaped
    either ``(B,)`` or ``(B, 1)`` work (assumes one class label per sample —
    confirm for other target layouts).

    Args:
        model: classifier producing per-class scores on its last dimension.
        loader: iterable of ``(inputs, targets)`` batches.
        device: target passed to ``.to()`` for the model and each batch.

    Returns:
        Accuracy in ``[0, 100]`` as a Python float.

    Raises:
        ValueError: if ``loader`` yields no samples, because accuracy over
            zero samples is undefined.
    """
    model.eval()
    model.to(device)

    correct_preds = 0
    num_samples = 0
    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc="Evaluate"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            preds = model(inputs).argmax(-1)
            # Compare on-device and pull back only the scalar count; this avoids
            # copying both full tensors to host NumPy arrays on every batch.
            correct_preds += (preds == targets.flatten()).sum().item()
            num_samples += inputs.shape[0]

    if num_samples == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    return 100.0 * correct_preds / num_samples

    

def inference(model, loader, device=None):
    """Run ``model`` over every batch in ``loader`` and stack the outputs on CPU.

    Args:
        model: network to run. It is switched to eval mode but is NOT moved
            between devices.
        loader: iterable of batches. Only ``batch[0]`` is fed to the model, so
            labelled ``(inputs, targets)`` batches can be passed unchanged.
        device: where each input batch is sent. When None, this defaults to
            the device of the model's first parameter. For a model with no
            parameters, the ``.to`` call is given None.

    Returns:
        The per-batch outputs moved to CPU and combined with ``torch.vstack``.
        Outputs with 2+ dims are concatenated along dim 0; 1-D outputs become
        rows, following ``torch.vstack`` semantics.
    """
    model.eval()
    if device is None:
        # Follow the model's own placement so inputs and weights always agree,
        # instead of relying on a module-level device constant.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    all_preds = []
    with torch.no_grad():
        for batch in loader:
            all_preds.append(model(batch[0].to(device)).cpu())
    return torch.vstack(all_preds)