'''Train CIFAR10 with PyTorch.'''

import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import time

from model import get_model
from data import get_data
from evaluation import test
from options import options
from utils import simple_lapsed_time



def get_generalization_gap(args, net, testloader, device, *,
                           data_root='~/data', batch_size=100, num_workers=2):
    """Measure the gap between train and test accuracy of ``net`` on CIFAR10.

    Test accuracy is computed on ``testloader``. Train accuracy is computed on
    the CIFAR10 training split loaded with the evaluation transform (no data
    augmentation), so both numbers come from the same preprocessing.

    Args:
        args: Parsed options. ``args.set_seed`` seeds torch's global RNG, and
            the object is forwarded unchanged to ``evaluation.test``.
        net: Model to evaluate.
        testloader: DataLoader over the held-out test split.
        device: Device forwarded unchanged to ``evaluation.test``.
        data_root: Directory where the CIFAR10 training split is stored or
            downloaded to (default ``'~/data'``, the original hard-coded path).
        batch_size: Batch size of the un-augmented training-set loader.
        num_workers: Worker processes of the un-augmented training-set loader.

    Returns:
        ``train_acc - test_acc``, in the units returned by ``evaluation.test``
        (printed with a ``%`` suffix).
    """
    torch.manual_seed(args.set_seed)

    # Evaluate on the test set.
    test_acc, _ = test(args, net, testloader, device, 1)
    print(f"Test Accuracy: {test_acc:.2f}%")

    # Evaluate on the training set with the test-time transform (no data
    # augmentation), so the gap reflects generalization rather than the
    # difference between augmented and clean inputs.
    # NOTE(review): these mean/std values are assumed to match the
    # normalization used to build ``testloader`` (see data.get_data) — confirm.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # download=True only fetches the archive if it is not already under data_root.
    raw_trainset = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=transform_test)
    # A fixed (unshuffled) order is sufficient for evaluation.
    raw_trainloader = torch.utils.data.DataLoader(
        raw_trainset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers)

    train_acc, _ = test(args, net, raw_trainloader, device, 1)
    print(f"Train Accuracy: {train_acc:.2f}%")

    # Positive gap means the model fits the training data better than it
    # generalizes to the test data.
    generalization_gap = train_acc - test_acc
    print(f"Generalization Gap: {generalization_gap:.2f}%")
    return generalization_gap


if __name__ == 'main':
    args = options().parse_args()
    print(args)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'