import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

import argparse
import os
from pprint import pprint

from utils import set_seed, PoisonDataset, make_and_restore_tinyimagenet_model, TinyImageNet
from train import train_model, eval_model


def make_data(args):
    """Build the Tiny-ImageNet train/test data loaders selected by ``args``.

    Args:
        args: parsed options. Reads ``data_type``, ``data_path`` and
            ``batch_size``; ``num_workers`` is optional and defaults to 8
            (the previously hard-coded value) when the attribute is absent.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles each
        epoch, the test loader keeps dataset order.
    """
    # Random-crop (64px, padding 4) + horizontal-flip augmentation is applied
    # only for these training sets; all other types are fed as plain tensors.
    augmented_types = ('PoisoningLinf', 'PoisoningL2', 'Quality')
    if args.data_type in augmented_types:
        transform_train = transforms.Compose([
            transforms.RandomCrop(64, 4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.ToTensor()
    transform_test = transforms.ToTensor()

    # 'Quality' and 'Naive' use the TinyImageNet training split directly; every
    # other data type is read through PoisonDataset (which, per the traceback at
    # the end of this file, loads '<data_path>/<data_type>.data').
    if args.data_type in ('Quality', 'Naive'):
        train_set = TinyImageNet(args.data_path, train=True, transform=transform_train)
    else:
        train_set = PoisonDataset(args.data_path, data_type=args.data_type, transform=transform_train)
    test_set = TinyImageNet(args.data_path, train=False, transform=transform_test)

    # Shared loader settings; the worker count is configurable but keeps the
    # old default of 8 for callers whose args lack `num_workers`.
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=getattr(args, 'num_workers', 8),
        pin_memory=True,
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

def main(args):
    """Train a model unless its checkpoint already exists, then evaluate it.

    If ``args.model_path`` is not an existing file, a fresh ``args.arch`` model
    is trained with SGD (momentum 0.9) and a MultiStepLR schedule, logging to
    TensorBoard at ``args.tensorboard_path``. In every case the model is then
    rebuilt from ``args.model_path`` and evaluated on the test loader.
    """
    # Seed first so that any randomness used while building the datasets and
    # loaders is reproducible as well, not only the training loop.
    set_seed(args.seed)
    train_loader, test_loader = make_data(args)

    if not os.path.isfile(args.model_path):
        model = make_and_restore_tinyimagenet_model(args.arch)
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        schedule = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=args.lr_step)
        writer = SummaryWriter(args.tensorboard_path)
        try:
            train_model(args, model, optimizer, schedule, train_loader, test_loader, writer)
        finally:
            # Flush buffered events and release the event file even if training
            # raises; SummaryWriter.close() is safe to call more than once.
            writer.close()

    # Reload from disk rather than reusing the in-memory model; presumably
    # train_model writes the checkpoint at args.model_path — TODO confirm.
    model = make_and_restore_tinyimagenet_model(args.arch, args.model_path)

    eval_model(args, model, test_loader)


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog`; the text is a
    # description, so pass it by keyword to keep the usage line correct.
    parser = argparse.ArgumentParser(description='Training classifiers for Tiny-ImageNet')

    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--arch', default='ResNet18', type=str, choices=['MLP', 'VGG16', 'ResNet18', 'DenseNet121', 'WRN28-10'])
    parser.add_argument('--train_loss', default='ST', type=str, choices=['ST', 'AT', 'TRADES', 'THRM'])
    parser.add_argument('--constraint', default='Linf', choices=['Linf', 'L2'], type=str)
    parser.add_argument('--eps', default=8/255, type=float)
    parser.add_argument('--data_type', default='Quality', type=str, choices=['Naive', 'Noise', 'Mislabeling', 'Poisoning', 'Quality'])
    parser.add_argument('--device', default=0, type=int)
    parser.add_argument('--num_workers', default=8, type=int, help='number of DataLoader worker processes')

    parser.add_argument('--beta', default=6, type=float)

    args = parser.parse_args()

    # Restrict this process to the selected GPU. This only takes effect if CUDA
    # has not been initialised yet, so it must stay before any CUDA use.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)

    # Training options
    args.epochs = 60
    args.batch_size = 64
    # Same learning rate for every architecture and data type.
    args.lr = 0.01
    # No weight decay on the Noise / Mislabeling training sets.
    args.weight_decay = 0 if args.data_type in ('Noise', 'Mislabeling') else 5e-4
    args.lr_step = 0.1
    args.lr_milestones = [50, 55]
    args.log_gap = 5
    # Attack options
    args.step_size = args.eps / 4
    args.num_steps = 10
    args.random_restarts = 1

    # 'Poisoning' is resolved to the constraint-specific variant
    # ('PoisoningLinf' / 'PoisoningL2') used by make_data.
    if args.data_type == 'Poisoning':
        args.data_type = args.data_type + args.constraint

    # Miscellaneous
    args.out_dir = 'results/tinyimagenet'
    args.data_path = '../datasets/tinyimagenet'
    args.exp_name = f'{args.arch}-{args.train_loss}-{args.data_type}-{args.constraint}-seed{args.seed}'
    exp_dir = os.path.join(args.out_dir, args.exp_name)
    args.tensorboard_path = os.path.join(exp_dir, 'tensorboard')
    args.model_path_best = os.path.join(exp_dir, 'checkpoint_best.pth')
    args.model_path_last = os.path.join(exp_dir, 'checkpoint_last.pth')
    # Training is skipped in main() when this checkpoint already exists.
    args.model_path = args.model_path_last

    pprint(vars(args))

    torch.backends.cudnn.benchmark = True
    main(args)


"""
Traceback (most recent call last):
  File "/data/zsj/robust hypocritical/code/train_tinyimagenet.py", line 103, in <module>
    main(args)
  File "/data/zsj/robust hypocritical/code/train_tinyimagenet.py", line 37, in main
    train_loader, test_loader = make_data(args)
  File "/data/zsj/robust hypocritical/code/train_tinyimagenet.py", line 29, in make_data
    train_set = PoisonDataset(args.data_path, data_type=args.data_type, transform=transform_train)
  File "/data/zsj/robust hypocritical/code/utils.py", line 189, in __init__
    self.data, self.targets = torch.load(self.file_path)
  File "/data/zsj/.conda/envs/zsj/lib/python3.9/site-packages/torch/serialization.py", line 699, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/data/zsj/.conda/envs/zsj/lib/python3.9/site-packages/torch/serialization.py", line 230, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/data/zsj/.conda/envs/zsj/lib/python3.9/site-packages/torch/serialization.py", line 211, in __init__
    super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: '../datasets/tinyimagenet/Noise.data'
"""