import torch,time
import torch.nn as nn 
import torch.nn.functional as F 


def compute_accuracy_and_loss(model, data_loader,criterion,device):
	"""Evaluate ``model`` on every batch yielded by ``data_loader``.

	The function does not switch the model between train and eval mode;
	callers that need eval-mode behavior must call ``model.eval()`` first.
	The whole pass runs under ``torch.no_grad()``, so no autograd graph is
	built even if the caller did not disable gradients.

	Args:
		model: Callable mapping a batch of features to logits whose
			dimension 1 indexes the classes.
		data_loader: Iterable yielding ``(features, targets)`` batches. It
			does not need to define ``__len__``; batches are counted while
			iterating.
		criterion: Callable ``criterion(logits, targets)`` returning a
			single-element loss tensor.
		device: Device each batch is moved to before the forward pass.

	Returns:
		A tuple ``(accuracy, mean_loss)``. ``accuracy`` is a 0-dim float
		tensor holding the percentage of correctly predicted examples.
		``mean_loss`` is a Python float: the unweighted mean of the
		per-batch losses (a smaller final batch counts as much as a full
		one).

	Raises:
		ValueError: If ``data_loader`` yields no batches.
	"""
	correct_pred = None
	num_examples = 0
	num_batches = 0
	running_loss = 0.0

	with torch.no_grad():
		for features, targets in data_loader:
			features = features.to(device)
			targets = targets.to(device)

			logits = model(features)
			running_loss += criterion(logits, targets).item()
			num_batches += 1

			# The index of the largest logit along dim 1 is the predicted class.
			_, predicted_labels = torch.max(logits, 1)
			batch_correct = (predicted_labels == targets).sum()
			# Accumulate as a tensor so the count stays on the batch's device.
			if correct_pred is None:
				correct_pred = batch_correct
			else:
				correct_pred = correct_pred + batch_correct
			num_examples += targets.size(0)

	if num_batches == 0:
		raise ValueError('data_loader yielded no batches; cannot compute accuracy or loss')

	return correct_pred.float() / num_examples * 100, running_loss / num_batches

def train_epochs(model, optimizer, criterion, train_loader, device, num_epochs, valid_loader=None, log_interval=50):
	"""Train ``model`` for ``num_epochs`` epochs with periodic evaluation.

	Every ``log_interval`` batches (starting with batch 0 of each epoch) the
	running mean training loss of the current epoch is printed and the model
	is evaluated with ``compute_accuracy_and_loss``. After each evaluation
	the model is put back into training mode, so the remaining batches of
	the epoch are not trained in eval mode.

	Args:
		model: Module to train; called as ``model(features)`` and switched
			between modes with ``model.train()`` / ``model.eval()``.
		optimizer: Optimizer stepped once per training batch.
		criterion: Callable ``criterion(logits, targets)`` returning the
			loss tensor that is backpropagated.
		train_loader: Sized iterable yielding ``(features, targets)``
			training batches.
		device: Device each batch is moved to before the forward pass.
		num_epochs: Number of passes over ``train_loader``.
		valid_loader: Optional iterable of ``(features, targets)`` batches
			used for the periodic evaluation. When ``None`` (the default)
			the evaluation runs on ``train_loader``, so the reported
			"Valid" metrics are really training-set metrics.
		log_interval: Number of batches between two logging/evaluation
			steps. Must be at least 1.

	Returns:
		A tuple ``(valid_acc_lst, valid_loss_lst)`` with one entry per
		evaluation step, in the order the evaluations were run. Entries are
		whatever ``compute_accuracy_and_loss`` returned.

	Raises:
		ValueError: If ``log_interval`` is smaller than 1.
	"""
	if log_interval < 1:
		raise ValueError('log_interval must be >= 1, got %r' % (log_interval,))

	eval_loader = train_loader if valid_loader is None else valid_loader

	start_time = time.time()
	valid_acc_max = 0
	valid_acc_lst, valid_loss_lst = [], []
	for epoch in range(num_epochs):
		train_loss = 0.0
		model.train()
		for batch_idx, (features, targets) in enumerate(train_loader):
			optimizer.zero_grad()

			features = features.to(device)
			targets = targets.to(device)

			### FORWARD AND BACK PROP
			logits = model(features)
			loss = criterion(logits, targets)
			loss.backward()
			optimizer.step()
			train_loss += loss.item()

			### LOGGING
			if batch_idx % log_interval:
				continue

			print ('Epoch: %03d/%03d | Batch %04d/%04d | Loss: %.4f' % (
				epoch+1, num_epochs, batch_idx,
				len(train_loader), train_loss/(batch_idx+1)))

			model.eval()
			with torch.no_grad(): # save memory during inference
				valid_acc, valid_loss = compute_accuracy_and_loss(
					model, eval_loader, criterion, device=device)
			# Evaluation happens mid-epoch: restore training mode so that the
			# following batches still use dropout / batch-norm updates.
			model.train()

			print('Epoch: %03d/%03d | Valid Acc: %.3f%% | Valid loss: %.3f' % (
				epoch+1, num_epochs, valid_acc, valid_loss))

			if valid_acc >= valid_acc_max:
				# NOTE(review): the message announces a save, but no checkpoint
				# is written here; only the best accuracy so far is tracked.
				print('Validation acc increase ({:.6f} --> {:.6f}).  Saving model ...'.format(
					valid_acc_max,
					valid_acc))
				valid_acc_max = valid_acc
			valid_acc_lst.append(valid_acc)
			valid_loss_lst.append(valid_loss)

		print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
	print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
	return valid_acc_lst, valid_loss_lst
