import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from Layers import layers


def train(model,
          loss,
          optimizer,
          scheduler,
          dataloader,
          device,
          epoch,
          verbose,
          log_interval=100):
    """Run one pass over ``dataloader`` in training mode.

    For every batch: move inputs/labels to ``device``, zero the gradients,
    compute ``loss(model(inputs), labels)``, backpropagate and call
    ``optimizer.step()``. A ``LambdaLR`` scheduler is stepped once per batch;
    any other scheduler type is left for the caller to step.

    Args:
        model: Module being trained (switched to ``train()`` mode).
        loss: Criterion called as ``loss(output, target)``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: LR scheduler, or any object (e.g. ``None``) when no
            per-batch stepping is wanted.
        dataloader: Iterable of ``(data, target)`` batches whose ``.dataset``
            supports ``len()``.
        device: Destination for ``Tensor.to``.
        epoch: Unused here; accepted for call-site compatibility.
        verbose: Unused here; accepted for call-site compatibility.
        log_interval: Unused here; accepted for call-site compatibility.

    Returns:
        Sum over batches of ``loss.item() * batch_size`` divided by
        ``len(dataloader.dataset)`` -- the mean per-sample loss, assuming
        ``loss`` averages over the batch (TODO confirm for custom criteria).
    """
    model.train()
    # The scheduler object does not change during the epoch, so its type
    # only needs checking once.
    per_batch_schedule = isinstance(scheduler,
                                    torch.optim.lr_scheduler.LambdaLR)
    running_loss = 0
    for inputs, labels in dataloader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        batch_loss = loss(model(inputs), labels)
        # Weight by batch size so a short final batch is not over-counted.
        running_loss += batch_loss.item() * inputs.size(0)
        batch_loss.backward()
        optimizer.step()
        if per_batch_schedule:
            scheduler.step()
    return running_loss / len(dataloader.dataset)


def eval(model, loss, dataloader, device, verbose):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Module to evaluate (switched to ``eval()`` mode). Its output
            is indexed as ``(batch, num_classes)`` scores (``topk`` along
            dim 1).
        loss: Criterion called as ``loss(output, target)``.
        dataloader: Iterable of ``(data, target)`` batches whose ``.dataset``
            supports ``len()``; ``target`` holds one class index per sample.
        device: Destination for ``Tensor.to``.
        verbose: Unused; accepted for call-site compatibility.

    Returns:
        ``(average_loss, accuracy1, accuracy5)``. ``average_loss`` is the sum
        of ``loss.item() * batch_size`` divided by ``len(dataloader.dataset)``.
        The accuracies are top-1 and top-5 hit rates in percent over
        ``len(dataloader.dataset)``. When the model has fewer than 5 outputs,
        "top-5" covers every class.
    """
    model.eval()
    total = 0
    correct1 = 0
    correct5 = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total += loss(output, target).item() * data.size(0)
            # topk raises if k exceeds the number of classes, so clamp k to
            # keep models with fewer than 5 outputs evaluable.
            k = min(5, output.size(1))
            _, pred = output.topk(k, dim=1)
            # correct[i, j] is True when the j-th ranked prediction for
            # sample i equals its target.
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            correct1 += correct[:, :1].sum().item()
            correct5 += correct.sum().item()
    num_samples = len(dataloader.dataset)
    average_loss = total / num_samples
    accuracy1 = 100. * correct1 / num_samples
    accuracy5 = 100. * correct5 / num_samples
    return average_loss, accuracy1, accuracy5


def train_eval_loop(model, loss, optimizer, scheduler, train_loader,
                    test_loader, device, epochs, verbose):
    """Train for ``epochs`` epochs, evaluating on ``test_loader`` each time.

    The model is evaluated once before training to give a baseline row.
    After each epoch it is trained with ``train`` and evaluated with ``eval``.
    A ``MultiStepLR`` scheduler is stepped once per epoch here; ``train`` is
    responsible for per-batch ``LambdaLR`` stepping.

    Args:
        model, loss, optimizer, scheduler: Forwarded to ``train``/``eval``.
        train_loader: Batches used for training.
        test_loader: Batches used for evaluation; ``.dataset`` must support
            ``len()``.
        device: Destination for ``Tensor.to``.
        epochs: Number of training epochs.
        verbose: If truthy, wrap the epoch range in ``tqdm`` and print a
            summary line after every epoch.

    Returns:
        ``(frame, best_val_acc)``. ``frame`` is a DataFrame with columns
        ``train_loss, test_loss, top1_accuracy, top5_accuracy``. It has one
        row per epoch, preceded by the baseline row, whose ``train_loss`` is
        NaN. ``best_val_acc`` is the highest post-epoch top-1 accuracy
        (0 if ``epochs`` is 0; the baseline is not included).
    """
    best_val_acc = 0
    test_loss, accuracy1, accuracy5 = eval(model, loss, test_loader, device,
                                           verbose)
    # Baseline row: no training has happened yet, so there is no train loss.
    rows = [[np.nan, test_loss, accuracy1, accuracy5]]
    if verbose:
        running = tqdm(range(epochs))
    else:
        running = range(epochs)
    for epoch in running:
        train_loss = train(model, loss, optimizer, scheduler, train_loader,
                           device, epoch, verbose)
        test_loss, accuracy1, accuracy5 = eval(model, loss, test_loader, device,
                                               verbose)
        best_val_acc = max(best_val_acc, accuracy1)
        row = [train_loss, test_loss, accuracy1, accuracy5]
        if verbose:
            num_test = len(test_loader.dataset)
            # accuracy1 is a percentage, so the "{}/{}" field needs the
            # number of correct samples, not the percentage itself.
            correct1 = int(round(accuracy1 * num_test / 100.))
            print(
                'Evaluation: Average loss: {:.4f}, Top 1 Accuracy: {}/{} ({:.2f}%)'
                .format(test_loss, correct1, num_test, accuracy1))
        if isinstance(scheduler, torch.optim.lr_scheduler.MultiStepLR):
            scheduler.step()
        rows.append(row)
    columns = ['train_loss', 'test_loss', 'top1_accuracy', 'top5_accuracy']
    return pd.DataFrame(rows, columns=columns), best_val_acc


def speed_test(model,
               loss,
               dummy_input,
               device,
               run_iter=100,
               backward=False,
               cuda=True,
               warmup_iter=10):
    """Measure the mean latency of a model pass on one fixed batch.

    Args:
        model: Module to time. It is put in ``train()`` mode when ``backward``
            is True and in ``eval()`` mode otherwise.
        loss: Criterion called as ``loss(output, target)``; only used when
            ``backward`` is True.
        dummy_input: ``(data, target)`` pair reused for every iteration.
        device: Destination for ``Tensor.to`` when ``cuda`` is True.
        run_iter: Number of timed iterations; must be positive.
        backward: If True, time forward + loss + backward. Otherwise time a
            forward pass under ``torch.no_grad()``.
        cuda: If True, move ``data`` and ``target`` to ``device`` first.
        warmup_iter: Untimed forward passes run before measuring.

    Returns:
        Mean time per iteration in milliseconds. CUDA events are used when
        CUDA is available. Otherwise host wall time from
        ``time.perf_counter`` is used.

    Raises:
        ValueError: If ``run_iter`` is not positive.
    """
    import time

    if run_iter <= 0:
        # Without this guard the loop below would divide by zero (and the
        # previous ``while True`` version never terminated).
        raise ValueError('run_iter must be positive, got {}'.format(run_iter))

    if backward:
        model.train()
    else:
        model.eval()
    data, target = dummy_input
    if cuda:
        data, target = data.to(device), target.to(device)

    use_cuda_events = torch.cuda.is_available()
    if use_cuda_events:
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

    def _run_once():
        # One timed unit of work: forward (+ loss + backward if requested).
        if backward:
            output = model(data)
            loss(output, target).backward()
        else:
            with torch.no_grad():
                model(data)

    # Warm up so one-time startup costs are not included in the measurement.
    with torch.no_grad():
        for _ in range(warmup_iter):
            model(data)

    sum_time = 0.
    for _ in range(run_iter):
        if backward:
            # Clearing gradients is kept outside the timed region.
            model.zero_grad()
        if use_cuda_events:
            starter.record()
            _run_once()
            ender.record()
            # Wait for queued GPU work so elapsed_time is valid.
            torch.cuda.synchronize()
            sum_time += starter.elapsed_time(ender)
        else:
            start = time.perf_counter()
            _run_once()
            sum_time += (time.perf_counter() - start) * 1000.
    return sum_time / run_iter
