import os
import shutil
import time
#import update_loss_util

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import numpy as np
import torch.optim as optim


def train_loop(
    trainloader,
    model,
    criterion,
    optim_choice,
    optimizer,
    epoch,
    device):
    """Run one training pass over ``trainloader`` and print its statistics.

    For each batch, the inputs and targets are moved to ``device``. The
    gradients are then cleared, the ``criterion`` loss is back-propagated and
    ``optimizer`` takes one step. Accuracy is counted from the arg-max along
    dim 1 of the outputs computed before that step.

    Args:
        trainloader: Iterable of ``(inputs, targets)`` batches. ``len()`` is
            taken on it to average the loss, so it must be sized.
        model: Module to train. It is switched to train mode first.
        criterion: Loss callable, invoked as ``criterion(outputs, targets)``.
        optim_choice: Not used by this function.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        epoch: Not used by this function.
        device: Target passed to ``.to()`` for inputs and targets.

    Returns:
        Training accuracy over all seen samples, in percent.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    for inputs, targets in trainloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # Index 1 of the (values, indices) result holds the predicted class.
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    acc = 100. * n_correct / n_seen
    print(f'Accuracy of the model on the training images: {acc} %')
    # Average of per-batch loss values, not a per-sample average.
    print(f'Loss on training images: {running_loss / len(trainloader)} ')
    return acc


def validate(testloader, model, criterion, epoch, device):
    """Evaluate ``model`` on ``testloader`` without gradients and print stats.

    Args:
        testloader: Iterable of ``(inputs, targets)`` batches. ``len()`` is
            taken on it to average the loss, so it must be sized.
        model: Module to evaluate. It is switched to eval mode first.
        criterion: Loss callable, invoked as ``criterion(outputs, targets)``.
        epoch: Not used by this function.
        device: Target passed to ``.to()`` for inputs and targets.

    Returns:
        Accuracy over all seen samples, in percent, where a prediction is the
        arg-max along dim 1 of the model outputs.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    summed_loss = 0
    hits = 0
    count = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            summed_loss += criterion(outputs, targets).item()
            # Index 1 of the (values, indices) result holds the predicted class.
            guesses = outputs.max(1)[1]
            count += targets.size(0)
            hits += guesses.eq(targets).sum().item()

    acc = 100. * hits / count
    print(f'Accuracy of the model on the test images: {acc} %')
    # Average of per-batch loss values, not a per-sample average.
    print(f'Loss on the test images: {summed_loss / len(testloader)} ')
    return acc

def create_optimizer(model, optim_choice, hparams):
    """Build an SGD optimizer over ``model.parameters()``.

    ``optim_choice`` is currently ignored: SGD is always returned.

    Args:
        model: Module whose parameters are optimized.
        optim_choice: Not used by this function.
        hparams: Mapping that must provide ``'initial_lr'``, ``'momentum'``
            and ``'weight_decay'``.

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    params = model.parameters()
    sgd_kwargs = {
        'lr': hparams['initial_lr'],
        'momentum': hparams['momentum'],
        'weight_decay': hparams['weight_decay'],
    }
    return optim.SGD(params, **sgd_kwargs)

        
def adjust_lr(optimizer, optim_choice, epoch,  hparams):
    """Set the learning rate of every param group for the given ``epoch``.

    The schedule is selected by ``hparams['lr_sched']``. Only ``'wr_default'``
    is implemented, and it maps to ``lr_wr_default``. The new rate is always
    derived from ``hparams['initial_lr']``, so every param group receives the
    same value and any previous per-group rate is overwritten.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry), e.g. a ``torch.optim`` optimizer.
        optim_choice: Not used by this function.
        epoch: Epoch value forwarded to the schedule function.
        hparams: Mapping providing ``'lr_sched'``, ``'initial_lr'``,
            ``'frst_ann'`` and ``'snd_ann'``.

    Returns:
        List with the learning rate assigned to each param group, in order.

    Raises:
        ValueError: If ``hparams['lr_sched']`` names an unknown schedule.
            This is checked before any param group is modified.
    """
    sched_name = hparams['lr_sched']
    if sched_name == 'wr_default':
        sched_func = lr_wr_default
    else:
        # Previously an unknown name left ``sched_func`` unbound. That gave an
        # UnboundLocalError inside the loop, or a silent no-op with no groups.
        raise ValueError(
            "Unknown lr schedule {!r}; supported: 'wr_default'".format(sched_name))

    # The schedule only depends on hparams and epoch, so compute it once.
    new_lr = sched_func(hparams['initial_lr'],
                        epoch, hparams['frst_ann'], hparams['snd_ann'])
    lr_vals = []
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
        lr_vals.append(param_group['lr'])
    return lr_vals

def lr_wr_default(lr,  epoch, frst_ann, snd_ann):
    """Step-decay schedule: scale ``lr`` by 0.1 for each milestone reached.

    Each milestone is checked independently. ``epoch >= frst_ann`` applies
    one factor of 0.1, and ``epoch >= snd_ann`` applies another.

    Args:
        lr: Base learning rate.
        epoch: Current epoch value compared against the milestones.
        frst_ann: First milestone.
        snd_ann: Second milestone.

    Returns:
        ``lr`` multiplied by 1.0 or 0.1 per milestone.
    """
    # 0.1 ** int(flag) is exactly 0.1 or 1.0, so the conditionals below are
    # bit-identical. The multiplication order is also kept as lr * f1 * f2.
    first_factor = 0.1 if epoch >= frst_ann else 1.0
    second_factor = 0.1 if epoch >= snd_ann else 1.0
    return lr * first_factor * second_factor
    
