import os
import shutil
import time
#import update_loss_util

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from optim_util import SGDGD,GD_momnoise
import numpy as np
import torch.optim as optim


def train_loop(
    trainloader,
    model,
    criterion,
    optim_choice,
    optimizer,
    epoch,
    device):
    """Train ``model`` for one pass over ``trainloader``.

    Each batch is moved to ``device``, run through the model, scored with
    ``criterion``, back-propagated, and followed by one ``optimizer.step()``.
    A one-line summary with the mean batch loss is printed at the end.

    Args:
        trainloader: Iterable of ``(inputs, targets)`` batches; must also
            support ``len()`` (used to average the loss over batches).
        model: Module to train; switched to training mode on entry.
        criterion: Callable ``criterion(outputs, targets)`` returning a
            scalar loss that supports ``backward()`` and ``item()``.
        optim_choice: Not used by this function.
        optimizer: Optimizer whose ``zero_grad()``/``step()`` are called
            once per batch.
        epoch: Zero-based epoch index; printed as ``epoch + 1``.
        device: Target passed to ``Tensor.to`` for inputs and targets.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the sum of per-batch losses divided
        by the number of batches, and the percentage of samples whose
        highest-scoring index along dim 1 equals the target.
    """
    model.train()

    running_loss = 0
    num_correct = 0
    num_seen = 0

    for batch_inputs, batch_targets in trainloader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

        # Predicted class = index of the largest output along dim 1.
        predicted = logits.max(1)[1]
        num_seen += batch_targets.size(0)
        num_correct += predicted.eq(batch_targets).sum().item()

    mean_loss = running_loss / len(trainloader)
    print("Epoch [{}], Loss: {:.4f}".format(epoch + 1, mean_loss))

    accuracy = 100. * num_correct / num_seen
    return mean_loss, accuracy


def validate(
    testloader, 
    model, 
    criterion, 
    epoch,device
    ):
    """Measure classification accuracy of ``model`` on ``testloader``.

    The model is put in evaluation mode and every batch is processed with
    gradient tracking disabled.

    Args:
        testloader: Iterable of ``(inputs, targets)`` batches.
        model: Module to evaluate; switched to eval mode on entry.
        criterion: Callable ``criterion(outputs, targets)``; it is evaluated
            on every batch, but the accumulated loss is not returned.
        epoch: Not used by this function.
        device: Target passed to ``Tensor.to`` for inputs and targets.

    Returns:
        Percentage of samples whose highest-scoring index along dim 1
        equals the target.
    """
    model.eval()

    loss_sum = 0
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch_inputs, batch_targets in testloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_inputs)

            # The loss is still computed and summed per batch even though
            # only the accuracy is handed back to the caller.
            loss_sum += criterion(logits, batch_targets).item()

            predicted = logits.max(1)[1]
            seen += batch_targets.size(0)
            hits += predicted.eq(batch_targets).sum().item()

    return 100. * hits / seen







def create_optimizer(model, optim_choice, hparams):
    """Build an SGD optimizer over all parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are optimized.
        optim_choice: Not used; an ``optim.SGD`` instance is always returned.
        hparams: Mapping providing ``'initial_lr'``, ``'momentum'`` and
            ``'weight_decay'``.

    Returns:
        The configured ``optim.SGD`` instance.
    """
    trainable = model.parameters()
    sgd_settings = {
        'lr': hparams['initial_lr'],
        'momentum': hparams['momentum'],
        'weight_decay': hparams['weight_decay'],
    }
    return optim.SGD(trainable, **sgd_settings)

        
def adjust_lr(optimizer, optim_choice, epoch,  hparams):
    """Set the learning rate of every parameter group for ``epoch``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten in place.
        optim_choice: Not used by this function.
        epoch: Epoch index forwarded to the schedule function.
        hparams: Mapping providing ``'lr_sched'``, ``'initial_lr'``,
            ``'frst_ann'`` and ``'snd_ann'``.

    Returns:
        List with the learning rate assigned to each parameter group, in
        ``optimizer.param_groups`` order.

    Raises:
        ValueError: If ``hparams['lr_sched']`` is not ``'wr_default'``, the
            only schedule this function knows how to apply.
    """
    schedule_name = hparams['lr_sched']
    # Fail fast on an unknown schedule instead of hitting an unbound
    # ``sched_func`` later inside the loop.
    if schedule_name != 'wr_default':
        raise ValueError(
            "Unsupported lr_sched {!r}; expected 'wr_default'".format(
                schedule_name))
    sched_func = lr_wr_default

    lr_vals = []
    for param_group in optimizer.param_groups:
        param_group['lr'] = sched_func(hparams['initial_lr'],
                                       epoch,
                                       hparams['frst_ann'],
                                       hparams['snd_ann'])
        lr_vals.append(param_group['lr'])

    return lr_vals
    

def lr_wr_default(lr,  epoch, frst_ann, snd_ann):
    """Return ``lr`` decayed by a factor of 0.1 at each milestone reached.

    Args:
        lr: Base learning rate.
        epoch: Current epoch index.
        frst_ann: First milestone; from this epoch on, ``lr`` is scaled
            by 0.1.
        snd_ann: Second milestone; from this epoch on, a further factor of
            0.1 is applied.

    Returns:
        ``lr`` multiplied by 0.1 once for every milestone with
        ``epoch >= milestone``.
    """
    first_factor = 0.1 if epoch >= frst_ann else 1.0
    second_factor = 0.1 if epoch >= snd_ann else 1.0
    # Multiply left to right so the float result matches term-by-term scaling.
    return lr * first_factor * second_factor
    
