import math
import logging
import time
import sys
import argparse
import torch
import numpy as np
import pickle
from pathlib import Path

from evaluation import eval_edge_prediction
from utils.utils import EarlyStopMonitor, RandEdgeSampler
from utils.data_processing import get_data, compute_time_statistics
from sklearn.metrics import average_precision_score, roc_auc_score
from ogn import OGN

# Seed the NumPy and PyTorch global RNGs with the same fixed value.
_SEED = 0
np.random.seed(_SEED)
torch.manual_seed(_SEED)

### Command-line interface
parser = argparse.ArgumentParser("Self-supervised training")

# Dataset, optimisation and bookkeeping options.
parser.add_argument(
    "-d", "--data", type=str, default="wikipedia",
    help="Dataset name (eg. wikipedia or reddit)",
)
parser.add_argument("--bs", type=int, default=200, help="Batch_size")
parser.add_argument(
    "--prefix", type=str, default="", help="Prefix to name the checkpoints"
)
parser.add_argument("--n_epoch", type=int, default=100, help="Number of epochs")
parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate")
parser.add_argument("--n_runs", type=int, default=1, help="Number of runs")
parser.add_argument(
    "--drop_out", type=float, default=0.1, help="Dropout probability"
)
parser.add_argument("--gpu", type=int, default=0, help="Idx for the gpu to use")
parser.add_argument(
    "--backprop_every", type=int, default=1,
    help="Every how many batches to backprop",
)
parser.add_argument(
    "--patience", type=int, default=10, help="Patience for early stopping"
)

# Options controlling how the dataset is split and pre-processed.
parser.add_argument(
    "--different_new_nodes", action="store_true",
    help="Whether to use disjoint set of new nodes for train and val",
)
parser.add_argument(
    "--randomize_features", action="store_true",
    help="Whether to randomize node features",
)
parser.add_argument(
    "--filter_users", action="store_true",
    help="Whether to filter users to make all users have the same number of values.",
)
parser.add_argument(
    "--set_ts_to_zero", action="store_true",
    help="Whether to set all ts to 0.",
)
parser.add_argument(
    "--set_ts_to_uniform", action="store_true",
    help="Whether to set all ts to uniformly spaced.",
)
parser.add_argument(
    "--initial_batch", type=int, default=0,
    help="The initial batch to use in each epoch.",
)
parser.add_argument(
    "--plot_timestamp_distribution", action="store_true",
    help="Whether to plot the distribution for the timestamps for each user.",
)
parser.add_argument(
    "--randomize_edge_features", action="store_true",
    help="Whether to change the edge features to random.",
)
parser.add_argument(
    "--erase_edge_features", action="store_true",
    help="Whether to set edge features to 0s.",
)

# Model hyper-parameters.
parser.add_argument(
    "--activation", type=str, default="tanh", choices=["tanh", "relu"],
    help="The activation layer for the model.",
)
parser.add_argument(
    "--update_type", type=str, default="mean", choices=["mean", "harmonic"],
    help="The update type for the neighbor embedding.",
)
parser.add_argument(
    "--harmonic_weight", type=float, default=0.6,
    help="The weight for the harmonic mean.",
)
parser.add_argument(
    "--set_uniform", action="store_true",
    help="Whether to set all ts to uniformly spaced.",
)
parser.add_argument(
    "--alpha", type=float, default=1e-8,
    help="Alpha for the exponent of OGN.",
)
parser.add_argument(
    "--batch_norm", action="store_true",
    help="Whether to use batch norm.",
)
parser.add_argument(
    "--dropout", action="store_true",
    help="Whether to use dropout.",
)
parser.add_argument(
    "--dropout_probability", type=float, default=0.1,
    help="Dropout probability.",
)
parser.add_argument(
    "--bins", type=int, default=10,
    help="Number of bins in the one hot encoding for the data.",
)
parser.add_argument(
    "--max_value_bin", type=int, default=1000,
    help="Maximum value for the bin in the dataset.",
)
parser.add_argument(
    "--plot_distribution", action="store_true",
    help="Plot the distribution for the positive and negative nodes.",
)
# store_false: the parsed value is True unless the flag is given on the command line.
parser.add_argument(
    "--ignore_time", action="store_false",
    help="Ignore the time embedding in the model.",
)
parser.add_argument(
    "--perturb", action="store_true",
    help="Perturbate timestamps.",
)
parser.add_argument(
    "--remove_neighbors", action="store_true",
    help="Remove neighbors.",
)
# Parse the command line.
# argparse reports both "--help" and usage errors by raising SystemExit.
# The previous bare `except:` caught that (and every other exception), printed
# the help a second time for "--help", and always exited with status 0, so
# callers could not detect invalid arguments.
# Now the full help is printed only when argparse exits with a non-zero status,
# and the original exit status is propagated.
try:
    args = parser.parse_args()
except SystemExit as exit_request:
    if exit_request.code not in (0, None):
        parser.print_help()
    raise


# Shorthand constants taken from the parsed command-line arguments.
DATA = args.data
GPU = args.gpu
BATCH_SIZE = args.bs
NUM_EPOCH = args.n_epoch
LEARNING_RATE = args.lr
DROP_OUT = args.drop_out
NUM_NEG = 1

# Create the output directories for final models and per-epoch checkpoints.
for _out_dir in ("./saved_models/", "./saved_checkpoints/"):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}.pth'


def get_checkpoint_path(epoch):
    """Return the checkpoint file path used for the given epoch."""
    return f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}.pth'

### set up logger
# The root logger accepts everything from DEBUG up. A file handler writes all of
# it to a timestamped file under log/, while the extra console handler only
# shows warnings and above.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

Path("log/").mkdir(parents=True, exist_ok=True)
fh = logging.FileHandler(f'log/{time.time()}.log')
ch = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

for _handler, _level in ((fh, logging.DEBUG), (ch, logging.WARN)):
    _handler.setLevel(_level)
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)

# Record the full run configuration at the top of the log.
logger.info(args)

### Load the features and the full / train / val / test / new-node splits
(node_features, edge_features, full_data, train_data, val_data, test_data,
 new_node_val_data, new_node_test_data) = get_data(
    DATA,
    different_new_nodes_between_val_and_test=args.different_new_nodes,
    randomize_features=args.randomize_features,
    filter_users=args.filter_users,
    set_ts_to_uniform=args.set_ts_to_uniform,
    plot_distribution=args.plot_timestamp_distribution,
    randomize_edge_features=args.randomize_edge_features,
    erase_edge_features=args.erase_edge_features,
    perturb_timestamps=args.perturb,
)

# Simple statistics over the destination ids of the full dataset.
destination_values = full_data.destinations
destination_min = np.min(destination_values)
unique_values = np.unique(destination_values).size

# Negative edge samplers. Validation and test samplers get fixed, distinct seeds
# so that the sampled negatives are the same across different runs.
# NB: the new-node (inductive) samplers are built from the new-node splits only,
# so their negatives are drawn amongst other new nodes.
_all_sources, _all_destinations = full_data.sources, full_data.destinations

train_rand_sampler = RandEdgeSampler(
    train_data.sources, train_data.destinations, seed=0
)
val_rand_sampler = RandEdgeSampler(_all_sources, _all_destinations, seed=0)
nn_val_rand_sampler = RandEdgeSampler(
    new_node_val_data.sources, new_node_val_data.destinations, seed=1
)
test_rand_sampler = RandEdgeSampler(_all_sources, _all_destinations, seed=2)
nn_test_rand_sampler = RandEdgeSampler(
    new_node_test_data.sources, new_node_test_data.destinations, seed=3
)

# Pick the requested GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_string = f'cuda:{GPU}'
else:
    device_string = 'cpu'
device = torch.device(device_string)

# Time-shift statistics computed over the full dataset; passed to the model below.
(mean_time_shift_src, std_time_shift_src,
 mean_time_shift_dst, std_time_shift_dst) = compute_time_statistics(
    full_data.sources,
    full_data.destinations,
    full_data.timestamps,
    set_uniform=args.set_uniform,
)

# Train, validate and test the model `args.n_runs` times.
# NOTE: MODEL_SAVE_PATH and the checkpoint paths do not contain the run index,
# so every run overwrites the model files written by the previous one. Only the
# results pickle is kept per run.
for i in range(args.n_runs):
    # The first run writes results/<prefix>.pkl; later runs add a _<run index> suffix.
    results_path = "results/{}_{}.pkl".format(args.prefix, i) if i > 0 else "results/{}.pkl".format(args.prefix)
    Path("results/").mkdir(parents=True, exist_ok=True)

    # Initialize Model
    model = OGN(node_features=node_features,
                edge_features=edge_features,
                device=device,
                harmonic_weight=args.harmonic_weight,
                activation=args.activation,
                update_type=args.update_type,
                mean_time_shift_src=mean_time_shift_src,
                std_time_shift_src=std_time_shift_src,
                mean_time_shift_dst=mean_time_shift_dst,
                std_time_shift_dst=std_time_shift_dst,
                uniform_time=args.set_uniform,
                alpha=args.alpha,
                batch_norm=args.batch_norm,
                dropout=args.dropout,
                dropout_probability=args.dropout_probability,
                bins=args.bins,
                max_value=args.max_value_bin,
                # --ignore_time is a store_false flag: True unless the flag is passed.
                consider_time=args.ignore_time,
                remove_neighbors=args.remove_neighbors)

    # Move the model to the target device first so the optimizer is built from
    # the parameters that will actually be trained.
    model = model.to(device)
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    num_instance = len(train_data.sources)
    num_batch = math.ceil(num_instance / BATCH_SIZE)

    logger.info('num of training instances: {}'.format(num_instance))
    logger.info('num of batches per epoch: {}'.format(num_batch))

    # Per-epoch histories, written to the results pickle.
    # new_nodes_val_aps is never filled in this script; it is kept so the pickle
    # keeps the same keys.
    new_nodes_val_aps = []
    val_aucs = []
    train_aucs = []
    train_aps = []
    val_aps = []
    epoch_times = []
    total_epoch_times = []
    train_losses = []

    early_stopper = EarlyStopMonitor(max_round=args.patience)
    for epoch in range(NUM_EPOCH):
        start_epoch = time.time()
        ### Training
        # Reset the model's counter and memory at the start of every epoch.
        model._reset_counter()
        model._reset_memory()
        m_loss = []

        logger.info('start {} epoch'.format(epoch))
        # Per-batch training metrics, averaged at the end of the epoch.
        train_ap_values = []
        train_auc_values = []
        for k in range(args.initial_batch, num_batch, args.backprop_every):
            loss = 0
            optimizer.zero_grad()

            # Custom loop to allow to perform backpropagation only every a certain number of batches
            for j in range(args.backprop_every):
                batch_idx = k + j

                # Batch indices only grow, so once past the end nothing is left to process.
                if batch_idx >= num_batch:
                    break

                start_idx = batch_idx * BATCH_SIZE
                end_idx = min(num_instance, start_idx + BATCH_SIZE)
                sources_batch = train_data.sources[start_idx:end_idx]
                destinations_batch = train_data.destinations[start_idx:end_idx]
                edge_idxs_batch = train_data.edge_idxs[start_idx:end_idx]
                timestamps_batch = train_data.timestamps[start_idx:end_idx]

                size = len(sources_batch)
                _, negatives_batch = train_rand_sampler.sample(size)

                # A sampled "negative" that coincides with the true destination is
                # a real edge, so it is labelled 1 instead of 0.
                collisions = destinations_batch == negatives_batch

                with torch.no_grad():
                    pos_label = torch.ones(size, dtype=torch.float, device=device)
                    neg_label = torch.zeros(size, dtype=torch.float, device=device) + \
                                torch.tensor(collisions, device=device)

                model = model.train()
                pos_prob, neg_prob = model.compute_edge_probabilities(sources_batch, destinations_batch,
                                                                      negatives_batch, timestamps_batch,
                                                                      edge_idxs_batch)
                pred_score = np.concatenate([pos_prob.detach().cpu().numpy(),
                                             neg_prob.detach().cpu().numpy()])
                true_label = np.concatenate([np.ones(size), np.zeros(size) + collisions])
                train_ap_values.append(average_precision_score(true_label, pred_score))
                train_auc_values.append(roc_auc_score(true_label, pred_score))
                loss += criterion(pos_prob.squeeze(), pos_label.double()) + \
                    criterion(neg_prob.squeeze(), neg_label.double())

            loss /= args.backprop_every

            loss.backward()
            optimizer.step()
            m_loss.append(loss.item())

        # Epoch-level training metrics are the mean of the per-batch values.
        # Previously only the last batch's scores were used and the per-batch
        # lists were never read.
        train_ap = np.mean(train_ap_values)
        train_auc = np.mean(train_auc_values)
        epoch_time = time.time() - start_epoch
        epoch_times.append(epoch_time)

        ### Validation
        val_ap, val_auc, _ = eval_edge_prediction(model=model,
                                                  negative_edge_sampler=val_rand_sampler,
                                                  data=val_data)

        val_aps.append(val_ap)
        val_aucs.append(val_auc)
        train_losses.append(np.mean(m_loss))
        train_aps.append(train_ap)
        train_aucs.append(train_auc)

        # Unlike epoch_time, this also includes the validation pass.
        total_epoch_time = time.time() - start_epoch
        total_epoch_times.append(total_epoch_time)
        # Save temporary results to disk (the context manager closes the file).
        with open(results_path, "wb") as results_file:
            pickle.dump({
                "val_aps": val_aps,
                "val_aucs": val_aucs,
                "new_nodes_val_aps": new_nodes_val_aps,
                "train_losses": train_losses,
                "epoch_times": epoch_times,
                "total_epoch_times": total_epoch_times,
                "train_aps": train_aps,
                "train_aucs": train_aucs,
            }, results_file)

        logger.info('epoch: {} took {:.2f}s'.format(epoch, total_epoch_time))
        logger.info('Epoch mean loss: {:.2f}'.format(np.mean(m_loss)))
        logger.info(f'train auc: {train_auc:.2f}, train ap: {train_ap:.2f}')
        logger.info(
            'val auc: {:.2f}'.format(val_auc))
        logger.info(
            'val ap: {:.2f}'.format(val_ap))

        # Early stopping
        if early_stopper.early_stop_check(val_ap):
            logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
            logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
            best_model_path = get_checkpoint_path(early_stopper.best_epoch)
            model.load_state_dict(torch.load(best_model_path))
            logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')
            model.eval()
            break
        else:
            # Checkpoint every epoch that does not trigger the stop, so the best
            # epoch can be reloaded above.
            torch.save(model.state_dict(), get_checkpoint_path(epoch))

    # Snapshot of the model's features after training.
    # NOTE(review): presumably reloaded below so both test evaluations start from
    # the same state - confirm against OGN.save_features / load_features.
    features = model.save_features()

    # Test
    test_ap, test_auc, test_accuracy = eval_edge_prediction(model=model,
                                                            negative_edge_sampler=test_rand_sampler,
                                                            data=test_data)
    # Test on unseen nodes
    model.load_features(features)
    test_ap_nn, test_auc_nn, test_accuracy_nn = eval_edge_prediction(model=model,
                                                                     negative_edge_sampler=nn_test_rand_sampler,
                                                                     data=new_node_test_data)
    model.load_features(features)

    logger.info(
        f'{args.data} - {args.activation} - {args.alpha}')
    logger.info(
        'Test statistics (Random): Old nodes -- Test AP: {:.2f}, Test AUC: {:.2f}, Test accuracy: {:.2f}'.format(test_ap, test_auc, test_accuracy))
    logger.info(
        'Test statistics (Random): New nodes -- Test AP: {:.2f}, Test AUC: {:.2f}, Test accuracy: {:.2f}'.format(test_ap_nn, test_auc_nn, test_accuracy_nn))

    # Save results for this run (overwrites the temporary per-epoch file).
    with open(results_path, "wb") as results_file:
        pickle.dump({
            "val_aps": val_aps,
            "val_aucs": val_aucs,
            "new_nodes_val_aps": new_nodes_val_aps,
            "train_losses": train_losses,
            "epoch_times": epoch_times,
            "total_epoch_times": total_epoch_times,
            "train_aps": train_aps,
            "train_aucs": train_aucs,
            "test_ap": test_ap,
            "test_auc": test_auc,
            "test_ap_nn": test_ap_nn,
            "test_auc_nn": test_auc_nn,
            "test_accuracy": test_accuracy,
            "test_accuracy_nn": test_accuracy_nn,
        }, results_file)

    logger.info('Saving model')
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    logger.info('Model saved')
