import os
from networks import LeNet5Feats, ResNetFeats18, classifier
#import resnet
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import hypergrad as hg
#from utils import save_checkpoint
import time
import matplotlib.pyplot as plt
import pickle
import numpy as np
import random

def compute_grad_norm(grad_list):
    """Return the global L2 norm over a collection of gradient tensors.

    The result is the square root of the sum of squared entries of every
    tensor in ``grad_list``, i.e. the norm of all gradients viewed as one
    flattened vector.

    Args:
        grad_list: Iterable of tensors. ``None`` entries are skipped.

    Returns:
        A 0-dim tensor holding the norm. If ``grad_list`` is empty or
        contains only ``None`` entries, a 0-dim zero tensor is returned.
    """
    total_sq = None
    for grad in grad_list:
        if grad is None:
            continue
        sq = torch.sum(grad * grad)
        # Out-of-place accumulation: never mutates a tensor that was
        # produced from the caller's gradients.
        total_sq = sq if total_sq is None else total_sq + sq

    if total_sq is None:
        # Nothing to accumulate. torch.sqrt only accepts tensors, so a plain
        # Python float accumulator would raise TypeError here.
        return torch.zeros(())
    return torch.sqrt(total_sq)

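    # NOTE: everything below is a disabled draft of an SLS/Armijo line-search
    # method (`get_step_size`), kept for reference only. As written it never
    # assigns `step_size` or `loss_next` (those lines are themselves commented
    # out), so it cannot be re-enabled without restoring them.
    # @torch.no_grad()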
    # def get_step_size(self, closure_deterministic, loss, params_current, grad_current, grad_norm, pp_norm,
    #                  for_backtracking=False):
    #         # reset step size
    #         # step_size=self.state['step_size']
    #         # step_size = ut.reset_step(step_size=self.state['step_size'],
    #         #                     n_batches_per_epoch=self.n_batches_per_epoch,
    #         #                     gamma=self.gamma,
    #         #                     reset_option=self.reset_option,
    #         #                     init_step_size=self.init_step_size,
    #         #                     eta_max=self.eta_max,
    #         #                     step=self.state['step'])
    #         # sls line search
    #         for e in range(100):
    #             # make potential step
    #             self.try_sgd_precond_update(self.hparams, step_size, hparams_current, grad_current, momentum=0.)

    #             # if for_backtracking:
    #             #     loss_next = closure_deterministic(for_backtracking=True)
    #             # else:
    #             #     loss_next = closure_deterministic()

    #             # compute new loss
    #             self.state['n_forwards'] += 1
    #             found, step_size = self.check_armijo_precond_conditions(
    #                                     step_size, 
    #                                     loss, 
    #                                     loss_next,  
    #                                     pp_norm)
    #             if found == 1:
    #                 break
    #         if found == 0:
    #             step_size = 1e-6
    #     return step_size