# originally modified from https://github.com/kuangliu/pytorch-cifar

'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR

from my_transforms import transform_train_224x224 as transform_train
from my_transforms import transform_test_224x224 as transform_test

import os
import argparse
import fs as pyfs
import numpy as np

from torchvision import models
# Names of the torchvision architectures that get_model() will accept,
# kept in alphabetical (string) order and grouped one family per row.
model_list = [
    'alexnet',
    'densenet121', 'densenet161', 'densenet169', 'densenet201',
    'googlenet',
    'inception_v3',
    'mnasnet0_5', 'mnasnet1_0',
    'mobilenet_v2',
    'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50',
    'resnext101_32x8d', 'resnext50_32x4d',
    'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
    'squeezenet1_0', 'squeezenet1_1',
    'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
    'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
    'wide_resnet101_2', 'wide_resnet50_2',
]

def get_model(model_name, pretrained = True, num_classes = 10):
    """Return a torchvision model with its classification head replaced.

    The model is built with ``torchvision.models.<model_name>(pretrained=...)``
    and its final classification layer is swapped for a freshly initialised
    layer with ``num_classes`` outputs. The input width of the new layer is
    read from the layer it replaces, so no per-architecture sizes are
    hard-coded.

    Args:
        model_name: One of the names in ``model_list``.
        pretrained: Forwarded to the torchvision constructor.
        num_classes: Number of outputs of the new head (10 for CIFAR-10).

    Returns:
        The torchvision module with the replaced head.

    Raises:
        NotImplementedError: If ``model_name`` is not in ``model_list``, or is
            listed but has no head-replacement rule here (``inception_v3``).
    """
    if model_name not in model_list:
        raise NotImplementedError(f"Model {model_name} not implemented")

    net = getattr(torchvision.models, model_name)(pretrained = pretrained)
    # Seed the global RNG right before the new head is created so that its
    # random initialisation is the same on every call.
    torch.manual_seed(42)

    if model_name.startswith('squeezenet'):
        # SqueezeNet classifies with a 1x1 convolution rather than a Linear.
        old_conv = net.classifier[1]
        net.classifier[1] = torch.nn.Conv2d(
            old_conv.in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))
        # NOTE(review): some torchvision releases appear to reshape the output
        # with this attribute; keeping it in sync is harmless otherwise.
        net.num_classes = num_classes
    elif model_name.startswith('densenet'):
        # DenseNet exposes a single Linear as `classifier`.
        net.classifier = torch.nn.Linear(net.classifier.in_features, num_classes)
    elif model_name.startswith(('alexnet', 'vgg', 'mobilenet_v2', 'mnasnet')):
        # These use a Sequential classifier whose last module is the Linear head.
        net.classifier[-1] = torch.nn.Linear(
            net.classifier[-1].in_features, num_classes)
    elif model_name.startswith(
            ('googlenet', 'resnet', 'resnext', 'shufflenet', 'wide_resnet')):
        # These expose the Linear head as `fc`.
        net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    else:
        raise NotImplementedError(
            f"Model {model_name} has no classifier replacement rule")
    return net

# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training Pretrained Model')
parser.add_argument('--run', default=0, type=int, help='run number')
parser.add_argument('--epochs', default=250, type=int, help='Number of Epochs')
parser.add_argument('--batch_size', default=100, type=int, help='Batch size')
parser.add_argument('--lr', default=0.001, type=float, help='Learning rate')
# Required: without it model_name would be None and the run would only fail
# later, when the model is built.
parser.add_argument('--model_name', type=str, required=True, help='Model name')
parser.add_argument('--measure_frequency', default = 5, type=int, help='Frequency of measurement')
# Flag without a value: absent means False (random initialisation).
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='Use pretrained weights')
parser.add_argument('--num_workers', default=4, type=int, help='Number of cpu workers')

args = parser.parse_args()

# Use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')

batch_size, num_workers = args.batch_size, args.num_workers


def _cifar10(train, transform):
    """Build a CIFAR10 dataset under ./data (downloading it if needed)."""
    return torchvision.datasets.CIFAR10(
        root='./data', train=train, download=True, transform=transform)


def _loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader using the script-wide batch settings."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


# Training split with the training transform, shuffled.
trainset = _cifar10(True, transform_train)
trainloader = _loader(trainset, True)

# Standard CIFAR-10 test split.
testset = _cifar10(False, transform_test)
testloader = _loader(testset, False)

# Second test set: a CIFAR10 dataset object whose images and labels are
# overwritten with the arrays stored in the local cifar10.1_v6 .npy files.
testset101 = _cifar10(False, transform_test)
testset101.data = np.load('./data/cifar10.1_v6_data.npy')
testset101.targets = list(map(int, np.load('./data/cifar10.1_v6_labels.npy')))
testloader101 = _loader(testset101, False)

# Training split again, but with the test transform and no shuffling, used
# only for evaluation.
trainset_eval = _cifar10(True, transform_test)
trainloader_eval = _loader(trainset_eval, False)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Measurement histories; each entry is [epoch, step, value].
stored_train_loss, stored_test10_loss, stored_test101_loss = [], [], []
stored_train_acc, stored_test10_acc, stored_test101_acc = [], [], []

# Model
if args.pretrained:
    print("==> Building pre-trained model...")
else:
    print("==> Building randomly initialized model...")

net = get_model(args.model_name, args.pretrained)
net = net.to(device)
if device == 'cuda':
    # Wrap for multi-GPU data parallelism and turn on cuDNN benchmark mode.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=1e-4)
# Learning-rate scheduling is currently disabled; the learning rate stays at args.lr.
# scheduler = CosineAnnealingLR(optimizer, args.epochs)

# Training
def _record_measurements(epoch, step):
    """Evaluate the network and append the results to the stored_* lists.

    Runs test() on the CIFAR-10 test loader, the CIFAR-10.1 test loader and
    the evaluation view of the training set (capped via num_examples=10000),
    in that order, and appends one [epoch, step, value] entry to each of the
    six history lists.

    Returns:
        The accuracy measured on the training-set evaluation loader.
    """
    acc_test_10, loss_test_10 = test(epoch, step, testloader)
    acc_test_101, loss_test_101 = test(epoch, step, testloader101)
    acc_train, loss_train = test(epoch, step, trainloader_eval, 10000)

    # loss
    stored_train_loss.append([epoch, step, loss_train])
    stored_test10_loss.append([epoch, step, loss_test_10])
    stored_test101_loss.append([epoch, step, loss_test_101])

    # acc
    stored_train_acc.append([epoch, step, acc_train])
    stored_test10_acc.append([epoch, step, acc_test_10])
    stored_test101_acc.append([epoch, step, acc_test_101])

    return acc_train


def train(epoch):
    """Train the network for one epoch, measuring metrics along the way.

    Before the first batch of epoch 0 a baseline measurement is recorded with
    step 0. After each optimizer step, a measurement is recorded whenever the
    global step index is a multiple of ``args.measure_frequency``.

    Args:
        epoch: Zero-based index of the epoch to run.
    """
    print('\nEpoch: %d' % epoch)
    net.train()

    max_batch = len(trainloader)
    optimizer.zero_grad()

    # measurement at init
    if epoch == 0:
        _record_measurements(epoch, 0)

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Zero-based index of the batch just processed, counted across epochs.
        # The measurement after the very first update therefore carries the
        # same label (epoch 0, step 0) as the baseline measurement.
        step = max_batch*epoch + batch_idx

        if step % args.measure_frequency == 0:
            acc_train = _record_measurements(epoch, step)

            # Once training accuracy exceeds 90%, measure every 50 steps for
            # the rest of the run (this mutates the shared args object).
            if acc_train > 0.90:
                args.measure_frequency = 50


def test(epoch, step, testloader, num_examples = None):
    """Evaluate the network on a data loader and return accuracy and loss.

    The network is put in eval mode for the evaluation and switched back to
    train mode before returning.

    Args:
        epoch: Epoch number, used only in the printed progress line.
        step: Step number, used only in the printed progress line.
        testloader: Iterable yielding ``(inputs, targets)`` batches.
        num_examples: If given, stop after the first batch that brings the
            number of evaluated examples to at least this value.

    Returns:
        Tuple ``(accuracy, mean_loss)`` where both are averaged over the
        examples that were actually evaluated.
    """
    net.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            batch_count = targets.size(0)

            # The script's criterion (nn.CrossEntropyLoss with default
            # reduction) yields the batch mean; weight it by the batch size so
            # that dividing by `total` below gives a per-example mean that
            # does not depend on the batch size.
            loss_sum += criterion(outputs, targets).item() * batch_count
            _, predicted = outputs.max(1)
            total += batch_count
            correct += predicted.eq(targets).sum().item()

            # Stop as soon as the requested number of examples is reached.
            if num_examples is not None and total >= num_examples:
                break

    acc = correct/total
    print(f"Epoch {epoch}, step {step}: {acc:.4f}")
    net.train()
    return acc, loss_sum/total

# Run the requested number of epochs (the scheduler step is disabled).
for epoch in range(args.epochs):
    train(epoch)
    # scheduler.step()

# Final evaluation on both test sets; the results are only printed by test().
for final_loader in (testloader, testloader101):
    test(args.epochs, 0, final_loader)

# Measurement histories to persist; each value is a list of
# [epoch, step, value] entries.
res = {'train_loss':   stored_train_loss,
       'test_10_loss': stored_test10_loss,
       'test_101_loss':stored_test101_loss,
       'train_acc':    stored_train_acc,
       'test_10_acc':  stored_test10_acc,
       'test_101_acc': stored_test101_acc}


# The parser in this script does not define --width_factor, so read it
# defensively and only add the suffix when the attribute actually exists.
width_factor = getattr(args, 'width_factor', None)
if args.model_name[:6] == "ResNet" and width_factor is not None:
    save_name = args.model_name + '-' + str(width_factor)
else:
    save_name = args.model_name

# Save Results to disk
filename = f"{save_name}_lr-{args.lr}_run-{args.run}_pretrained-{args.pretrained}"
with open(f"{filename}.npz", 'wb') as f:
    np.savez(f, **res)
