"""Train MLPs and convolutional nets on MNIST or CIFAR-10/100, using APO to
tune the global learning rate.

Example
-------
python train.py \
    --save_dir=mnist_apo \
    --dataset=mnist \
    --model=mlp \
    --seed=11 \
    --epochs=100 \
    --batch_size=100 \
    --base_optimizer=rmsprop \
    --meta_optimizer=rmsprop \
    --num_meta_steps=1 \
    --lr=1e-4 \
    --meta_lr=0.1 \
    --meta_interval=1 \
    --lam=1e-5
"""
import os
import sys
import pdb
import time
import math
import random
import logging
import datetime
import argparse
import numpy as np
from tqdm import tqdm

# YAML setup
from ruamel.yaml import YAML
yaml = YAML()
yaml.preserve_quotes = True
yaml.boolean_representation = ['False', 'True']

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

# Local imports
from csv_logger import CSVLogger

sys.path.insert(0, 'models')
import models
import optimizers
from resnet import ResNet34
from resnet_cifar import resnet32, resnet32x4
from wide_resnet import WideResNet

import utils
import data_utils
from meta_optim import APO
from schedulers import PresetLRScheduler


# Values accepted by --model / --dataset. These lists only gate argument
# parsing; the code that actually builds the model and loads the data lives
# further down the script and must have a matching branch for each entry.
model_options = [
    'linear', 'mlp', 'vgg11', 'vgg13', 'vgg16', 'resnet18',
    'resnet32', 'resnet32x4', 'resnet34', 'wideresnet'
]
dataset_options = ['mnist', 'cifar10', 'cifar100', 'svhn']

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='CNN')

# Data, model and plain training settings
parser.add_argument('--dataset', '-d', default='cifar10',
                    choices=dataset_options)
parser.add_argument('--model', '-a', default='resnet32',
                    choices=model_options)
parser.add_argument('--batch_size', type=int, default=128,
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=200,
                    help='number of epochs to train (default: 200)')
parser.add_argument('--lr', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--data_augmentation', action='store_true', default=False,
                    help='augment data by flipping and cropping')
parser.add_argument('--wdecay', type=float, default=0,
                    help='Weight decay applied to all weights')
parser.add_argument('--schedule', action='store_true', default=False,
                    help='Use manual learning rate decay schedule.')
parser.add_argument('--factor', type=float, default=0.2,
                    help='The multiplicative decay factor for the learning rate schedule.')
# Comma-separated list of integers, parsed later when the schedule is built.
parser.add_argument('--decay_at', type=str, default='60,120,160',
                    help='The epochs at which to decay the learning rate.')
parser.add_argument('--use_val', action='store_true', default=False,
                    help='Use a separate validation split.')
parser.add_argument('--use_extra_data', action='store_true', default=False,
                    help='Use extra data for SVHN')

# Base optimizer (updates the model) and meta optimizer (updates the base
# optimizer's hyperparameters)
parser.add_argument('--base_optimizer', type=str, default='rmsprop',
                    choices=['kfac', 'sgd', 'sgdwd', 'sgdmwd', 'adam', 'rmsprop'],
                    help='Optimizer to update the main model parameters. We '
                         'tune the hyperparameters of this optimizer.')
parser.add_argument('--meta_optimizer', type=str, default=None,
                    choices=['sgd', 'rmsprop', 'adam', 'lbfgs'],
                    help='Type of optimizer to use to update hyperparameters')
parser.add_argument('--meta_lr', type=float, default=0.1,
                    help='Learning rate for the hyperparameter optimizer.')
parser.add_argument('--num_meta_steps', type=int, default=0,
                    help='Number of meta-optimization steps to take per main '
                         'parameter step.')
parser.add_argument('--meta_interval', type=int, default=10,
                    help='Perform meta-optimization every N base optimization '
                         'steps (e.g., every 1 or 10 steps)')
parser.add_argument('--lam', type=float, default=0,
                    help='Lambda weighting of the dissimilarity term (can set '
                         'to 0 to ignore dissimilarity)')

# Hyperparameters for SGD and RMSprop
# NOTE(review): --tune_lr and --tune_lambda are parsed but do not appear to be
# read anywhere else in this script; confirm whether they are still needed.
parser.add_argument('--tune_lr', action='store_true', default=False,
                    help='Tune the learning rate with APO.')
# The help text used to be a copy of --tune_lr's; this flag adds 'rho' to the
# list of tuned hyperparameters.
parser.add_argument('--tune_rho', action='store_true', default=False,
                    help='Tune rho with APO (in addition to the learning rate).')

parser.add_argument('--tune_lambda', action='store_true', default=False,
                    help='Use a manual schedule for lambda.')

# Architecture knobs
parser.add_argument('--nlayers', type=int, default=2,
                    help='Set the number of layers for an MLP or linear model.')
parser.add_argument('--no_bn', action='store_true', default=False,
                    help='Turn off batch normalization.')

# Logging / stopping cadence
parser.add_argument('--eval_interval', type=int, default=10,
                    help='Evaluate on fixed train/validation batches every N iterations.')
parser.add_argument('--patience', type=int, default=10,
                    help='Number of epochs to wait for the training performance '
                         'to improve before early stopping.')

# Reproducibility, device selection and output location
parser.add_argument('--seed', type=int, default=13,
                    help='Set the random seed for reproducibility')
parser.add_argument('--disable_cuda', action='store_true', default=False,
                    help='Flag to DISABLE CUDA (which is ENABLED by default)')
parser.add_argument('--gpu', type=int, default=0,
                    help='Select which GPU to use (e.g., 0, 1, 2, or 3)')
parser.add_argument('--prefix', type=str, default=None,
                    help='Optional prefix for the experiment directory name')
parser.add_argument('--save_dir', type=str, default=None,
                    help='Base experiment directory')
parser.add_argument('--overwrite', action='store_true', default=False,
                    help='Overwrite the experiment data even if result.csv exists')

# Value exported as CUDA_VISIBLE_DEVICES when CUDA is in use.
parser.add_argument('--cudaid', type=str, default='0',
                    help='which cuda to use')


args = parser.parse_args()

# CUDA is used only when it has not been disabled on the command line and a
# CUDA device is actually available.
args.cuda = False if args.disable_cuda else torch.cuda.is_available()

# Let cuDNN benchmark its kernels; this should make training faster for large
# models.
cudnn.benchmark = True

# One consecutive index (0..k-1) per comma-separated entry of --cudaid.
# NOTE(review): this list does not appear to be used elsewhere in the script.
cudaid_list = list(range(len(args.cudaid.split(','))))

torch.set_default_dtype(torch.float32)

if args.cuda:
  # Export the device selection through the environment.
  os.environ.update({
      "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
      "CUDA_VISIBLE_DEVICES": args.cudaid,
  })

utils.print_opts(args)

# Device that batches are moved to during training and evaluation.
if args.disable_cuda or not torch.cuda.is_available():
  use_device = torch.device('cpu')
else:
  use_device = torch.device('cuda:{}'.format(args.gpu))

# Seed NumPy, PyTorch and Python's `random` (in that order) with the same
# value for reproducibility, then the CUDA generator when one is available.
for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
  seed_fn(args.seed)
if torch.cuda.is_available():
  torch.cuda.manual_seed(args.seed)

# The experiment directory name encodes the settings that define a run, so
# that runs with different settings never share a directory.
# NOTE(review): a seed of 0 is rendered as 'None' by the truthiness test
# below; left unchanged so the directory names of existing runs stay stable.
exp_name = 'dset:{}-model:{}-nl:{}-b:{}-m:{}-bs:{}-ilr:{}-mlr:{}-lam:{}-mstp:{}-' \
           'mint:{}-ag:{}-wd-{}-val:{}-ep:{}-fac:{}-dat:{}-seed:{}'.format(
            args.dataset, args.model, args.nlayers, args.base_optimizer,
            args.meta_optimizer, args.batch_size, args.lr, args.meta_lr,
            args.lam, args.num_meta_steps, args.meta_interval,
            int(args.data_augmentation), args.wdecay, int(args.use_val),
            args.epochs, args.factor, args.decay_at,
            args.seed if args.seed else 'None')

# Create log folder
BASE_SAVE_DIR = 'experiments'
# --save_dir defaults to None, which os.path.join rejects with a TypeError.
# An empty component is skipped by os.path.join, so a missing --save_dir puts
# the run directly under BASE_SAVE_DIR.
save_dir = os.path.join(BASE_SAVE_DIR, args.save_dir or '', exp_name)
# exist_ok avoids the check-then-create race when several runs start at once.
os.makedirs(save_dir, exist_ok=True)

# Refuse to clobber a finished experiment unless --overwrite was given. The
# presence of result.csv is what marks an experiment as finished.
result_path = os.path.join(save_dir, 'result.csv')
if os.path.exists(result_path) and not args.overwrite:
  print('The result file {} exists! Run with --overwrite to overwrite this experiment.'.format(
        result_path))
  sys.exit(0)

# Save command-line arguments
with open(os.path.join(save_dir, 'args.yaml'), 'w') as f:
  yaml.dump(vars(args), f)


# Logging: every record goes both to <save_dir>/output.log and to the console.
log_handlers = [
    logging.FileHandler(os.path.join(save_dir, "output.log")),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    handlers=log_handlers,
)

# The root logger is used for all messages emitted by this script.
logger = logging.getLogger()


# Build the train/val/test loaders for the requested dataset and record the
# number of target classes, which the model constructors below need.
if args.dataset == 'mnist':
  num_classes = 10
  train_dataloader, val_dataloader, test_dataloader = data_utils.load_mnist(
      args.batch_size, val_split=args.use_val
  )
elif args.dataset == 'fashion':
  # NOTE(review): 'fashion' is not listed in dataset_options, so argparse
  # rejects it and this branch is currently unreachable from the command line.
  num_classes = 10
  train_dataloader, val_dataloader, test_dataloader = data_utils.load_fashion_mnist(
      args.batch_size, val_split=args.use_val
  )
elif args.dataset == 'cifar10':
  num_classes = 10
  train_dataloader, val_dataloader, test_dataloader = data_utils.load_cifar10(
      args.batch_size, val_split=args.use_val, augmentation=args.data_augmentation
  )
elif args.dataset == 'cifar100':
  num_classes = 100
  train_dataloader, val_dataloader, test_dataloader = data_utils.load_cifar100(
      args.batch_size, val_split=args.use_val, augmentation=args.data_augmentation
  )
elif args.dataset == 'svhn':
  num_classes = 10
  train_dataloader, val_dataloader, test_dataloader = data_utils.load_svhn(
      args.batch_size, val_split=args.use_val, use_extra_data=args.use_extra_data
  )
else:
  # Fail here with a clear message rather than with a NameError on
  # num_classes or the loaders further down the script.
  raise ValueError('Unsupported dataset: {}'.format(args.dataset))


# Preset learning-rate schedule: maps an epoch index to the natural log of the
# learning rate associated with that epoch. It is only applied during training
# when --schedule is given.
if args.dataset != 'svhn':
  # The k-th epoch listed in --decay_at (k = 1, 2, ...) maps to
  # log(lr * factor**k).
  decay_epochs = [int(epoch_str) for epoch_str in args.decay_at.split(',')]
  lr_schedule_dict = {}
  for i, epoch in enumerate(decay_epochs):
    lr_schedule_dict[epoch] = math.log(args.lr * (args.factor ** (i + 1)))
else:
  # SVHN ignores --decay_at and uses fixed decay points at epochs 80 and 120.
  lr_schedule_dict = {
      80: math.log(args.lr * args.factor),
      120: math.log(args.lr * (args.factor**2)),
  }

scheduler = PresetLRScheduler(lr_schedule_dict)


# Instantiate the network selected with --model.
if args.model == 'mlp':
  model = models.MLP(
      ninp=784, nhid=1000, nout=10, nlayers=args.nlayers,
      dropout=0, use_bias=True
  )
elif args.model == 'linear':
  model = models.LinearNetwork(
      ninp=784, nout=10, nlayers=args.nlayers, use_bias=True
  )
# NOTE(review): no name `vgg` is imported in this file as shown, so the three
# VGG branches below raise NameError; import the module that defines VGG
# before relying on them.
elif args.model == 'vgg11':
  model = vgg.VGG('VGG11')
elif args.model == 'vgg13':
  model = vgg.VGG('VGG13')
elif args.model == 'vgg16':
  model = vgg.VGG('VGG16')
elif args.model == 'resnet34':
  model = ResNet34(num_classes=num_classes)
elif args.model == 'resnet32':
  model = resnet32(num_classes=num_classes)
elif args.model == 'resnet32x4':
  model = resnet32x4(num_classes=num_classes)
elif args.model == 'wideresnet':
  # SVHN uses a smaller, dropout-regularized WideResNet than the other datasets.
  if args.dataset == 'svhn':
    model = WideResNet(depth=16, num_classes=num_classes, widen_factor=4, dropRate=0.4)
  else:
    model = WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=0)
else:
  # --model accepts names (e.g. 'resnet18') that have no constructor above;
  # report that directly instead of failing with a NameError on `model`.
  raise ValueError(
      'Model "{}" is accepted by --model but is not implemented in this script.'.format(
          args.model))

utils.tally_parameters(model)
# Move the model to the same device the batches are sent to. The previous
# unconditional model.cuda() ignored --disable_cuda and --gpu: it crashed on
# CPU-only machines and left model and data on different GPUs for --gpu != 0.
model = model.to(use_device)

def optim_parameters(model):
  """Return the parameters of `model` that the base optimizer should update.

  This forwards to ``model.parameters()`` unchanged; it exists so the
  optimizers can be handed a callable that selects the parameters.
  """
  params = model.parameters()
  return params


# Names of the base-optimizer hyperparameters passed as `tune=` below. The
# learning rate is always included; 'rho' is added with --tune_rho.
# NOTE(review): --tune_lr is not consulted here.
tune = ['lr']
if args.tune_rho:
  tune.append('rho')

# Base optimizer: updates the model parameters.
if args.base_optimizer == 'sgdmwd':
  # NOTE(review): `tune` is not passed and cuda=True is hard-coded for this
  # optimizer; confirm against optimizers.MySGDMwd.
  base_optimizer = optimizers.MySGDMwd(
      model, optim_parameters, lr=args.lr, momentum=0.9, wdecay=args.wdecay, cuda=True
  )
elif args.base_optimizer == 'rmsprop':
  base_optimizer = optimizers.MyRMSprop(
      model, optim_parameters, lr=args.lr, gamma=0.9, tune=tune, weight_decay=args.wdecay
  )
elif args.base_optimizer == 'adam':
  base_optimizer = optimizers.MyAdam(
      model, optim_parameters, lr=args.lr, tune=tune, weight_decay=args.wdecay
  )
else:
  # --base_optimizer also accepts 'kfac', 'sgd' and 'sgdwd', none of which is
  # constructed above. Raise a clear error instead of a NameError on
  # `base_optimizer` at the next statement.
  raise ValueError(
      'Base optimizer "{}" is accepted by --base_optimizer but is not '
      'implemented in this script.'.format(args.base_optimizer))

# Meta optimizer: updates the hyperparameters exposed by the base optimizer.
# NOTE(review): --meta_optimizer is parsed and recorded in the experiment name,
# but RMSprop is always used here regardless of its value.
meta_optimizer = optim.RMSprop(base_optimizer.parameters(), lr=args.meta_lr, alpha=0.99)

# APO wraps both optimizers; the training loop calls optimizer.step(loss_fn).
optimizer = APO(
  model, base_optimizer, meta_optimizer, num_meta_steps=args.num_meta_steps,
  meta_interval=args.meta_interval, train_dataloader=train_dataloader,
  lam=args.lam, batch_size_prime=args.batch_size
)


# One CSV column per tuned hyperparameter, named "<index> <hyperparameter>",
# where <index> is the position of the entry in
# base_optimizer.parameters_with_names().
hparam_fieldnames = [
    '{} {}'.format(layer, name)
    for (layer, param_dict) in enumerate(base_optimizer.parameters_with_names())
    for name in param_dict
]

# Per-epoch summary log.
epoch_fieldnames = ['epoch', 'time_elapsed', 'train_loss', 'train_acc', 'val_acc', 'test_acc']
epoch_csv_logger = CSVLogger(
    fieldnames=epoch_fieldnames,
    filename=os.path.join(save_dir, 'epoch_log.csv')
)

# Per-iteration log: running training statistics plus the hyperparameter values.
iteration_fieldnames = ['iteration', 'train_loss', 'train_acc'] + hparam_fieldnames
iteration_csv_logger = CSVLogger(
    fieldnames=iteration_fieldnames,
    filename=os.path.join(save_dir, 'iteration_log.csv')
)


def model_save(fname):
  """Serialize the state dict of the module-level `model` to `fname`."""
  weights = model.state_dict()
  torch.save(weights, fname)


def model_load(fname):
  """Load the state dict stored at `fname` into the module-level `model`.

  The model object itself is not replaced; only its weights are overwritten
  in place.
  """
  state = torch.load(fname)
  model.load_state_dict(state)


def evaluate(dataloader):
  """Evaluate the module-level `model` on every batch of `dataloader`.

  The model is switched to eval mode for the pass and put back into train
  mode before returning. Gradients are not tracked.

  Returns:
    A `(loss, accuracy)` pair: the unweighted mean of the per-batch
    cross-entropy losses, and the fraction of correctly classified examples.
  """
  model.eval()

  batch_losses = []
  num_correct = 0.
  num_seen = 0.

  with torch.no_grad():
    for inputs, targets in dataloader:
      inputs = inputs.to(use_device)
      targets = targets.to(use_device)

      logits = model(inputs)
      batch_losses.append(F.cross_entropy(logits, targets).item())

      # Predicted class = index of the largest output along dimension 1.
      _, predicted = torch.max(logits.data, 1)
      num_seen += targets.size(0)
      num_correct += (predicted == targets).sum().item()

  accuracy = num_correct / num_seen
  model.train()
  return np.mean(batch_losses), accuracy


# ---------------------------------------------------------------------------
# Training loop and the state it maintains
# ---------------------------------------------------------------------------
# Wall-clock reference for the 'time_elapsed' column and the final timing log.
start_time = time.time()
# Number of optimizer steps taken so far, across all epochs.
global_iteration = 0
# Best validation accuracy seen so far, and the validation loss of that epoch.
best_val_acc = 0.0
best_val_loss = 1e6

# Best epoch-level training loss so far and the number of consecutive epochs
# without improvement.
# NOTE(review): patience_elapsed is updated below but never compared against
# args.patience in this script, so no early stopping actually takes place.
best_train_loss = 1e6
patience_elapsed = 0

# Remain None when no validation loader is in use; they are then written to
# the logs/CSV as None.
val_loss, val_acc = None, None

# Wall-clock duration of the training part of each completed epoch.
epoch_times = []

# Ctrl-C (or a NaN training loss, see below) leaves the loop early; execution
# then continues with the final evaluation after the try block.
try:
  for epoch in range(args.epochs):
    if args.schedule:
      # Hand the base optimizer and the epoch index to the preset scheduler.
      scheduler(base_optimizer, epoch)

    # Running SUM of the batch losses of this epoch (despite the name); it is
    # divided by the number of batches wherever it is reported.
    xentropy_loss_avg = 0.
    # Running counts for the training accuracy of this epoch.
    correct = 0.
    total = 0.

    # NOTE(review): filled below but not read anywhere in this script.
    losses = []

    epoch_start_time = time.time()

    progress_bar = tqdm(train_dataloader)
    for i, (images, labels) in enumerate(progress_bar):
      progress_bar.set_description('Epoch ' + str(epoch))
      images, labels = images.to(use_device), labels.to(use_device)

      # Closure over the current batch, passed to optimizer.step. It evaluates
      # whichever model it is given on that batch.
      # NOTE(review): `updated` is accepted but unused here; presumably it is
      # part of the callback signature APO expects - confirm in meta_optim.
      def loss_fn(model, updated=False):
        predictions = model(images)
        loss = F.cross_entropy(predictions, labels)
        return loss, predictions

      # One APO step; it returns the loss and the predictions for this batch.
      xentropy_loss, pred = optimizer.step(loss_fn)

      xentropy_loss_avg += xentropy_loss.item()
      losses.append(xentropy_loss.item())

      # Calculate running average of accuracy
      pred = torch.max(pred.data, 1)[1]
      total += labels.size(0)
      correct += (pred == labels.data).sum().item()
      accuracy = correct / total

      # NOTE(review): current_lr is not used after this assignment.
      current_lr = base_optimizer.parameters_with_names()[0]['lr'].item()

      # Every eval_interval iterations, snapshot the hyperparameter values and
      # write them with the running training statistics to the iteration log.
      # global_iteration starts at 0, so the very first iteration always takes
      # this branch, which is what defines hparam_string for the progress bar.
      if global_iteration % args.eval_interval == 0:
        hparam_dict = {}
        for (layer, param_dict) in enumerate(base_optimizer.parameters_with_names()):
          for name in param_dict:
            fieldname = '{} {}'.format(layer, name)
            hparam_dict[fieldname] = param_dict[name].item()

        iteration_csv_logger.writerow({
            **{'iteration': global_iteration,
               'train_loss': xentropy_loss_avg / (i+1),
               'train_acc': accuracy},
            **hparam_dict
        })

        hparam_string = ' '.join('{} {}'.format(key, value) for key,value in hparam_dict.items())

      # Between logging iterations the progress bar shows the hyperparameter
      # values from the most recent snapshot.
      progress_bar.set_postfix(
          xentropy='{:6.4e}'.format(xentropy_loss_avg / (i+1)),
          acc='{:6.4e}'.format(accuracy),
          lr=hparam_string
      )

      global_iteration += 1

    epoch_end_time = time.time()
    epoch_elapsed_time = epoch_end_time - epoch_start_time
    epoch_times.append(epoch_elapsed_time)
    logger.info('Train epoch time: {}'.format(epoch_elapsed_time))


    # Mean batch loss over the epoch (i is the index of the last batch).
    train_loss = xentropy_loss_avg / (i + 1)

    # A NaN loss means training has diverged. Raising KeyboardInterrupt reuses
    # the handler at the bottom to leave the loop and go to final evaluation.
    if math.isnan(train_loss):
      raise KeyboardInterrupt

    if train_loss < best_train_loss:
      best_train_loss = train_loss
      patience_elapsed = 0
    else:
      patience_elapsed += 1

    if val_dataloader:
      val_loss, val_acc = evaluate(val_dataloader)
      logger.info('Val loss: {:6.4f} | Val acc: {:6.4f}'.format(val_loss, val_acc))

      # Checkpoint only on a strict improvement in validation accuracy.
      if val_acc > best_val_acc:
        best_val_loss = val_loss
        best_val_acc = val_acc
        model_save(os.path.join(save_dir, 'best_val_checkpoint_file.pt'))
        logger.info('Saved new best val checkpoint!')

    # The test set is evaluated every epoch for logging only; it does not
    # influence checkpointing.
    test_loss, test_acc = evaluate(test_dataloader)
    logger.info('Test loss: {:6.4f} | Test acc: {:6.4f}'.format(test_loss, test_acc))

    time_elapsed = time.time() - start_time

    # train_acc is the running accuracy after the last batch of the epoch.
    epoch_row = {
        'epoch': str(epoch),
        'time_elapsed': time_elapsed,
        'train_loss': train_loss,
        'train_acc': str(accuracy),
        'val_acc': str(val_acc),
        'test_acc': str(test_acc)
    }
    epoch_csv_logger.writerow(epoch_row)
except KeyboardInterrupt:
  logger.info('Exiting training early...')

# Training is over (normally or through an early exit): close the running logs.
epoch_csv_logger.close()
iteration_csv_logger.close()


if val_dataloader:
  # Load model with best validation performance and evaluate it on the test set
  best_checkpoint = os.path.join(save_dir, 'best_val_checkpoint_file.pt')
  if os.path.exists(best_checkpoint):
    model_load(best_checkpoint)
  else:
    # No checkpoint is written if training stopped before the first validation
    # pass or validation accuracy never rose above 0. Evaluate the weights
    # currently in memory instead of crashing before result.csv is written.
    logger.info('No checkpoint found at {}; evaluating the current model.'.format(
        best_checkpoint))
  val_loss, val_acc = evaluate(val_dataloader)

test_loss, test_acc = evaluate(test_dataloader)

logger.info('Val: {} | Test: {}'.format(val_acc, test_acc))

# Record final (best) validation loss and validation accuracy in a result.csv file
result_logger = CSVLogger(
    fieldnames=['val_loss', 'val_acc', 'test_loss', 'test_acc'],
    filename=os.path.join(save_dir, 'result.csv')
)
result_logger.writerow({
    'val_loss': val_loss,
    'val_acc': val_acc,
    'test_loss': test_loss,
    'test_acc': test_acc
})
# Close this logger like the two above. result.csv is what marks the
# experiment as finished on later runs.
result_logger.close()

end_time = time.time()
# epoch_times is empty when no epoch completed; report nan without triggering
# NumPy's empty-mean warning.
avg_epoch_time = np.mean(epoch_times) if epoch_times else float('nan')
logger.info('Avg per-epoch time: {}'.format(avg_epoch_time))
logger.info('Total time: {}'.format(end_time - start_time))
