import math
import logging
import time
import sys
import random
import argparse
import pickle
from pathlib import Path

import torch
# Print the installed PyTorch version at startup.
print(torch.__version__)
import numpy as np

from model.tgn import TGN
from utils.utils import EarlyStopMonitor, get_neighbor_finder, MLP
from utils.data_processing import compute_time_statistics, get_data_node_classification
from evaluation.evaluation import eval_node_classification

# Seed Python's, NumPy's and PyTorch's random number generators with the same fixed value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
	_seed_rng(0)

### Argument and global variables
parser = argparse.ArgumentParser('TGN self-supervised training')

# One (flags, keyword arguments) entry per command-line option. The entries are
# registered in this exact order, which is also the order shown in the help text.
_CLI_OPTIONS = (
	(('-d', '--data'), dict(type=str, default='wikipedia',
		help='Dataset name (eg. wikipedia or reddit)')),
	(('--bs',), dict(type=int, default=100, help='Batch_size')),
	(('--prefix',), dict(type=str, default='', help='Prefix to name the checkpoints')),
	(('--n_degree',), dict(type=int, default=10, help='Number of neighbors to sample')),
	(('--n_head',), dict(type=int, default=2, help='Number of heads used in attention layer')),
	(('--n_epoch',), dict(type=int, default=10, help='Number of epochs')),
	(('--n_layer',), dict(type=int, default=1, help='Number of network layers')),
	(('--lr',), dict(type=float, default=3e-4, help='Learning rate')),
	(('--patience',), dict(type=int, default=5, help='Patience for early stopping')),
	(('--n_runs',), dict(type=int, default=1, help='Number of runs')),
	(('--drop_out',), dict(type=float, default=0.1, help='Dropout probability')),
	(('--gpu',), dict(type=int, default=0, help='Idx for the gpu to use')),
	(('--node_dim',), dict(type=int, default=100, help='Dimensions of the node embedding')),
	(('--time_dim',), dict(type=int, default=100, help='Dimensions of the time embedding')),
	(('--backprop_every',), dict(type=int, default=1,
		help='Every how many batches to backprop')),
	(('--use_memory',), dict(action='store_true',
		help='Whether to augment the model with a node memory')),
	(('--embedding_module',), dict(type=str, default="graph_attention",
		choices=["graph_attention", "graph_sum", "identity", "time"],
		help='Type of embedding module')),
	(('--message_function',), dict(type=str, default="identity",
		choices=["mlp", "identity"], help='Type of message function')),
	(('--aggregator',), dict(type=str, default="last", help='Type of message aggregator')),
	(('--memory_update_at_end',), dict(action='store_true',
		help='Whether to update memory at the end or at the start of the batch')),
	(('--message_dim',), dict(type=int, default=100, help='Dimensions of the messages')),
	(('--memory_dim',), dict(type=int, default=172,
		help='Dimensions of the memory for each user')),
	(('--different_new_nodes',), dict(action='store_true',
		help='Whether to use disjoint set of new nodes for train and val')),
	(('--uniform',), dict(action='store_true',
		help='take uniform sampling from temporal neighbors')),
	(('--randomize_features',), dict(action='store_true',
		help='Whether to randomize node features')),
	(('--use_destination_embedding_in_message',), dict(action='store_true',
		help='Whether to use the embedding of the destination node as part of the message')),
	(('--use_source_embedding_in_message',), dict(action='store_true',
		help='Whether to use the embedding of the source node as part of the message')),
	(('--n_neg',), dict(type=int, default=1)),
	(('--use_validation',), dict(action='store_true',
		help='Whether to use a validation set')),
	(('--new_node',), dict(action='store_true', help='model new node')),
)

for _flags, _settings in _CLI_OPTIONS:
	parser.add_argument(*_flags, **_settings)

# Parse the command line. argparse signals both "-h/--help" and usage errors by
# raising SystemExit: code 0 after help (already printed), non-zero after an error.
try:
	args = parser.parse_args()
except SystemExit as exit_request:
	# After a usage error also show the full option list, then let the original
	# exit status propagate so that calling scripts can detect the failure.
	if exit_request.code not in (0, None):
		parser.print_help()
	raise

# Module-level shorthands for the parsed command-line options.
BATCH_SIZE = args.bs
NUM_NEIGHBORS = args.n_degree
# Taken from --n_neg (default 1) so the command-line option is no longer ignored.
NUM_NEG = args.n_neg
NUM_EPOCH = args.n_epoch
NUM_HEADS = args.n_head
DROP_OUT = args.drop_out
GPU = args.gpu
UNIFORM = args.uniform
NEW_NODE = args.new_node
# The sequence length is tied to the number of sampled neighbors.
SEQ_LEN = NUM_NEIGHBORS
DATA = args.data
NUM_LAYER = args.n_layer
LEARNING_RATE = args.lr
# Fixed value; there is no command-line option for it.
NODE_LAYER = 1
NODE_DIM = args.node_dim
TIME_DIM = args.time_dim
USE_MEMORY = args.use_memory
MESSAGE_DIM = args.message_dim
MEMORY_DIM = args.memory_dim

# Make sure the output directories for models and per-epoch checkpoints exist.
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)

# The file names are built from single-line literals so that no source
# indentation (tab characters) ends up inside the paths.
MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}-node-classification.pth'


def get_checkpoint_path(epoch):
	"""Return the checkpoint file path for the given epoch number."""
	return f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}-node-classification.pth'

### set up logger
# The file handler below raises FileNotFoundError when its directory is missing,
# so create it first (same approach as the other output directories).
Path("log/").mkdir(parents=True, exist_ok=True)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Every run writes DEBUG and above to its own file, named after the start time.
fh = logging.FileHandler('log/{}.log'.format(str(time.time())))
fh.setLevel(logging.DEBUG)

# Additional console handler restricted to warnings and above.
ch = logging.StreamHandler()
ch.setLevel(logging.WARN)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)

# Record the full run configuration at the top of the log.
logger.info(args)

# Load the features and the full/train/val/test splits for the selected dataset.
full_data, node_features, edge_features, train_data, val_data, test_data = \
	get_data_node_classification(DATA, use_validation=args.use_validation)

# Largest node id in the full dataset; handed to the neighbor finder as max_node_idx.
max_idx = max(full_data.unique_nodes)

# Neighbor finder built from the training data only.
train_ngh_finder = get_neighbor_finder(train_data, uniform=UNIFORM, max_node_idx=max_idx)

# Set device: use the GPU selected with --gpu when CUDA is available, else the CPU.
device_string = 'cuda:{}'.format(GPU) if torch.cuda.is_available() else 'cpu'
device = torch.device(device_string)

# Compute time statistics
mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst = \
	compute_time_statistics(full_data.sources, full_data.destinations, full_data.timestamps)

# Main experiment loop. Each run loads the saved TGN model, keeps it in eval mode
# (embeddings are computed under torch.no_grad()), trains a fresh MLP decoder on
# the source-node embeddings, and pickles the collected metrics to a results file.
for i in range(args.n_runs):
	# The first run keeps the un-suffixed file name; later runs append the run index.
	if i > 0:
		results_path = "results/{}_node_classification_{}.pkl".format(args.prefix, i)
	else:
		results_path = "results/{}_node_classification.pkl".format(args.prefix)
	Path("results/").mkdir(parents=True, exist_ok=True)

	# Initialize Model
	tgn = TGN(neighbor_finder=train_ngh_finder, node_features=node_features,
						edge_features=edge_features, device=device,
						n_layers=NUM_LAYER,
						n_heads=NUM_HEADS, dropout=DROP_OUT, use_memory=USE_MEMORY,
						message_dimension=MESSAGE_DIM, memory_dimension=MEMORY_DIM,
						memory_update_at_start=not args.memory_update_at_end,
						embedding_module_type=args.embedding_module,
						message_function=args.message_function,
						aggregator_type=args.aggregator, n_neighbors=NUM_NEIGHBORS,
						mean_time_shift_src=mean_time_shift_src, std_time_shift_src=std_time_shift_src,
						mean_time_shift_dst=mean_time_shift_dst, std_time_shift_dst=std_time_shift_dst,
						use_destination_embedding_in_message=args.use_destination_embedding_in_message,
						use_source_embedding_in_message=args.use_source_embedding_in_message)

	tgn = tgn.to(device)

	num_instance = len(train_data.sources)
	num_batch = math.ceil(num_instance / BATCH_SIZE)

	logger.debug('Num of training instances: {}'.format(num_instance))
	logger.debug('Num of batches per epoch: {}'.format(num_batch))

	logger.info('Loading saved TGN model')
	model_path = f'./saved_models/{args.prefix}-{DATA}.pth'
	# map_location lets a model that was saved on a GPU be loaded on the current device.
	tgn.load_state_dict(torch.load(model_path, map_location=device))
	tgn.eval()
	logger.info('TGN models loaded')
	logger.info('Start training node classification task')

	# Only the decoder is optimized; it is constructed with the node feature dimension.
	decoder = MLP(node_features.shape[1], drop=DROP_OUT)
	decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=args.lr)
	decoder = decoder.to(device)
	decoder_loss_criterion = torch.nn.BCELoss()

	val_aucs = []
	train_losses = []

	early_stopper = EarlyStopMonitor(max_round=args.patience)
	for epoch in range(args.n_epoch):
		start_epoch = time.time()

		# Initialize memory of the model at each epoch
		if USE_MEMORY:
			tgn.memory.__init_memory__()

		tgn = tgn.eval()
		decoder = decoder.train()
		loss = 0

		for k in range(num_batch):
			s_idx = k * BATCH_SIZE
			e_idx = min(num_instance, s_idx + BATCH_SIZE)

			sources_batch = train_data.sources[s_idx: e_idx]
			destinations_batch = train_data.destinations[s_idx: e_idx]
			timestamps_batch = train_data.timestamps[s_idx: e_idx]
			# NOTE(review): edge ids are sliced from full_data while every other field comes
			# from train_data; presumably the training split is a prefix of the full data - confirm.
			edge_idxs_batch = full_data.edge_idxs[s_idx: e_idx]
			labels_batch = train_data.labels[s_idx: e_idx]

			decoder_optimizer.zero_grad()
			# No gradients flow into the TGN model.
			# NOTE(review): destinations_batch is passed twice; presumably the third argument is
			# a negative-sample slot that is not needed here - confirm against
			# TGN.compute_temporal_embeddings.
			with torch.no_grad():
				source_embedding, destination_embedding, _ = tgn.compute_temporal_embeddings(
					sources_batch,
					destinations_batch,
					destinations_batch,
					timestamps_batch,
					edge_idxs_batch,
					NUM_NEIGHBORS)

			labels_batch_torch = torch.from_numpy(labels_batch).float().to(device)
			pred = decoder(source_embedding).sigmoid()
			decoder_loss = decoder_loss_criterion(pred, labels_batch_torch)
			decoder_loss.backward()
			decoder_optimizer.step()
			loss += decoder_loss.item()
		train_losses.append(loss / num_batch)

		val_auc = eval_node_classification(tgn, decoder, val_data, full_data.edge_idxs, BATCH_SIZE,
																			 n_neighbors=NUM_NEIGHBORS)
		val_aucs.append(val_auc)

		# Persist intermediate metrics after every epoch; the context manager closes the file.
		with open(results_path, "wb") as results_file:
			pickle.dump({
				"val_aps": val_aucs,
				"train_losses": train_losses,
				"epoch_times": [0.0],
				"new_nodes_val_aps": [],
			}, results_file)

		logger.info(f'Epoch {epoch}: train loss: {loss / num_batch}, val auc: {val_auc}, time: {time.time() - start_epoch}')

		# Early stopping is evaluated once per epoch, inside the epoch loop, so that `break`
		# ends the training of this run only and a checkpoint exists for every epoch that
		# may later be selected as the best one.
		if args.use_validation:
			if early_stopper.early_stop_check(val_auc):
				logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
				break
			else:
				torch.save(decoder.state_dict(), get_checkpoint_path(epoch))

	if args.use_validation:
		logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
		best_model_path = get_checkpoint_path(early_stopper.best_epoch)
		decoder.load_state_dict(torch.load(best_model_path))
		logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')
		decoder.eval()

		test_auc = eval_node_classification(tgn, decoder, test_data, full_data.edge_idxs, BATCH_SIZE,
																				n_neighbors=NUM_NEIGHBORS)
	else:
		# If we are not using a validation set, the test performance is just the performance computed
		# in the last epoch
		test_auc = val_aucs[-1]

	# Final results for this run; overwrites the per-epoch file written above.
	with open(results_path, "wb") as results_file:
		pickle.dump({
			"val_aps": val_aucs,
			"test_ap": test_auc,
			"train_losses": train_losses,
			"epoch_times": [0.0],
			"new_nodes_val_aps": [],
			"new_node_test_ap": 0,
		}, results_file)

	logger.info(f'test auc: {test_auc}')
