from tqdm import tqdm
import torch.nn as nn
from test import test
import numpy as np
import torch

# Shared classification loss used by the training loops below: negative
# log-likelihood, averaged over the samples of a batch.
# NOTE(review): NLLLoss takes log-probabilities as input, so the models passed
# to these loops presumably end in log_softmax -- confirm against the models.
criterion_classifier = nn.NLLLoss(reduction="mean")

def train(model, optimizer, trainloader, epochs=30, test_loader=None, m_eval=False, scheduler=None):
    """Train ``model`` with the module-level ``criterion_classifier`` loss.

    Every batch is moved to the default CUDA device with ``.cuda()``, so a
    GPU is required. The progress bar shows the running training accuracy of
    the current epoch.

    Args:
        model: Callable network; ``model(x)`` must return per-class scores
            with the classes along dimension 1.
        optimizer: Optimizer that updates the parameters of ``model``.
        trainloader: Iterable yielding ``(x, y)`` tensor batches.
        epochs: Number of passes over ``trainloader``.
        test_loader: When given, ``test(model, test_loader, m_eval)`` is
            called after every epoch.
        m_eval: Forwarded unchanged to ``test``.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.

    Returns:
        Training accuracy of the last epoch as a fraction in ``[0, 1]``, or
        ``0.0`` when no sample was processed (``epochs == 0`` or an empty
        loader).
    """
    # NOTE(review): the train/eval mode of ``model`` is never set here. If
    # ``test`` switches the model to eval mode, later epochs would train in
    # that mode -- confirm against ``test``.

    # Initialised up front so the final return is valid even with zero epochs.
    corrects = 0
    total = 0
    t = tqdm(range(epochs))
    for epoch in t:
        # Accuracy counters are reset per epoch; only the last epoch is returned.
        corrects = 0
        total = 0
        for x, y in trainloader:
            x = x.cuda()
            y = y.cuda()
            y_hat = model(x)
            loss = criterion_classifier(y_hat, y)

            # Predicted class = index of the largest score along dimension 1.
            _, predicted = y_hat.max(1)
            corrects += predicted.eq(y).sum().item()
            total += y.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t.set_description(f'epoch:{epoch} current accuracy:{round(corrects / total * 100, 2)}%')
        if scheduler is not None:
            scheduler.step()
        if test_loader is not None:
            test(model, test_loader, m_eval)
    # Guard against division by zero when nothing was iterated.
    return corrects / total if total else 0.0

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Blend every sample of a batch with a randomly chosen partner sample.

    Args:
        x: Batch of inputs, indexed by sample along dimension 0.
        y: Targets aligned with ``x`` along dimension 0.
        alpha: Parameter of the ``Beta(alpha, alpha)`` distribution the
            mixing weight is drawn from; a non-positive value fixes the
            weight to 1, which leaves the inputs unchanged.
        use_cuda: If true, the permutation index is moved to the default
            CUDA device.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is
        ``lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` itself and
        ``y_b`` is ``y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # One random pairing of the batch, shared by inputs and targets.
    perm = torch.randperm(x.size(0))
    if use_cuda:
        perm = perm.cuda()

    partner_x = x[perm, :]
    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, y[perm], lam

def rand_bbox(size, lam):
    """Sample a random box inside the plane spanned by dimensions 2 and 3.

    The box is centred on a uniformly drawn point and has side lengths
    ``sqrt(1 - lam)`` times the extents ``size[2]`` and ``size[3]``, so its
    unclipped area is roughly a ``1 - lam`` fraction of the plane. Because it
    is clipped to the plane, the final box may be smaller.

    Args:
        size: Sequence with at least four entries; only ``size[2]`` and
            ``size[3]`` are read.
        lam: Mixing weight; expected in ``[0, 1]``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: start/end along dimension 2 (``bbx*``)
        and along dimension 3 (``bby*``), usable as slice bounds.
    """
    # "W"/"H" are only labels for the extents of dimensions 2 and 3.
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int: the ``np.int`` alias was removed from NumPy in 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre, drawn uniformly (dimension 2 first, then dimension 3).
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def cutmix_data(x, y, alpha=1.0, use_cuda=True):
    """Apply CutMix: paste one random box from shuffled samples into ``x``.

    NOTE: ``x`` is overwritten in place and the very same tensor is returned;
    callers that need the original batch must pass a copy.

    Args:
        x: Batch of inputs with at least four dimensions; the box is cut
            from dimensions 2 and 3.
        y: Targets aligned with ``x`` along dimension 0.
        alpha: Parameter of the ``Beta(alpha, alpha)`` distribution the
            initial mixing weight is drawn from; a non-positive value fixes
            the weight to 1, which results in an empty box.
        use_cuda: If true, the permutation index is moved to the default
            CUDA device.

    Returns:
        ``(x, y_a, y_b, lam)`` where ``y_a`` is ``y``, ``y_b`` is the
        permuted targets and ``lam`` is the fraction of each sample that was
        left untouched.
    """
    # Same guard as mixup_data: np.random.beta rejects non-positive
    # parameters, so treat them as "no mixing" instead of raising.
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)
    y_a, y_b = y, y[index]

    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # Indexing with ``index`` copies the source patches before the in-place
    # assignment, so samples do not read from already overwritten regions.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
    # Adjust lambda to exactly match the pixel ratio of the (clipped) box.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
    return x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the losses against both target sets, weighted by ``lam``.

    Args:
        criterion: Callable ``criterion(pred, target)`` returning a loss.
        pred: Model output passed to ``criterion`` for both targets.
        y_a: Targets weighted by ``lam``.
        y_b: Targets weighted by ``1 - lam``.
        lam: Mixing weight.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


def train_mixup(model, optimizer, trainloader, epochs=30, test_loader=None, m_eval=False, scheduler=None, alpha=1):
    """Train ``model`` with mixup augmentation.

    Each batch is moved to the default CUDA device, blended with a shuffled
    copy of itself by ``mixup_data``, and the loss is the matching
    ``lam``-weighted combination of ``criterion_classifier`` on both target
    sets.

    Args:
        model: Callable network; ``model(x)`` must return per-class scores
            with the classes along dimension 1.
        optimizer: Optimizer that updates the parameters of ``model``.
        trainloader: Iterable yielding ``(x, y)`` tensor batches.
        epochs: Number of passes over ``trainloader``.
        test_loader: When given, ``test(model, test_loader, m_eval)`` is
            called after every epoch.
        m_eval: Forwarded unchanged to ``test``.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
        alpha: Beta-distribution parameter forwarded to ``mixup_data``.

    Returns:
        Training accuracy of the last epoch as a fraction in ``[0, 1]``, or
        ``0.0`` when no sample was processed. A prediction counts as correct
        only when it equals the original (unpermuted) label ``y``.
    """
    # Initialised up front so the final return is valid even with zero epochs.
    corrects = 0
    total = 0
    t = tqdm(range(epochs))
    for epoch in t:
        # Accuracy counters are reset per epoch; only the last epoch is returned.
        corrects = 0
        total = 0
        for x, y in trainloader:
            x = x.cuda()
            y = y.cuda()
            mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=alpha)
            # Feed the *mixed* batch: passing the raw ``x`` would train on
            # unmixed inputs while the loss assumes mixed ones.
            y_hat = model(mixed_x)
            loss = mixup_criterion(criterion_classifier, y_hat, y_a, y_b, lam)

            _, predicted = y_hat.max(1)
            corrects += predicted.eq(y).sum().item()
            total += y.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t.set_description(f'epoch:{epoch} current accuracy:{round(corrects / total * 100, 2)}%')
        if scheduler is not None:
            scheduler.step()
        if test_loader is not None:
            test(model, test_loader, m_eval)
    # Guard against division by zero when nothing was iterated.
    return corrects / total if total else 0.0

def train_cutmix(model, optimizer, trainloader, epochs=30, test_loader=None, m_eval=False, scheduler=None, alpha=1):
    """Train ``model`` with CutMix augmentation.

    Each batch is moved to the default CUDA device, patched by
    ``cutmix_data`` with a box taken from shuffled samples, and the loss is
    the ``lam``-weighted combination of ``criterion_classifier`` on both
    target sets.

    Args:
        model: Callable network; ``model(x)`` must return per-class scores
            with the classes along dimension 1.
        optimizer: Optimizer that updates the parameters of ``model``.
        trainloader: Iterable yielding ``(x, y)`` tensor batches.
        epochs: Number of passes over ``trainloader``.
        test_loader: When given, ``test(model, test_loader, m_eval)`` is
            called after every epoch.
        m_eval: Forwarded unchanged to ``test``.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
        alpha: Beta-distribution parameter forwarded to ``cutmix_data``.

    Returns:
        Training accuracy of the last epoch as a fraction in ``[0, 1]``, or
        ``0.0`` when no sample was processed. A prediction counts as correct
        only when it equals the original (unpermuted) label ``y``.
    """
    # Initialised up front so the final return is valid even with zero epochs.
    corrects = 0
    total = 0
    t = tqdm(range(epochs))
    for epoch in t:
        # Accuracy counters are reset per epoch; only the last epoch is returned.
        corrects = 0
        total = 0
        for x, y in trainloader:
            x = x.cuda()
            y = y.cuda()
            mixed_x, y_a, y_b, lam = cutmix_data(x, y, alpha=alpha)
            # Use the returned batch explicitly instead of relying on
            # cutmix_data having overwritten ``x`` in place.
            y_hat = model(mixed_x)
            loss = mixup_criterion(criterion_classifier, y_hat, y_a, y_b, lam)

            _, predicted = y_hat.max(1)
            corrects += predicted.eq(y).sum().item()
            total += y.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t.set_description(f'epoch:{epoch} current accuracy:{round(corrects / total * 100, 2)}%')
        if scheduler is not None:
            scheduler.step()
        if test_loader is not None:
            test(model, test_loader, m_eval)
    # Guard against division by zero when nothing was iterated.
    return corrects / total if total else 0.0

