import torch
import numpy as np
from tqdm import tqdm
import math
import logging
import random
import timeit
from .nat_utils import set_random_seed, nat_results, save_model, load_model
from .eval import eval_tgb_lpp, eval_tgb_lpp_tvi
# Disable the matplotlib font-manager and ticker loggers so any records they
# emit are dropped (presumably to keep matplotlib noise out of training logs).
for _quiet_logger in ('matplotlib.font_manager', 'matplotlib.ticker'):
  logging.getLogger(_quiet_logger).disabled = True
del _quiet_logger

# k passed to the evaluation routines as `k_value` (used when computing MRR).
K_VALUE = 10

def train_val(train_val_data, model, bs, epochs, criterion, optimizer, early_stopper, 
              train_rand_sampler, logger, tgb_negative_sampler, metric, evaluator, n_hop=2, exp_seed=0):
  """Train `model` for up to `epochs` epochs, validating after every epoch.

  Each epoch walks the training edges in their stored order in mini-batches of
  `bs`. Each mini-batch is paired with negatives from `train_rand_sampler`, and
  `criterion` is optimised on the positive/negative scores returned by
  `model.contrast`. After each epoch the model is evaluated on the validation
  split with `eval_tgb_lpp`, and `early_stopper` decides whether to stop.

  On an early stop, the checkpoint of the best epoch is reloaded with
  `load_model`, re-saved under the "best" tag, and the model is switched to
  eval mode. Otherwise the current epoch's checkpoint is saved.

  Args:
    train_val_data: `(train_data, val_data)`. Each is a 5-tuple
      `(src, tgt, ts, edge_idx, label)` of arrays that support integer-array
      indexing. The labels are unpacked but not used.
    model: the model being trained. It must expose `n_feat_th`, `set_seed`,
      `reset_store`, `reset_self_rep`, `contrast`, `train` and `eval`.
    bs: mini-batch size, used for training and validation.
    epochs: maximum number of epochs.
    criterion: loss applied separately to positive and negative scores
      (presumably a BCE-style loss on probabilities -- confirm with caller).
    optimizer: torch optimizer over the model parameters.
    early_stopper: object with `early_stop_check(metric)`, `max_round` and
      `best_epoch`.
    train_rand_sampler: object whose `sample(size)` returns a pair. Only the
      second element (the negative targets) is used.
    logger: `logging.Logger` for progress output.
    tgb_negative_sampler, metric, evaluator: forwarded to `eval_tgb_lpp`.
    n_hop: forwarded to `save_model` / `load_model`.
    exp_seed: seed for Python's `random` module, from which the single
      per-run training seed is drawn.

  Returns:
    `(train_time, val_time, val_metric_list, free_mem_l, total_mem_l,
    used_mem_l)`. These are per-epoch training / validation wall-clock seconds,
    the per-epoch validation metric, and per-epoch GPU memory readings in GiB.
    The memory lists are empty when CUDA is unavailable.
  """
  # unpack the data, prepare for the training
  train_data, val_data = train_val_data
  train_src_l, train_tgt_l, train_ts_l, train_e_idx_l, _ = train_data
  val_src_l, val_tgt_l, val_ts_l, val_e_idx_l, _ = val_data

  device = model.n_feat_th.data.device
  print("DEBUG: device: ", device)
  num_instance = len(train_src_l)
  num_batch = math.ceil(num_instance / bs)
  logger.info('Training: num of training instances: {}'.format(num_instance))
  logger.info('Training: num of batches per epoch: {}'.format(num_batch))
  idx_list = np.arange(num_instance)

  # One seed is drawn per run (reproducibly, from exp_seed) and re-applied at
  # the start of every epoch. seeds[e] records the seed used in epoch e so the
  # best epoch can be reloaded with it.
  seeds = []
  random.seed(exp_seed)
  seed = random.randint(0, 100)
  train_time, val_time = [], []
  val_metric_list = []
  free_mem_l, total_mem_l, used_mem_l = [], [], []

  for epoch in range(epochs):
    train_start = timeit.default_timer()

    model.set_seed(seed)
    set_random_seed(seed)
    seeds.append(seed)

    model.reset_store()
    model.reset_self_rep()

    m_loss = []
    logger.info('Start epoch {}'.format(epoch))
    for k in tqdm(range(num_batch), ncols=120, disable=True):
      # generate training mini-batch (batches follow the stored edge order)
      s_idx = k * bs
      e_idx = min(num_instance, s_idx + bs)
      if s_idx == e_idx:
        continue
      # NOTE(review): basic slicing returns a *view*, so the in-place shuffle
      # also permutes idx_list itself and carries over into later epochs.
      # Kept as-is to preserve training behaviour. Use
      # idx_list[s_idx:e_idx].copy() if an epoch-independent order is wanted.
      batch_idx = idx_list[s_idx:e_idx]
      np.random.shuffle(batch_idx)
      src_l_cut, tgt_l_cut = train_src_l[batch_idx], train_tgt_l[batch_idx]
      ts_l_cut = train_ts_l[batch_idx]
      e_l_cut = train_e_idx_l[batch_idx]
      size = len(src_l_cut)
      _, bad_l_cut = train_rand_sampler.sample(size)

      # feed in the data and learn from error
      optimizer.zero_grad()
      model.train()
      pos_prob, neg_prob = model.contrast(src_l_cut, tgt_l_cut, bad_l_cut, ts_l_cut, e_l_cut)   # the core training code
      pos_label = torch.ones(size, dtype=torch.float, device=device, requires_grad=False)
      neg_label = torch.zeros(size, dtype=torch.float, device=device, requires_grad=False)
      loss = criterion(pos_prob, pos_label) + criterion(neg_prob, neg_label)
      loss.backward()
      optimizer.step()
      m_loss.append(loss.item())

    # checking GPU memory usage
    if torch.cuda.is_available():
      # Query a single device for the name, capacity and allocation: the
      # model's device when it is on CUDA, otherwise the current CUDA device.
      gpu = device if device.type == 'cuda' else torch.cuda.current_device()
      print("DEBUG: device: {}".format(torch.cuda.get_device_name(gpu)))

      total_mem = torch.cuda.get_device_properties(gpu).total_memory
      # "used" is the memory currently allocated by PyTorch's caching allocator
      used_mem = torch.cuda.memory_stats(gpu)["allocated_bytes.all.current"]
      free_mem = total_mem - used_mem

      logger.info("------------Epoch {}: GPU memory usage-----------".format(epoch))
      logger.info("Free memory: {}".format(free_mem))
      logger.info("Total available memory: {}".format(total_mem))
      logger.info("Used memory: {}".format(used_mem))
      logger.info("--------------------------------------------")
      free_mem_l.append(float((free_mem*1.0)/2**30))  # in GB
      total_mem_l.append(float((total_mem*1.0)/2**30))
      used_mem_l.append(float((used_mem*1.0)/2**30))

    train_end = timeit.default_timer()
    train_time.append(train_end - train_start)
    nat_results(logger, train_time, "train_time")

    # validation phase use all information
    val_start = timeit.default_timer()
    val_metric = eval_tgb_lpp(model, tgb_negative_sampler, val_src_l, val_tgt_l, val_ts_l, val_e_idx_l, 
                              evaluator, metric, split_mode="val", k_value=K_VALUE, bs=bs) 
    val_end = timeit.default_timer()

    val_time.append(val_end - val_start)
    val_metric_list.append(val_metric)

    logger.info('epoch: {}:'.format(epoch))
    logger.info('epoch mean loss: {}'.format(np.mean(m_loss)))
    logger.info('train time: {}'.format(train_end - train_start))
    logger.info('validation time: {}'.format(val_end - val_start))

    logger.info('validation {}: {}'.format(metric, val_metric))

    # early stop check and checkpoint saving
    if early_stopper.early_stop_check(val_metric):
      logger.info('No improvment over {} epochs, stop training'.format(early_stopper.max_round))
      logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
      # assumes best_epoch is a 0-based epoch index of this run (indexes seeds)
      load_model(model, n_hop, epoch = early_stopper.best_epoch, seed=seeds[early_stopper.best_epoch])
      logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')

      # save the best epoch model
      save_model(model, n_hop, epoch="best")

      model.eval()
      break
    else:
      save_model(model, n_hop, epoch)

  return train_time, val_time, val_metric_list, free_mem_l, total_mem_l, used_mem_l



def train_val_tvi(train_val_data, model, bs, epochs, criterion, optimizer, early_stopper, 
                  train_rand_sampler, logger, tgb_negative_sampler, metric, evaluator, n_hop=2,
                  nodebank=None, exp_seed=None):
  """Train `model` and report transductive and inductive validation metrics.

  The training loop matches `train_val`: mini-batches of `bs` in stored edge
  order, negatives from `train_rand_sampler`, and `criterion` on the scores
  from `model.contrast`. Validation uses `eval_tgb_lpp_tvi`, which returns a
  transductive and an inductive metric. Early stopping is driven by the
  TRANSDUCTIVE metric only.

  On an early stop, the best epoch's state dict, neighborhood stores
  (hops 0..n_hop), self representations / previous raw features and seed are
  restored from disk, and the model is switched to eval mode. Otherwise these
  artifacts are saved for the current epoch.

  Args:
    train_val_data: `(train_data, val_data)`. Each is a 5-tuple
      `(src, tgt, ts, edge_idx, label)` of arrays that support integer-array
      indexing. The labels are unpacked but not used.
    model: the model being trained (see attribute/method uses below).
    bs: mini-batch size, used for training and validation.
    epochs: maximum number of epochs.
    criterion: loss applied separately to positive and negative scores
      (presumably a BCE-style loss on probabilities -- confirm with caller).
    optimizer: torch optimizer over the model parameters.
    early_stopper: object with `early_stop_check(metric)`, `max_round` and
      `best_epoch`.
    train_rand_sampler: object whose `sample(size)` returns a pair. Only the
      second element (the negative targets) is used.
    logger: `logging.Logger` for progress output.
    tgb_negative_sampler, metric, evaluator, nodebank: forwarded to
      `eval_tgb_lpp_tvi`.
    n_hop: the neighborhood stores for hops 0..n_hop are saved/restored.
    exp_seed: optional seed for Python's `random` module before the per-run
      training seed is drawn. With the default None, the global `random`
      state is left untouched (the previous behaviour).

  Returns:
    `(train_time, val_time, val_trans_metric_list, val_induc_metric_list,
    free_mem_l, total_mem_l, used_mem_l)`. These are per-epoch wall-clock
    seconds, per-epoch transductive and inductive validation metrics, and
    per-epoch GPU memory readings in GiB. The memory lists are empty when CUDA
    is unavailable.
  """
  # unpack the data, prepare for the training
  train_data, val_data = train_val_data
  train_src_l, train_tgt_l, train_ts_l, train_e_idx_l, _ = train_data
  val_src_l, val_tgt_l, val_ts_l, val_e_idx_l, _ = val_data

  device = model.n_feat_th.data.device
  print("DEBUG: device: ", device)
  num_instance = len(train_src_l)
  num_batch = math.ceil(num_instance / bs)
  logger.info('Training: num of training instances: {}'.format(num_instance))
  logger.info('Training: num of batches per epoch: {}'.format(num_batch))
  idx_list = np.arange(num_instance)

  # One seed is drawn per run and re-applied at the start of every epoch.
  # seeds[e] records the seed used in epoch e so the best epoch can be
  # restored with it.
  seeds = []
  if exp_seed is not None:
    random.seed(exp_seed)
  seed = random.randint(0, 100)
  train_time, val_time = [], []
  val_trans_metric_list, val_induc_metric_list = [], []
  free_mem_l, total_mem_l, used_mem_l = [], [], []

  for epoch in range(epochs):
    train_start = timeit.default_timer()

    model.set_seed(seed)
    set_random_seed(seed)
    seeds.append(seed)
    model.reset_store()
    model.reset_self_rep()

    m_loss = []
    logger.info('Start epoch {}'.format(epoch))
    for k in tqdm(range(num_batch), ncols=120, disable=True):
      # generate training mini-batch (batches follow the stored edge order)
      s_idx = k * bs
      e_idx = min(num_instance, s_idx + bs)
      if s_idx == e_idx:
        continue
      # NOTE(review): basic slicing returns a *view*, so the in-place shuffle
      # also permutes idx_list itself and carries over into later epochs.
      # Kept as-is to preserve training behaviour. Use
      # idx_list[s_idx:e_idx].copy() if an epoch-independent order is wanted.
      batch_idx = idx_list[s_idx:e_idx]
      np.random.shuffle(batch_idx)
      src_l_cut, tgt_l_cut = train_src_l[batch_idx], train_tgt_l[batch_idx]
      ts_l_cut = train_ts_l[batch_idx]
      e_l_cut = train_e_idx_l[batch_idx]
      size = len(src_l_cut)
      _, bad_l_cut = train_rand_sampler.sample(size)

      # feed in the data and learn from error
      optimizer.zero_grad()
      model.train()
      pos_prob, neg_prob = model.contrast(src_l_cut, tgt_l_cut, bad_l_cut, ts_l_cut, e_l_cut)   # the core training code
      pos_label = torch.ones(size, dtype=torch.float, device=device, requires_grad=False)
      neg_label = torch.zeros(size, dtype=torch.float, device=device, requires_grad=False)
      loss = criterion(pos_prob, pos_label) + criterion(neg_prob, neg_label)
      loss.backward()
      optimizer.step()
      m_loss.append(loss.item())

    # checking GPU memory usage
    if torch.cuda.is_available():
      # Query a single device for both the name and the memory figures: the
      # model's device when it is on CUDA, otherwise the current CUDA device.
      gpu = device if device.type == 'cuda' else torch.cuda.current_device()
      print("DEBUG: device: {}".format(torch.cuda.get_device_name(gpu)))
      free_mem, total_mem = torch.cuda.mem_get_info(gpu)
      used_mem = total_mem - free_mem
      logger.info("------------Epoch {}: GPU memory usage-----------".format(epoch))
      logger.info("Free memory: {}".format(free_mem))
      logger.info("Total available memory: {}".format(total_mem))
      logger.info("Used memory: {}".format(used_mem))
      logger.info("--------------------------------------------")
      free_mem_l.append(float((free_mem*1.0)/2**30))  # in GB
      total_mem_l.append(float((total_mem*1.0)/2**30))
      used_mem_l.append(float((used_mem*1.0)/2**30))

    train_end = timeit.default_timer()
    train_time.append(train_end - train_start)
    nat_results(logger, train_time, "train_time")

    # validation phase use all information
    val_start = timeit.default_timer()
    val_trans_metric, val_induc_metric = eval_tgb_lpp_tvi(model, tgb_negative_sampler, val_src_l, val_tgt_l, val_ts_l, val_e_idx_l, 
                              evaluator, metric, split_mode="val", k_value=K_VALUE, bs=bs, nodebank=nodebank) 
    val_end = timeit.default_timer()

    val_time.append(val_end - val_start)
    val_trans_metric_list.append(val_trans_metric)
    val_induc_metric_list.append(val_induc_metric)

    logger.info('epoch: {}:'.format(epoch))
    logger.info('epoch mean loss: {}'.format(np.mean(m_loss)))
    logger.info('train time: {}'.format(train_end - train_start))
    logger.info('validation time: {}'.format(val_end - val_start))

    logger.info('validation {} transductive: {}'.format(metric, val_trans_metric))
    logger.info('validation {} inductive: {}'.format(metric, val_induc_metric))

    # early stop check and checkpoint saving
    if early_stopper.early_stop_check(val_trans_metric):  # NOTE: stops early based on TRANSDUCTIVE validation metric
      logger.info('No improvment over {} epochs, stop training'.format(early_stopper.max_round))
      logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
      best_epoch = early_stopper.best_epoch
      model.load_state_dict(torch.load(model.get_checkpoint_path(best_epoch)))
      # restore the per-hop neighborhood stores saved for the best epoch
      model.clear_store()
      best_ngh_store = [torch.load(model.get_ngh_store_path(best_epoch, i)) for i in range(n_hop + 1)]
      model.set_neighborhood_store(best_ngh_store)
      best_self_rep = torch.load(model.get_self_rep_path(best_epoch))
      best_prev_raw = torch.load(model.get_prev_raw_path(best_epoch))
      model.set_self_rep(best_self_rep, best_prev_raw)
      # assumes best_epoch is a 0-based epoch index of this run (indexes seeds)
      model.set_seed(seeds[best_epoch])
      logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')
      model.eval()
      break
    else:
      for i in range(n_hop + 1):
        torch.save(model.neighborhood_store[i], model.get_ngh_store_path(epoch, i))
      torch.save(model.state_dict(), model.get_checkpoint_path(epoch))
      torch.save(model.self_rep, model.get_self_rep_path(epoch))
      torch.save(model.prev_raw, model.get_prev_raw_path(epoch))

  # return the inductive list in 4th position (previously the transductive
  # list was returned twice)
  return train_time, val_time, val_trans_metric_list, val_induc_metric_list, free_mem_l, total_mem_l, used_mem_l
