import math
from tqdm import tqdm
import numpy as np
import torch
import time
import logging


from utils.attack import nat_attack


def eval_tgb_lpp(model, tgb_negative_sampler, src, tgt, ts, e_id,
                 evaluator, metric: str = 'mrr', split_mode: str = 'test', k_value: int = 10,
                 bs=32, attack_params=None, device=torch.device("cpu")):
    """Evaluate a NAT model for link property prediction with the TGB evaluator.

    The positive edges are processed in consecutive batches of ``bs``; the
    order inside each batch is shuffled with ``np.random.shuffle``. For every
    positive edge the model scores the edge itself and the negative
    destinations returned by ``tgb_negative_sampler.query_batch``; both score
    arrays are handed to ``evaluator.eval`` and the resulting ``metric`` values
    are averaged.

    Args:
        model: Object exposing ``eval()`` and ``contrast_modified(...)``.
        tgb_negative_sampler: Object exposing ``dataset_name`` and
            ``query_batch(src, dst, t, split_mode=...)``.
        src, tgt, ts: Per-edge sources, destinations and timestamps; they are
            indexed with an integer numpy array (assumed array-like -- TODO
            confirm against the caller).
        e_id: Per-edge ids indexed the same way, or ``None``.
        evaluator: Object exposing ``eval(input_dict)`` returning a mapping
            that contains ``metric``.
        metric: Name of the metric requested from the evaluator.
        split_mode: Split name forwarded to the negative sampler.
        k_value: Unused; kept for interface compatibility.
        bs: Batch size.
        attack_params: Optional object with ``attack_type`` and ``adv_budget``
            attributes. When truthy, ``nat_attack`` is run on each batch
            before it is scored.
        device: Forwarded to ``nat_attack``.

    Returns:
        The mean metric over all evaluated positive edges as a ``float``, or
        NaN when no edge was evaluated.
    """
    perf_list = []
    with torch.no_grad():
        model = model.eval()
        num_test_instance = len(src)
        num_batches = math.ceil(num_test_instance / bs)

        if attack_params:  # att_pos=front
            attack_type = attack_params.attack_type
            budget = attack_params.adv_budget
            attack_times = []

        for k in range(num_batches):
            dataset_name = tgb_negative_sampler.dataset_name
            # Known problem with this specific `tgbl-wiki` test batch: skip it.
            if dataset_name == 'tgbl-wiki' and split_mode == 'test' and k == 118:
                continue

            s_idx = k * bs
            e_idx = min(num_test_instance, s_idx + bs)
            batch_idx = np.arange(s_idx, e_idx)
            np.random.shuffle(batch_idx)
            src_l_cut = src[batch_idx]
            tgt_l_cut = tgt[batch_idx]
            ts_l_cut = ts[batch_idx]
            # Guard on the edge-id array itself, not on the batch end index.
            e_l_cut = e_id[batch_idx] if e_id is not None else None

            # NAT's preprocessing shifts TGB node ids by one, so the sampler
            # is queried with the ids shifted back.
            if 'tgb' in dataset_name:
                src_l_cut_orig = src_l_cut - 1
                tgt_l_cut_orig = tgt_l_cut - 1
            else:
                src_l_cut_orig = src_l_cut
                tgt_l_cut_orig = tgt_l_cut

            pos_t = np.array([int(t) for t in ts_l_cut])

            if attack_params and (n_adv_edges := round(len(pos_t) * budget)) > 0:  # att_pos=front
                pos_batch = src_l_cut, tgt_l_cut, pos_t, e_l_cut
                t1 = time.time()
                logging.info("Attacking")
                # The value returned by nat_attack replaces attack_params for
                # the following batches.
                attack_params = nat_attack(model, pos_batch, attack_type, n_adv_edges,
                                           attack_params, device=device)
                attack_times.append(time.time() - t1)
                logging.info("Attacked in %s", attack_times[-1])

            neg_batch_list = tgb_negative_sampler.query_batch(
                src_l_cut_orig, tgt_l_cut_orig, pos_t, split_mode=split_mode)

            for idx, neg_batch in enumerate(neg_batch_list):
                # Shift the sampled destinations back into NAT's id space.
                neg_dst = np.array(neg_batch) + 1
                n_neg = len(neg_dst)
                neg_src = np.full(n_neg, int(src_l_cut[idx]))
                neg_times = np.full(n_neg, ts_l_cut[idx])

                # negative edges
                negative_probabilities = model.contrast_modified(
                    neg_src, neg_dst, neg_times, e_idx_l=None,
                    pos_edge=False, test=True)

                # one positive edge; the edge id is omitted when none was given
                pos_e_idx = None if e_l_cut is None else np.array([e_l_cut[idx]])
                positive_probabilities = model.contrast_modified(
                    np.array([src_l_cut[idx]]), np.array([tgt_l_cut[idx]]),
                    np.array([ts_l_cut[idx]]), pos_e_idx,
                    pos_edge=True, test=True)

                # compute the requested metric for this positive edge
                input_dict = {
                    'y_pred_pos': np.array(positive_probabilities.cpu()),
                    'y_pred_neg': np.array(negative_probabilities.cpu()),
                    'eval_metric': [metric]
                }
                perf_list.append(evaluator.eval(input_dict)[metric])

    if not perf_list:
        # Same value np.mean would produce for an empty list, without the warning.
        return float('nan')
    return float(np.mean(np.array(perf_list)))



def eval_tgb_lpp_old(model, tgb_negative_sampler, val_src_l, val_tgt_l, val_ts_l, val_e_idx_l,
                     evaluator, metric, split_mode, k_value, bs):
    """Disabled legacy evaluation routine; does nothing and returns ``None``.

    The previous body sat behind an unconditional ``return`` and could never
    run; it also referenced names that were never defined in this function
    (``device``, ``full_data``, ``assoc``, ``perf_list``, ...). It has been
    removed. The signature is kept so existing call sites keep working; use
    ``eval_tgb_lpp`` for an actual evaluation.

    All parameters are accepted and ignored.
    """
    return None



def eval_tgb_lpp_tvi(model, tgb_negative_sampler, src, tgt, ts, e_id,
                 evaluator, metric: str = 'mrr', split_mode: str = 'test', k_value: int = 10,
                 bs=32, nodebank=None):
    """Evaluate a NAT model with the TGB evaluator, split transductive vs. inductive.

    Works like a plain TGB link-property-prediction evaluation: positive edges
    are processed in consecutive batches of ``bs`` (shuffled inside each batch
    with ``np.random.shuffle``), and every positive edge is scored together
    with the negatives returned by ``tgb_negative_sampler.query_batch``. The
    per-edge metric goes to the transductive list when
    ``nodebank.query_node(source)`` is truthy and to the inductive list
    otherwise.

    Args:
        model: Object exposing ``eval()`` and ``contrast_modified(...)``.
        tgb_negative_sampler: Object exposing ``dataset_name`` and
            ``query_batch(src, dst, t, split_mode=...)``.
        src, tgt, ts: Per-edge sources, destinations and timestamps; they are
            indexed with an integer numpy array (assumed array-like -- TODO
            confirm against the caller).
        e_id: Per-edge ids indexed the same way, or ``None``.
        evaluator: Object exposing ``eval(input_dict)`` returning a mapping
            that contains ``metric``.
        metric: Name of the metric requested from the evaluator.
        split_mode: Split name forwarded to the negative sampler.
        k_value: Unused; kept for interface compatibility.
        bs: Batch size.
        nodebank: Object exposing ``query_node(node)``; required.

    Returns:
        ``(transductive_mean, inductive_mean)`` as floats. A side with no
        evaluated edge is reported as NaN.

    Raises:
        AssertionError: If ``nodebank`` is ``None``.
    """
    assert nodebank is not None, "NodeBank should be intialized for Transductive vs. Inductive evaluation!"
    trans_perf_list, induc_perf_list = [], []
    with torch.no_grad():
        model = model.eval()
        num_test_instance = len(src)
        num_batches = math.ceil(num_test_instance / bs)

        for k in range(num_batches):
            # Known problem with this specific `tgbl-wiki` test batch: skip it.
            if tgb_negative_sampler.dataset_name == 'tgbl-wiki' and split_mode == 'test' and k == 118:
                continue

            s_idx = k * bs
            e_idx = min(num_test_instance, s_idx + bs)
            batch_idx = np.arange(s_idx, e_idx)
            np.random.shuffle(batch_idx)
            src_l_cut = src[batch_idx]
            tgt_l_cut = tgt[batch_idx]
            ts_l_cut = ts[batch_idx]
            # Guard on the edge-id array itself, not on the batch end index.
            e_l_cut = e_id[batch_idx] if e_id is not None else None

            # The sampler is queried with node ids shifted down by one; the
            # sampled negatives are shifted back up below.
            src_l_cut_orig = src_l_cut - 1
            tgt_l_cut_orig = tgt_l_cut - 1
            pos_t = np.array([int(t) for t in ts_l_cut])

            neg_batch_list = tgb_negative_sampler.query_batch(
                src_l_cut_orig, tgt_l_cut_orig, pos_t, split_mode=split_mode)

            for idx, neg_batch in enumerate(neg_batch_list):
                neg_dst = np.array(neg_batch) + 1  # due to the special data loading processing ...
                n_neg = len(neg_dst)
                neg_src = np.full(n_neg, int(src_l_cut[idx]))
                # Negatives share the positive edge's timestamp (not its edge id).
                neg_times = np.full(n_neg, ts_l_cut[idx])

                # negative edges
                negative_probabilities = model.contrast_modified(
                    neg_src, neg_dst, neg_times, e_idx_l=None,
                    pos_edge=False, test=True)

                # one positive edge; the edge id is omitted when none was given
                pos_e_idx = None if e_l_cut is None else np.array([e_l_cut[idx]])
                positive_probabilities = model.contrast_modified(
                    np.array([src_l_cut[idx]]), np.array([tgt_l_cut[idx]]),
                    np.array([ts_l_cut[idx]]), pos_e_idx,
                    pos_edge=True, test=True)

                # compute the requested metric for this positive edge
                input_dict = {
                    'y_pred_pos': np.array(positive_probabilities.cpu()),
                    'y_pred_neg': np.array(negative_probabilities.cpu()),
                    'eval_metric': [metric]
                }
                if nodebank.query_node(src_l_cut[idx]):
                    bucket = trans_perf_list
                else:
                    bucket = induc_perf_list
                bucket.append(evaluator.eval(input_dict)[metric])

    # NaN is what np.mean yields for an empty list; produce it without the warning.
    avg_trans_perf_metric = float(np.mean(np.array(trans_perf_list))) if trans_perf_list else float('nan')
    avg_induc_perf_metric = float(np.mean(np.array(induc_perf_list))) if induc_perf_list else float('nan')
    return avg_trans_perf_metric, avg_induc_perf_metric

