import math
import logging
import time
import sys
import argparse
import torch
import numpy as np
import pickle
from pathlib import Path
from copy import deepcopy

from evaluation.evaluation import eval_edge_prediction
from model.rtrgn import LGN
from utils.utils import EarlyStopMonitor, RandEdgeSampler, get_neighbor_finder
from utils.data_processing import get_data, compute_time_statistics

# Seed the torch and numpy random number generators with a fixed value (0)
# at import time, before any data loading or model construction below.
torch.manual_seed(0)
np.random.seed(0)

### Argument and global variables
# Command-line interface of the training script. Every option below ends up as
# an attribute of the parsed namespace; boolean switches use store_true (default
# False) except --full_lazygraph, which uses store_false (default True).
parser = argparse.ArgumentParser('RTRGN self-supervised training')
parser.add_argument('-d', '--data', type=str, help='Dataset name (eg. wikipedia or reddit)',
                    default='wikipedia')
parser.add_argument('--bs', type=int, default=200, help='Batch_size')
parser.add_argument('--prefix', type=str, default='', help='Prefix to name the checkpoints')
parser.add_argument('--n_degree', type=int, default=10, help='Number of neighbors to sample')
parser.add_argument('--n_head', type=int, default=2, help='Number of heads used in attention layer')
parser.add_argument('--n_epoch', type=int, default=50, help='Number of epochs')
parser.add_argument('--n_layer', type=int, default=1, help='Number of network layers')
parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--n_runs', type=int, default=1, help='Number of runs')
parser.add_argument('--drop_out', type=float, default=0.1, help='Dropout probability')
parser.add_argument('--gpu', type=int, default=0, help='Idx for the gpu to use')
parser.add_argument('--node_dim', type=int, default=100, help='Dimensions of the node embedding')
parser.add_argument('--time_dim', type=int, default=100, help='Dimensions of the time embedding')
parser.add_argument('--backprop_every', type=int, default=1, help='Every how many batches to '
                                                                  'backprop')
parser.add_argument('--use_memory', action='store_true',
                    help='Whether to augment the model with a node memory')
# argparse does not validate a default against `choices`, so the default value
# must be listed explicitly; otherwise the value the script runs with by default
# would be rejected when passed on the command line.
# NOTE(review): confirm that the model actually supports "graph_attention";
# if it does not, the default should be changed to one of the cache modules.
parser.add_argument('--embedding_module', type=str, default="graph_attention", choices=[
                    "graph_attention", "memory_gcn_cache", "memory_attention_cache",],
                    help='Type of embedding modules')
parser.add_argument('--message_function', type=str, default="identity", choices=[
  "mlp", "identity"], help='Type of message function')
parser.add_argument('--memory_updater', type=str, default="gru", choices=[
  "gru", "rnn"], help='Type of memory updater')
parser.add_argument('--aggregator', type=str, default="last", help='Type of message '
                                                                        'aggregator')
parser.add_argument('--memory_update_at_end', action='store_true',
                    help='Whether to update memory at the end or at the start of the batch')
parser.add_argument('--message_dim', type=int, default=100, help='Dimensions of the messages')
parser.add_argument('--memory_dim', type=int, default=172, help='Dimensions of the memory for '
                                                                'each user')
parser.add_argument('--different_new_nodes', action='store_true',
                    help='Whether to use disjoint set of new nodes for train and val')
parser.add_argument('--uniform', action='store_true',
                    help='take uniform sampling from temporal neighbors')
parser.add_argument('--randomize_features', action='store_true',
                    help='Whether to randomize node features')
parser.add_argument('--use_destination_embedding_in_message', action='store_true',
                    help='Whether to use the embedding of the destination node as part of the message')
parser.add_argument('--use_source_embedding_in_message', action='store_true',
                    help='Whether to use the embedding of the source node as part of the message')
parser.add_argument('--dyrep', action='store_true',
                    help='Whether to run the dyrep model')

parser.add_argument('--print_interval', type=int, default=140, help='not used. print stats')
parser.add_argument('--min_epoch', type=int, default=0, help='Min Number of epochs')
# store_false: the attribute defaults to True and passing the flag turns it off.
parser.add_argument('--full_lazygraph', action='store_false')

# Options shared by both GN modules
parser.add_argument('--diff_msg', action='store_true',
                    help='whether the message is differentiable in BP')
parser.add_argument('--base_memory_level', type=int, default=1, choices=[0,1],
                    help='USED IN BOTH TGN AND LGN. '
                         'default level of base memory in msg passing. '
                         '0 means using (trainable) base embedding, 1 means using a self-updated memory backbone.')

# Options used only by LGN
parser.add_argument('--history_limit', type=int, default=5,
                    help='USED IN LGN. default max length of history')
## Time-related options
parser.add_argument('--time_encoding', type=str, default="baseline",
                    choices=["baseline", "none", "copy"],
                    help='Type of time_encoding modules in the memory module (see lgn.py).'
                         'Not trained if diff_msg is false.'
                         'if copy means using the same one as passed to the embedding module.')
parser.add_argument('--unique_time_encoding', action='store_true',
                    help='Same as above. whether use same time_encoder for memory / emb module')
parser.add_argument('--remove_time', action='store_true',
                    help='whether remove time in all places besides embedding module. see lgn.py and lazay_graphemb.py')
parser.add_argument('--remove_time_in_msg', action='store_true',
                    help='whether remove time in only the msg passing (lazay_graphemb.py)')
## Message-related options
parser.add_argument('--remove_dst_in_msg', action='store_true',
                    help='whether remove dst (self). see (lazay_graphemb.py)')
parser.add_argument('--update_cache_at_start', action='store_true',
                    help='whether update_cache_at_start with newest model even if not doing emb_postprocess.'
                         'only effects if not emb_postprocess.')
parser.add_argument('--reduce_to_base_case', action='store_true',
                    help='when compute msg whether reduce_to_base_case to the max of history length, '
                         'or else use one step cache')
parser.add_argument('--msg_from_cache', action='store_true',
                    help='whether compute msg_from_cache')
parser.add_argument('--dst_msg_asynchronous', action='store_true',
                    help='whether compute des msg using same level k')
## Base embedding options
parser.add_argument('--only_base', action='store_true',
                    help='whether only use base during training. see (lazay_graphemb.py)'
                         'will return the base if not forced by memory_after_emb_postprocess.')
parser.add_argument('--adapt_base', action='store_true',
                    help='whether adapt base appropriately during validation/testing.')
parser.add_argument('--base_cell', type=str, default="gru",
                    help='Type of base updater gru or rnn')
parser.add_argument('--stack_layers', action='store_true',
                    help='whether do weighted means across rnn layers.')
## Post-processing: fusing the embedding and memory modules
parser.add_argument('--emb_postprocess', action='store_true',
                    help='USED IN LGN. whether emb_postprocessing with the embedding module')
parser.add_argument('--memory_after_emb_postprocess', action='store_true',
                    help='whether add a memory layer after emb.'
                         'only support base embedding and only support output one-layer memory.')
# Embedding module options
parser.add_argument('--wo_diff', action='store_true',
                    help='whether do')
parser.add_argument('--leave_out', action='store_true',
                    help='whether do')
parser.add_argument('--weighted_gnn', action='store_true',
                    help='whether do')
parser.add_argument('--rollout_ratio', type=float, default=1.0,
                    help='The fraction of nodes computed exact')
parser.add_argument('--base_lr_ratio', type=float, default=1.0,
                    help='The fraction of base_lr')

# EXPERIMENT MODE. Flags used for quickly tuning some experiments; the matching
# changes have to be made in the relevant modules.
parser.add_argument('--delete_time_or_seq', action='store_true',
                    help='Assert FALSE. operate in embedding module.')
parser.add_argument('--change_time_to_seq', action='store_true',
                    help='Assert FALSE. operate in embedding module.')
## Flags that affect the cache embedding
parser.add_argument('--add_edge_feature_in_msg', action='store_true',
                    help='operate in embedding module.')
parser.add_argument('--add_node_feature_in_msg', action='store_true',
                    help='Assert FALSE in wiki and ml100k. operate in embedding module.')
# Parse the command line. argparse signals both "-h/--help" and invalid
# arguments by raising SystemExit (it has already printed usage / the error).
try:
  args = parser.parse_args()
except SystemExit as exit_request:
  # Only on a real parsing error show the full help text as an extra hint;
  # for -h/--help argparse has printed it already. Re-raise in both cases so
  # the original exit status is preserved (a bad command line must not be
  # reported to the shell as success).
  if exit_request.code not in (0, None):
    parser.print_help()
  raise

######## Condition check and processing of args
# Extend the user prefix into the run tag used for checkpoint and result file
# names: <prefix><data><embedding_module>_ followed by the key settings joined
# with "_".
args.prefix += args.data + args.embedding_module + "_" + "_".join(
  str(setting) for setting in (
    args.diff_msg,
    args.base_memory_level,
    args.time_encoding,
    args.only_base,
    args.weighted_gnn,
    args.n_layer,
    100*args.rollout_ratio,
  )
)
print(args.prefix)

# True for the "*cache*" embedding modules; read later in the script to decide
# whether the embedding module's own memory is backed up / restored / reset.
cache_in_embedding_module = "cache" in args.embedding_module
if cache_in_embedding_module:
  print("Using cache in embedding module.")

# diff_msg together with msg_from_cache is allowed but only warned about.
if not args.diff_msg:
  print("WARNING: No gradient bp to time encoder in the memory module.")
elif args.msg_from_cache:
  print("WARNING: diff_msg but msg_from_cache.")

if args.memory_after_emb_postprocess:
  # Explicit check instead of `assert`: assertions are removed when Python runs
  # with -O, which would let an unsupported flag combination through silently.
  prerequisites_met = (args.only_base and args.adapt_base and
                       args.emb_postprocess and args.msg_from_cache and
                       args.time_encoding == "none" and args.base_memory_level == 0)
  if not prerequisites_met:
    parser.error("--memory_after_emb_postprocess requires --only_base, --adapt_base, "
                 "--emb_postprocess, --msg_from_cache, --time_encoding none and "
                 "--base_memory_level 0")
if args.base_memory_level != 0:
  print("WARNING: Using aligned memory propagation. (No used base embedding and now only use the last layer)")

# Echo the final value of every option.
for arg in vars(args):
  print(arg, getattr(args, arg))
######################################################## finish

######### List of Experiment Settings ##############
####################################################

# Module-level aliases for the command-line settings.
BATCH_SIZE = args.bs
NUM_NEIGHBORS = args.n_degree
NUM_NEG = 1
NUM_EPOCH = args.n_epoch
NUM_HEADS = args.n_head
DROP_OUT = args.drop_out
GPU = args.gpu
DATA = args.data
NUM_LAYER = args.n_layer
LEARNING_RATE = args.lr
NODE_DIM = args.node_dim
TIME_DIM = args.time_dim
USE_MEMORY = args.use_memory
MESSAGE_DIM = args.message_dim
MEMORY_DIM = args.memory_dim

# Make sure the output directories for models, checkpoints and logs exist.
for out_dir in ("./saved_models/", "./saved_checkpoints/", "log/"):
  Path(out_dir).mkdir(parents=True, exist_ok=True)

MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}.pth'


def get_checkpoint_path(epoch):
  """Return the checkpoint file path for the given epoch."""
  return f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}.pth'


### set up logger
# The root logger accepts everything from DEBUG up. A timestamped file under
# log/ receives DEBUG and above, the extra stream handler WARN and above; both
# share one formatter.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(f'log/{time.time()}.log')
ch = logging.StreamHandler()
for handler, level in ((fh, logging.DEBUG), (ch, logging.WARN)):
  handler.setLevel(level)
  handler.setFormatter(formatter)
  logger.addHandler(handler)
logger.info(args)

### Extract data for training, validation and testing
(node_features, edge_features, full_data, train_data, val_data, test_data,
 new_node_val_data, new_node_test_data) = get_data(
  DATA,
  different_new_nodes_between_val_and_test=args.different_new_nodes,
  randomize_features=args.randomize_features)

# One neighbor finder built from the training data only, and one built from the
# full data that is used for validation and testing.
train_ngh_finder = get_neighbor_finder(train_data, args.uniform)
full_ngh_finder = get_neighbor_finder(full_data, args.uniform)

# Negative edge samplers. The validation and test samplers get fixed seeds so
# that the sampled negatives are the same across different runs.
# NB: in the inductive setting, negatives are sampled only amongst other new nodes
train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations)
val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0)
nn_val_rand_sampler = RandEdgeSampler(
  new_node_val_data.sources, new_node_val_data.destinations, seed=1)
test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=2)
nn_test_rand_sampler = RandEdgeSampler(
  new_node_test_data.sources, new_node_test_data.destinations, seed=3)

# Pick the requested GPU when CUDA is available and the index is non-negative,
# otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available() and GPU >= 0
device_string = f'cuda:{GPU}' if use_cuda else 'cpu'
device = torch.device(device_string)
print(device)

# Time-shift statistics over the full data; passed to the model constructor.
(mean_time_shift_src, std_time_shift_src,
 mean_time_shift_dst, std_time_shift_dst) = compute_time_statistics(
  full_data.sources, full_data.destinations, full_data.timestamps)




def backup_flow():
  """Snapshot the current model memory.

  Reads the module-level ``rtrgn`` model and ``cache_in_embedding_module`` flag.

  Returns:
    A 3-tuple ``(lge_backup, embedding_backup, None)``: the backup returned by
    ``rtrgn.lge.backup_memory()``, the backup of the embedding module (``None``
    when no cache embedding module is in use), and a third slot that is
    always ``None``.
  """
  lge_state = rtrgn.lge.backup_memory()
  cache_state = rtrgn.embedding_module.backup_memory() if cache_in_embedding_module else None
  return lge_state, cache_state, None

def restore_flow(memory_backup, embedding_memory_backup, memory_backup2, with_debug=True):
  """Restore model memory from a snapshot, optionally calling the debug hooks first.

  Operates on the module-level ``rtrgn`` model.

  Args:
    memory_backup: state handed to ``rtrgn.lge.restore_memory``.
    embedding_memory_backup: state handed to the embedding module's
      ``restore_memory``; only used when a cache embedding module is active.
    memory_backup2: accepted for symmetry with ``backup_flow``; not used.
    with_debug: when true, call ``debug()`` on the embedding module (and on
      ``rtrgn.lge`` unless ``args.only_base`` is set) before restoring.
  """
  if with_debug:
    if not args.only_base:
      rtrgn.lge.debug()
    rtrgn.embedding_module.debug()

  rtrgn.lge.restore_memory(memory_backup)
  if cache_in_embedding_module:
    rtrgn.embedding_module.restore_memory(embedding_memory_backup)



for i in range(args.n_runs):
  # Each iteration is one complete run: build a fresh model, train it with early
  # stopping on the validation AP, then evaluate on the test split. Run 0 writes
  # results/<prefix>.pkl, run i > 0 writes results/<prefix>_<i>.pkl.
  results_path = "results/{}_{}.pkl".format(args.prefix, i) if i > 0 else "results/{}.pkl".format(args.prefix)
  Path("results/").mkdir(parents=True, exist_ok=True)
  # Initialize Model
  rtrgn = LGN(neighbor_finder=train_ngh_finder, node_features=node_features,
            edge_features=edge_features, device=device,
            n_layers=NUM_LAYER,
            n_heads=NUM_HEADS, dropout=DROP_OUT, use_memory=USE_MEMORY,
            message_dimension=MESSAGE_DIM, memory_dimension=MEMORY_DIM,
            memory_update_at_start=not args.memory_update_at_end,
            embedding_module_type=args.embedding_module,
            message_function=args.message_function,
            aggregator_type=args.aggregator,
            memory_updater_type=args.memory_updater,
            n_neighbors=NUM_NEIGHBORS,
            mean_time_shift_src=mean_time_shift_src, std_time_shift_src=std_time_shift_src,
            mean_time_shift_dst=mean_time_shift_dst, std_time_shift_dst=std_time_shift_dst,
            use_destination_embedding_in_message=args.use_destination_embedding_in_message,
            use_source_embedding_in_message=args.use_source_embedding_in_message,
            dyrep=args.dyrep,
            args=args)
  criterion = torch.nn.BCELoss()
  # Print the registered name and shape of every trainable tensor. The name is
  # taken from named_parameters(); the tensor's own `.name` attribute is not the
  # name the parameter is registered under in the module.
  for param_name, p in rtrgn.named_parameters():
    if p.requires_grad:
      print(param_name, p.shape) #, p.data)
  optimizer = torch.optim.Adam(rtrgn.parameters(), lr=LEARNING_RATE) #*args.rollout_ratio)
  if args.adapt_base:
    # The base embedding gets its own optimizers (one for training, one for
    # old-node validation, one for new-node validation / testing), all with the
    # scaled learning rate.
    LEARNING_RATE_BASE = args.base_lr_ratio*LEARNING_RATE
    optimizer_base_emb = torch.optim.Adam([rtrgn.lge.base_embedding], lr=LEARNING_RATE_BASE)
    optimizer_base_emb_on = torch.optim.Adam([rtrgn.lge.base_embedding], lr=LEARNING_RATE_BASE)
    optimizer_base_emb_testornn = torch.optim.Adam([rtrgn.lge.base_embedding], lr=LEARNING_RATE_BASE)
    # Remove parameter #2 from the main optimizer after checking that its shape
    # is (num_nodes + 1, feature_dim).
    # NOTE(review): presumably this entry is rtrgn.lge.base_embedding; the shape
    # check is the only guard and the index depends on parameter registration
    # order in the model — confirm.
    assert optimizer.param_groups[0]["params"][2].data.shape[0] == node_features.shape[0] + 1
    assert optimizer.param_groups[0]["params"][2].data.shape[1] == node_features.shape[1]
    del optimizer.param_groups[0]["params"][2]
  else:
    optimizer_base_emb, optimizer_base_emb_on, optimizer_base_emb_testornn = None, None, None
  rtrgn = rtrgn.to(device)

  num_instance = len(train_data.sources)
  num_batch = math.ceil(num_instance / BATCH_SIZE)

  logger.info('num of training instances: {}'.format(num_instance))
  logger.info('num of batches per epoch: {}'.format(num_batch))
  idx_list = np.arange(num_instance)

  # Per-epoch history that is written to results_path.
  new_nodes_val_aps = []
  val_aps = []
  epoch_times = []
  total_epoch_times = []
  train_losses = []

  early_stopper = EarlyStopMonitor(max_round=args.patience)
  for epoch in range(NUM_EPOCH):
    if cache_in_embedding_module:
      # Reset the embedding module's cache_var buffer and iteration counter.
      # NOTE(review): the second dimension is hard-coded to 10000 — confirm it
      # is large enough for the number of iterations per epoch.
      rtrgn.embedding_module.cache_var=np.zeros([NUM_LAYER,10000])
      rtrgn.embedding_module.current_iter=0

    start_epoch = time.time()
    ### Training

    if USE_MEMORY:
      if epoch>0:
        # Go back to the memory snapshot taken at the end of the previous
        # epoch's training phase.
        restore_flow(train_memory_backup, embedding_train_memory_backup, train_memory_backup2)
      rtrgn.lge.__init_cache_var__() # Reinitialize memory of the model at the start of each epoch
      if args.adapt_base and epoch>0:
        rtrgn.lge.restore_memory(train_memory_backup)

    # Train using only training graph
    rtrgn.set_neighbor_finder(train_ngh_finder)
    m_loss = []

    logger.info('start {} epoch'.format(epoch))
    for k in range(0, num_batch, args.backprop_every):
      loss = 0
      optimizer.zero_grad()
      if args.adapt_base:
        optimizer_base_emb.zero_grad()
        optimizer_base_emb_on.zero_grad()
        optimizer_base_emb_testornn.zero_grad()

      # Custom loop to allow to perform backpropagation only every a certain number of batches
      for j in range(args.backprop_every):
        batch_idx = k + j

        # The last group may contain fewer than backprop_every batches.
        if batch_idx >= num_batch:
          continue

        start_idx = batch_idx * BATCH_SIZE
        end_idx = min(num_instance, start_idx + BATCH_SIZE)
        sources_batch, destinations_batch = train_data.sources[start_idx:end_idx], \
                                            train_data.destinations[start_idx:end_idx]
        edge_idxs_batch = train_data.edge_idxs[start_idx: end_idx]
        timestamps_batch = train_data.timestamps[start_idx:end_idx]

        # One sampled negative destination per positive edge.
        size = len(sources_batch)
        _, negatives_batch = train_rand_sampler.sample(size)

        with torch.no_grad():
          pos_label = torch.ones(size, dtype=torch.float, device=device)
          neg_label = torch.zeros(size, dtype=torch.float, device=device)

        rtrgn = rtrgn.train()
        pos_prob, neg_prob = rtrgn.compute_edge_probabilities(sources_batch, destinations_batch, negatives_batch,
                                                            timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS)

        loss += criterion(pos_prob.squeeze(), pos_label) + criterion(neg_prob.squeeze(), neg_label)
      loss /= args.backprop_every
      loss.backward()

      optimizer.step()
      if args.adapt_base:
        optimizer_base_emb.step()
      m_loss.append(loss.item())

      # Detach memory after 'args.backprop_every' number of batches so we don't backpropagate to the start of time
      if USE_MEMORY:
        rtrgn.lge.detach_memory()

    # Training-only time for this epoch (validation is timed separately below).
    epoch_time = time.time() - start_epoch
    epoch_times.append(epoch_time)

    ### Validation
    # Validation uses the full graph
    rtrgn.set_neighbor_finder(full_ngh_finder)
    if USE_MEMORY:
      # Backup memory at the end of training, so later we can restore it and use it for the
      # validation on unseen nodes
      train_memory_backup, embedding_train_memory_backup, train_memory_backup2 = backup_flow()
    if args.adapt_base: # old_node use train state
      optimizer_base_emb_on.load_state_dict(optimizer_base_emb.state_dict())
    val_ap, val_auc = eval_edge_prediction(model=rtrgn,
                                           negative_edge_sampler=val_rand_sampler,
                                           data=val_data,
                                           n_neighbors=NUM_NEIGHBORS,
                                           device=device,
                                           train_base_embedding=args.adapt_base,
                                           base_embedding=rtrgn.lge.base_embedding,
                                           lr=args.lr,
                                           optimizer = optimizer_base_emb_on)
    if USE_MEMORY:
      val_memory_backup, embedding_val_memory_backup, val_memory_backup2 = backup_flow()
      # Restore memory we had at the end of training to be used when validating on new nodes.
      # Also backup memory after validation so it can be used for testing (since test edges are strictly later in time than validation edges)
      restore_flow(train_memory_backup,embedding_train_memory_backup,train_memory_backup2)
    # Validate on unseen nodes
    if args.adapt_base: # new node use train state
      optimizer_base_emb_testornn.load_state_dict(optimizer_base_emb.state_dict())
    # NOTE(review): this call reuses val_rand_sampler, whereas the new-node test
    # below uses nn_test_rand_sampler; nn_val_rand_sampler is created earlier in
    # the script but not used in the visible code — confirm which is intended.
    nn_val_ap, nn_val_auc = eval_edge_prediction(model=rtrgn,
                                                 negative_edge_sampler=val_rand_sampler,
                                                 data=new_node_val_data,
                                                 n_neighbors=NUM_NEIGHBORS,
                                                 device=device,
                                                 train_base_embedding=args.adapt_base,
                                                 base_embedding=rtrgn.lge.base_embedding,
                                                 lr=args.lr,
                                                 optimizer = optimizer_base_emb_testornn)

    if USE_MEMORY:      # Restore memory we had at the end of validation
      restore_flow(val_memory_backup, embedding_val_memory_backup, val_memory_backup2, with_debug=False)
    new_nodes_val_aps.append(nn_val_ap)
    val_aps.append(val_ap)
    train_losses.append(np.mean(m_loss))

    # Record the total (training + validation) time before dumping, so that the
    # interim file holds one entry per finished epoch in every list.
    total_epoch_time = time.time() - start_epoch
    total_epoch_times.append(total_epoch_time)

    # Save temporary results to disk; the context manager closes the file even
    # if pickling fails.
    with open(results_path, "wb") as results_file:
      pickle.dump({
        "val_aps": val_aps,
        "new_nodes_val_aps": new_nodes_val_aps,
        "train_losses": train_losses,
        "epoch_times": epoch_times,
        "total_epoch_times": total_epoch_times
      }, results_file)

    logger.info('epoch: {} took {:.2f}s'.format(epoch, total_epoch_time))
    logger.info('Epoch mean loss: {}'.format(np.mean(m_loss)))
    logger.info(
      'val auc: {}, new node val auc: {}'.format(val_auc, nn_val_auc))
    logger.info(
      'val ap: {}, new node val ap: {}'.format(val_ap, nn_val_ap))

    # Early stopping. A stop signal is only honoured after min_epoch; before
    # that the epoch is neither checkpointed nor does training stop.
    if early_stopper.early_stop_check(val_ap):
      if epoch > args.min_epoch:
        logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
        logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
        best_model_path = get_checkpoint_path(early_stopper.best_epoch)
        rtrgn.load_state_dict(torch.load(best_model_path))
        logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')
        rtrgn.eval()
        break
      else:
        print("WARNING! MODEL NOT SAVED.")
    else:
      torch.save(rtrgn.state_dict(), get_checkpoint_path(epoch))

  # Training has finished, we have loaded the best model, and we want to backup its current
  # memory (which has seen validation edges) so that it can also be used when testing on unseen
  # nodes
  if USE_MEMORY:
    val_memory_backup, embedding_val_memory_backup, val_memory_backup2 = backup_flow()
    if args.adapt_base: # test use val old node state
      optimizer_base_emb_testornn.load_state_dict(optimizer_base_emb_on.state_dict())
  rtrgn.embedding_module.neighbor_finder = full_ngh_finder
  test_ap, test_auc = eval_edge_prediction(model=rtrgn,
                                           negative_edge_sampler=test_rand_sampler,
                                           data=test_data,
                                           n_neighbors=NUM_NEIGHBORS,
                                           device=device,
                                           train_base_embedding=args.adapt_base,
                                           base_embedding=rtrgn.lge.base_embedding,
                                           lr=args.lr,
                                           optimizer = optimizer_base_emb_testornn)

  if USE_MEMORY:
    # Return to the post-validation memory before evaluating on new nodes.
    restore_flow(val_memory_backup, embedding_val_memory_backup, val_memory_backup2, with_debug=False)
    if args.adapt_base: # test use val old node state
      optimizer_base_emb_testornn.load_state_dict(optimizer_base_emb_on.state_dict())
  # Test on unseen nodes
  nn_test_ap, nn_test_auc = eval_edge_prediction(model=rtrgn,
                                                 negative_edge_sampler=nn_test_rand_sampler,
                                                 data=new_node_test_data,
                                                 n_neighbors=NUM_NEIGHBORS,
                                                 device=device,
                                                 train_base_embedding=args.adapt_base,
                                                 base_embedding=rtrgn.lge.base_embedding,
                                                 lr=args.lr,
                                                 optimizer = optimizer_base_emb_testornn)

  logger.info(
    'Test statistics: Old nodes -- auc: {}, ap: {}'.format(test_auc, test_ap))
  logger.info(
    'Test statistics: New nodes -- auc: {}, ap: {}'.format(nn_test_auc, nn_test_ap))
  # Save results for this run (overwrites the interim per-epoch file).
  with open(results_path, "wb") as results_file:
    pickle.dump({
      "val_aps": val_aps,
      "new_nodes_val_aps": new_nodes_val_aps,
      "test_ap": test_ap,
      "new_node_test_ap": nn_test_ap,
      "epoch_times": epoch_times,
      "train_losses": train_losses,
      "total_epoch_times": total_epoch_times
    }, results_file)

  if USE_MEMORY:
    restore_flow(val_memory_backup, embedding_val_memory_backup, val_memory_backup2, with_debug=False)
  # NOTE(review): the final save to MODEL_SAVE_PATH is disabled, so the message
  # below is logged although nothing is written here — confirm whether the save
  # should be re-enabled or the message dropped.
  # torch.save(rtrgn.state_dict(), MODEL_SAVE_PATH)
  logger.info('Model saved.')



