import logging
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import pandas as pd
import csv


class CustomizedDataset(Dataset):
    """
    Dataset whose items are simply the entries of a list of indices.
    """

    def __init__(self, indices_list: list):
        """
        Keep a reference to the indices served by this dataset.
        :param indices_list: list, list of indices
        """
        super().__init__()
        self.indices_list = indices_list

    def __len__(self):
        """
        :return: int, number of entries in self.indices_list
        """
        return len(self.indices_list)

    def __getitem__(self, idx: int):
        """
        fetch the entry of self.indices_list stored at the given position
        :param idx: int, the position to read
        :return: the entry of self.indices_list at idx
        """
        entry = self.indices_list[idx]
        return entry


def get_idx_data_loader(indices_list: list, batch_size: int, shuffle: bool):
    """
    build a DataLoader whose batches are drawn from a list of indices
    :param indices_list: list, list of indices
    :param batch_size: int, batch size
    :param shuffle: boolean, whether to shuffle the data
    :return: data_loader, DataLoader (the last, possibly smaller, batch is kept)
    """
    loader_kwargs = {
        'dataset': CustomizedDataset(indices_list=indices_list),
        'batch_size': batch_size,
        'shuffle': shuffle,
        'drop_last': False,
    }
    return DataLoader(**loader_kwargs)

class HyperData:
    """
    Container for node interactions together with per-interaction geometric information.
    """

    def __init__(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray, node_interact_times: np.ndarray,
                 edge_ids: np.ndarray, labels: np.ndarray, src_hyperbolicity: np.ndarray,
                 dst_hyperbolicity: np.ndarray, src_geo_feature: torch.Tensor, dst_geo_feature: torch.Tensor):
        """
        Data object to store the nodes interaction information and the geometric information.
        All arguments are stored as given (no copy is made).
        :param src_node_ids: ndarray
        :param dst_node_ids: ndarray
        :param node_interact_times: ndarray
        :param edge_ids: ndarray
        :param labels: ndarray
        :param src_hyperbolicity: hyperbolicity values of the source side (annotated ndarray)
        :param dst_hyperbolicity: hyperbolicity values of the destination side (annotated ndarray)
        :param src_geo_feature: Tensor, geometric features of the source side
        :param dst_geo_feature: Tensor, geometric features of the destination side
        """
        # interaction information
        self.src_node_ids = src_node_ids
        self.dst_node_ids = dst_node_ids
        self.node_interact_times = node_interact_times
        self.edge_ids = edge_ids
        self.labels = labels

        # geometric information
        self.src_hyperbolicity = src_hyperbolicity
        self.dst_hyperbolicity = dst_hyperbolicity
        self.src_geo_feature = src_geo_feature
        self.dst_geo_feature = dst_geo_feature

        # derived statistics: number of interactions and the distinct node ids over both endpoints
        self.num_interactions = len(src_node_ids)
        self.unique_node_ids = set(src_node_ids).union(set(dst_node_ids))
        self.num_unique_nodes = len(self.unique_node_ids)

class Data:
    """
    Container for the interactions (edges) of a temporal graph.
    """

    def __init__(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray, node_interact_times: np.ndarray,
                 edge_ids: np.ndarray, labels: np.ndarray):
        """
        Data object to store the nodes interaction information.
        All arguments are stored as given (no copy is made).
        :param src_node_ids: ndarray
        :param dst_node_ids: ndarray
        :param node_interact_times: ndarray
        :param edge_ids: ndarray
        :param labels: ndarray
        """
        # interaction information
        self.src_node_ids = src_node_ids
        self.dst_node_ids = dst_node_ids
        self.node_interact_times = node_interact_times
        self.edge_ids = edge_ids
        self.labels = labels

        # derived statistics: number of interactions and the distinct node ids over both endpoints
        self.num_interactions = len(src_node_ids)
        self.unique_node_ids = set(src_node_ids).union(set(dst_node_ids))
        self.num_unique_nodes = len(self.unique_node_ids)


def convert_discrete_time_to_second(df: pd.DataFrame, granularity: int):
    """
    replace every timestamp of a discrete-time dynamic graph by its rank among the distinct timestamps
    (0 for the smallest one) multiplied by the given time granularity
    :param df: pd.DataFrame, a temporal graph, its 'ts' column is overwritten in place
    :param granularity: new time granularity
    :return: the same DataFrame object that was passed in
    """
    distinct_stamps = sorted(set(df['ts'].tolist()))
    # the k-th smallest distinct timestamp is mapped to k * granularity (as float64)
    converted = np.arange(len(distinct_stamps), dtype=np.float64) * granularity
    stamp_to_seconds = {stamp: seconds for stamp, seconds in zip(distinct_stamps, converted)}
    df['ts'] = df['ts'].map(lambda stamp: stamp_to_seconds[stamp])
    return df


def get_link_prediction_data(dataset_name: str, val_ratio: float, test_ratio: float, logger: logging.Logger,
                             convert_time=False, device='cpu'):
    """
    generate data for link prediction task (inductive & transductive settings)
    :param dataset_name: str, dataset name
    :param val_ratio: float, validation data ratio
    :param test_ratio: float, test data ratio
    :param logger: logging.Logger, receives the statistics of every split
    :param convert_time: boolean, whether to convert the timestamps of the known discrete-time datasets to seconds
    :param device: device on which the hyperbolicity and geometric feature tensors are placed
    :return: node_raw_features, edge_raw_features, (np.ndarray, zero-padded to 172 columns),
            src_geo_feature, dst_geo_feature, (torch.Tensor, with one all-zero row prepended),
            full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data, (HyperData object)
    """
    # Load data and train val test split
    graph_df = pd.read_csv('TG_network_datasets/{}/ml_{}.csv'.format(dataset_name, dataset_name))
    edge_raw_features = np.load('TG_network_datasets/{}/ml_{}.npy'.format(dataset_name, dataset_name))
    node_raw_features = np.load('TG_network_datasets/{}/ml_{}_node.npy'.format(dataset_name, dataset_name))

    # convert the time granularity of discrete-time dynamic graph to second
    # for example, if the original time granularity is day, we multiply timestamps by 86400
    seconds_per_step = {
        'Flights': 86400,  # days
        'CanParl': 86400 * 365,  # years
        'USLegis': 86400 * 365,
        'UNtrade': 86400 * 365,
        'UNvote': 86400 * 365,
        'Contacts': 300,  # 5 minutes
    }
    if convert_time and dataset_name in seconds_per_step:
        graph_df = convert_discrete_time_to_second(graph_df, seconds_per_step[dataset_name])

    NODE_FEAT_DIM = EDGE_FEAT_DIM = 172
    assert NODE_FEAT_DIM >= node_raw_features.shape[
        1], f'Node feature dimension in dataset {dataset_name} is bigger than {NODE_FEAT_DIM}!'
    assert EDGE_FEAT_DIM >= edge_raw_features.shape[
        1], f'Edge feature dimension in dataset {dataset_name} is bigger than {EDGE_FEAT_DIM}!'
    # padding the features of edges and nodes to the same dimension (172 for all the datasets)
    if node_raw_features.shape[1] < NODE_FEAT_DIM:
        node_zero_padding = np.zeros((node_raw_features.shape[0], NODE_FEAT_DIM - node_raw_features.shape[1]))
        node_raw_features = np.concatenate([node_raw_features, node_zero_padding], axis=1)
    if edge_raw_features.shape[1] < EDGE_FEAT_DIM:
        edge_zero_padding = np.zeros((edge_raw_features.shape[0], EDGE_FEAT_DIM - edge_raw_features.shape[1]))
        edge_raw_features = np.concatenate([edge_raw_features, edge_zero_padding], axis=1)

    assert NODE_FEAT_DIM == node_raw_features.shape[1] and EDGE_FEAT_DIM == edge_raw_features.shape[
        1], 'Unaligned feature dimensions after feature padding!'

    # get the timestamp of validate and test set
    val_time, test_time = list(np.quantile(graph_df.ts, [(1 - val_ratio - test_ratio), (1 - test_ratio)]))

    src_node_ids = graph_df.u.values.astype(np.longlong)
    dst_node_ids = graph_df.i.values.astype(np.longlong)
    node_interact_times = graph_df.ts.values.astype(np.float64)
    edge_ids = graph_df.idx.values.astype(np.longlong)
    labels = graph_df.label.values

    src_hyperbolicity, dst_hyperbolicity, src_geo_feature, dst_geo_feature = get_hyperbolic_data(
        dataset_name=dataset_name)

    # as_tensor moves the data to the device without the copy-construct warning of torch.tensor(tensor)
    src_geo_feature = torch.as_tensor(src_geo_feature, device=device)
    dst_geo_feature = torch.as_tensor(dst_geo_feature, device=device)
    src_hyperbolicity = src_hyperbolicity.to(device)
    dst_hyperbolicity = dst_hyperbolicity.to(device)

    # union to get node set; the node ids are required to form a contiguous range
    node_set = set(src_node_ids) | set(dst_node_ids)
    num_total_unique_node_ids = len(node_set)
    min_id = min(src_node_ids.min(), dst_node_ids.min())
    max_id = max(src_node_ids.max(), dst_node_ids.max())
    assert max_id - min_id + 1 == num_total_unique_node_ids

    def build_split(mask):
        # select the interactions (and their per-interaction geometric information) picked by a boolean mask
        return HyperData(src_node_ids=src_node_ids[mask], dst_node_ids=dst_node_ids[mask],
                         node_interact_times=node_interact_times[mask],
                         edge_ids=edge_ids[mask], labels=labels[mask],
                         src_hyperbolicity=src_hyperbolicity[mask], dst_hyperbolicity=dst_hyperbolicity[mask],
                         src_geo_feature=src_geo_feature[mask], dst_geo_feature=dst_geo_feature[mask])

    full_data = HyperData(src_node_ids=src_node_ids, dst_node_ids=dst_node_ids,
                          node_interact_times=node_interact_times, edge_ids=edge_ids, labels=labels,
                          src_hyperbolicity=src_hyperbolicity, dst_hyperbolicity=dst_hyperbolicity,
                          src_geo_feature=src_geo_feature, dst_geo_feature=dst_geo_feature)

    # the setting of seed follows previous works
    random.seed(2020)

    # compute nodes which appear at test time
    test_node_set = set(src_node_ids[node_interact_times > val_time]).union(
        set(dst_node_ids[node_interact_times > val_time]))
    # sample nodes which we keep as new nodes (to test inductiveness), so then we have to remove all their edges from training
    # random.sample no longer accepts a set on Python 3.11+; tuple(...) is the conversion older versions applied
    # internally, so the sampled nodes stay the same on interpreters where the set form still worked
    new_test_node_set = set(random.sample(tuple(test_node_set), int(0.1 * num_total_unique_node_ids)))

    # mask for each source and destination to denote whether they are new test nodes
    new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
    new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values

    # mask, which is true for edges with both destination and source not being new test nodes (because we want to remove all edges involving any new test node)
    observed_edges_mask = np.logical_and(~new_test_source_mask, ~new_test_destination_mask)

    # for train data, we keep edges happening before the validation time which do not involve any new node, used for inductiveness
    train_mask = np.logical_and(node_interact_times <= val_time, observed_edges_mask)
    train_data = build_split(train_mask)

    # define the new nodes sets for testing inductiveness of the model
    train_node_set = set(train_data.src_node_ids).union(train_data.dst_node_ids)
    assert len(train_node_set & new_test_node_set) == 0
    # new nodes that are not in the training set
    new_node_set = node_set - train_node_set

    val_mask = np.logical_and(node_interact_times <= test_time, node_interact_times > val_time)
    test_mask = node_interact_times > test_time

    # new edges with new nodes in the val and test set (for inductive evaluation)
    edge_contains_new_node_mask = np.array([(src_node_id in new_node_set or dst_node_id in new_node_set)
                                            for src_node_id, dst_node_id in zip(src_node_ids, dst_node_ids)])
    new_node_val_mask = np.logical_and(val_mask, edge_contains_new_node_mask)
    new_node_test_mask = np.logical_and(test_mask, edge_contains_new_node_mask)

    # validation and test data
    val_data = build_split(val_mask)
    test_data = build_split(test_mask)

    # validation and test with edges that at least has one new node (not in training set)
    new_node_val_data = build_split(new_node_val_mask)
    new_node_test_data = build_split(new_node_test_mask)

    # the returned feature tensors get one all-zero row in front of the per-interaction rows
    # (the splits above keep the rows without this extra row)
    padded_src_geo_feature = torch.cat(
        (torch.zeros(1, src_geo_feature.shape[1]).to(device), src_geo_feature), dim=0)
    padded_dst_geo_feature = torch.cat(
        (torch.zeros(1, dst_geo_feature.shape[1]).to(device), dst_geo_feature), dim=0)

    logger.info("The dataset has {} interactions, involving {} different nodes".format(full_data.num_interactions,
                                                                                       full_data.num_unique_nodes))
    logger.info("The training dataset has {} interactions, involving {} different nodes".format(
        train_data.num_interactions, train_data.num_unique_nodes))
    logger.info("The validation dataset has {} interactions, involving {} different nodes".format(
        val_data.num_interactions, val_data.num_unique_nodes))
    logger.info("The test dataset has {} interactions, involving {} different nodes".format(
        test_data.num_interactions, test_data.num_unique_nodes))
    logger.info("The new node validation dataset has {} interactions, involving {} different nodes".format(
        new_node_val_data.num_interactions, new_node_val_data.num_unique_nodes))
    logger.info("The new node test dataset has {} interactions, involving {} different nodes".format(
        new_node_test_data.num_interactions, new_node_test_data.num_unique_nodes))
    logger.info("{} nodes were used for the inductive testing, i.e. are never seen during training".format(
        len(new_test_node_set)))

    return (node_raw_features, edge_raw_features, padded_src_geo_feature, padded_dst_geo_feature,
            full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data)

def get_node_class_data(dataset_name: str, val_ratio: float, test_ratio: float, logger: logging.Logger,
                        convert_time=False, device='cpu'):
    """
    generate data for node classification task (chronological train / validation / test split)
    :param dataset_name: str, dataset name
    :param val_ratio: float, validation data ratio
    :param test_ratio: float, test data ratio
    :param logger: logging.Logger, accepted for signature parity with get_link_prediction_data, not used here
    :param convert_time: boolean, whether to convert the timestamps of the known discrete-time datasets to seconds
    :param device: device on which the hyperbolicity and geometric feature tensors are placed
    :return: node_raw_features, edge_raw_features, (np.ndarray, zero-padded to 172 columns),
            src_geo_feature, dst_geo_feature, (torch.Tensor, with one all-zero row prepended),
            full_data, train_data, val_data, test_data, (HyperData object)
    """
    # Load data and train val test split
    graph_df = pd.read_csv('TG_network_datasets/{}/ml_{}.csv'.format(dataset_name, dataset_name))
    edge_raw_features = np.load('TG_network_datasets/{}/ml_{}.npy'.format(dataset_name, dataset_name))
    node_raw_features = np.load('TG_network_datasets/{}/ml_{}_node.npy'.format(dataset_name, dataset_name))

    # convert the time granularity of discrete-time dynamic graph to second
    # for example, if the original time granularity is day, we multiply timestamps by 86400
    seconds_per_step = {
        'Flights': 86400,  # days
        'CanParl': 86400 * 365,  # years
        'USLegis': 86400 * 365,
        'UNtrade': 86400 * 365,
        'UNvote': 86400 * 365,
        'Contacts': 300,  # 5 minutes
    }
    if convert_time and dataset_name in seconds_per_step:
        graph_df = convert_discrete_time_to_second(graph_df, seconds_per_step[dataset_name])

    NODE_FEAT_DIM = EDGE_FEAT_DIM = 172
    assert NODE_FEAT_DIM >= node_raw_features.shape[
        1], f'Node feature dimension in dataset {dataset_name} is bigger than {NODE_FEAT_DIM}!'
    assert EDGE_FEAT_DIM >= edge_raw_features.shape[
        1], f'Edge feature dimension in dataset {dataset_name} is bigger than {EDGE_FEAT_DIM}!'
    # padding the features of edges and nodes to the same dimension (172 for all the datasets)
    if node_raw_features.shape[1] < NODE_FEAT_DIM:
        node_zero_padding = np.zeros((node_raw_features.shape[0], NODE_FEAT_DIM - node_raw_features.shape[1]))
        node_raw_features = np.concatenate([node_raw_features, node_zero_padding], axis=1)
    if edge_raw_features.shape[1] < EDGE_FEAT_DIM:
        edge_zero_padding = np.zeros((edge_raw_features.shape[0], EDGE_FEAT_DIM - edge_raw_features.shape[1]))
        edge_raw_features = np.concatenate([edge_raw_features, edge_zero_padding], axis=1)

    assert NODE_FEAT_DIM == node_raw_features.shape[1] and EDGE_FEAT_DIM == edge_raw_features.shape[
        1], 'Unaligned feature dimensions after feature padding!'

    # get the timestamp of validate and test set
    val_time, test_time = list(np.quantile(graph_df.ts, [(1 - val_ratio - test_ratio), (1 - test_ratio)]))

    src_node_ids = graph_df.u.values.astype(np.longlong)
    dst_node_ids = graph_df.i.values.astype(np.longlong)
    node_interact_times = graph_df.ts.values.astype(np.float64)
    edge_ids = graph_df.idx.values.astype(np.longlong)
    labels = graph_df.label.values

    src_hyperbolicity, dst_hyperbolicity, src_geo_feature, dst_geo_feature = get_hyperbolic_data(
        dataset_name=dataset_name)

    # as_tensor moves the data to the device without the copy-construct warning of torch.tensor(tensor)
    src_geo_feature = torch.as_tensor(src_geo_feature, device=device)
    dst_geo_feature = torch.as_tensor(dst_geo_feature, device=device)
    src_hyperbolicity = src_hyperbolicity.to(device)
    dst_hyperbolicity = dst_hyperbolicity.to(device)

    # the node ids are required to form a contiguous range
    min_id = min(src_node_ids.min(), dst_node_ids.min())
    max_id = max(src_node_ids.max(), dst_node_ids.max())
    assert max_id - min_id + 1 == len(set(src_node_ids) | set(dst_node_ids))

    def build_split(mask):
        # select the interactions (and their per-interaction geometric information) picked by a boolean mask
        return HyperData(src_node_ids=src_node_ids[mask], dst_node_ids=dst_node_ids[mask],
                         node_interact_times=node_interact_times[mask],
                         edge_ids=edge_ids[mask], labels=labels[mask],
                         src_hyperbolicity=src_hyperbolicity[mask], dst_hyperbolicity=dst_hyperbolicity[mask],
                         src_geo_feature=src_geo_feature[mask], dst_geo_feature=dst_geo_feature[mask])

    full_data = HyperData(src_node_ids=src_node_ids, dst_node_ids=dst_node_ids,
                          node_interact_times=node_interact_times, edge_ids=edge_ids, labels=labels,
                          src_hyperbolicity=src_hyperbolicity, dst_hyperbolicity=dst_hyperbolicity,
                          src_geo_feature=src_geo_feature, dst_geo_feature=dst_geo_feature)

    # the setting of seed follows previous works
    # (nothing is sampled in this function; the call is kept because it reseeds the global random module)
    random.seed(2020)

    # chronological split: train up to val_time, validation in (val_time, test_time], test afterwards
    train_mask = node_interact_times <= val_time
    val_mask = np.logical_and(node_interact_times <= test_time, node_interact_times > val_time)
    test_mask = node_interact_times > test_time

    train_data = build_split(train_mask)

    # validation and test data
    val_data = build_split(val_mask)
    test_data = build_split(test_mask)

    # the returned feature tensors get one all-zero row in front of the per-interaction rows
    # (the splits above keep the rows without this extra row)
    padded_src_geo_feature = torch.cat(
        (torch.zeros(1, src_geo_feature.shape[1]).to(device), src_geo_feature), dim=0)
    padded_dst_geo_feature = torch.cat(
        (torch.zeros(1, dst_geo_feature.shape[1]).to(device), dst_geo_feature), dim=0)

    return (node_raw_features, edge_raw_features, padded_src_geo_feature, padded_dst_geo_feature,
            full_data, train_data, val_data, test_data)



def get_hyperbolic_data(dataset_name):
    """
    load the precomputed geometric metrics of the source side and the destination side of a dataset
    from 'graph_properties/{dataset_name}/src/geo_metrics.csv' and 'graph_properties/{dataset_name}/dst/geo_metrics.csv'
    :param dataset_name: str, dataset name
    :return: src_hyperbolicity, dst_hyperbolicity, (1-D float32 torch.Tensor, first column of the raw metrics, not scaled),
            src_geo_feature, dst_geo_feature, (2-D float32 torch.Tensor, all columns standardized with StandardScaler)
    """
    dirpath = f'graph_properties/{dataset_name}/'
    src_path = dirpath + 'src'
    dst_path = dirpath + 'dst'

    def load_geo_data(save_path):
        # every column of geo_metrics.csv is read as one float32 feature
        df = pd.read_csv(save_path + '/geo_metrics.csv')
        return torch.tensor(df.values, dtype=torch.float32)

    def standardize(raw_feature):
        # each side is fitted separately, so the statistics of one side never leak into the other
        scaled = StandardScaler().fit_transform(raw_feature.numpy())
        return torch.tensor(scaled, dtype=torch.float32)

    src_raw_feature = load_geo_data(src_path)
    # the destination metrics come from dst_path (they were previously read from src_path by mistake,
    # which made the destination outputs a copy of the source outputs)
    dst_raw_feature = load_geo_data(dst_path)

    # the first column is taken before scaling (presumably the hyperbolicity value, going by the names);
    # clone() gives an independent tensor without the copy-construct warning of torch.tensor(tensor)
    src_hyperbolicity = src_raw_feature[:, 0].clone()
    dst_hyperbolicity = dst_raw_feature[:, 0].clone()

    src_geo_feature = standardize(src_raw_feature)
    dst_geo_feature = standardize(dst_raw_feature)

    return src_hyperbolicity, dst_hyperbolicity, src_geo_feature, dst_geo_feature

    