import torch.optim as optim
from utilities import *

from copy import deepcopy
from utilities import validate

def adjust_learning_rate(optimizer, epoch, lr):
    """Apply a step decay to every parameter group of ``optimizer``.

    The rate is ``lr`` multiplied by 0.1 once for each completed block of
    30 epochs (``epoch // 30``).

    :param optimizer: object exposing ``param_groups`` (a list of dicts)
    :param epoch: current epoch index
    :param lr: initial (undecayed) learning rate
    :return: the decayed learning rate that was written to each group
    """
    completed_decays = epoch // 30
    decayed = lr * (0.1 ** completed_decays)
    for group in optimizer.param_groups:
        group['lr'] = decayed
    return decayed


def cos_lr(base_lr, epoch, iteration, num_iter, num_epoch, warmup=False):
    """
    cosine learning rate schedule with optional linear warm-up
    from https://github.com/d-li14/mobilenetv2.pytorch
    :param base_lr: learning rate at the start of the cosine phase
    :param epoch: current epoch
    :param iteration: current iteration
    :param num_iter: the number of the iteration in a epoch
    :param num_epoch: the number of the total epoch
    :param warmup: learning rate warm-up in first 5 epoch
    :return: learning rate
    """
    from math import cos, pi

    warmup_epoch = 5 if warmup else 0
    warmup_iter = warmup_epoch * num_iter
    current_iter = iteration + epoch * num_iter

    # Handle warm-up before touching the cosine term: when num_epoch equals
    # the warm-up length the cosine denominator is zero, and evaluating it
    # eagerly would raise even though only the warm-up value is needed.
    if epoch < warmup_epoch:
        return base_lr * current_iter / warmup_iter

    max_iter = num_epoch * num_iter
    progress = (current_iter - warmup_iter) / (max_iter - warmup_iter)
    return base_lr * (1 + cos(pi * progress)) / 2


def adjust_cos_llearning_rate(base_lr, epoch, iteration, num_iter, num_epoch, optimizer):
    """Set every parameter group of ``optimizer`` to the cosine-schedule rate.

    The rate is obtained from ``cos_lr`` (called without warm-up) for the
    given epoch/iteration position, stored under ``'lr'`` in each entry of
    ``optimizer.param_groups``, and returned.
    """
    scheduled = cos_lr(base_lr, epoch, iteration, num_iter, num_epoch)
    for group in optimizer.param_groups:
        group['lr'] = scheduled
    return scheduled


def full_imagenet_train_nn(model, train_loader, testloader, criterion, device, max_epoch, cosine=True):
    """Train ``model`` with SGD for ``max_epoch`` epochs and track the best score.

    :param model: network to train; its ``parameters()`` are handed to SGD
    :param train_loader: passed through to ``one_epoch_imagenet_train_nn``
    :param testloader: passed to ``validate`` after every epoch
    :param criterion: loss function forwarded to training and validation
    :param device: device forwarded to training and validation
    :param max_epoch: number of epochs to run
    :param cosine: when true (the default, matching the previous hard-coded
        setting) use base rate 0.05 and let the per-epoch routine apply the
        cosine schedule each iteration; otherwise use base rate 0.1 with the
        step decay of ``adjust_learning_rate``
    :return: the largest value returned by ``validate`` over all epochs
        (0.0 when no epoch is run)
    """
    best_acc = 0.0
    base_lr = 0.05 if cosine else 0.1
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4)
    for epoch in range(max_epoch):
        if cosine:
            # The cosine schedule is applied per iteration inside the
            # one-epoch routine; here lr is only the value that gets printed.
            lr = base_lr
        else:
            lr = adjust_learning_rate(optimizer, epoch, base_lr)

        if epoch % 10 == 0:
            print(lr)
        # train for one epoch
        best_acc = one_epoch_imagenet_train_nn(model, train_loader, testloader, criterion, device,
                                               epoch, optimizer, best_acc, base_lr, max_epoch, cosine)

        temp_acc = validate(testloader, model, criterion, device, initial=False)
        if temp_acc >= best_acc:
            print('-----------------------------------------')
            print('New best acc: ' + str(temp_acc))
            print('-----------------------------------------')
            best_acc = temp_acc

    return best_acc

def one_epoch_imagenet_train_nn(model,train_loader, testloader, criterion, device,epoch,optimizer,best_acc,base_lr,max_epoch,cosine=False):
    """Run one training pass over ``train_loader``.

    For every batch: optionally refresh the learning rate from the cosine
    schedule, move the batch to ``device``, compute the loss, record
    loss / top-1 / top-5 statistics, and take one optimizer step. Progress
    is displayed every 500 iterations.

    ``testloader`` is not used here, and ``best_acc`` is returned exactly as
    it was received.

    NOTE(review): ``AverageMeter``, ``ProgressMeter``, ``accuracy`` and
    ``time`` are not defined in this file; presumably they arrive through
    ``from utilities import *`` -- confirm.
    """
    use_cosine = cosine == True
    steps_per_epoch = len(train_loader)

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    meters = [batch_time, data_time, losses, top1, top5]
    progress = ProgressMeter(steps_per_epoch, meters, prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tick = time.time()
    for step, (images, target) in enumerate(train_loader):
        if use_cosine:
            # cosine schedule is refreshed on every iteration
            lr = adjust_cos_llearning_rate(base_lr, epoch, step, steps_per_epoch, max_epoch, optimizer)

        # time elapsed since the previous iteration finished
        data_time.update(time.time() - tick)

        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # forward pass
        output = model(images)
        loss = criterion(output, target)

        # record accuracy and loss, weighted by the batch size
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

        # backward pass and SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # elapsed time for the whole iteration
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % 500 != 0:
            continue
        if use_cosine:
            print('New learning rate is: '+ str(lr))
        progress.display(step)

    return best_acc






def full_cifar_train_nn(model, trainloader,testloader,criterion, device, train_mode,max_epochs):
    """Train ``model`` with SGD under a three-stage step learning-rate schedule.

    The rate is 0.1, then 0.01, then 0.001; ``train_mode`` picks the epochs
    at which it drops:

    * mode 2: epochs 150 and 225
    * mode 1 or 3: epochs 120 and 160
    * mode 0 and any other value: epochs 80 and 120

    Modes 2 and 3 additionally enable Nesterov momentum.

    After every epoch ``loss_net`` and ``acc_net`` are evaluated; whenever
    the accuracy is at least the best seen so far, a deep copy of the
    model's ``state_dict`` is kept. A status line is printed every 10 epochs.

    :return: ``(best_acc, state_dict)``; ``state_dict`` is ``None`` when no
        epoch was run
    """
    model.train()

    # Epochs at which the learning rate drops by a factor of ten.
    if train_mode == 2:
        first_drop, second_drop = 150, 225
    elif train_mode == 1 or train_mode == 3:
        first_drop, second_drop = 120, 160
    else:
        # mode 0 and every unrecognised mode share the default schedule
        first_drop, second_drop = 80, 120

    sgd_options = dict(lr=0.1, momentum=0.9, weight_decay=1e-4)
    if train_mode == 2 or train_mode == 3:
        sgd_options['nesterov'] = True
    optimizer = optim.SGD(model.parameters(), **sgd_options)

    best_acc = 0
    best_state = None
    for epoch in range(max_epochs):
        if epoch < first_drop:
            lr = 0.1
        elif epoch < second_drop:
            lr = 0.01
        else:
            lr = 0.001

        # push the scheduled rate into every parameter group
        for group in optimizer.param_groups:
            group['lr'] = lr

        for batch in trainloader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

        model_loss = loss_net(model, criterion, trainloader, device)
        model.eval()
        model_acc = acc_net(model, testloader, device)
        if model_acc >= best_acc:
            best_acc = model_acc
            best_state = deepcopy(model.state_dict())
        if epoch % 10 == 0:
            print('[Epoch %d] Train Loss: %.7f, Test Acc: %.3f' % (epoch,model_loss,model_acc))
        model.train()

    print('Finished Training. Best acc: '+ str(best_acc))
    return best_acc, best_state





def fine_tune_nn(model, trainloader,testloader,criterion, device,max_epochs=80):
    """Fine-tune ``model`` with SGD and report the best accuracy observed.

    The learning rate is 0.01 for the first 40 epochs and 0.001 afterwards.
    After each epoch ``loss_net`` and ``acc_net`` are evaluated, and a status
    line is printed every 10 epochs.

    :return: the largest ``acc_net`` value seen (0 if none exceeded 0)
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

    best_acc = 0
    for epoch in range(max_epochs):
        lr = 0.01 if epoch < 40 else 0.001
        for group in optimizer.param_groups:
            group['lr'] = lr

        for batch in trainloader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

        model_loss = loss_net(model, criterion, trainloader, device)
        # evaluate in eval mode, then return to train mode for the next epoch
        model.eval()
        model_acc = acc_net(model, testloader, device)
        model.train()

        if model_acc > best_acc:
            best_acc = model_acc

        if epoch % 10 == 0:
            print('[Epoch %d] Train Loss: %.7f, Test Acc: %.3f' % (epoch, model_loss, model_acc))

    print('Finished Fine-tuning')
    return best_acc




def pre_train_nn(model, trainloader,testloader,criterion, device,pre_train_epochs= 10):
    """Pre-train ``model`` with SGD at a fixed learning rate of 0.1.

    Runs ``pre_train_epochs`` passes over ``trainloader``. On every tenth
    epoch (starting with epoch 0) ``loss_net`` and ``acc_net`` are evaluated
    and a status line is printed. Nothing is returned; the model is updated
    in place.
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

    for epoch in range(pre_train_epochs):
        for batch in trainloader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

        if epoch % 10 != 0:
            continue

        # periodic report: evaluate in eval mode, then resume training mode
        model_loss = loss_net(model, criterion, trainloader, device)
        model.eval()
        model_acc = acc_net(model, testloader, device)
        print('[Epoch %d] Train Loss: %.7f, Test Acc: %.3f' % (epoch, model_loss, model_acc))
        model.train()

    print('Finished Pretraining-tuning')



