'''Train CIFAR10/CIFAR100 with PyTorch.'''
from __future__ import print_function
import json
import math
import os
import pdb
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from tensorboardX import SummaryWriter
from utils.network_utils import get_network
from utils.data_utils import get_dataloader
from utils.common_utils import PresetLRScheduler, makedirs

try:
    from frob import FactorizedConv, frobdecay, frobenius_norm, non_orthogonality, patch_module
    from make import compress_model, parameter_count
except ImportError:
    print("Failed to import factorization")

# Command-line interface.
#
# Each row is (flag spellings, add_argument keyword arguments); rows are
# registered in the order listed, which is also the order shown by --help.
parser = argparse.ArgumentParser()
for _flags, _settings in (
    # optimisation hyper-parameters
    (('--learning_rate',), {'default': 0.1, 'type': float}),
    (('--weight_decay',), {'default': 3e-3, 'type': float}),
    (('--batch_size',), {'default': 128, 'type': int}),
    # model / data selection (forwarded to get_network and get_dataloader)
    (('--network',), {'default': 'vgg', 'type': str}),
    (('--depth',), {'default': 19, 'type': int}),
    (('--dataset',), {'default': 'cifar10', 'type': str}),
    # schedule
    (('--epoch',), {'default': 10, 'type': int}),
    (('--decay_every',), {'default': 60, 'type': int}),
    (('--decay_ratio',), {'default': 0.1, 'type': float}),
    # CUDA device index handed to torch.cuda.set_device / Module.to
    (('--device',), {'default': 0, 'type': int}),
    # checkpointing / logging
    (('--resume', '-r'), {'action': 'store_true'}),
    (('--load_path',), {'default': '', 'type': str}),
    (('--log_dir',), {'default': 'runs/pretrain', 'type': str}),
    # factorization: a non-zero --rank-scale or --target-ratio switches the
    # convolution compression on; --wd2fd exempts the factorized layers from
    # ordinary weight decay; --spectral selects the 'spectral' init;
    # --kaiming is forwarded to get_network.
    (('--rank-scale',), {'default': 0.0, 'type': float}),
    (('--wd2fd',), {'action': 'store_true'}),
    (('--spectral',), {'action': 'store_true'}),
    (('--kaiming',), {'action': 'store_true'}),
    (('--target-ratio',), {'default': 0.0, 'type': float}),
    (('--auto-resume',), {'action': 'store_true'}),
):
    parser.add_argument(*_flags, **_settings)
args = parser.parse_args()

# Build the uncompressed network and record its parameter count; the count
# is the baseline for the compression ratio reported after factorization.
net = get_network(
    network=args.network,
    depth=args.depth,
    dataset=args.dataset,
    kaiming=args.kaiming,
)
origpar = parameter_count(net)
print('Original weight count:', origpar)
# Optional factorization of the convolution layers.
#
# Outputs of this section:
#   compress  - function that patches a model's convolutions with FactorizedConv
#   no_decay  - name fragments of parameters that get weight_decay=0 in the
#               optimizer (the factorized layers, when --wd2fd is set)
#   skiplist  - passed to frobdecay in the training loop (the factorized
#               layers, when --wd2fd is NOT set)
if args.rank_scale or args.target_ratio:
    if args.network == 'vgg':
        # Gather (child name, init denominator) for every plain nn.Conv2d in
        # net.feature except the one at index 0, which the `i` filter skips.
        # Both lists come from the SAME filtered pass so that
        # zip(names, denoms) pairs each layer with its own denominator; the
        # previous code filtered only `names`, shifting every denominator by
        # one layer whenever net.feature[0] was itself a Conv2d.
        factorized_convs = [
            (str(i), child.out_channels * child.kernel_size[0] * child.kernel_size[1])
            for i, child in enumerate(net.feature)
            if i and type(child) is nn.Conv2d
        ]
        names = [name for name, _ in factorized_convs]
        denoms = [denom for _, denom in factorized_convs]

        def compress(model, rank_scale, spectral=False):
            """Patch the selected convolutions of ``model.feature`` with FactorizedConv.

            Args:
                model: network exposing a ``feature`` container whose child
                    names match the entries of ``names``.
                rank_scale: forwarded to ``patch_module`` as ``rank_scale``.
                spectral: when true the layers are initialised with
                    ``init='spectral'``; otherwise with a zero-mean normal of
                    standard deviation ``sqrt(2 / denom)`` for that layer.

            Returns:
                The same ``model`` object that was passed in.
            """
            for name, denom in zip(names, denoms):
                if spectral:
                    init = 'spectral'
                else:
                    # Bind the std as a default argument so the initialiser
                    # keeps this layer's value even if it is invoked after
                    # the loop has moved on (late-binding closure hazard).
                    std = math.sqrt(2. / denom)
                    init = lambda X, std=std: nn.init.normal_(X, 0., std)
                patch_module(model.feature, name, FactorizedConv,
                             rank_scale=rank_scale,
                             init=init)
            return model

        no_decay = names if args.wd2fd else []
        skiplist = [] if args.wd2fd else names
    else:
        def compress(model, rank_scale, spectral=False):
            """Patch the block convolutions of ``model`` with FactorizedConv.

            Blocks are taken from every child of ``model`` except the first
            two and the last one. In each block ``conv1`` and ``conv2`` are
            patched, plus ``downsample.0`` when the block has a non-None
            ``downsample`` attribute.

            Args:
                model: network whose children (after the first two, before
                    the last) are iterables of blocks.
                rank_scale: forwarded to ``patch_module`` as ``rank_scale``.
                spectral: when true the layers are initialised with
                    ``init='spectral'``; otherwise with
                    ``nn.init.kaiming_normal_``.

            Returns:
                The same ``model`` object that was passed in.
            """
            init = 'spectral' if spectral else (lambda X: nn.init.kaiming_normal_(X))
            stages = list(model.children())[2:-1]
            for stage in stages:
                for block in stage:
                    targets = ['conv1', 'conv2']
                    if getattr(block, 'downsample', None) is not None:
                        targets.append('downsample.0')
                    for name in targets:
                        patch_module(block, name, FactorizedConv,
                                     rank_scale=rank_scale,
                                     init=init)
            return model

        no_decay = ['.conv1', '.conv2', 'downsample'] if args.wd2fd else []
        skiplist = [] if args.wd2fd else ['.conv1', '.conv2', 'downsample']

    if args.target_ratio:
        # compress_model returns a (model, rank_scale) pair.
        # NOTE(review): presumably it searches for the rank scale that meets
        # the target ratio - confirm against make.compress_model.
        if args.spectral:
            # Only the selected rank scale is kept; the spectral patching is
            # then applied to `net` explicitly.
            _, rank_scale = compress_model(net, compress, args.target_ratio)
            compress(net, rank_scale, spectral=True)
        else:
            net, _ = compress_model(net, compress, args.target_ratio)
    else:
        compress(net, args.rank_scale, spectral=args.spectral)

    newpar = parameter_count(net)
    print('Compressed weight count:', newpar)
    print('Compression ratio:', newpar / origpar)
else:
    # No factorization requested: every parameter gets ordinary weight decay.
    no_decay, skiplist = [], []

# Select the CUDA device, let cuDNN pick its convolution algorithms by
# benchmarking, and move the model onto the device.
torch.cuda.set_device(args.device)
torch.backends.cudnn.benchmark = True
net = net.to(args.device)

# Data pipeline: train and test loaders share the same batch size.
trainloader, testloader = get_dataloader(
    dataset=args.dataset,
    train_batch_size=args.batch_size,
    test_batch_size=args.batch_size,
    num_workers=4,
    pin_memory=True,
)

# Optimizer and learning-rate schedule.
#
# Parameters are split into two groups by name: any parameter whose name
# contains one of the `no_decay` fragments is trained with weight_decay=0,
# all others use --weight_decay.
decay_params, no_decay_params = [], []
for param_name, param in net.named_parameters():
    if any(fragment in param_name for fragment in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': args.weight_decay},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = optim.SGD(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      momentum=0.9)

# Step schedule keyed by epoch index: base rate, then x0.1 from 50% of the
# run and x0.01 from 75% of the run.
base_lr = args.learning_rate
lr_schedule = {
    0: base_lr,
    int(args.epoch*0.5): base_lr*0.1,
    int(args.epoch*0.75): base_lr*0.01,
}
lr_scheduler = PresetLRScheduler(lr_schedule)
# lr_scheduler = #StairCaseLRScheduler(0, args.decay_every, args.decay_ratio)

# Classification loss.
criterion = nn.CrossEntropyLoss()


# Per-batch wall-clock times (seconds) accumulated over every epoch except
# epoch 0. NOTE(review): epoch 0 is presumably excluded as a warm-up pass
# (e.g. cuDNN benchmarking) - confirm.
training, inference = [], []
for epoch in range(0, args.epoch+1):

    # The training phase below is currently disabled; only inference is timed.
#    net.train()
#    prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), leave=True)
#    batchtimes = []
#    for batch_idx, (inputs, targets) in prog_bar:
#        inputs, targets = inputs.to(args.device), targets.to(args.device)
#        optimizer.zero_grad()
#        torch.cuda.synchronize()
#        start = time.perf_counter()
#        outputs = net(inputs)
#        loss = criterion(outputs, targets)
#        loss.backward()
#        frobdecay(net, coef=args.weight_decay, skiplist=skiplist)
#        optimizer.step()
#        torch.cuda.synchronize()
#        batchtime = time.perf_counter() - start
#        if epoch:
#            training.append(batchtime)
#        batchtimes.append(batchtime)
#    print('Training Epoch', epoch, 'Avg. Time per Batch', np.mean(batchtimes))

    net.eval()
    prog_bar = tqdm(enumerate(testloader), total=len(testloader), leave=True)
    batchtimes = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in prog_bar:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            # Wait for all queued GPU work, including the copy above, so the
            # timer brackets only the forward pass.
            torch.cuda.synchronize()
            start = time.perf_counter()
            outputs = net(inputs)
            # CUDA kernels run asynchronously; wait for the forward pass to
            # finish before reading the clock.
            torch.cuda.synchronize()
            batchtime = time.perf_counter() - start
            if epoch:
                inference.append(batchtime)
            batchtimes.append(batchtime)
    print('Inference Epoch', epoch, 'Avg. Time per Batch', np.mean(batchtimes))
