import time
import random 
import numpy as np
import torch 
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import math
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
from torch.utils.data import random_split
import os
from PIL import Image
#import pretrainedmodels
#import pretrainedmodels.utils as utils
import torchvision.models as models
import torch.nn.functional as F
from layers.feat_noise import Noise
from tqdm import tqdm

from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types

from utils import report
from generator.conditional_data import ConditionalCIFAR10, ConditionalMNIST


class CIFAR10(nn.Module):
    """Small VGG-style CNN classifier for CIFAR-10 (3x32x32 inputs, 10 logits).

    ``features`` applies four unpadded 3x3 conv + BatchNorm + ReLU layers with a
    2x2 max-pool after every second one, shrinking a 32x32 input to a
    128x5x5 = 3200-dim feature vector; a 3200-256-256-10 MLP with ReLU and
    dropout (p=0.5) produces the logits.
    """

    def __init__(self):
        super(CIFAR10, self).__init__()
        # Creation/registration order is kept as-is so seeded initialisation,
        # parameter order and state_dict keys match existing checkpoints.
        self.features = self._make_layers()
        self.fc1 = nn.Linear(3200, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 256)
        self.dropout = nn.Dropout(p=0.5)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10)."""
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.dropout(self.relu(self.fc1(out)))
        out = self.dropout(self.relu(self.fc2(out)))
        return self.fc3(out)

    def _make_layers(self):
        """Build the convolutional feature extractor as an ``nn.Sequential``."""
        layers = []
        in_channels = 3
        layers += [nn.Conv2d(in_channels, 64, kernel_size=3),
                   nn.BatchNorm2d(64),
                   nn.ReLU()]
        layers += [nn.Conv2d(64, 64, kernel_size=3),
                   nn.BatchNorm2d(64),
                   nn.ReLU()]
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        layers += [nn.Conv2d(64, 128, kernel_size=3),
                   nn.BatchNorm2d(128),
                   nn.ReLU()]
        layers += [nn.Conv2d(128, 128, kernel_size=3),
                   nn.BatchNorm2d(128),
                   nn.ReLU()]
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        return nn.Sequential(*layers)

    def predict(self, image):
        """Return the predicted class index (a 0-d tensor) for one image.

        ``image`` must hold 3*32*32 values; it is clamped to [0, 1], reshaped to
        (1, 3, 32, 32) and moved to the device of the model's parameters.
        Switches the model to eval mode (and leaves it there). No autograd
        graph is built.
        """
        self.eval()
        device = next(self.parameters()).device
        with torch.no_grad():
            batch = torch.clamp(image, 0, 1).view(1, 3, 32, 32).to(device)
            output = self(batch)
        _, predict = torch.max(output, 1)
        return predict[0]

    def predict_batch(self, image):
        """Return predicted class indices (one per row) for a batch of images.

        ``image`` is clamped to [0, 1] and moved to the device of the model's
        parameters; presumably shaped (batch, 3, 32, 32). Switches the model to
        eval mode (and leaves it there). No autograd graph is built.
        """
        self.eval()
        device = next(self.parameters()).device
        with torch.no_grad():
            output = self(torch.clamp(image, 0, 1).to(device))
        _, predict = torch.max(output, 1)
        return predict


class MNIST_Jalal(nn.Module):
    """Two-conv MNIST classifier (1x28x28 inputs, 10 logits).

    5x5 convs (padding 2) with 2x2 max-pooling reduce a 28x28 input to
    64x7x7 = 3136 features, followed by a 3136-1024-10 MLP.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        self.fc = nn.Linear(3136, 1024)
        self.final_fc = nn.Linear(1024, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10)."""
        out = F.max_pool2d(F.relu(self.conv1(x)), 2)
        out = F.max_pool2d(F.relu(self.conv2(out)), 2)
        out = out.view(out.shape[0], -1)
        out = F.relu(self.fc(out))
        return self.final_fc(out)

    def predict(self, image):
        """Return the predicted class index (a 0-d tensor) for one image.

        ``image`` must hold 28*28 values; it is clamped to [0, 1], reshaped to
        (1, 1, 28, 28) and moved to the device of the model's parameters.
        Switches the model to eval mode (and leaves it there). No autograd
        graph is built.
        """
        self.eval()
        device = next(self.parameters()).device
        with torch.no_grad():
            batch = torch.clamp(image, 0, 1).view(1, 1, 28, 28).to(device)
            output = self(batch)
        _, predict = torch.max(output, 1)
        return predict[0]

    def predict_batch(self, image):
        """Return predicted class indices (one per row) for a batch of images.

        ``image`` is clamped to [0, 1] and moved to the device of the model's
        parameters; presumably shaped (batch, 1, 28, 28). Switches the model to
        eval mode (and leaves it there). No autograd graph is built.
        """
        self.eval()
        device = next(self.parameters()).device
        with torch.no_grad():
            output = self(torch.clamp(image, 0, 1).to(device))
        _, predict = torch.max(output, 1)
        return predict
    

class MNIST(nn.Module):
    """Small VGG-style CNN classifier for MNIST (1x28x28 inputs, 10 logits).

    ``features`` applies four unpadded 3x3 conv + BatchNorm + ReLU layers with a
    2x2 max-pool after every second one, shrinking a 28x28 input to a
    64x4x4 = 1024-dim feature vector; a 1024-200-200-10 MLP with ReLU and
    dropout (p=0.5) produces the logits.
    """

    def __init__(self):
        super(MNIST, self).__init__()
        # Creation/registration order is kept as-is so seeded initialisation,
        # parameter order and state_dict keys match existing checkpoints.
        self.features = self._make_layers()
        self.fc1 = nn.Linear(1024, 200)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(200, 200)
        self.dropout = nn.Dropout(p=0.5)
        self.fc3 = nn.Linear(200, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10)."""
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.dropout(self.relu(self.fc1(out)))
        out = self.dropout(self.relu(self.fc2(out)))
        return self.fc3(out)

    def _make_layers(self):
        """Build the convolutional feature extractor as an ``nn.Sequential``."""
        layers = []
        in_channels = 1
        layers += [nn.Conv2d(in_channels, 32, kernel_size=3),
                   nn.BatchNorm2d(32),
                   nn.ReLU()]
        layers += [nn.Conv2d(32, 32, kernel_size=3),
                   nn.BatchNorm2d(32),
                   nn.ReLU()]
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        layers += [nn.Conv2d(32, 64, kernel_size=3),
                   nn.BatchNorm2d(64),
                   nn.ReLU()]
        layers += [nn.Conv2d(64, 64, kernel_size=3),
                   nn.BatchNorm2d(64),
                   nn.ReLU()]
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        return nn.Sequential(*layers)

    def predict(self, image):
        """Return the predicted class index (a 0-d tensor) for one image.

        ``image`` must hold 28*28 values; it is clamped to [0, 1], reshaped to
        (1, 1, 28, 28) and moved to the device of the model's parameters.
        Switches the model to eval mode (and leaves it there). No autograd
        graph is built.
        """
        self.eval()
        device = next(self.parameters()).device
        with torch.no_grad():
            batch = torch.clamp(image, 0, 1).view(1, 1, 28, 28).to(device)
            output = self(batch)
        _, predict = torch.max(output, 1)
        return predict[0]

    def predict_batch(self, image):
        """Return predicted class indices (one per row) for a batch of images.

        ``image`` is clamped to [0, 1] and moved to the device of the model's
        parameters; presumably shaped (batch, 1, 28, 28). Switches the model to
        eval mode (and leaves it there). No autograd graph is built.
        """
        self.eval()
        device = next(self.parameters()).device
        with torch.no_grad():
            output = self(torch.clamp(image, 0, 1).to(device))
        _, predict = torch.max(output, 1)
        return predict


def show_image(img):
    """
    Show MNIST digits in the console as ASCII art.

    ``img`` is flattened; if it does not hold exactly 784 (28x28) values,
    nothing is printed. Each value ``x`` is drawn as character
    ``round((x + 0.5) * 3)`` of ``"  .*#"`` (followed by more '#'), so
    x = -0.5 prints ' ' and larger values print '.', '*' and '#'. Indices are
    clamped to the valid range, so values below -0.5 print ' ' instead of
    wrapping around to '#', and very large values print '#' instead of raising.
    """
    remap = "  .*#" + "#" * 100
    last = len(remap) - 1
    img = (img.flatten() + .5) * 3
    if len(img) != 784:
        return
    for row in range(28):
        pixels = img[row * 28:(row + 1) * 28]
        print("".join(remap[min(last, max(0, int(round(x))))] for x in pixels))

        
def load_mnist_data(state, mode, shuffle_test=False, class_ix=None):
    """Build MNIST datasets and DataLoaders, after seeding every RNG.

    torch (CPU and CUDA), numpy and ``random`` are seeded from ``state["seed"]``
    and cuDNN is put in deterministic mode before anything is built.

    Args:
        state: dict providing "seed", "batch_size" and "test_batch_size".
        mode: 'generator' makes the "train" dataset the MNIST *test* split
            with RandomAffine(25, shear=20) augmentation; any other value
            loads the regular MNIST training split with plain ToTensor.
        shuffle_test: whether the test loader shuffles (the adversary must
            shuffle during target sample search).
        class_ix: generator mode only; if given, the train dataset is a
            ConditionalMNIST rooted at ./data/mnist/<class_ix> (presumably
            restricted to that class — see generator.conditional_data).

    Returns:
        (train_loader, test_loader, train_dataset, test_dataset); the train
        loader always shuffles.
    """
    seed = state["seed"]
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

    if mode != 'generator':
        train_dataset = dsets.MNIST(root='./data/mnist', train=True,
                                    transform=transforms.ToTensor(), download=True)
    else:
        print("Loading generator version")
        augment = transforms.Compose([
            transforms.RandomAffine(25, shear=20),
            transforms.ToTensor(),
        ])
        if class_ix is None:
            train_dataset = dsets.MNIST(root='./data/mnist', train=False,
                                        transform=augment, download=True)
        else:
            train_dataset = ConditionalMNIST(root=f'./data/mnist/{class_ix}', class_ix=class_ix,
                                             train=False, transform=augment, download=True)

    test_dataset = dsets.MNIST(root='./data/mnist', train=False,
                               transform=transforms.ToTensor(), download=True)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=state['batch_size'], shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=state['test_batch_size'], shuffle=shuffle_test)
    return train_loader, test_loader, train_dataset, test_dataset


def load_cifar10_data(state, mode, shuffle_test=False, class_ix=None):
    """Build CIFAR-10 datasets and DataLoaders, after seeding every RNG.

    torch (CPU and CUDA), numpy and ``random`` are seeded from ``state["seed"]``
    and cuDNN is put in deterministic mode before anything is built.

    Args:
        state: dict providing "seed", "batch_size" and "test_batch_size".
        mode: 'generator' makes the "train" dataset the CIFAR-10 *test* split
            with ColorJitter(0.1 each) augmentation; any other value loads the
            regular CIFAR-10 training split with plain ToTensor.
        shuffle_test: whether the test loader shuffles.
        class_ix: generator mode only; if given, the train dataset is a
            ConditionalCIFAR10 rooted at ./data/cifar10-py/<class_ix>
            (presumably restricted to that class — see generator.conditional_data).

    Returns:
        (train_loader, test_loader, train_dataset, test_dataset); the train
        loader always shuffles.
    """
    seed = state["seed"]
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

    if mode != 'generator':
        train_dataset = dsets.CIFAR10('./data/cifar10-py', download=True, train=True,
                                      transform=transforms.ToTensor())
    else:
        print("Loading generator version")
        augment = transforms.Compose([
            transforms.ColorJitter(brightness=.1, contrast=.1, saturation=.1, hue=.1),
            transforms.ToTensor(),
        ])
        if class_ix is None:
            train_dataset = dsets.CIFAR10('./data/cifar10-py', download=True, train=False,
                                          transform=augment)
        else:
            train_dataset = ConditionalCIFAR10(f'./data/cifar10-py/{class_ix}', class_ix=class_ix,
                                               download=True, train=False, transform=augment)

    test_dataset = dsets.CIFAR10('./data/cifar10-py', download=True, train=False,
                                 transform=transforms.ToTensor())

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=state['batch_size'], shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=state['test_batch_size'], shuffle=shuffle_test)
    return train_loader, test_loader, train_dataset, test_dataset


# https://github.com/NVIDIA/DALI/blob/master/docs/examples/use_cases/pytorch/resnet50/main.py
class HybridTrainPipe(Pipeline):
    """DALI training pipeline: read, random-crop decode, resize, crop/mirror/normalize.

    Adapted from NVIDIA's DALI ResNet-50 example. Per sample, ``define_graph``
    reads an encoded image and its label from ``data_dir``, decodes it with a
    random area / aspect-ratio crop, resizes to ``resize`` x ``resize``
    (aspect ratio not preserved), divides pixel values by 255, then crops to
    ``crop`` x ``crop``, randomly mirrors with probability 0.5, optionally
    normalizes with ImageNet mean/std, and emits FLOAT tensors in NCHW layout
    from a GPU operator.

    Args:
        batch_size, num_threads, device_id, seed: forwarded to ``Pipeline``.
        data_dir: ``file_root`` of DALI's FileReader (presumably one
            sub-directory per class — confirm against the data layout).
        crop: side length of the final square crop.
        shard_id, num_shards: reader sharding across processes.
        dali_cpu: decode/resize on the CPU instead of the mixed/GPU path;
            CropMirrorNormalize is always configured with device="gpu".
        resize: side length of the square resize.
        normalize: if True, apply ImageNet channel mean/std (given on the
            [0, 1] scale, matching the /255 in define_graph).

    NOTE(review): DALI derives per-operator seeds from construction order, so
    the order of the ops created below should not be changed casually.
    """
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                 shard_id, num_shards, seed, dali_cpu=False, resize=256, normalize=True):
        super(HybridTrainPipe, self).__init__(batch_size,
                                              num_threads,
                                              device_id,
                                              seed=seed)
        # Shuffling file reader yielding (encoded image, label) pairs.
        self.input = ops.FileReader(file_root=data_dir,
                                    shard_id=shard_id,
                                    num_shards=num_shards,
                                    random_shuffle=True)
        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
        # without additional reallocations
        device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
        host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
        # Decode fused with a random crop covering 10%-100% of the image area
        # with aspect ratio in [0.8, 1.25].
        self.decode = ops.ImageDecoderRandomCrop(device=decoder_device, output_type=types.RGB,
                                                 device_memory_padding=device_memory_padding,
                                                 host_memory_padding=host_memory_padding,
                                                 random_aspect_ratio=[0.8, 1.25],
                                                 random_area=[0.1, 1.0],
                                                 num_attempts=100)
        # Square resize: width and height are both set to `resize`.
        self.res = ops.Resize(device=dali_device,
                              resize_x=resize,
                              resize_y=resize,
                              interp_type=types.INTERP_TRIANGULAR)
        
        if normalize:
            # Imagenet vals (on the [0, 1] scale, since define_graph divides by 255)
            prior_mean = [0.485, 0.456, 0.406]
            prior_std = [0.229, 0.224, 0.225]
        else:
            print("TrainPipe is not normalizing!")
            # NOTE(review): assumes DALI treats mean=None / std=None as "not set"
            # (operator defaults, i.e. no normalization) — confirm for the DALI
            # version in use.
            prior_mean = None
            prior_std = None
            
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            mean=prior_mean,
                                            std=prior_std)
        # Per-sample coin flip feeding CropMirrorNormalize's `mirror` argument.
        self.coin = ops.CoinFlip(probability=0.5)
        print(f"DALI {dali_device} variant")

    def define_graph(self):
        """Wire the ops together; returns [images, labels] per batch.

        The reader is named "Reader" so DALIClassificationIterator can be
        created with reader_name="Reader".
        """
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        # Scale decoded pixel values (0-255) down to [0, 1].
        images = images / 255.
        # .gpu() makes sure data reaches the GPU CropMirrorNormalize even when
        # decode/resize ran on the CPU (dali_cpu=True).
        output = self.cmnp(images.gpu(), mirror=rng)
        return [output, self.labels]

    
def load_imagenet_train(state, mode, normalize=True):
    """Build a DALI iterator over the ImageNet training split.

    Seeds torch (CPU and CUDA), numpy and ``random`` from ``state["seed"]`` and
    enables deterministic cuDNN, then builds a HybridTrainPipe over
    ``<state["dataset_path"]>/train`` (GPU decoding, 224x224 crops, single
    shard on device 0) wrapped in a DALIClassificationIterator.

    For the sake of attacking in [0, 1], callers can pass ``normalize=False``
    and only normalize after attacking; generator mode is owned by the
    adversary.

    Args:
        state: dict providing "seed", "batch_size", "num_threads" and
            "dataset_path".
        mode: accepted for interface compatibility; currently unused.
        normalize: forwarded to HybridTrainPipe (ImageNet mean/std).

    Returns:
        The DALIClassificationIterator.
    """
    seed = state["seed"]
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

    train_dir = os.path.join(state["dataset_path"], "train")
    pipe = HybridTrainPipe(
        batch_size=state["batch_size"],
        num_threads=state["num_threads"],
        device_id=0,
        data_dir=train_dir,
        crop=224,
        seed=seed,
        dali_cpu=False,
        num_shards=1,
        shard_id=0,
        normalize=normalize,
    )
    pipe.build()
    return DALIClassificationIterator(pipe, reader_name="Reader")


def load_imagenet_generator(state, normalize=False, class_ix=None):
    """Build a DALI iterator over the ImageNet validation split with training augmentation.

    Seeds torch (CPU and CUDA), numpy and ``random`` from ``state["seed"]`` and
    enables deterministic cuDNN, then builds a HybridTrainPipe (random crop,
    mirror) over ``<state["dataset_path"]>/val`` on device 0.

    Args:
        state: dict providing "seed", "batch_size", "num_threads" and
            "dataset_path".
        normalize: forwarded to HybridTrainPipe; defaults to False so images
            stay in [0, 1].
        class_ix: accepted for interface compatibility; currently unused.

    Returns:
        The DALIClassificationIterator.
    """
    seed = state["seed"]
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

    val_dir = os.path.join(state["dataset_path"], "val")
    pipe = HybridTrainPipe(
        batch_size=state["batch_size"],
        num_threads=state["num_threads"],
        device_id=0,
        data_dir=val_dir,
        crop=224,
        seed=seed,
        dali_cpu=False,
        num_shards=1,
        shard_id=0,
        normalize=normalize,
    )
    pipe.build()
    return DALIClassificationIterator(pipe, reader_name="Reader")

    
class HybridValPipe(Pipeline):
    """DALI validation pipeline: read, decode, resize, crop/normalize (no random augmentation).

    Per sample, ``define_graph`` reads an encoded image and its label from
    ``data_dir``, decodes it (mixed CPU/GPU decoder), resizes to
    ``resize`` x ``resize`` on the GPU (aspect ratio not preserved), divides
    pixel values by 255, then crops to ``crop`` x ``crop`` (at DALI's default
    crop position — presumably centered; confirm) and optionally normalizes
    with ImageNet mean/std, emitting FLOAT tensors in NCHW layout.

    Args:
        batch_size, num_threads, device_id, seed: forwarded to ``Pipeline``.
        data_dir: ``file_root`` of DALI's FileReader.
        crop: side length of the final square crop.
        shard_id, num_shards: reader sharding across processes.
        normalize: if True, apply ImageNet channel mean/std (on the [0, 1]
            scale, matching the /255 in define_graph).
        resize: side length of the square resize.
        shuffle: whether the reader shuffles samples.
    """
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                 shard_id, num_shards, seed, normalize=True, resize=256, shuffle=False):
        super(HybridValPipe, self).__init__(batch_size,
                                           num_threads,
                                            device_id,
                                            seed=seed)
        self.input = ops.FileReader(file_root=data_dir,
                                    shard_id=shard_id,
                                    num_shards=num_shards,
                                    random_shuffle=shuffle)
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        # Square resize: width and height are both set to `resize`.
        self.res = ops.Resize(device="gpu",
                              resize_x=resize,
                              resize_y=resize,
                              interp_type=types.INTERP_TRIANGULAR)
        if normalize:
            # Imagenet vals (on the [0, 1] scale, since define_graph divides by 255)
            prior_mean = [0.485, 0.456, 0.406]
            prior_std = [0.229, 0.224, 0.225]
        else:
            print("ValPipe is not normalizing!")
            # NOTE(review): assumes DALI treats mean=None / std=None as "not set"
            # (operator defaults, i.e. no normalization) — confirm for the DALI
            # version in use.
            prior_mean = None
            prior_std = None
        
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            mean=prior_mean,
                                            std=prior_std)

    def define_graph(self):
        """Wire the ops together; returns [images, labels] per batch.

        The reader is named "Reader" so a DALI iterator can be created with
        reader_name="Reader".
        """
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        # Scale decoded pixel values (0-255) down to [0, 1].
        images = images / 255.
        # No `mirror` argument is passed, so no flipping is requested here.
        output = self.cmnp(images)
        return [output, self.labels]


def load_imagenet_test(state, normalize=False, shuffle_test=False):
    """Build the ImageNet validation DataLoader (resize 256, center crop 224).

    Seeds torch (CPU and CUDA), numpy and ``random`` from ``state["seed"]`` and
    enables deterministic cuDNN, then wraps ``<state["dataset_path"]>/val`` in
    a torchvision ImageFolder.

    For the sake of attacking in [0, 1], the default leaves images as float
    tensors in [0, 1] so they can be normalized only after attacking. With
    ``normalize=True`` the ImageNet channel mean/std (the same values the DALI
    pipelines in this module use) is applied. Before this fix the flag was
    ignored.

    Args:
        state: dict providing "seed", "dataset_path" and "test_batch_size".
        normalize: apply ImageNet mean/std normalization after ToTensor.
        shuffle_test: whether the loader shuffles.

    Returns:
        (val_loader, val_dataset); the loader yields ``(images, labels)``.
    """
    torch.manual_seed(state["seed"])
    torch.cuda.manual_seed(state["seed"])
    np.random.seed(state["seed"])
    random.seed(state["seed"])
    torch.backends.cudnn.deterministic = True

    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225]))
    val_dataset = dsets.ImageFolder(
        os.path.join(state["dataset_path"], "val"),
        transforms.Compose(steps))
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=state["test_batch_size"],
                                             shuffle=shuffle_test)

    return val_loader, val_dataset


def train_simple_mnist(model, train_loader, num_epochs=5, learning_rate=1e-3):
    """Train ``model`` with Adam and cross-entropy loss.

    The original body referred to undefined globals ``num_epochs`` and
    ``learning_rate``, so it always raised NameError. They are now keyword
    parameters with defaults.

    Args:
        model: classifier to train in place; switched to train mode.
        train_loader: iterable of ``(images, labels)`` batches; each batch is
            moved to the device of the model's parameters.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate.

    Returns:
        The trained ``model``. Prints the loss every 100 iterations.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    device = next(model.parameters()).device
    model.train()

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            # Forward + Backward + Optimize
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                # loss.item(): indexing a 0-d tensor (loss.data[0]) raises.
                print('Epoch [%d/%d], Iter [%d] Loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1, loss.item()))
    return model


def train_mnist(model, train_loader, hyperparams):
    """Train an MNIST classifier with Nesterov SGD and periodic checkpoints.

    Args:
        model: network trained in place; put in train mode at the start.
        train_loader: sized iterable of ``(images, labels)`` batches, moved to
            CUDA when it is available.
        hyperparams: dict with "num_epochs", "lr", "momentum", "save_freq",
            "ckpt_dir" (checkpoint directory) and "seed" (used only in the
            checkpoint file name).

    Every ``save_freq`` epochs the state_dict is written to
    ``<ckpt_dir>/mnist_epoch-<epoch>_seed-<seed>.ckpt``.
    """
    model.train()
    num_epochs = hyperparams['num_epochs']
    lr = hyperparams['lr']
    momentum = hyperparams['momentum']
    save_freq = hyperparams['save_freq']
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, nesterov=True)
    use_cuda = torch.cuda.is_available()

    for epoch in range(num_epochs):
        with tqdm(total=len(train_loader)) as progress:
            for images, labels in train_loader:
                if use_cuda:
                    images = images.cuda()
                    labels = labels.cuda()
                optimizer.zero_grad()

                loss = loss_fn(model(images), labels)
                loss.backward()
                optimizer.step()

                loss_text = f"{loss.cpu().data.numpy():.4f}"
                progress.update(1)
                progress.set_postfix(epoch=epoch, loss=loss_text)

        if (epoch + 1) % save_freq != 0:
            continue
        ckpt_dir = os.path.join(hyperparams['ckpt_dir'],
                                f"mnist_epoch-{epoch+1}_seed-{hyperparams['seed']}.ckpt")
        torch.save(model.state_dict(), ckpt_dir)
        print(f"Saved to {ckpt_dir}")



def test_mnist(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print top-1 accuracy.

    Runs in eval mode (BatchNorm uses running statistics, dropout is off) and
    under ``torch.no_grad()`` so no autograd graph is built. Batches are moved
    to the device of the model's parameters.

    Args:
        model: classifier returning logits of shape (batch, classes).
        test_loader: sized iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy in percent as a float (0.0 for an empty loader).
    """
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
    device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad(), tqdm(total=len(test_loader)) as pb:
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += int((predicted == labels).sum())

            pb.update(1)
            pb.set_postfix(accuracy=f"{100.0 * (correct / total):.4f}")

    acc = 100.0 * correct / total if total else 0.0
    print('Test Accuracy of the model on the 10000 test images: %.2f %%' % acc)
    return acc

def cross_entropy(log_input, target):
    """Soft-target cross-entropy, averaged over the first dimension.

    Returns ``-sum(log_input * target) / N`` as a 0-d tensor, where
    ``N = log_input.size(0)``. ``log_input`` is assumed to already hold
    log-probabilities (nothing is checked), and ``target`` must be
    broadcast-compatible with it, e.g. a one-hot or soft label distribution.
    """
    batch_scale = -1 / log_input.size()[0]
    return torch.sum(log_input * target) * batch_scale
         

def train_cifar10(model, train_loader, test_loader, hyperparams):
    """Train a CIFAR-10 classifier with Nesterov SGD, checkpoints and periodic evaluation.

    Every 10 epochs (from epoch 10 on) both the learning rate and the momentum
    are multiplied by 0.95. The existing optimizer's param groups are updated
    in place, so the accumulated momentum buffers are kept instead of being
    discarded by a fresh optimizer.

    Args:
        model: network trained in place; batches are moved to the device of
            its parameters.
        train_loader: sized iterable of ``(images, labels)`` batches.
        test_loader: passed to test_cifar10 every ``test_freq`` epochs.
        hyperparams: dict with "num_epochs", "lr", "momentum", "save_freq",
            "test_freq", "ckpt_dir" (checkpoint directory) and "seed" (used
            only in the checkpoint file name).

    Returns:
        The trained ``model``.
    """
    model.train()
    num_epochs = hyperparams['num_epochs']
    lr = hyperparams['lr']
    momentum = hyperparams['momentum']
    save_freq = hyperparams['save_freq']
    test_freq = hyperparams['test_freq']
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, nesterov=True)
    device = next(model.parameters()).device

    for epoch in range(num_epochs):
        if epoch % 10 == 0 and epoch != 0:
            lr = lr * 0.95
            momentum = momentum * 0.95
            for group in optimizer.param_groups:
                group['lr'] = lr
                group['momentum'] = momentum
        with tqdm(total=len(train_loader)) as pb:
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()

                # Forward + Backward + Optimize
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                pb.update(1)
                pb.set_postfix(epoch=epoch, loss=f"{loss.item():.4f}")

        if (epoch + 1) % save_freq == 0:
            ckpt_dir = os.path.join(hyperparams['ckpt_dir'], f"cifar_epoch-{epoch+1}_seed-{hyperparams['seed']}.ckpt")
            torch.save(model.state_dict(), ckpt_dir)
            print(f"Saved to {ckpt_dir}")
        if (epoch + 1) % test_freq == 0:
            model.eval()
            test_cifar10(model, test_loader)
            model.train()

    return model


def test_cifar10(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print top-1 accuracy.

    Runs in eval mode (BatchNorm uses running statistics, dropout is off) and
    under ``torch.no_grad()`` so no autograd graph is built. Batches are moved
    to the device of the model's parameters.

    Args:
        model: classifier returning logits of shape (batch, classes).
        test_loader: sized iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy in percent as a float (0.0 for an empty loader).
    """
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
    device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad(), tqdm(total=len(test_loader)) as pb:
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += int((predicted == labels).sum())

            pb.update(1)
            pb.set_postfix(accuracy=f"{100.0 * (correct / total):.4f}")

    acc = 100.0 * correct / total if total else 0.0
    print('Test Accuracy of the model on the 10000 test images: %.4f %%' % acc)
    return acc


# From NVIDIA DALI sample code
def adjust_learning_rate(initial_lr, optimizer, epoch, step, len_epoch):
    """Set every param group's LR from a step schedule with linear warmup.

    LR schedule that should yield 76% converged accuracy with batch size 256
    (from the NVIDIA DALI ResNet-50 example). The base rate is
    ``initial_lr * 0.1 ** k`` with ``k = epoch // 30``, plus one extra decade
    from epoch 80 on. During the first 5 epochs it is further scaled by
    ``(1 + step + epoch * len_epoch) / (5 * len_epoch)``, a linear warmup over
    the first five epochs' worth of steps.
    """
    decades = epoch // 30 + (1 if epoch >= 80 else 0)
    lr = initial_lr * (0.1 ** decades)

    # Warmup
    if epoch < 5:
        lr = lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)

    for group in optimizer.param_groups:
        group['lr'] = lr

        
# From Nvidia DALI sample code
def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: scores of shape (batch, classes).
        target: true class indices of shape (batch,).
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element tensor per k, holding the percentage of rows
        whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` comes from a transposed tensor and is
        # generally non-contiguous, so .view(-1) raises for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def train_imagenet(model, train_loader, test_loader, state, num_epochs=None):
    """Train an ImageNet classifier from a DALI iterator with Adam and a step/warmup LR schedule.

    Args:
        model: network trained in place; batches are moved to the device of
            its parameters.
        train_loader: DALI-style iterator yielding ``[{"data": ..., "label": ...}]``
            and exposing ``_size`` and ``reset()`` (as built by
            load_imagenet_train).
        test_loader: passed to test_imagenet every ``state["test_freq"]`` epochs.
        state: dict providing "learning_rate", "weight_decay", "batch_size",
            "save_freq", "test_freq", "ckpt_path", "report_path", "seed" and,
            unless ``num_epochs`` is given, "num_epochs".
        num_epochs: number of epochs. The original body referenced an
            undefined global here (NameError).
            NOTE(review): the fallback key "num_epochs" is assumed; confirm it
            against the code that builds ``state``.

    Returns:
        The trained ``model``.
    """
    if num_epochs is None:
        num_epochs = state["num_epochs"]
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=state["learning_rate"],
                                 weight_decay=state['weight_decay'])
    # Batches per epoch via true ceiling division. The old
    # ceil(a // b) + 1 over-counted by one when the size divides evenly.
    # At least 1, so the warmup never divides by zero.
    n_batches = max(1, math.ceil(train_loader._size / state["batch_size"]))
    device = next(model.parameters()).device

    for epoch in range(num_epochs):
        with tqdm(total=n_batches) as pb:
            for i, data in enumerate(train_loader):
                images = data[0]["data"].to(device)
                # reshape(-1) flattens (batch, 1)-shaped labels while keeping
                # the batch dimension even for a batch of one.
                labels = data[0]["label"].reshape(-1).long().to(device)

                adjust_learning_rate(state["learning_rate"], optimizer, epoch, i, n_batches)
                optimizer.zero_grad()

                # Forward + Backward + Optimize
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                pb.update(1)
                pb.set_postfix(epoch=epoch, loss=f"{loss.item():.4f}")

        # DALI iterators must be reset before they can be iterated again.
        train_loader.reset()

        if (epoch + 1) % state["save_freq"] == 0:
            ckpt_dir = os.path.join(state['ckpt_path'],
                                    f"imagenet_epoch-{epoch + 1}_seed-{state['seed']}.ckpt")
            torch.save(model.state_dict(), ckpt_dir)
            report(state['report_path'], f"Saved to {ckpt_dir}")
        if (epoch + 1) % state["test_freq"] == 0:
            model.eval()
            test_imagenet(model, test_loader, state)
            model.train()

    return model


def test_imagenet(model, test_loader, state):
    """Evaluate top-1 accuracy on an ImageNet loader and report it.

    Accepts either of two loader types:
      * a DALI iterator: batches of ``[{"data": ..., "label": ...}]``, with
        ``_size`` and ``reset()``;
      * a plain ``torch.utils.data.DataLoader`` yielding ``(images, labels)``,
        as returned by load_imagenet_test.
    The original body only handled the DALI form.

    Runs in eval mode and under ``torch.no_grad()``; batches are moved to the
    device of the model's parameters. The result is written with
    ``report(state['report_path'], ...)``.

    Args:
        state: dict providing "report_path" and, for DALI loaders,
            "batch_size" (used only for the progress-bar total).

    Returns:
        Top-1 accuracy in percent as a float (0.0 for an empty loader).
    """
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
    device = next(model.parameters()).device
    if hasattr(test_loader, "_size"):
        n_batches = math.ceil(test_loader._size / state["batch_size"])
    else:
        n_batches = len(test_loader)

    correct = 0
    total = 0
    with torch.no_grad(), tqdm(total=n_batches) as pb:
        for batch in test_loader:
            if isinstance(batch[0], dict):
                images = batch[0]["data"]
                labels = batch[0]["label"].reshape(-1).long()
            else:
                images, labels = batch
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += int((predicted == labels).sum())

            pb.update(1)
            pb.set_postfix(accuracy=f"{100.0 * (correct / total):.4f}")

    # DALI iterators must be reset before reuse; DataLoaders need nothing.
    if hasattr(test_loader, "reset"):
        test_loader.reset()

    acc = 100.0 * correct / total if total else 0.0
    report(state['report_path'], 'Test Accuracy of the model on test images: %.4f %%' % acc)
    return acc


class ToSpaceBGR(object):
    """Transform that optionally swaps entries 0 and 2 along a tensor's first dimension (RGB <-> BGR).

    The input is never modified. When ``is_bgr`` is true, a copy with the two
    channels exchanged is returned; any other channels are copied as-is.
    Otherwise the input tensor object itself is returned.
    """

    def __init__(self, is_bgr):
        self.is_bgr = is_bgr

    def __call__(self, tensor):
        if not self.is_bgr:
            return tensor
        swapped = tensor.clone()
        # The right-hand side reads from the untouched input, so order is safe.
        swapped[0], swapped[2] = tensor[2], tensor[0]
        return swapped

class ToRange255(object):
    """Transform that optionally multiplies a tensor by 255 *in place*.

    When ``is_255`` is true the input tensor is scaled in place and returned
    (same object). Otherwise it is returned untouched.
    """

    def __init__(self, is_255):
        self.is_255 = is_255

    def __call__(self, tensor):
        return tensor.mul_(255) if self.is_255 else tensor

def save_model(model, filename):
    """Serialize ``model.state_dict()`` to ``filename`` using ``torch.save``."""
    state = model.state_dict()
    torch.save(state, filename)

def load_model(model, filename, map_location="cpu"):
    """Load a state_dict saved with ``torch.save`` from ``filename`` into ``model``.

    Args:
        model: module whose parameters/buffers are overwritten in place.
        filename: checkpoint path.
        map_location: forwarded to ``torch.load``. The default "cpu" lets
            checkpoints saved on a GPU load on CPU-only hosts. The model's own
            device is unchanged, because ``load_state_dict`` copies the values
            into the existing tensors. Pass ``None`` to restore tensors to the
            devices they were saved from.
    """
    model.load_state_dict(torch.load(filename, map_location=map_location))
