import time
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import log_softmax
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm

from .misc import get_device


def train_epoch(
    model: Module,
    data_loader: DataLoader,
    opt: Optimizer,
    criterion: _Loss,
    disable_pbar: bool = False,
) -> Tuple[float, float]:
    """
    Train ``model`` for one pass over ``data_loader``.

    The model is moved to the device returned by ``get_device()`` and put in
    training mode. For every batch the gradients are reset, the loss is
    back-propagated and ``opt`` takes one step.

    Args:
        model: network to train (its parameters are updated in place).
        data_loader: yields ``(x, y)`` batches. When ``y`` is 2-D it is
            treated as per-class targets (e.g. one-hot / soft labels) and its
            argmax over the last dim is used for the accuracy count.
        opt: optimizer stepping ``model``'s parameters.
        criterion: loss, called as ``criterion(pred, y)`` with the raw ``y``.
        disable_pbar: hide the tqdm progress bar.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the average of the
        per-batch losses. ``accuracy`` is the fraction of samples actually
        seen whose argmax prediction matches the label. Each is ``0.0`` when
        there is nothing to average over (empty loader).
    """
    device = get_device()
    model = model.to(device)
    model.train()

    running_loss = 0.0
    correct = 0
    n_batches = 0
    n_samples = 0
    for x, y in tqdm(data_loader, ncols=80, disable=disable_pbar, leave=False):
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        pred = model(x)
        batch_loss = criterion(pred, y)
        batch_loss.backward()
        opt.step()

        running_loss += batch_loss.item()
        n_batches += 1
        # Count samples actually seen: with drop_last=True or a sampler that
        # covers only part of the dataset, len(dataset) would over-count.
        n_samples += y.shape[0]

        pred_class = torch.argmax(pred, dim=-1)
        if y.ndim == 2:
            # One-hot / soft targets: compare against the most likely class.
            y = torch.argmax(y, dim=-1)
        correct += (pred_class == y).sum().item()

    mean_loss = running_loss / n_batches if n_batches else 0.0
    acc = correct / n_samples if n_samples else 0.0
    return mean_loss, acc


def test(model: Module, data_loader: DataLoader) -> float:
    """
    Compute the classification accuracy of ``model`` on ``data_loader``.

    The model is moved to the device returned by ``get_device()``, switched to
    eval mode and run without gradient tracking. The predicted class is the
    argmax over dim 1 of the model output.

    Args:
        model: network to evaluate.
        data_loader: yields ``(x, y)`` batches. If ``y`` has the same shape as
            the model output (e.g. one-hot / soft targets) its argmax over
            dim 1 is used as the label, mirroring ``train_epoch``.

    Returns:
        Number of matching predictions divided by the number of samples seen
        (batch dim of ``y``), or ``0.0`` if the loader yields no samples.
    """
    device = get_device()
    model = model.to(device)
    model.eval()

    correct = 0
    n_samples = 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            pred_class = torch.argmax(pred, dim=1)
            if y.shape == pred.shape:
                # Per-class targets: reduce to class indices before comparing,
                # otherwise (N,) vs (N, C) would broadcast or raise.
                y = torch.argmax(y, dim=1)
            correct += (pred_class == y).sum().item()
            # Count samples actually seen rather than len(dataset), which is
            # wrong with drop_last=True or subset samplers.
            n_samples += y.shape[0]

    return correct / n_samples if n_samples else 0.0


def train(
    model: Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    criterion: _Loss,
    opt: Optimizer,
    epochs: int = 10,
    sch: Optional[_LRScheduler] = None,
    disable_pbar: bool = False,
) -> None:
    """
    Train ``model`` for ``epochs`` epochs, evaluating after each one.

    Each epoch runs ``train_epoch`` on ``train_loader``, then measures accuracy
    on ``test_loader`` with ``test``. If given, ``sch`` is stepped once. A
    one-line summary with the epoch's elapsed time in seconds is then printed.

    Args:
        model: network to train.
        train_loader: training batches.
        test_loader: evaluation batches.
        criterion: training loss.
        opt: optimizer over ``model``'s parameters.
        epochs: number of epochs to run.
        sch: optional LR scheduler, stepped once per epoch via ``sch.step()``
            with no arguments. Schedulers whose ``step`` requires a metric
            (e.g. ReduceLROnPlateau) are therefore not supported.
        disable_pbar: hide the per-epoch tqdm progress bar.
    """
    for epoch in range(1, epochs + 1):
        # perf_counter is monotonic, so durations are not skewed by
        # system clock adjustments the way time.time() can be.
        start = time.perf_counter()
        train_loss, train_acc = train_epoch(
            model, train_loader, opt, criterion, disable_pbar=disable_pbar
        )
        test_acc = test(model, test_loader)
        if sch is not None:
            sch.step()
        elapsed = time.perf_counter() - start
        print(
            f"Epoch: {epoch} train_loss: {train_loss:.3f} "
            f"train_acc: {train_acc * 100:.2f}%, test_acc: {test_acc * 100:.2f}% "
            f"time: {elapsed:.1f}"
        )
