import logging
import time
import sys
import os
import csv
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
from tqdm import tqdm
import numpy as np
import warnings
import shutil
import json
import torch
import torch.nn as nn
import pandas as pd
import argparse
import sys
from utils import set_random_seed, convert_to_gpu, get_parameter_sizes, create_optimizer
from utils import get_neighbor_sampler, NegativeEdgeSampler
from DataLoader import get_idx_data_loader,get_raw_data



def get_args(is_evaluation: bool = False):
    """
    build the command-line parser for the link prediction task and parse the process arguments
    :param is_evaluation: boolean, whether in the evaluation process (kept for interface
        compatibility; this function does not read it)
    :return: argparse.Namespace with one attribute per option declared below
    """
    dataset_choices = [
        'wikipedia', 'reddit', 'mooc', 'lastfm', 'enron', 'SocialEvo', 'uci',
        'Flights', 'CanParl', 'USLegis', 'UNtrade', 'UNvote', 'Contacts',
    ]
    time_scaling_help = (
        'the hyperparameter that controls the sampling preference with time interval, '
        'a large time_scaling_factor tends to sample more on recent links, 0.0 corresponds to uniform sampling, '
        'it works when sample_neighbor_strategy == time_interval_aware'
    )

    # (flag, add_argument keyword options), registered in this order so the help listing is unchanged
    option_specs = [
        ('--dataset_name', dict(type=str, help='dataset to be used', default='wikipedia',
                                choices=dataset_choices)),
        ('--batch_size', dict(type=int, default=200, help='batch size')),
        ('--num_neighbors', dict(type=int, default=20,
                                 help='number of neighbors to sample for each node')),
        ('--sample_neighbor_strategy', dict(type=str, default='recent',
                                            choices=['uniform', 'recent', 'time_interval_aware'],
                                            help='how to sample historical neighbors')),
        ('--time_scaling_factor', dict(default=1e-6, type=float, help=time_scaling_help)),
        ('--num_walk_heads', dict(type=int, default=8,
                                  help='number of heads used for the attention in walk encoder')),
        ('--patch_size', dict(type=int, default=1, help='patch size')),
        ('--channel_embedding_dim', dict(type=int, default=50,
                                         help='dimension of each channel embedding')),
        ('--max_input_sequence_length', dict(type=int, default=32,
                                             help='maximal length of the input sequence of each node')),
        ('--patience', dict(type=int, default=20, help='patience for early stopping')),
        ('--val_ratio', dict(type=float, default=0.15, help='ratio of validation set')),
        ('--test_ratio', dict(type=float, default=0.15, help='ratio of test set')),
        ('--num_runs', dict(type=int, default=5, help='number of runs')),
        ('--test_interval_epochs', dict(type=int, default=1,
                                        help='how many epochs to perform testing once')),
        ('--negative_sample_strategy', dict(type=str, default='random',
                                            choices=['random', 'historical', 'inductive'],
                                            help='strategy for the negative edge sampling')),
        ('--load_best_configs', dict(action='store_true', default=False,
                                     help='whether to load the best configurations')),
    ]

    cli_parser = argparse.ArgumentParser('Interface for the link prediction task')
    for flag, options in option_specs:
        cli_parser.add_argument(flag, **options)

    # parse_args() with no arguments reads sys.argv[1:]
    return cli_parser.parse_args()


if __name__ == "__main__":
    # Script entry point: load one dataset, build a neighbor sampler over the full
    # interaction data, then run the sampler's multi-hop neighbor-graph property
    # computation twice -- once for the source node ids and once for the destination
    # node ids -- each with its own save directory.

    warnings.filterwarnings('ignore')
    # get arguments
    args = get_args(is_evaluation=False)

    # load the dataset; only the full interaction data (third returned value) is used here
    _, _, full_data, _ = get_raw_data(dataset_name=args.dataset_name)

    # initialize neighbor sampler to retrieve temporal graph
    # (the first returned value is not needed by this script)
    _, full_neighbor_sampler = get_neighbor_sampler(data=full_data, sample_neighbor_strategy=args.sample_neighbor_strategy,
                                                    time_scaling_factor=args.time_scaling_factor, seed=1)

    node_interact_times = full_data.node_interact_times

    # process the source side first, then the destination side
    endpoint_node_ids = (
        ('src', full_data.src_node_ids),
        ('dst', full_data.dst_node_ids),
    )
    for endpoint, node_ids in endpoint_node_ids:
        save_path = f'graph_properties/{args.dataset_name}/{endpoint}'
        # exist_ok=True avoids the check-then-create race of os.path.exists() + os.makedirs()
        os.makedirs(save_path, exist_ok=True)
        # NOTE(review): presumably the results are written under save_path -- confirm in the sampler
        full_neighbor_sampler.compute_multi_hop_neighbor_graphs_properties(save_path, num_hops=3, node_ids=node_ids, node_interact_times=node_interact_times)
    
