'''Train ResNet / NF-ResNet variants on CIFAR-10, CIFAR-100, ImageNet or Tiny-ImageNet with PyTorch.'''
from Imagenet.resnet_b import ResNet50_b_imagenet
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset

import torchvision
import torchvision.transforms as transforms

import argparse
import copy
import json
import logging
import math
import os
import time

from resnet import *
from resnet_weighted import *
from resnet_nobn import *
from resnet_lastbn import *
from resnet_middlebn import *
from resnet_firstbn import *
from resnet_fixup import *
from resnet_b_fixup import *
from resnet_skip import *
from resnet_b_skip import *
from resnet_type_b import *
from resnet_type_b_nobn import *
from nfnets import *
from nfnets_b import *

from Imagenet.resnet_lastbn import *
from Imagenet.resnet_middlebn import *
from Imagenet.resnet_firstbn import *
from Imagenet.resnet import *
from Imagenet.resnet_b_nobn import *
from Imagenet.resnet_nobn import *
from Imagenet.resnet_skip import *
from Imagenet.resnet_b_skip import *
from Imagenet.resnet_fixup import *
from Imagenet.resnet_b_fixup import *
from Imagenet.nfnets import *
from Imagenet.nfnets_b import *

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


# from utils import progress_bar

def init_kaiming_normal(model):
    """Re-initialise every Conv2d / Linear weight of ``model`` in place.

    Weights are drawn from N(0, std^2) with std = sqrt(2 / fan_in), i.e. the
    He/Kaiming normal scheme in fan-in mode. Biases and all other module types
    are left untouched.

    Args:
        model: an ``nn.Module``; every submodule reached by ``model.modules()``
            is visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            # only fan_in is needed for this scheme
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
            std = math.sqrt(2 / fan_in)
            nn.init.normal_(m.weight, mean=0.0, std=std)

logger = None  # root logger, configured by setup_logger()


def setup_logger(args):
    """Route INFO-level log records to ``<args.result_dir>/<args.seed>.log``.

    Uses the root logger. On repeated calls, every handler currently attached
    to it is detached *and closed*. This means reconfiguring does not leak open
    log files.

    Args:
        args: parsed options; only ``result_dir`` and ``seed`` are read.
    """
    global logger
    if logger is None:
        logger = logging.getLogger()
    else:
        for handler in list(logger.handlers):  # copy: the list is mutated below
            logger.removeHandler(handler)
            handler.close()

    log_path = os.path.join(args.result_dir, '{0}.log'.format(args.seed))

    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%H:%M:%S')

    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

def print_and_log(msg):
    """Echo ``msg`` to stdout, then record it at INFO level on the shared logger.

    The module-level ``logger`` must already be configured by ``setup_logger``;
    while it is still ``None`` this raises after printing.
    """
    print(msg)
    logger.info(msg)





def _str2bool(value):
    """Parse a boolean command-line value for argparse.

    ``type=bool`` is unusable with argparse: ``bool('False')`` is ``True``, so
    ``--fp False`` used to enable the Fisher penalty. Common spellings are
    accepted (an empty string stays False, as before); anything else is rejected.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
# argparse does not check defaults against `choices`, so the default must be
# spelled exactly like one of them or no init branch will ever match.
parser.add_argument('--init-method', type=str, default='torch-default-init',
                    choices=['ortho-block-alpha', 'torch-default-init', 'kaiming-normal'])
parser.add_argument('--result-dir', type=str, default='results')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--model', type=str, default='resnet',
                    choices=['resnet101-skip', 'resnet-alpha', 'resnet101-nobn',
                             'resnet101-fixup', 'nfnet-b', 'resnet-skip', 'resnet-b-skip',
                             'resnet-b-fixup', 'resnet-b-nobn', 'resnet', 'resnet-weighted',
                             'resnet-nobn', 'resnet-b', 'resnet-lastbn', 'resnet50-lastbn',
                             'resnet50-middlebn', 'resnet50-firstbn', 'resnet-firstbn',
                             'resnet-middlebn', 'resnet50-fixup', 'resnet50-nobn',
                             'resnet-fixup', 'nfnet', 'resnet50', 'nfnet50'])
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet', 'cifar100', 'tiny-imagenet'])
parser.add_argument('--scheduler', type=str, default='cosine', choices=['cosine', 'step'])
parser.add_argument('--fp-every', default=10, type=int, help='fisher penalty iters')
parser.add_argument('--fp-scale', default=1e-3, type=float, help='fisher penalty scale')
# `--fp` alone, `--fp true/1/yes` enable the penalty; `--fp false/0/no` disable it.
parser.add_argument('--fp', default=False, type=_str2bool, nargs='?', const=True,
                    help='fisher penalty')
parser.add_argument('--batch-size', default=128, type=int, help='batch size')
parser.add_argument('--epochs', default=200, type=int, help='epochs')
parser.add_argument('--alpha', default=1.0, type=float, help='alpha scaling res block')


args = parser.parse_args()
# Create the output directory (no-op if it already exists). Unlike the former
# bare `except: pass`, real failures such as permission errors still surface.
os.makedirs(args.result_dir, exist_ok=True)

setup_logger(args)
print_and_log(args)
print_and_log(time.time())

# Persist the exact run configuration next to the logs and checkpoints.
with open(os.path.join(args.result_dir, 'args.json'), 'w') as f:
    json.dump(args.__dict__, f, sort_keys=True, indent=4)


# Global run state: compute device, best test accuracy seen so far, and the
# first epoch index (the latter two are taken from the checkpoint on --resume).
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
best_acc = 0
start_epoch = 0

# Data
print_and_log('==> Preparing data..')
print_and_log(time.time())

# CIFAR-10 and CIFAR-100 share one pipeline:
# - training: random 32x32 crop with 4px padding, then a horizontal flip;
# - both splits: channel normalisation with the dataset's own statistics.
# Each entry is (dataset class, root dir, Normalize mean, Normalize std).
_CIFAR_SPECS = {
    'cifar10': (torchvision.datasets.CIFAR10, './data',
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    'cifar100': (torchvision.datasets.CIFAR100, 'data/',
                 (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
}

if args.dataset in _CIFAR_SPECS:
    _cifar_cls, _cifar_root, _cifar_mean, _cifar_std = _CIFAR_SPECS[args.dataset]

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_cifar_mean, _cifar_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_cifar_mean, _cifar_std),
    ])

    trainset = _cifar_cls(root=_cifar_root, train=True, download=True,
                          transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True)

    testset = _cifar_cls(root=_cifar_root, train=False, download=True,
                         transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False)

    if args.dataset == 'cifar10':
        classes = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')


if args.dataset in ('imagenet', 'tiny-imagenet'):
    # ImageNet and Tiny-ImageNet share the same preprocessing:
    # - training: 224px RandomResizedCrop, then a horizontal flip;
    # - evaluation: resize to 256, then a 224 center crop;
    # - both splits: normalised with the channel statistics below.
    imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        imagenet_normalize,
    ])

    # Resize/crop the decoded image *before* ToTensor, as the training path
    # does. The previous order (ToTensor first) resized full-resolution float
    # tensors. That is slower, and on torchvision releases whose tensor Resize
    # is not antialiased it resamples differently from training.
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        imagenet_normalize,
    ])

    # ImageFolder roots: ImageNet lives under ./data, Tiny-ImageNet in the cwd.
    if args.dataset == 'imagenet':
        train_root, val_root = 'data/train', 'data/val'
    else:
        train_root, val_root = 'train', 'val'

    trainset = torchvision.datasets.ImageFolder(
        train_root, transform=transform_train)
    print_and_log(time.time())
    print_and_log('Loading Imagenet')
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, num_workers=12)
    print_and_log(time.time())

    testset = torchvision.datasets.ImageFolder(
        root=val_root, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=12)

    print_and_log('==> dataloaders defined..')


# Model
print_and_log('==> Building model..')
torch.manual_seed(args.seed)
print_and_log(time.time())


def _build_model(dataset, model_name, alpha):
    """Instantiate the network registered for ``(dataset, model_name)``.

    Each architecture is wrapped in a lambda, so only the selected constructor
    name is looked up and called (same as the original if-chain). The RNG
    seeded above is therefore consumed by exactly one model.

    Args:
        dataset: one of 'cifar10', 'cifar100', 'imagenet', 'tiny-imagenet'.
        model_name: value of ``--model``.
        alpha: residual scaling passed to ``_set_alpha`` for 'resnet-alpha'.

    Returns:
        The constructed (not yet device-placed) network.

    Raises:
        ValueError: if no architecture is registered for the combination.
            Previously this surfaced later as ``NameError: name 'net' ...``.
    """
    def _resnet18_alpha():
        model = ResNet18_llskip_alpha([3, 32, 32], 10)
        model._set_alpha(alpha)
        return model

    registry = {
        'cifar10': {
            'resnet': lambda: ResNet18_llskip([3, 32, 32], 10),
            'resnet-alpha': _resnet18_alpha,
            'resnet-b': lambda: ResNet18_llidentity([3, 32, 32], 10),
            'resnet-skip': lambda: ResNet18_skip([3, 32, 32], 10),
            'resnet-b-skip': lambda: ResNet18_b_skip([3, 32, 32], 10),
            'resnet50': lambda: ResNet50_llskip([3, 32, 32], 10),
            'resnet-weighted': lambda: ResNet18_weighted([3, 32, 32], 10),
            'resnet-nobn': lambda: ResNet18_nobn([3, 32, 32], 10),
            'resnet-b-nobn': lambda: ResNet18_llidentity_nobn([3, 32, 32], 10),
            'resnet50-nobn': lambda: ResNet50_nobn([3, 32, 32], 10),
            'resnet-lastbn': lambda: ResNet18_lastbn([3, 32, 32], 10),
            'resnet50-lastbn': lambda: ResNet50_lastbn([3, 32, 32], 10),
            'resnet-firstbn': lambda: ResNet18_firstbn([3, 32, 32], 10),
            'resnet50-firstbn': lambda: ResNet50_firstbn([3, 32, 32], 10),
            'resnet-middlebn': lambda: ResNet18_middlebn([3, 32, 32], 10),
            'resnet50-middlebn': lambda: ResNet50_middlebn([3, 32, 32], 10),
            'resnet-fixup': lambda: ResNet18_fixup([3, 32, 32], 10),
            'resnet-b-fixup': lambda: ResNet18_b_fixup([3, 32, 32], 10),
            # NOTE(review): ImageNet-variant class on CIFAR input -- confirm intended.
            'resnet50-fixup': lambda: ResNet50_fixup_imagenet([3, 32, 32], 10),
            'nfnet': lambda: NFResNet18([3, 32, 32], 10),
            'nfnet-b': lambda: NFResNet18_b([3, 32, 32], 10),
            'nfnet50': lambda: NFResNet50([3, 32, 32], 10),
        },
        'cifar100': {
            'resnet50': lambda: ResNet50_llskip([3, 32, 32], 100),
            'resnet101-nobn': lambda: ResNet101_nobn([3, 32, 32], 100),
            'resnet101-fixup': lambda: ResNet101_fixup([3, 32, 32], 100),
            'resnet101-skip': lambda: ResNet101_skip([3, 32, 32], 100),
            'resnet-b': lambda: ResNet50_llidentity([3, 32, 32], 100),
            'resnet-skip': lambda: ResNet50_skip([3, 32, 32], 100),
            'resnet-b-skip': lambda: ResNet50_b_skip([3, 32, 32], 100),
            'resnet-b-nobn': lambda: ResNet50_llidentity_nobn([3, 32, 32], 100),
            'resnet': lambda: ResNet18_llskip([3, 32, 32], 100),
            'resnet-weighted': lambda: ResNet18_weighted([3, 32, 32], 100),
            'resnet-nobn': lambda: ResNet18_nobn([3, 32, 32], 100),
            'resnet50-nobn': lambda: ResNet50_nobn([3, 32, 32], 100),
            'resnet-lastbn': lambda: ResNet50_lastbn([3, 32, 32], 100),
            'resnet-middlebn': lambda: ResNet50_middlebn([3, 32, 32], 100),
            'resnet-firstbn': lambda: ResNet50_firstbn([3, 32, 32], 100),
            'resnet-fixup': lambda: ResNet50_fixup([3, 32, 32], 100),
            'resnet-b-fixup': lambda: ResNet50_b_fixup([3, 32, 32], 100),
            'nfnet': lambda: NFResNet50([3, 32, 32], 100),
            'nfnet-b': lambda: NFResNet50_b([3, 32, 32], 100),
        },
        'imagenet': {
            'resnet': lambda: ResNet101([3, 224, 224], 1000),
            'resnet-nobn': lambda: ResNet101_nobn_imagenet([3, 224, 224], 1000),
            'resnet50': lambda: ResNet50([3, 224, 224], 1000),
            'nfnet': lambda: NFResNet50_imagenet([3, 224, 224], 1000),
        },
        'tiny-imagenet': {
            'resnet-nobn': lambda: ResNet50_nobn_imagenet([3, 224, 224], 200),
            'resnet-b-nobn': lambda: ResNet50_b_nobn_imagenet([3, 224, 224], 200),
            'resnet-lastbn': lambda: ResNet50_lastbn_imagenet([3, 224, 224], 200),
            'resnet-middlebn': lambda: ResNet50_middlebn_imagenet([3, 224, 224], 200),
            'resnet-firstbn': lambda: ResNet50_firstbn_imagenet([3, 224, 224], 200),
            'resnet-fixup': lambda: ResNet50_fixup_imagenet([3, 224, 224], 200),
            'resnet-b-fixup': lambda: ResNet50_B_fixup_imagenet([3, 224, 224], 200),
            'resnet-skip': lambda: ResNet50_skip_imagenet([3, 224, 224], 200),
            'resnet-b-skip': lambda: ResNet50_b_skip_imagenet([3, 224, 224], 200),
            'resnet50': lambda: ResNet50([3, 224, 224], 200),
            'resnet-b': lambda: ResNet50_b_imagenet([3, 224, 224], 200),
            'nfnet': lambda: NFResNet50_imagenet([3, 224, 224], 200),
            'nfnet-b': lambda: NFResNet50_b_imagenet([3, 224, 224], 200),
        },
    }

    try:
        builder = registry[dataset][model_name]
    except KeyError:
        raise ValueError('model {!r} is not available for dataset {!r}'.format(
            model_name, dataset)) from None
    return builder()


net = _build_model(args.dataset, args.model, args.alpha)
net = net.to(device)



print_and_log('==> Initializing Model ..')

# Exactly one branch can match; any other value leaves the net exactly as built.
if args.init_method == 'ortho-block-alpha':
    print_and_log('ortho-block-alpha')
    print_and_log(time.time())
    net._ortho_block_init()
elif args.init_method == 'torch-default-init':
    # keep the per-layer initialisation the constructors already applied
    print_and_log('torch-default-init')
elif args.init_method == 'kaiming-normal':
    print_and_log('kaiming-normal')
    init_kaiming_normal(net)



if args.resume:
    # Load checkpoint.
    print_and_log('==> Resuming from checkpoint..')
    checkpoint = torch.load(
        os.path.join(args.result_dir, 'ckpt_' + str(args.seed) + '.pth'),
        map_location=device,  # lets a GPU-written checkpoint resume on a CPU-only host
    )
    state_dict = checkpoint['net']
    # Checkpoints saved from an nn.DataParallel-wrapped net prefix every key
    # with 'module.', but `net` is not wrapped yet at this point. Strip the
    # prefix so load_state_dict matches the bare model.
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    net.load_state_dict(state_dict)
    best_acc = checkpoint['acc']
    # The stored epoch has already been trained and evaluated before saving,
    # so continue with the next one instead of repeating it.
    start_epoch = checkpoint['epoch'] + 1
    # NOTE(review): the optimizer state stored in the checkpoint is not
    # restored, and the LR scheduler is rebuilt from scratch.

# With several visible GPUs, replicate the model and split each batch along dim 0.
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=5e-4,
)

# 'step' decays the LR tenfold every 30 epochs; otherwise cosine-anneal over all epochs.
if args.scheduler != 'cosine':
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
else:
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

# Training
def fp_loss(model, scale, data, batch_size, num_classes, loss):
    """Return the Fisher penalty ``scale * ||g||^2`` as a differentiable tensor.

    ``g`` is the gradient of ``loss(model(data), labels)`` with respect to the
    model's trainable parameters. ``labels`` are drawn uniformly from
    ``[0, num_classes)``. See https://arxiv.org/pdf/2012.14193.pdf.

    The previous version turned each gradient norm into a Python float with
    ``.item()``. The value added to the training loss was then a constant with
    no gradient, so ``--fp`` only changed the logged loss. Here the gradients
    are taken with ``create_graph=True``, so the penalty is back-propagated
    (double backward). ``.grad`` buffers are not touched.

    NOTE(review): the cited paper appears to sample labels from the model's
    predictive distribution rather than uniformly -- confirm which is intended.

    Args:
        model: network to evaluate. One extra forward pass runs in the model's
            current mode, so train-mode layers (e.g. BatchNorm running stats,
            if present) see this batch again.
        scale: multiplier applied to the squared gradient norm.
        data: input batch.
        batch_size: number of random labels to draw (should match ``data``).
        num_classes: exclusive upper bound of the random labels.
        loss: criterion called as ``loss(outputs, labels)``.

    Returns:
        A scalar tensor attached to the autograd graph (or ``scale * 0.0``
        when no parameter requires grad).
    """
    out = model(data)
    labels = torch.randint(0, num_classes, (batch_size,)).to(data.device)
    pred_loss = loss(out, labels)
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        return scale * 0.0
    grads = torch.autograd.grad(pred_loss, params, create_graph=True, allow_unused=True)
    penalty = sum(g.pow(2).sum() for g in grads if g is not None)
    return scale * penalty

def train(epoch):
    """Run one training epoch over ``trainloader`` and log loss/accuracy.

    Uses the module-level ``net``, ``optimizer``, ``criterion``, ``device`` and
    ``args``. When ``args.fp`` is set, every ``args.fp_every``-th batch adds the
    value returned by ``fp_loss`` to the cross-entropy loss before backprop.

    Args:
        epoch: epoch index, used only in log messages.
    """
    print_and_log('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.type(torch.LongTensor).to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        # test args.fp first so a zero --fp-every cannot hit a modulo-by-zero
        if args.fp and batch_idx % args.fp_every == 0:
            fisher_penalty = fp_loss(net, args.fp_scale, inputs, outputs.shape[0],
                                     outputs.shape[1], criterion)
            loss += fisher_penalty
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

        if batch_idx % 100 == 0:
            print_and_log('Train Epoch: {} [{}] \t Loss: {:.6f}'.format(
                epoch, batch_idx, loss.item()))

    # Epoch summary from the running accumulators (skipped for an empty loader).
    if num_batches:
        print_and_log('Train Epoch: {} \t Avg Loss: {:.6f} \t Train Accuracy: {:.3f}'.format(
            epoch, train_loss / num_batches, 100. * correct / total))
            
        

def test(epoch):
    """Evaluate ``net`` on ``testloader``; checkpoint whenever accuracy improves.

    Logs the mean per-batch loss and the accuracy. When accuracy beats the
    global ``best_acc``, the model/optimizer state is saved to
    ``<result_dir>/ckpt_<seed>.pth`` and ``best_acc`` is updated.

    Args:
        epoch: epoch index, stored in the checkpoint.

    Returns:
        Test accuracy in percent.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.type(torch.LongTensor).to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    acc = 100. * correct / total
    print_and_log('Test Loss: {}'.format(test_loss / num_batches))
    print_and_log('Test Accuracy: {}'.format(acc))

    # Save checkpoint.
    if acc > best_acc:
        print('Saving..')
        # Unwrap nn.DataParallel so the saved keys carry no 'module.' prefix
        # and load directly into the bare model that is built before wrapping.
        model = net.module if isinstance(net, nn.DataParallel) else net
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
        }
        # The checkpoint lives in result_dir (the old code created an unused ./checkpoint).
        os.makedirs(args.result_dir, exist_ok=True)
        torch.save(state, os.path.join(args.result_dir, 'ckpt_' + str(args.seed) + '.pth'))
        best_acc = acc
    return acc


acc_list = []
end_epoch = start_epoch + args.epochs
for epoch in range(start_epoch, end_epoch):
    print_and_log('Training an Epoch')
    print_and_log(time.time())

    train(epoch)
    acc = test(epoch)
    acc_list.append(acc)
    scheduler.step()

    # Re-save the whole per-epoch accuracy history after every epoch.
    torch.save(acc_list, '{}/test_acc_{}.pkl'.format(args.result_dir, args.seed))
