"""Evaluates the model"""

import argparse
import logging
import os

import numpy as np
import torch
from spikingjelly.clock_driven import functional
from torch.autograd import Variable



from torchvision import datasets, transforms




import argparse
import logging

import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.autograd import Variable
from tqdm import tqdm

import utils

#from densenet import *
from torchvision import datasets, transforms
# Command-line options: where the experiment's params.json lives and which
# weights file inside that directory should be loaded.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model_dir',
    default='experiments/base_model',
    help="Directory of params.json",
)
parser.add_argument(
    '--restore_file',
    default='best',
    # Written as one literal: a backslash continuation inside the string
    # would embed the next line's indentation in the help text.
    help="name of the file in --model_dir containing weights to load",
)


def evaluate(model, loss_fn, dataloader, metrics, params):
    """Measure the classification accuracy of `model` and checkpoint its weights.

    Args:
        model: network to evaluate. It is called on every data batch and
            passed to spikingjelly's `functional.reset_net` after each batch.
        loss_fn: unused; kept so existing callers keep working.
        dataloader: iterable yielding `(data_batch, labels_batch)` tensor pairs.
        metrics: unused; kept so existing callers keep working.
        params: object exposing a `device` attribute that batches are moved to.

    Returns:
        The fraction of labels predicted correctly, as a float. Returns 0.0
        when the dataloader yields no labels instead of dividing by zero.

    Side effects:
        Switches the model to eval mode, prints a progress marker and the
        accuracy to stdout, and writes the model's state dict to
        `./rescifar100.pth.tar`.
    """
    model.eval()

    correct_sum = 0
    test_sum = 0

    # Evaluation never back-propagates, so disable autograd bookkeeping to
    # avoid building graphs and holding activations for every batch.
    with torch.no_grad():
        for data_batch, labels_batch in dataloader:
            data_batch = data_batch.to(params.device)
            labels_batch = labels_batch.to(params.device)

            output_batch = model(data_batch)
            # Predicted class is the index of the largest value along dim 1
            # (assumes output is (batch, num_classes) — TODO confirm).
            predictions = output_batch.max(1)[1]
            correct_sum += (predictions == labels_batch).float().sum().item()
            test_sum += labels_batch.numel()
            # Reset the network after each batch; presumably this clears
            # neuron state so batches do not influence each other.
            functional.reset_net(model)

    print('bbb')
    test_accuracy = correct_sum / test_sum if test_sum else 0.0

    # NOTE(review): checkpointing from inside an evaluation routine with a
    # hard-coded path is surprising; kept because callers may rely on it.
    torch.save(
        {'state_dict': model.state_dict()},
        os.path.join('./', 'rescifar100.pth.tar'),
    )
    print(test_accuracy)
    return test_accuracy



def evaluate_kd(model, dataloader, metrics, params):
    """Measure the classification accuracy of `model` over `dataloader`.

    Args:
        model: network to evaluate. It is called on every data batch and
            passed to spikingjelly's `functional.reset_net` after each batch.
        dataloader: iterable yielding `(data_batch, labels_batch)` tensor pairs.
        metrics: unused; kept so existing callers keep working.
        params: object exposing a `device` attribute. Batches are moved to it
            only when it is truthy.

    Returns:
        The fraction of labels predicted correctly, as a float. Returns 0.0
        when the dataloader yields no labels instead of dividing by zero.

    Side effects:
        Switches the model to eval mode and prints two progress markers to
        stdout. Nothing is written to disk.
    """
    model.eval()

    correct_sum = 0
    test_sum = 0

    # Evaluation never back-propagates, so disable autograd bookkeeping to
    # avoid building graphs and holding activations for every batch.
    with torch.no_grad():
        for data_batch, labels_batch in dataloader:
            if params.device:
                data_batch = data_batch.to(params.device)
                labels_batch = labels_batch.to(params.device)

            output_batch = model(data_batch)
            # Predicted class is the index of the largest value along dim 1
            # (assumes output is (batch, num_classes) — TODO confirm).
            # Labels are compared as-is: they were already moved above when a
            # device is configured, so no second unconditional `.to()` call.
            predictions = output_batch.max(1)[1]
            correct_sum += (predictions == labels_batch).float().sum().item()
            test_sum += labels_batch.numel()
            # Reset the network after each batch; presumably this clears
            # neuron state so batches do not influence each other.
            functional.reset_net(model)

    print('bbb')
    test_accuracy = correct_sum / test_sum if test_sum else 0.0

    # NOTE(review): this marker is printed although this function saves
    # nothing; kept so existing log output stays the same.
    print('savedpyres')
    return test_accuracy


