"""Unified interface to all dynamic graph model experiments"""
import math
import logging
import os
import time
import random
import sys
import argparse

import torch
import pandas as pd
import numpy as np
from tqdm import tqdm

# import numba

from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_auc_score

from module import TGAN
from graph import NeighborFinder
from utils import EarlyStopMonitor, RandEdgeSampler
from process import simulate_dataset_train_flag


# import warnings
# def my_formatwarning(message, category, filename, lineno, line=None):
#   print(message, category)
#   # lineno is the line number you are looking for
#   print('file:', filename, 'line number:', lineno)
#   ...
# warnings.formatwarning = my_formatwarning
import warnings

# Escalate every warning to an exception, so the statement that triggers one
# fails with a traceback instead of printing a message and carrying on.
warnings.filterwarnings("error")


### Argument and global variables
parser = argparse.ArgumentParser("Interface for TGAT experiments on link predictions")

# One (flags, add_argument keywords) entry per command-line option.  The
# entries are registered in the order written, which is also the order in
# which --help lists them.  Every option has a default, so the parser accepts
# an empty command line.
_OPTION_SPECS = (
    (
        ("-d", "--data"),
        dict(
            type=str,
            default="simulate",
            help="data sources to use, try wikipedia or reddit",
        ),
    ),
    (("--bs",), dict(type=int, default=64, help="batch_size")),
    (
        ("--prefix",),
        dict(type=str, default="", help="prefix to name the checkpoints"),
    ),
    (
        ("--n_degree",),
        dict(type=int, default=20, help="number of neighbors to sample"),
    ),
    (
        ("--n_head",),
        dict(type=int, default=2, help="number of heads used in attention layer"),
    ),
    (("--n_epoch",), dict(type=int, default=30, help="number of epochs")),
    (("--n_layer",), dict(type=int, default=2, help="number of network layers")),
    (("--lr",), dict(type=float, default=0.0001, help="learning rate")),
    (("--drop_out",), dict(type=float, default=0.1, help="dropout probability")),
    (("--gpu",), dict(type=int, default=0, help="idx for the gpu to use")),
    (
        ("--node_dim",),
        dict(type=int, default=100, help="Dimentions of the node embedding"),
    ),
    (
        ("--time_dim",),
        dict(type=int, default=100, help="Dimentions of the time embedding"),
    ),
    (
        ("--agg_method",),
        dict(
            type=str,
            choices=["attn", "lstm", "mean"],
            default="attn",
            help="local aggregation method",
        ),
    ),
    (
        ("--attn_mode",),
        dict(
            type=str,
            choices=["prod", "map"],
            default="prod",
            help="use dot product attention or mapping based",
        ),
    ),
    (
        ("--time",),
        dict(
            type=str,
            choices=["time", "pos", "empty"],
            default="time",
            help="how to use time information",
        ),
    ),
    (
        ("--uniform",),
        dict(
            action="store_true",
            help="take uniform sampling from temporal neighbors",
        ),
    ),
)
for _flags, _keywords in _OPTION_SPECS:
    parser.add_argument(*_flags, **_keywords)

# argparse signals both a bad command line and -h/--help by raising
# SystemExit (non-zero code for a usage error, 0 after printing help).
# Only that exception is intercepted; anything else propagates untouched.
try:
    args = parser.parse_args()
except SystemExit as exc:
    # After a usage error, also show the full option list.  Skip this for
    # -h/--help, where argparse has already printed it.
    if exc.code not in (0, None):
        parser.print_help()
    # Re-raise so the process keeps argparse's exit status instead of
    # reporting success (0) for an invalid command line.
    raise

# ---- Settings copied from the parsed command line ----
# Data and optimisation.
DATA = args.data
BATCH_SIZE = args.bs
NUM_EPOCH = args.n_epoch
LEARNING_RATE = args.lr
GPU = args.gpu

# Neighbor sampling.
NUM_NEIGHBORS = args.n_degree
SEQ_LEN = NUM_NEIGHBORS
UNIFORM = args.uniform

# Model architecture.
NUM_LAYER = args.n_layer
NUM_HEADS = args.n_head
DROP_OUT = args.drop_out
NODE_DIM = args.node_dim
TIME_DIM = args.time_dim
USE_TIME = args.time
AGG_METHOD = args.agg_method
ATTN_MODE = args.attn_mode

# Fixed value, not configurable from the command line.
NUM_NEG = 1

# ---- Output locations ----
# Both directories are created up front so later torch.save calls can write
# into them.
for _directory in ("saved_models", "saved_checkpoints"):
    os.makedirs(_directory, exist_ok=True)

MODEL_SAVE_PATH = f"./saved_models/tgat_{args.data}_best.pth"


def get_checkpoint_path(epoch):
    """Return the checkpoint path for ``epoch`` under the current data,
    aggregation-method and attention-mode settings."""
    return f"./saved_checkpoints/{args.data}-{args.agg_method}-{args.attn_mode}-{epoch}.pth"

### set up logger
# The root logger is configured with basicConfig and then lowered to DEBUG so
# that the handlers attached below decide what is actually emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# One log file per run, named after the launch timestamp.
os.makedirs("log", exist_ok=True)
fh = logging.FileHandler("log/{}.log".format(str(time.time())))
ch = logging.StreamHandler()

# The file handler records DEBUG and above; the extra stream handler only
# WARN and above.  Both share one format and are attached file-first.
for _handler, _level in ((fh, logging.DEBUG), (ch, logging.WARN)):
    _handler.setLevel(_level)
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)

# Record the full run configuration at the top of the log.
logger.info(args)


def eval_one_epoch(hint, tgan, sampler, src, dst, ts, label):
    """Evaluate ``tgan`` on a set of labelled events, batch by batch.

    The events are processed in order, in contiguous batches of 30.  For each
    batch the model's scores are thresholded at 0.5 to compute accuracy, and
    average precision is computed from the raw scores.  The values returned
    are the unweighted means of the per-batch metrics.

    Args:
        hint: Unused; kept so existing call sites keep working.
        tgan: Model providing ``eval()`` and ``get_prob(src, dst, ts)``.  The
            value returned by ``get_prob`` must support ``.cpu().numpy()``.
        sampler: Unused; kept so existing call sites keep working.
        src: Source node of each event (sliceable sequence).
        dst: Destination node of each event, parallel to ``src``.
        ts: Timestamp of each event, parallel to ``src``.
        label: Label of each event, parallel to ``src``; converted to ``int``
            before being compared with the predictions.

    Returns:
        ``(mean_accuracy, mean_average_precision, 0, 0)``.  The last two
        entries are placeholders in the positions of the F1 and AUC metrics,
        which are not computed.

    Raises:
        AssertionError: If ``src`` is empty, so no batch was evaluated.

    Note:
        The model is switched to eval mode and is left in that mode on return.
    """
    TEST_BATCH_SIZE = 30
    num_test_instance = len(src)
    num_test_batch = math.ceil(num_test_instance / TEST_BATCH_SIZE)

    val_acc, val_ap = [], []
    with torch.no_grad():
        tgan = tgan.eval()
        for k in tqdm(range(num_test_batch)):
            s_idx = k * TEST_BATCH_SIZE
            # Slice ends are exclusive, so the upper bound is the instance
            # count itself.  Clamping to ``num_test_instance - 1`` silently
            # dropped the final event and produced an empty last batch
            # whenever exactly one event was left over.
            e_idx = min(num_test_instance, s_idx + TEST_BATCH_SIZE)
            src_l_cut = src[s_idx:e_idx]
            dst_l_cut = dst[s_idx:e_idx]
            ts_l_cut = ts[s_idx:e_idx]
            label_l_cut = label[s_idx:e_idx]

            pos_prob = tgan.get_prob(src_l_cut, dst_l_cut, ts_l_cut)
            pred_score = pos_prob.cpu().numpy()

            pred_label = pred_score > 0.5
            true_label = np.array(label_l_cut, dtype=int)

            val_acc.append((pred_label == true_label).mean())
            val_ap.append(average_precision_score(true_label, pred_score))

    assert len(val_acc) != 0, "eval_one_epoch received no events to evaluate"
    return np.mean(val_acc), np.mean(val_ap), 0, 0


### Load data and train val test split
# Three pre-processed artefacts per dataset: the event table plus the edge and
# node feature arrays.
g_df = pd.read_csv("./processed/ml_{}.csv".format(DATA))
e_feat = np.load("./processed/ml_{}.npy".format(DATA))
n_feat = np.load("./processed/ml_{}_node.npy".format(DATA))

# Timestamps at the 70% and 85% quantiles of the event times.
val_time, test_time = np.quantile(g_df["ts"], [0.70, 0.85])

# Column views of the event table, one array per field.
src_l = g_df["u"].values  # events' source node list, array
dst_l = g_df["i"].values  # events' target node list, array
e_idx_l = g_df["idx"].values  # event's index values, array
label_l = g_df["label"].values  # event's state label, array
ts_l = g_df["ts"].values  # event's time, array

# Largest node id on the source side, and across both endpoints.
max_src_index = src_l.max()
max_idx = max(max_src_index, dst_l.max())

# Seeds Python's built-in ``random`` module.
random.seed(2020)

# total_node_set = set(np.unique(np.hstack([g_df.u.values, g_df.i.values])))
# num_total_unique_nodes = len(total_node_set)

# mask_node_set = set(random.sample(set(src_l[ts_l > val_time]).union(set(dst_l[ts_l > val_time])), int(0.1 * num_total_unique_nodes)))
# mask_src_flag = g_df.u.map(lambda x: x in mask_node_set).values
# mask_dst_flag = g_df.i.map(lambda x: x in mask_node_set).values
# none_node_flag = (1 - mask_src_flag) * (1 - mask_dst_flag)

# valid_train_flag = (ts_l <= val_time) * (none_node_flag > 0)

# train_src_l = src_l[valid_train_flag]
# train_dst_l = dst_l[valid_train_flag]
# train_ts_l = ts_l[valid_train_flag]
# train_e_idx_l = e_idx_l[valid_train_flag]
# train_label_l = label_l[valid_train_flag]

# new train, test data
# Keep only the events selected by simulate_dataset_train_flag.
total_flag = simulate_dataset_train_flag(g_df)
total_src_l, total_dst_l, total_ts_l, total_e_idx_l, total_label_l = (
    column[total_flag] for column in (src_l, dst_l, ts_l, e_idx_l, label_l)
)

# Positional 80/20 split: the leading 80% of the kept events are used for
# training, the remainder for testing.
train_num = int(len(total_src_l) * 0.8)
train_flag = np.arange(len(total_src_l)) < train_num
test_flag = ~train_flag

train_src_l, train_dst_l, train_ts_l, train_e_idx_l, train_label_l = (
    column[train_flag]
    for column in (
        total_src_l,
        total_dst_l,
        total_ts_l,
        total_e_idx_l,
        total_label_l,
    )
)

test_src_l, test_dst_l, test_ts_l, test_e_idx_l, test_label_l = (
    column[test_flag]
    for column in (
        total_src_l,
        total_dst_l,
        total_ts_l,
        total_e_idx_l,
        total_label_l,
    )
)

# import ipdb; ipdb.set_trace()

# define the new nodes sets for testing inductiveness of the model
# train_node_set = set(train_src_l).union(train_dst_l)
# assert(len(train_node_set - mask_node_set) == len(train_node_set))
# new_node_set = total_node_set - train_node_set

# select validation and test dataset
# valid_val_flag = (ts_l <= test_time) * (ts_l > val_time)
# valid_test_flag = ts_l > test_time

# is_new_node_edge = np.array([(a in new_node_set or b in new_node_set) for a, b in zip(src_l, dst_l)])
# nn_val_flag = valid_val_flag * is_new_node_edge
# nn_test_flag = valid_test_flag * is_new_node_edge

# validation and test with all edges
# val_src_l = src_l[valid_val_flag]
# val_dst_l = dst_l[valid_val_flag]
# val_ts_l = ts_l[valid_val_flag]
# val_e_idx_l = e_idx_l[valid_val_flag]
# val_label_l = label_l[valid_val_flag]

# test_src_l = src_l[valid_test_flag]
# test_dst_l = dst_l[valid_test_flag]
# test_ts_l = ts_l[valid_test_flag]
# test_e_idx_l = e_idx_l[valid_test_flag]
# test_label_l = label_l[valid_test_flag]
# validation and test with edges that at least has one new node (not in training set)
# nn_val_src_l = src_l[nn_val_flag]
# nn_val_dst_l = dst_l[nn_val_flag]
# nn_val_ts_l = ts_l[nn_val_flag]
# nn_val_e_idx_l = e_idx_l[nn_val_flag]
# nn_val_label_l = label_l[nn_val_flag]

# nn_test_src_l = src_l[nn_test_flag]
# nn_test_dst_l = dst_l[nn_test_flag]
# nn_test_ts_l = ts_l[nn_test_flag]
# nn_test_e_idx_l = e_idx_l[nn_test_flag]
# nn_test_label_l = label_l[nn_test_flag]

### Initialize the data structure for graph and edge sampling
# build the graph for fast query
# graph only contains the training data (with 10% nodes removal)
# adj_list = [[] for _ in range(max_idx + 1)]
# for src, dst, eidx, ts in zip(train_src_l, train_dst_l, train_e_idx_l, train_ts_l):
#     adj_list[src].append((dst, eidx, ts))
#     adj_list[dst].append((src, eidx, ts))
# train_ngh_finder = NeighborFinder(adj_list, uniform=UNIFORM) # used in training

# import ipdb; ipdb.set_trace()

# Adjacency lists built from every event in the table (not only the train
# split).  Slot n collects one (neighbor, event index, timestamp) tuple for
# each event that has node n as an endpoint; each event is recorded under
# both of its endpoints.
full_adj_list = [[] for _ in range(max_idx + 1)]
for u, v, edge_id, stamp in zip(src_l, dst_l, e_idx_l, ts_l):
    full_adj_list[u].append((v, edge_id, stamp))
    full_adj_list[v].append((u, edge_id, stamp))

# uniform must be False to use the limit number scheme and support event mask
full_ngh_finder = NeighborFinder(full_adj_list, uniform=False)

# train_rand_sampler = RandEdgeSampler(train_src_l, train_dst_l)
# val_rand_sampler = RandEdgeSampler(src_l, dst_l)
# nn_val_rand_sampler = RandEdgeSampler(nn_val_src_l, nn_val_dst_l)
# test_rand_sampler = RandEdgeSampler(src_l, dst_l)
# nn_test_rand_sampler = RandEdgeSampler(nn_test_src_l, nn_test_dst_l)


### Model initialize
# Use the CUDA device chosen with --gpu (index 0 by default) when CUDA is
# available, otherwise fall back to the CPU.  The --gpu value used to be
# parsed but ignored.
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(GPU))
else:
    device = torch.device("cpu")

# The model queries the neighbor finder built over the full event table.
tgan = TGAN(
    full_ngh_finder,
    n_feat,
    e_feat,
    device=device,
    num_layers=NUM_LAYER,
    use_time=USE_TIME,
    agg_method=AGG_METHOD,
    attn_mode=ATTN_MODE,
    n_head=NUM_HEADS,
    drop_out=DROP_OUT,
    num_neighbors=NUM_NEIGHBORS,
)
optimizer = torch.optim.Adam(tgan.parameters(), lr=LEARNING_RATE)
# Binary cross-entropy computed on the model's outputs against the 0/1 labels.
criterion = torch.nn.BCELoss()
tgan = tgan.to(device)

num_instance = len(train_src_l)
num_batch = math.ceil(num_instance / BATCH_SIZE)

logger.info("num of training instances: {}".format(num_instance))
logger.info("num of batches per epoch: {}".format(num_batch))

# Index permutation over the training instances.
idx_list = np.arange(num_instance)
np.random.shuffle(idx_list)

# Created but not consulted below: early stopping is disabled in this script,
# so every epoch runs and every epoch writes a checkpoint.
early_stopper = EarlyStopMonitor()
for epoch in range(NUM_EPOCH):
    # Per-batch training metrics for this epoch.
    acc, ap, m_loss = [], [], []

    # NOTE: idx_list is reshuffled every epoch but is not used to pick the
    # batches, which are contiguous slices of the training arrays.  The call
    # is kept so numpy's global random state advances as before.
    np.random.shuffle(idx_list)
    logger.info("start {} epoch".format(epoch))

    for k in tqdm(range(num_batch)):
        s_idx = k * BATCH_SIZE
        # Slice ends are exclusive, so clamp to ``num_instance`` itself.
        # Clamping to ``num_instance - 1`` meant the last training instance
        # was never used, and the final batch was empty whenever exactly one
        # instance was left over.
        e_idx = min(num_instance, s_idx + BATCH_SIZE)
        src_l_cut, dst_l_cut = train_src_l[s_idx:e_idx], train_dst_l[s_idx:e_idx]
        ts_l_cut = train_ts_l[s_idx:e_idx]
        label_l_cut = train_label_l[s_idx:e_idx]

        label_l_tensor = torch.tensor(label_l_cut, dtype=torch.float32, device=device)

        # One optimisation step on the batch.
        optimizer.zero_grad()
        tgan = tgan.train()
        pos_prob = tgan.get_prob(src_l_cut, dst_l_cut, ts_l_cut)

        loss = criterion(pos_prob, label_l_tensor)
        loss.backward()
        optimizer.step()

        # Training metrics are taken from the scores produced by the forward
        # pass above (before the parameter update); no second pass is run.
        with torch.no_grad():
            tgan = tgan.eval()
            pred_score = pos_prob.cpu().detach().numpy()

            pred_label = pred_score > 0.5
            true_label = np.array(label_l_cut, dtype=int)

            acc.append((pred_label == true_label).mean())
            ap.append(average_precision_score(true_label, pred_score))
            m_loss.append(loss.item())

    # Evaluate on the held-out split after every epoch.
    # NOTE(review): the original author marked this call as the place "where
    # warnings come from".  Warnings are escalated to errors at the top of
    # this file, so a warning raised in here aborts the run.
    test_acc, test_ap, test_f1, test_auc = eval_one_epoch(
        "val for old nodes", tgan, None, test_src_l, test_dst_l, test_ts_l, test_label_l
    )

    logger.info("epoch: {}:".format(epoch))
    logger.info("Epoch mean loss: {}".format(np.mean(m_loss)))
    logger.info("train acc: {}, test acc: {}".format(np.mean(acc), test_acc))
    logger.info("train ap: {}, test ap: {}".format(np.mean(ap), test_ap))

    # Checkpoint the weights of every epoch.
    torch.save(tgan.state_dict(), get_checkpoint_path(epoch))


# testing phase use all information
# tgan.ngh_finder = full_ngh_finder
# test_acc, test_ap, test_f1, test_auc = eval_one_epoch('test for old nodes', tgan, test_rand_sampler, test_src_l,
# test_dst_l, test_ts_l, test_label_l)

# nn_test_acc, nn_test_ap, nn_test_f1, nn_test_auc = eval_one_epoch('test for new nodes', tgan, nn_test_rand_sampler, nn_test_src_l,
# nn_test_dst_l, nn_test_ts_l, nn_test_label_l)

# logger.info('Test statistics: Old nodes -- acc: {}, auc: {}, ap: {}'.format(test_acc, test_auc, test_ap))
# logger.info('Test statistics: New nodes -- acc: {}, auc: {}, ap: {}'.format(nn_test_acc, nn_test_auc, nn_test_ap))

# Write the weights as they stand after the last epoch to MODEL_SAVE_PATH.
# NOTE(review): the path ends in "_best", but nothing in this script selects a
# best epoch (the early-stopping branch is commented out), so this is simply
# the final-epoch model.
logger.info("Saving TGAN model")
torch.save(tgan.state_dict(), MODEL_SAVE_PATH)
logger.info("TGAN models saved")
