'''The adversarial training framework'''

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data.dataset import Dataset
from torch.optim.lr_scheduler import MultiStepLR
import torchvision.transforms as transforms

import h5py

import os
import sys
import time
import types
import numpy as np
import pandas as pd

from framework.wrapper import adv_train_net
from framework.config import get_config, get_arch, get_dataset, get_transform, get_pin_memory
from framework.label import label_modified

class Engine(object):
  '''
  Adversarial Training Engine
  process_trainset(trainset):       how to preprocess trainset
  process_trainset_test(trainset):  how to preprocess the copy of the trainset that is used for testing (trainset_test)
  process_testset(testset):         how to preprocess testset
  basic_net:                  A function that returns a network architecture
                              Pass in a function instead of the network itself (you can use lambda)!
  scheduler:                  default decays the learning rate by a factor of 10 at epochs 60 and 120
                              set scheduler = None will use no scheduler
                              A function that accepts an optimizer, e.g.
                              lambda opt: MultiStepLR(
                                opt, milestones=[50, 100], gamma=0.1)
  scheduler_type:             normal: as above
                              specific: lambda epoch: return lr for the epoch
  resume_scheduler:           If true, scheduler.step(start_epoch)
  continue_train:             If true, continue from checkpoint epoch. Else, start from epoch 1.
  optimizer:                  If none, then momentum SGD. Else, optimizer(net, basic_net)
  lr:                         default is 0.1
  momentum:                   default is 0.9
  wd (weight decay):          default is 5e-4
  pathdir:                    don't save by default. set this if need saving
  epochs:                     how many epochs of training? default 200
  batch_size:                 batch_size for trainloader
  criterion:                  a loss function class (called as criterion(reduction='mean')), default nn.CrossEntropyLoss
  device:                     device for training
  attack:                     if set, train_attack = test_attack = attack
  train_attack, test_attack:  attacks of interest. default 'none'
  test_trainset:              need to test the training set as well?
  test_freq:                  test frequency (if 0 then no test)
  test_first:                 test starts from epoch test_first.
                              test before train if test_first == 0
  save_epochs:                None or Int or Tuple.
                              None: Save after test
                              Int: Save for every n epochs
                              Tuple. Epochs when checkpoint needs to be saved
  save_optimizer:             Set true if want to save the optimizer state
  save_dict:                  Additional parameters to save in checkpoints
  shuffle_train/test:         Whether shuffle train/test set
  train_task_level:           Use which label class to train, 'class' for full label, 'sclass' for superclass, 'ssclass' for supersuperclass
  test_task_level:            Use which label class to test
  train_label:                Numbers of training labels
  test_label:                 Numbers of testing labels
  repre_root:                 path to the minibatch for representation
  
  transform_train and transform_test use default if not specified
  pytorch_seed, numpy_seed can be used to do deterministic training
  attack parameters: eps, step_size, step_num, up, down, normalize (whether / 255)
  set framework.config for default values

  How to personalize the engine?
  1. Construct the engine
     engine = Engine(...)
     This will load the dataset, perform data augmentation, build the model and load the checkpoint.

  2. Define the following two functions with registration wrappers if needed:
     @engine.register_train_function
     def train_function(epoch)

     @engine.register_test_function
     def test_function(epoch)

  3. Start the engine
     engine.start()

  4. engine.train() and engine.test() can be used to do more specific training

  Notes:
  1. For Cifar, SVHN and MNIST
    Download if not exist in root
  2. For ImageNet
    The engine would never attempt to download ImageNet
    To download it, use dataset.imagenet_download
    trainset.imgs contains a list of (image_path, target)
  '''

  def __init__(self, dataset='cifar100', root=None, process_trainset=None, process_trainset_test=None,
               process_testset=None, basic_net=None, eps=8.0, step_size=2.0, 
               step_num=7, resume_checkpoint=None, continue_train=True, scheduler='default', scheduler_type='normal',
               resume_scheduler=False, optimizer=None, lr=0.1, momentum=0.9, wd=5e-4, pathdir=None,
               epochs=200, batch_size=128, test_batch_size=128, num_workers=None, attack=None,
               criterion=None, train_attack='none', test_attack='none', test_trainset=False, transform_train=None,
               transform_test=None, device=None, test_freq=5, test_first=1, normalize=True,
               pytorch_seed=None, numpy_seed=None, save_epochs=None, save_optimizer=False,
               save_dict=None, up=1.0, down=0.0, sigma=0.5, shuffle_train=True, shuffle_test=False,
               train_task_level='sclass', test_task_level='class', train_label=5, test_label=10, 
               repre_save_path='none', repre_root='none', 
               repre_save_folder='none', cifar100_label=None):
    '''
    Build the engine in four steps:
      1. load the datasets and wrap them in DataLoaders,
      2. build the network, replace its `fc` head with a `train_label`-way
         linear layer and wrap it with adv_train_net,
      3. create the optimizer, the scheduler and the train/test settings,
      4. load `resume_checkpoint` if one is given.

    The meaning of each argument is described in the class docstring.
    Values left as None for `root` and `num_workers` are read from
    framework.config.
    '''
    self.repre_save_folder = repre_save_folder
    self.cifar100_label = cifar100_label
    self.repre_root = repre_root
    self.repre_save_path = repre_save_path
    config = get_config()
    if root is None:
      root = config['root']
    if num_workers is None:
      num_workers = config['num_workers_{}'.format(dataset)]
    self.pytorch_seed = pytorch_seed
    self.numpy_seed = numpy_seed
    # torch.manual_seed() does not accept None, so only seed when a seed was
    # actually requested. The numpy seed is guarded the same way so that a
    # seed set by the caller beforehand is not discarded.
    if pytorch_seed is not None:
      torch.manual_seed(pytorch_seed)
    if numpy_seed is not None:
      np.random.seed(numpy_seed)
    
    self.dataset = dataset
    root = os.path.join(root, dataset)
    self.root = root
    pin_memory = get_pin_memory(dataset)

    # 1. Preprocess datasets
    print('==> Preparing data..')
    default_transform_train, default_transform_test = get_transform(dataset)

    self.batch_size = batch_size
    self.test_batch_size = test_batch_size
    self.transform_train = transform_train or default_transform_train
    self.transform_test = transform_test or default_transform_test
    self.train_task_level = train_task_level
    self.test_task_level = test_task_level
    self.train_label = train_label
    self.test_label = test_label
    print(dataset)

    self.trainset, self.trainset_test, self.testset, self.num_classes = get_dataset(
        dataset, root, self.transform_train, self.transform_test)

    # User hooks run on the dataset objects before the loaders are created.
    if process_trainset:
      process_trainset(self.trainset)
    if process_trainset_test:
      process_trainset_test(self.trainset_test)
    self.trainloader = torch.utils.data.DataLoader(
        self.trainset, batch_size=batch_size, shuffle=shuffle_train, num_workers=num_workers, pin_memory=pin_memory)
    self.trainloader_test = torch.utils.data.DataLoader(
        self.trainset_test, batch_size=test_batch_size, shuffle=shuffle_test, num_workers=num_workers, pin_memory=pin_memory)

    if process_testset:
      process_testset(self.testset)
    self.testloader = torch.utils.data.DataLoader(
        self.testset, batch_size=test_batch_size, shuffle=shuffle_test, num_workers=num_workers, pin_memory=pin_memory)

    # 2. Build model
    print('==> Building model..')

    if basic_net is None:
      basic_net = get_arch(config['basic_net_{}'.format(dataset)])
    else:
      # basic_net is a factory (e.g. a lambda), not the network itself.
      basic_net = basic_net()
      
    # change the fully connected layer so it predicts `train_label` classes
    num_fc_ftr = basic_net.fc.in_features
    basic_net.fc = nn.Linear(num_fc_ftr, self.train_label)
    
    self.basic_net = basic_net

    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = device or default_device
    self.eps = eps
    self.step_size = step_size
    self.step_num = step_num
    self.basic_net = self.basic_net.to(self.device)
    # The wrapper receives the criterion class; the engine keeps its own
    # instance (mean reduction) for the training loss.
    if criterion is None:
      self.criterion = nn.CrossEntropyLoss(reduction='mean')
      self.adv_net = adv_train_net(self.basic_net, self.eps,
                               self.step_size, self.step_num, self.device, normalize,
                               up=up, down=down, sigma=sigma, criterion=nn.CrossEntropyLoss)
    else:
      self.criterion = criterion(reduction='mean')
      self.adv_net = adv_train_net(self.basic_net, self.eps,
                               self.step_size, self.step_num, self.device, normalize,
                               up=up, down=down, sigma=sigma, criterion=criterion)

    if self.device == 'cuda':
      self.net = torch.nn.DataParallel(self.adv_net)
      # Deterministic cuDNN only when both seeds were supplied.
      if pytorch_seed is None or numpy_seed is None:
        cudnn.benchmark = True
        cudnn.deterministic = False
      else:
        cudnn.benchmark = False
        cudnn.deterministic = True
    else:
      self.net = self.adv_net

    # 3. Training and testing components
    self.lr = lr
    self.momentum = momentum
    self.wd = wd
    if optimizer is None:
      self.optimizer = optim.SGD(
          self.net.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
    else:
      self.optimizer = optimizer(self.net, self.basic_net)
    
    if scheduler is None:
      self.scheduler = None
      self.scheduler_type = 'normal'
    elif scheduler == 'default':
      self.scheduler = MultiStepLR(
          self.optimizer, milestones=[60, 120], gamma=0.1)
      self.scheduler_type = 'normal'
    elif scheduler == 'cifar100':
      self.scheduler = MultiStepLR(
          self.optimizer, milestones=[60,120, 160], gamma=0.2)
      self.scheduler_type = 'normal'
    else:
      self.scheduler_type = scheduler_type
      if scheduler_type == 'normal':
        # 'normal': a factory that takes the optimizer
        self.scheduler = scheduler(self.optimizer)
      else:
        # 'specific': a function epoch -> lr, called by train()
        self.scheduler = scheduler

    self.resume_scheduler = resume_scheduler
    self.pathdir = pathdir
    self.epochs = epochs
    self.attack = attack
    # `attack`, when given, overrides both train_attack and test_attack.
    if attack is None:
      self.train_attack = train_attack
      self.test_attack = test_attack
    else:
      self.train_attack = attack
      self.test_attack = attack
    
    self.test_trainset = test_trainset
    self.test_first = test_first
    self.test_freq = test_freq
    self.train_function = None
    self.test_function = None
    self.save_epochs = save_epochs
    self.save_optimizer = save_optimizer
    self.save_dict = save_dict
    self.checkpoint = None

    # 4. Resume from checkpoint if required
    self.resume_checkpoint = resume_checkpoint
    self.start_epoch = 0
    self.load_checkpoint(resume_checkpoint, continue_train=continue_train)

  def register_train_function(self, func):
    '''
    Install `func(epoch)` as the custom training routine used by train().
    Returns `func` unchanged so the method can be used as a decorator.
    '''
    self.train_function = func
    return func

  def register_test_function(self, func):
    '''
    Install `func(epoch)` as the custom evaluation routine used by test().
    Returns `func` unchanged so the method can be used as a decorator.
    '''
    self.test_function = func
    return func

  def load_checkpoint(self, checkpoint, is_basic=True, continue_train=True):
    '''
    Restore weights from `checkpoint` and set the epoch to resume from.

    checkpoint:      path handed to torch.load, or None to skip loading
    is_basic:        load the weights into self.basic_net if true,
                     otherwise into self.net
    continue_train:  if true, resume from the epoch stored in the checkpoint;
                     otherwise start_epoch stays 0

    When resume_scheduler is set and the scheduler is of type 'normal',
    the scheduler is stepped to the resume epoch.
    '''
    self.resume_checkpoint = checkpoint
    resume_epoch = 0

    if checkpoint is not None:
      print('==> Resuming from checkpoint..')
      print(checkpoint)
      self.checkpoint = torch.load(checkpoint, map_location=self.device)

      target_module = self.basic_net if is_basic else self.net
      target_module.load_state_dict(self.checkpoint['net'])

      if continue_train:
        resume_epoch = self.checkpoint['epoch']

    self.start_epoch = resume_epoch

    steppable = self.scheduler is not None and self.scheduler_type == 'normal'
    if steppable and self.resume_scheduler:
      self.scheduler.step(resume_epoch)


  def train(self, epoch):
    print('===train(epoch={})==='.format(epoch))
    t1 = time.time()
    self.net.train()
    if self.scheduler is not None and self.scheduler_type == 'specific':
      lr = self.scheduler(epoch)
      self.optimizer.param_groups[0].update(lr=lr)

    if self.train_function:
      self.train_function(epoch)
    else:
      # Default train function
      print('lr: {}'.format(self.optimizer.param_groups[0]['lr']))
      for batch_idx, (inputs, targets) in enumerate(self.trainloader):
        for i in range(targets.shape[0]):
          # Update labels for the training set.
          if self.dataset == 'cifar10':
            targets[i] = label_modified(targets[i], self.dataset, self.train_task_level)
          elif self.dataset == 'cifar100':
            targets[i] = label_modified(targets[i], self.dataset, self.train_task_level, labelfile=self.cifar100_label)
          elif self.dataset == 'svhn':
            targets[i] = label_modified(targets[i], self.dataset, self.train_task_level)

        inputs, targets = inputs.to(self.device), targets.to(self.device)
        
        self.optimizer.zero_grad()
        outputs, pert_inputs = self.net(inputs, self.train_attack, targets)
        
        loss = self.criterion(outputs, targets)
        loss.backward()
        self.optimizer.step()

    t2 = time.time()
    print('Elapsed time: {}'.format(t2 - t1))
    if self.scheduler is not None and self.scheduler_type == 'normal':
      self.scheduler.step()
    # sys.stdout.flush()
  
  #finetune the linear classification head
  def finetune_train(self, epoch):  
    
      
    print('===fintunetrain(epoch={})==='.format(epoch))
    t1 = time.time()
    self.net.eval()
    if self.scheduler is not None and self.scheduler_type == 'specific':
      lr = self.scheduler(epoch)/10
      self.optimizer.param_groups[0].update(lr=lr)
   
    if self.train_function:
      self.train_function(epoch)
    else:
      print('lr: {}'.format(self.optimizer.param_groups[0]['lr']))
      
      for batch_idx, (inputs, targets) in enumerate(self.trainloader):
        for i in range(targets.shape[0]):
        
          if self.dataset == 'cifar10':
            targets[i] = label_modified(targets[i], self.dataset, self.test_task_level)
          elif self.dataset == 'cifar100':
            targets[i] = label_modified(targets[i], self.dataset, self.test_task_level, labelfile=self.cifar100_label)
          elif self.dataset == 'svhn':
            targets[i] = label_modified(targets[i], self.dataset, self.test_task_level)
        
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        
        self.optimizer.zero_grad()
        outputs, pert_inputs = self.net(inputs, self.test_attack, targets)
        loss = self.criterion(outputs, targets)
        loss.backward()
        self.optimizer.step()

    t2 = time.time()
    print('Elapsed time: {}'.format(t2 - t1))
    if self.scheduler is not None and self.scheduler_type == 'normal':
      self.scheduler.step()
    # sys.stdout.flush()
    
  def specific_test(self):
    '''
    Print the clean accuracy on the representation set.

    The set is read from the HDF5 file 'cifar10repre.h5' inside
    self.repre_root (datasets 'data' and 'labels'). Labels are mapped to
    test_task_level, and the whole set is sent through the net in a single
    forward pass without an attack.
    '''
    print('specific test for the representation set')
    normal_test_correct = 0
    normal_test_total = 0
    # Evaluate in eval mode, like default_test() and test().
    self.net.eval()

    # The context manager closes the file even if reading a dataset fails.
    with h5py.File(os.path.join(self.repre_root, 'cifar10repre.h5'), 'r') as f:
      inputs = f['data'][:]
      targets = f['labels'][:]

    with torch.no_grad():
      inputs = torch.tensor(inputs)

      # Map every label to the test task level, in place.
      # Datasets other than cifar10 / cifar100 / svhn keep their labels.
      for i in range(targets.shape[0]):
        if self.dataset == 'cifar10':
          targets[i] = label_modified(targets[i], self.dataset, self.test_task_level)
        elif self.dataset == 'cifar100':
          targets[i] = label_modified(targets[i], self.dataset, self.test_task_level, labelfile=self.cifar100_label)
        elif self.dataset == 'svhn':
          targets[i] = label_modified(targets[i], self.dataset, self.test_task_level)

      targets = torch.tensor(targets)
      inputs, targets = inputs.to(self.device), targets.to(self.device)
      outputs, _ = self.net(inputs)
      _, predicted = outputs.max(1)
      normal_test_total += targets.size(0)
      normal_test_correct += predicted.eq(targets).sum().item()
      normal_test_acc = 100. * normal_test_correct / normal_test_total
      print('Normal test accuracy: {}'.format(normal_test_acc))
  
  def default_test(self, clean_train=True, clean_test=True, robust_train=False, robust_test=False):
    '''
    Print clean and/or robust accuracy on the test set and the train set.

    clean_*:   accuracy on unperturbed inputs (net called without an attack)
    robust_*:  accuracy under self.test_attack
    *_test is measured on self.testloader, *_train on self.trainloader_test.
    Labels are mapped to test_task_level before comparison.
    '''
    self.net.eval()

    def relabel(targets):
      # Map every label in the batch to the test task level, in place.
      # Datasets other than cifar10 / cifar100 / svhn keep their labels.
      if self.dataset in ('cifar10', 'svhn'):
        for idx in range(targets.shape[0]):
          targets[idx] = label_modified(targets[idx], self.dataset, self.test_task_level)
      elif self.dataset == 'cifar100':
        for idx in range(targets.shape[0]):
          targets[idx] = label_modified(targets[idx], self.dataset, self.test_task_level, labelfile=self.cifar100_label)

    def evaluate(loader, want_clean, want_robust):
      # Returns (clean_acc, robust_acc); an entry is None when not requested.
      clean_hits = clean_seen = 0
      robust_hits = robust_seen = 0
      for inputs, targets in loader:
        relabel(targets)
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        if want_robust:
          outputs, _ = self.net(inputs, self.test_attack, targets)
          robust_seen += targets.size(0)
          robust_hits += outputs.max(1)[1].eq(targets).sum().item()
        if want_clean:
          outputs, _ = self.net(inputs)
          clean_seen += targets.size(0)
          clean_hits += outputs.max(1)[1].eq(targets).sum().item()
      clean_acc = 100. * clean_hits / clean_seen if want_clean else None
      robust_acc = 100. * robust_hits / robust_seen if want_robust else None
      return clean_acc, robust_acc

    normal_test_acc = robust_test_acc = None
    normal_train_acc = robust_train_acc = None
    with torch.no_grad():
      if clean_test or robust_test:
        normal_test_acc, robust_test_acc = evaluate(self.testloader, clean_test, robust_test)
      if clean_train or robust_train:
        normal_train_acc, robust_train_acc = evaluate(self.trainloader_test, clean_train, robust_train)

    report = (
        (clean_train, 'Normal train accuracy: {}', normal_train_acc),
        (clean_test, 'Normal test accuracy: {}', normal_test_acc),
        (robust_train, 'Robust train accuracy: {}', robust_train_acc),
        (robust_test, 'Robust test accuracy: {}', robust_test_acc),
    )
    for wanted, template, value in report:
      if wanted:
        print(template.format(value))

  def test(self, epoch):
    print('===test(epoch={})==='.format(epoch))
    t1 = time.time()
    self.net.eval()

    if self.test_function:
      self.test_function(epoch)
    else:
      # Default test function
      if self.test_trainset:
        self.default_test(True, True, True, True)
      else:
        self.default_test(False, True, False, True)

    t2 = time.time()
    print('Elapsed time: {}'.format(t2 - t1))


  def save_state(self, path, save_optimizer, save_dict, epoch):
    '''
    Write a checkpoint to `path` with torch.save.

    The checkpoint holds the basic_net weights under 'net', the epoch under
    'epoch', the optimizer state under 'optimizer' when save_optimizer is
    true, and every entry of save_dict when it is not None.
    An OSError while saving is reported on stdout and otherwise ignored.
    '''
    try:
      state = {'net': self.basic_net.state_dict()}
      if save_optimizer:
        state['optimizer'] = self.optimizer.state_dict()
      state['epoch'] = epoch
      if save_dict is not None:
        state.update(save_dict)
      torch.save(state, path)
    except OSError:
      print('OSError while saving {}'.format(path))
      print('Ignoring...')


  def save_checkpoint(self, epoch):
    '''
    Save the checkpoint for `epoch` as '<pathdir>/<epoch>.pth'.

    Does nothing when pathdir is None. The directory is created on demand.
    '''
    if self.pathdir is None:
      return
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(self.pathdir, exist_ok=True)
    print('==> Saving {}.pth..'.format(epoch))
    self.save_state('{}/{}.pth'.format(self.pathdir, epoch), self.save_optimizer, self.save_dict, epoch)


  def start(self):
    '''
    Run the train / test / save loop from start_epoch + 1 up to epochs.

    Testing happens every test_freq epochs (counted from start_epoch) once
    the epoch has reached test_first; test_first == 0 also tests once before
    training. Checkpoints follow save_epochs: None saves after each test,
    an int saves every that many epochs, a list or tuple saves at the
    listed epochs.
    '''
    if self.test_first == 0:
      self.test(self.start_epoch)

    for epoch in range(self.start_epoch + 1, self.epochs + 1):
      self.train(epoch)

      test_due = (
          self.test_freq > 0
          and (epoch - self.start_epoch) % self.test_freq == 0
          and epoch >= self.test_first
      )
      if test_due:
        self.test(epoch)
        if self.save_epochs is None:
          self.save_checkpoint(epoch)

      if self.save_epochs is None:
        continue

      if isinstance(self.save_epochs, int):
        save_due = (epoch - self.start_epoch) % self.save_epochs == 0
      else:
        assert(isinstance(self.save_epochs, (list, tuple)))
        save_due = epoch in self.save_epochs
      if save_due:
        self.save_checkpoint(epoch)

  def finetunestart(self, finetune_epochs=50):
    '''
    Fine-tune a fresh linear head on top of the frozen trained network.

    Freezes every parameter of the current basic net, replaces its `fc`
    layer with a new `test_label`-way linear layer, re-wraps the net with
    adv_train_net, and trains only the new head with SGD (lr 0.01,
    momentum 0.9, weight decay 5e-4), testing after every epoch.

    finetune_epochs:  number of fine-tuning epochs (default 50)

    NOTE(review): the wrapper is rebuilt with normalize=True, up=1.0,
    down=0.0, sigma=0.5 regardless of the values passed to __init__.
    NOTE(review): self.scheduler is left untouched, so a 'normal' scheduler
    created in __init__ is still attached to the previous optimizer.
    '''
    # On cuda self.net is a DataParallel whose wrapped net is `.module`;
    # on other devices self.net is the adversarial wrapper itself.
    adv_wrapper = getattr(self.net, 'module', self.net)
    new_basic_net = adv_wrapper.basic_net
    
    # set grad=false for original model
    for p in new_basic_net.parameters():
      p.requires_grad = False
        
    # The new head is created after freezing, so it stays trainable.
    num_fc_ftr = new_basic_net.fc.in_features
    new_basic_net.fc = nn.Linear(num_fc_ftr, self.test_label)
    self.basic_net = new_basic_net
    
    # encode network again
    self.basic_net = self.basic_net.to(self.device)
    self.criterion = nn.CrossEntropyLoss(reduction='mean')
    self.adv_net = adv_train_net(self.basic_net, self.eps,
                               self.step_size, self.step_num, self.device, normalize = True,
                               up=1.0, down=0.0, sigma=0.5, criterion=nn.CrossEntropyLoss)
    
    if self.device == 'cuda':
      self.net = torch.nn.DataParallel(self.adv_net)
      if self.pytorch_seed is None or self.numpy_seed is None:
        cudnn.benchmark = True
        cudnn.deterministic = False
      else:
        cudnn.benchmark = False
        cudnn.deterministic = True
    else:
      self.net = self.adv_net
    
    # Optimize the new head only.
    head = getattr(self.net, 'module', self.net).basic_net.fc
    self.optimizer = optim.SGD(
          head.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    
    if self.test_first == 0:
      self.test(self.start_epoch)
    for epoch in range(0, finetune_epochs):
      self.finetune_train(epoch)
      self.test(epoch)
      
  def output_collect(self, root):
    '''
    Run the test set through the net and save the raw outputs as CSV.

    Each batch is relabelled to test_task_level and forwarded with
    test_attack. The outputs of all batches are stacked row-wise and written,
    without header or index, to
    root + repre_save_folder + '/' + '<train_task_level> <test_task_level> <repre_save_path>'.

    root:  path prefix, concatenated as-is (include a trailing separator)
    '''
    print('===== saving output results =====')
    # Collect in eval mode, like default_test() and test().
    self.net.eval()
    chunks = []
    for batch_idx, (inputs, targets) in enumerate(self.testloader):
      # Map every label in the batch to the test task level, in place.
      # Datasets other than cifar10 / cifar100 / svhn keep their labels.
      for i in range(targets.shape[0]):
        if self.dataset == 'cifar10':
          targets[i] = label_modified(targets[i], self.dataset, self.test_task_level)
        elif self.dataset == 'cifar100':
          targets[i] = label_modified(targets[i], self.dataset, self.test_task_level, labelfile=self.cifar100_label) 
        elif self.dataset == 'svhn':
          targets[i] = label_modified(targets[i], self.dataset, self.test_task_level) 
      inputs, targets = inputs.to(self.device), targets.to(self.device)
      outputs, _ = self.net(inputs, self.test_attack, targets)
      chunks.append(outputs.detach().cpu().numpy())
    # Stack once at the end instead of re-copying the whole array per batch;
    # an empty loader yields an empty table instead of crashing on `.shape`.
    outputresult = np.vstack(chunks) if chunks else np.empty((0, 0))
    print(outputresult.shape)
    print('===== saving final output to =====')
    data1 = pd.DataFrame(outputresult)
    data1.to_csv(root+self.repre_save_folder+'/'+self.train_task_level+' '+self.test_task_level+' '+self.repre_save_path, header = False, index = False)
    
        
  def featureout(self, root):
    '''
    Extract the representation of every test sample and save it as CSV.

    After each forward pass (with test_attack) the `feature` attribute of the
    wrapped basic net is read, squeezed and stored as one row per sample.
    The table is written, without header or index, to
    root + repre_save_folder + '/' + repre_save_path.

    root:  path prefix, concatenated as-is (include a trailing separator)

    NOTE(review): unlike the other evaluation loops, labels are not mapped
    to test_task_level here before being passed to the net — confirm this
    is intended when test_attack is not 'none'.
    '''
    print('======== extracting representation =======')
    # Extract in eval mode, like default_test() and test().
    self.net.eval()
    # On cuda self.net is a DataParallel whose wrapped net is `.module`;
    # on other devices self.net is the adversarial wrapper itself.
    adv_wrapper = getattr(self.net, 'module', self.net)
    chunks = []
    for batch_idx, (inputs, targets) in enumerate(self.testloader):        
      # Keep inputs and targets on the same device for the forward pass.
      inputs, targets = inputs.to(self.device), targets.to(self.device)
      _, _ = self.net(inputs, self.test_attack, targets)
      outputfeature = adv_wrapper.basic_net.feature
      chunks.append(outputfeature.detach().cpu().numpy().squeeze())
    # Stack once at the end; vstack turns a squeezed single-sample batch
    # back into one row.
    features = np.vstack(chunks) if chunks else np.empty((0, 0))
    print(features.shape)
    print('representation shape; saving to')
    data1 = pd.DataFrame(features)
    data1.to_csv(root+self.repre_save_folder+'/'+self.repre_save_path, header = False, index = False)
    
