from __future__ import print_function

import os
import datetime
import argparse
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import get_model, get_data_loaders, get_train_args
from torch.utils.tensorboard import SummaryWriter


def _make_unique_dir(base_path):
  """Create and return a fresh directory at base_path or base_path_<k>.

  Tries ``base_path`` first, then ``base_path_1``, ``base_path_2``, ... and
  returns the first path it manages to create.
  """
  candidate = base_path
  postfix = 1
  while True:
    try:
      # Creating the directory is itself the existence test, so two runs
      # launched at the same time can never both claim the same path.
      os.makedirs(candidate)
      return candidate
    except FileExistsError:
      candidate = base_path + f'_{postfix}'
      postfix += 1


def _evaluate(model, loader, device):
  """Return (summed cross-entropy loss, number of correct predictions) over loader."""
  model.eval()
  total_loss = 0
  correct = 0
  with torch.no_grad():
    for clndata, target in loader:
      clndata, target = clndata.to(device), target.to(device)
      output = model(clndata)
      # reduction='sum' so the caller can average over the whole dataset.
      total_loss += F.cross_entropy(
        output, target, reduction='sum').item()
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
  return total_loss, correct


def train(args):
  """Train a model with SGD and log/checkpoint it under a fresh run directory.

  Args:
    args: object exposing the attributes ``save_path``, ``data_path``,
      ``dataset``, ``arch``, ``pool_type``, ``max_num_pools``, ``noise_std``,
      ``exp_name``, ``seed`` and ``log_interval``. It is also forwarded
      unchanged to ``get_train_args`` and ``get_model``.

  Side effects:
    * Creates ``<save_path>/trained_models/<dataset>/<experiment name>``; a
      numeric suffix (``_1``, ``_2``, ...) is appended when that directory
      already exists.
    * Writes TensorBoard scalars into that directory and prints progress.
    * Saves ``model.state_dict()`` every 10th epoch and after the last epoch.

  NOTE(review): the optimizer is always SGD with momentum 0.9; nothing in this
  function reads ``args.optimizer``.
  NOTE(review): ``model_filename`` already ends in ``.pt`` and the checkpoint
  name appends ``_<epoch>.pt`` to it, so files end in ``.pt_<epoch>.pt``. Kept
  as-is because existing loaders may depend on that name.
  """
  ROOT_PATH = args.save_path
  EXP_NAME = f'cln_{args.arch}_{args.pool_type}_{args.max_num_pools}_{args.noise_std}{args.exp_name}'
  TRAINED_MODEL_PATH = os.path.join(ROOT_PATH, f'trained_models/{args.dataset}', EXP_NAME)
  DATA_PATH = os.path.join(args.data_path, args.dataset)

  TRAINED_MODEL_PATH = _make_unique_dir(TRAINED_MODEL_PATH)
  writer = SummaryWriter(TRAINED_MODEL_PATH)

  # Everything after the writer is opened runs under try/finally so pending
  # events are flushed and the event file is closed even if training fails.
  try:
    torch.manual_seed(args.seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    trainargs = get_train_args(args)
    nb_epoch = trainargs['nb_epoch']
    lr_steps = trainargs['schedule_milestones']
    model_filename = f"{args.dataset}_{args.arch}_clntrained.pt"
    train_loader, test_loader = get_data_loaders(args.dataset,
                                                 trainargs['train_batch_size'], trainargs['test_batch_size'],
                                                 DATA_PATH,
                                                 noise_std=args.noise_std)
    model = get_model(args)

    if torch.cuda.device_count() > 1:
      print("Let's use", torch.cuda.device_count(), "GPUs!")
      model = nn.DataParallel(model)

    model.to(device)
    opt = optim.SGD(model.parameters(), lr=trainargs['lr'], momentum=0.9, weight_decay=trainargs['weight_decay'])
    # Learning rate is multiplied by 0.1 at each milestone epoch.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, lr_steps, gamma=0.1, last_epoch=-1)

    # Global step counter across epochs; used as the TensorBoard x-axis.
    it = 0
    for epoch in range(nb_epoch):
      model.train()
      for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        opt.zero_grad()
        output = model(data)
        loss = F.cross_entropy(
          output, target, reduction='mean')
        loss.backward()
        opt.step()
        if batch_idx % args.log_interval == 0:
          print(f'{datetime.datetime.now()} Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
          writer.add_scalar('Loss/C', loss.item(), it)

        it += 1

      test_clnloss, clncorrect = _evaluate(model, test_loader, device)
      test_clnloss /= len(test_loader.dataset)
      cln_acc = 100. * clncorrect / len(test_loader.dataset)
      print('\nTest set: avg cln loss: {:.4f},'
            ' cln acc: {}/{} ({:.0f}%)\n'.format(
          test_clnloss, clncorrect, len(test_loader.dataset),
          cln_acc))
      writer.add_scalar('Loss/C_test_cln', test_clnloss, it)
      writer.add_scalar('Test Accuracy/cln', cln_acc, it)

      # Checkpoint every 10th epoch and always after the final one.
      if (epoch % 10 == 0) or (epoch == (nb_epoch - 1)):
        torch.save(
          model.state_dict(),
          os.path.join(TRAINED_MODEL_PATH, model_filename + f'_{epoch}.pt'))
      lr_scheduler.step()
  finally:
    writer.close()

if __name__ == '__main__':
  # Command-line options as (flag, default, type, help), listed in the order
  # they are registered; that order is also the namedtuple field order below.
  cli_options = (
    ('--save_path', './chkpts', str, 'path to where to save checkpoints'),
    ('--data_path', '', str, 'path to data'),
    ('--optimizer', 'sgd', str, 'sgd | adam'),
    ('--dataset', 'cifar10', str, 'cifar10 | cifar100'),
    ('--seed', 0, int, None),
    ('--log_interval', 200, int, None),
    ('--arch', 'resnet18', str, None),
    ('--pool_type', None, str, None),
    ('--max_num_pools', 2, int, '# of kernel pools to apply'),
    ('--noise_std', 0.2, float, 'Noise STD.'),
    ('--exp_name', '', str, 'experiment name'),
  )

  parser = argparse.ArgumentParser(description='Train NT and AT')
  for flag, default_value, value_type, help_text in cli_options:
    parser.add_argument(flag, default=default_value, type=value_type, help=help_text)
  args = parser.parse_args()

  # Hand train() an immutable namedtuple snapshot of the parsed options
  # instead of the mutable argparse Namespace.
  parsed_fields = vars(args)
  nt = namedtuple('nt', list(parsed_fields))
  train(nt(**parsed_fields))
