'''Train CIFAR10 with PyTorch.'''

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random
import pickle

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import time

from model import get_model
from data import get_data, make_planeloader
from utils import get_loss_function, get_scheduler, get_random_images, produce_plot, get_noisy_images, AttackPGD
from evaluation import train, test, test_on_trainset, decision_boundary, test_on_adv
from options import options
from utils import simple_lapsed_time
from tqdm import tqdm
from set_seed import set_seed

# Parse command-line options and echo them so the run's configuration is logged.
args = options().parse_args()
print(args)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Seed, build the train/test loaders, then seed again. The second call makes
# everything after this point (e.g. model initialisation) start from the same
# RNG state no matter how many random draws get_data consumed.
# NOTE(review): assumes set_seed reseeds the global RNGs — confirm in set_seed.
set_seed(args.set_seed)
trainloader, testloader = get_data(args)
set_seed(args.set_seed)


# Un-augmented view of the CIFAR-10 training split: images are only converted
# to tensors and normalised per channel (no random crops or flips).
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)
raw_trainset = torchvision.datasets.CIFAR10(
    root='~/data',
    train=True,
    transform=transform_test,
    download=True,
)
raw_trainloader = torch.utils.data.DataLoader(
    dataset=raw_trainset,
    batch_size=args.bs,
    shuffle=True,
    num_workers=2,
)


# Accuracy history lists; nothing in this file currently appends to them.
test_accs = []
train_accs = []
net = get_model(args, device)

# Baseline test accuracy of the freshly initialised (untrained) network.
test_acc, predicted = test(args, net, testloader, device, 0)
print("scratch prediction ", test_acc)

criterion = get_loss_function(args)

# Build the optimizer selected by --opt. Only the SGD branch creates an LR
# scheduler; the training loop steps it only when args.opt == 'SGD'.
if args.opt == 'SGD':
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=args.weight_decay)
    scheduler = get_scheduler(args, optimizer)

elif args.opt == 'Adam':
    # NOTE: this branch does not pass args.weight_decay to Adam.
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)

elif args.opt.lower() == 'adamw':
    optimizer = torch.optim.AdamW(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

else:
    # Fail fast: otherwise `optimizer` stays unbound and the script dies later
    # with a NameError inside the training loop.
    raise ValueError(
        f"Unsupported optimizer {args.opt!r}; expected 'SGD', 'Adam' or 'AdamW'."
    )

# Train the network from scratch, or load saved weights when --load_net is given.
print("Training the network or loading the network")

start = time.time()
best_acc = 0  # best test accuracy seen so far
best_epoch = 0

save_path = f'./saved_models/corr_{args.net}/{str(args.set_seed)}'
if args.load_net is None:
    checkpoint_file = f'{save_path}/{args.save_net}.pth'

    for epoch in range(args.epochs):
        # NOTE(review): this trains on raw_trainloader (ToTensor + Normalize
        # only), not on `trainloader` from get_data — confirm intentional.
        train_acc = train(args, net, raw_trainloader, optimizer, criterion, device, args.train_mode, sam_radius=args.sam_radius)

        test_acc, predicted = test(args, net, testloader, device, epoch)
        print(f'EPOCH: {epoch}/{args.epochs}, Train acc: {train_acc:.2f}, Test acc: {test_acc:.2f}')
        if args.dryrun:
            # Dry run: stop after the first epoch, before stepping the
            # scheduler or writing a checkpoint.
            break
        if args.opt == 'SGD':
            scheduler.step()

        # Save a checkpoint whenever test accuracy strictly improves.
        if test_acc > best_acc:
            print(f'The best epoch is: {epoch}')
            os.makedirs(save_path, exist_ok=True)
            print(checkpoint_file)
            # Decide whether to unwrap based on the model itself (as the load
            # branch does), not on the GPU count. The saved keys then never
            # carry DataParallel's "module." prefix, and an unwrapped model on
            # a multi-GPU host does not hit a missing `.module` attribute.
            if isinstance(net, torch.nn.DataParallel):
                state_dict = net.module.state_dict()
            else:
                state_dict = net.state_dict()
            torch.save(state_dict, checkpoint_file)

            best_acc = test_acc
            best_epoch = epoch

else:
    # Load into the bare model: drop a DataParallel wrapper if present.
    if isinstance(net, torch.nn.DataParallel):
        net = net.module  # Remove DataParallel wrapper
    # map_location lets a checkpoint written on a GPU load on a CPU-only host.
    state_dict = torch.load(args.load_net, map_location=device)
    # Checkpoints saved from a wrapped model (possible with the previous
    # GPU-count-based save logic) have every key prefixed with "module.";
    # strip it so they load into the unwrapped model.
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    net.load_state_dict(state_dict)

end = time.time()
simple_lapsed_time("Time taken to train the model", end-start)

