import argparse
import os
import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR
import time
import datetime

from .model_utils import make_and_restore_model, make_and_restore_model_2
from .datasets import ImageNet
from .datasetsFromOtherRepo import get_datasetSpecial, DATASETS
from .train_utilsFromOtherRepo import AverageMeter, accuracy, init_logfile, log

#%%
# Command-line interface, parsed at import time.
# NOTE(review): `args.dataset`, `args.batch` and `args.workers` are not read
# anywhere else in this file as shown, yet `dataset` is a required positional,
# so the script cannot be launched without it. The description string also
# looks inherited from a training script. Confirm before relying on them.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('dataset', choices=DATASETS, type=str)
parser.add_argument(
    '--batch',
    type=int,
    default=256,
    metavar='N',
    help='batchsize (default: 256)',
)
parser.add_argument(
    '--workers',
    type=int,
    default=4,
    metavar='N',
    help='number of data loading workers (default: 4)',
)
args = parser.parse_args()

#%%

def _imagenet(_dir) -> Dataset:
    """Return an ``ImageFolder`` dataset rooted at ``_dir`` with eval-style preprocessing.

    Each image is resized (shorter side to 256), center-cropped to 224x224
    and converted to a tensor. No normalization is applied here.
    """
    # Deterministic resize + center crop: no random augmentation.
    preprocessing_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(preprocessing_steps)
    return datasets.ImageFolder(_dir, pipeline)

#%%

# Hard-coded ImageNet locations for the host this script was written on.
# NOTE(review): the env var is set after `.datasets` has already been imported
# at the top of the file. If that module reads IMAGENET_LOC_ENV at import time,
# this assignment comes too late. Confirm.
os.environ['IMAGENET_LOC_ENV'] = '/imagenet/'
img_dir = '/imagenet/dataset/val/'  # not referenced elsewhere in this file as shown

def save_model_parallel(existing_path, new_path, *, arch='resnet50'):
    """Restore a checkpoint with ``parallel=True`` and re-save its state dict.

    Loads the checkpoint at ``existing_path`` through the project's
    ``make_and_restore_model`` (with ``parallel=True``) and writes
    ``{'state_dict': model.state_dict()}`` to ``new_path``. The parent
    directory of ``new_path`` is created if it does not exist yet.

    Args:
        existing_path: Path of the checkpoint to restore.
        new_path: Destination file for the re-saved state dict.
        arch: Architecture name passed to ``make_and_restore_model``.
            Defaults to ``'resnet50'``, the value previously hard-coded.
    """
    # NOTE(review): the dataset object is built from the *checkpoint* path,
    # not an ImageNet data directory. Presumably make_and_restore_model only
    # uses dataset metadata (e.g. class count / normalization). Confirm.
    ds = ImageNet(existing_path)
    # The second return value (the raw checkpoint) is not needed here.
    model, _ = make_and_restore_model(
        arch=arch,
        dataset=ds,
        resume_path=existing_path,
        parallel=True,
        pytorch_pretrained=False,
        add_custom_forward=False,
        momentum=0.1,
    )
    # torch.save raises if the target directory is missing, so create it first.
    out_dir = os.path.dirname(new_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # With parallel=True the model is presumably wrapped (e.g. DataParallel),
    # so keys likely carry a 'module.' prefix. Verify against model_utils.
    torch.save({'state_dict': model.state_dict()}, new_path)

# Re-save each (source, destination) checkpoint pair, in order.
_CHECKPOINT_CONVERSIONS = (
    ('./model/imagenet_l2_eps765.pt', './model/imagenet_l2_eps765_parallel.pt'),
    ('./model/imagenet_linf_4.pt', './model/imagenet_linf_4_parallel.pt'),
)
for _src_ckpt, _dst_ckpt in _CHECKPOINT_CONVERSIONS:
    save_model_parallel(_src_ckpt, _dst_ckpt)
print('============= Exit ===================')


