import math
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt


class Data:
    """
    Plain container for a set of interactions (edges).

    Stores the five parallel sequences it is given and derives, at construction time, the number of
    interactions and the set (and count) of distinct node ids appearing as source or destination.
    """

    def __init__(self, sources, destinations, timestamps, edge_idxs, labels):
        # Parallel sequences: entry k of each one describes interaction k.
        self.sources, self.destinations = sources, destinations
        self.timestamps, self.edge_idxs, self.labels = timestamps, edge_idxs, labels

        # Derived summary values, computed once from the arguments.
        self.n_interactions = len(sources)
        self.unique_nodes = set(sources).union(destinations)
        self.n_unique_nodes = len(self.unique_nodes)


def perturb_timestamp(timestamp: np.ndarray) -> np.ndarray:
    """
    Perturb each timestamp to a random value in the interval [t_{i-1}, t_{i}].

    Every element is moved backwards by a uniformly random fraction of the gap to its predecessor, so
    for a non-decreasing input the output is also non-decreasing. The first element has no
    predecessor and is returned unchanged. Random numbers come from the global ``np.random`` state.

    :param timestamp:   One-dimensional array of timestamps. It is not modified; the result is a new
                        ``float64`` array.
    :return:            The perturbed timestamps, same length as the input.
    """
    # astype() returns a fresh array, so the in-place update below never touches the caller's data.
    timestamp = timestamp.astype('float64')
    # deltas[i] = t_i - t_{i-1}; prepending the first element makes deltas[0] == 0 and keeps the
    # lengths equal (also for an empty input).
    deltas = np.diff(timestamp, prepend=timestamp[:1])
    # Subtract (not add) so the result stays inside [t_{i-1}, t_i] as documented.
    timestamp -= deltas * np.random.rand(deltas.size)

    return timestamp


def get_data_node_classification(dataset_name, use_validation=False, visualize_banned=True):
    """
    Load a dataset from ``./data`` and split it chronologically for node classification.

    Reads ``./data/ml_<dataset_name>.csv`` (columns ``u``, ``i``, ``ts``, ``idx`` and ``label`` are
    used), ``./data/ml_<dataset_name>.npy`` (edge features) and ``./data/ml_<dataset_name>_node.npy``
    (node features). The split thresholds are the 0.70 and 0.85 quantiles of ``ts``.

    Side effects: prints the sum of the ``label`` column, seeds Python's ``random`` module with 2020
    and, when ``visualize_banned`` is true, shows a matplotlib figure.

    :param dataset_name:        Name used to build the three file paths.
    :param use_validation:      If true, train is ``ts <= q70`` and validation is ``q70 < ts <= q85``.
                                If false, train is ``ts <= q85`` and the validation split is the same
                                as the test split.
    :param visualize_banned:    If true, plot the row index of every row whose label is within 0.01
                                of 1 against the 0.70 / 0.85 quantiles of the row index.
    :return:                    ``(full_data, node_features, edge_features, train_data, val_data,
                                test_data)``.
    """
    ### Load data and train val test split
    graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name))
    edge_features = np.load('./data/ml_{}.npy'.format(dataset_name))
    node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name))

    val_time, test_time = list(np.quantile(graph_df.ts, [0.70, 0.85]))

    print(graph_df['label'].sum())

    # Visualize how many of the banned people are in the training dataset
    if visualize_banned:
        # The index quantiles are only needed to draw the two threshold lines.
        val_index, test_index = list(np.quantile(graph_df.index.to_series(), [0.70, 0.85]))

        banned_users = graph_df[abs(1.0 - graph_df['label']) < 0.01]
        index_of_bans = banned_users.index.to_series().reset_index(drop=True)
        x_values = index_of_bans.size
        plt.plot(index_of_bans, 'b.-')
        plt.plot([0, x_values], [val_index, val_index], 'r-', label='Validation threshold')
        plt.plot([0, x_values], [test_index, test_index], 'k-', label='Test threshold')
        plt.xlim([0, x_values])
        plt.legend()
        plt.grid()
        plt.ylabel('Value for the index')
        plt.show()

    sources = graph_df.u.values
    destinations = graph_df.i.values
    edge_idxs = graph_df.idx.values
    labels = graph_df.label.values
    timestamps = graph_df.ts.values

    # Kept for callers that rely on the global random state being reset here.
    random.seed(2020)

    train_mask = timestamps <= val_time if use_validation else timestamps <= test_time
    test_mask = timestamps > test_time
    val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time) if use_validation else test_mask

    full_data = Data(sources, destinations, timestamps, edge_idxs, labels)

    train_data = Data(sources[train_mask], destinations[train_mask], timestamps[train_mask],
                      edge_idxs[train_mask], labels[train_mask])

    val_data = Data(sources[val_mask], destinations[val_mask], timestamps[val_mask],
                    edge_idxs[val_mask], labels[val_mask])

    test_data = Data(sources[test_mask], destinations[test_mask], timestamps[test_mask],
                     edge_idxs[test_mask], labels[test_mask])

    return full_data, node_features, edge_features, train_data, val_data, test_data


def get_data(dataset_name, different_new_nodes_between_val_and_test=False, randomize_features=False,
             filter_users=False, minimum_values=10, maximum_values=30, set_ts_to_zero=False,
             set_ts_to_uniform=False, plot_distribution=True, randomize_edge_features=False,
             erase_edge_features=False, perturb_timestamps=False):
    """
    Load a dataset from ``./data`` and build the train / validation / test splits, including the
    "new node" validation and test subsets.

    Reads ``./data/ml_<dataset_name>.csv`` (columns ``u``, ``i``, ``ts``, ``idx`` and ``label`` are
    used), ``./data/ml_<dataset_name>.npy`` (edge features) and ``./data/ml_<dataset_name>_node.npy``
    (node features). The split thresholds are the 0.75 and 0.85 quantiles of ``ts``. A random sample
    of ``int(0.1 * number_of_unique_nodes)`` nodes that appear after the validation threshold is held
    out: every edge touching one of them is removed from the training split.

    Side effects: seeds Python's ``random`` module with 2020, prints split statistics and, when
    ``plot_distribution`` is true, shows a matplotlib figure.

    :param dataset_name:        Name used to build the three file paths.
    :param different_new_nodes_between_val_and_test:
                                If true, the held-out nodes are split into two halves, one used for
                                the new-node validation subset and one for the new-node test subset.
                                Otherwise both subsets use every node that is absent from training.
    :param randomize_features:  Replace the node features with uniform random values of the same shape.
    :param filter_users:        Not used by this function.
    :param minimum_values:      Not used by this function.
    :param maximum_values:      Not used by this function.
    :param set_ts_to_zero:      After the splits are decided, set the timestamps of the full and
                                training data to zero.
    :param set_ts_to_uniform:   Replace ``ts`` by ``row_index * 0.1`` before splitting.
    :param plot_distribution:   Plot the per-user average time difference of the raw data.
    :param randomize_edge_features:
                                Replace the edge features with uniform random values of the same shape.
    :param erase_edge_features: Replace the edge features with zeros of the same shape (applied after
                                ``randomize_edge_features``).
    :param perturb_timestamps:  Perturb the validation / test timestamps with ``perturb_timestamp``
                                and write the perturbed values back into the full data.
    :return:                    ``(node_features, edge_features, full_data, train_data, val_data,
                                test_data, new_node_val_data, new_node_test_data)``.
    """
    ### Load data and train val test split
    graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name))
    if plot_distribution:
        # Called only for the plot it shows; the returned lists are not needed here.
        compute_average_time_differences_users(graph_df, dataset=dataset_name)
    if set_ts_to_uniform:
        graph_df['ts'] = graph_df.index * 0.1

    edge_features = np.load('./data/ml_{}.npy'.format(dataset_name))
    node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name))

    if randomize_edge_features:
        edge_features = np.random.rand(edge_features.shape[0], edge_features.shape[1])
    if erase_edge_features:
        edge_features = np.zeros((edge_features.shape[0], edge_features.shape[1]))
    if randomize_features:
        node_features = np.random.rand(node_features.shape[0], node_features.shape[1])

    val_time, test_time = list(np.quantile(graph_df.ts, [0.75, 0.85]))

    sources = graph_df.u.values
    destinations = graph_df.i.values
    edge_idxs = graph_df.idx.values
    labels = graph_df.label.values
    timestamps = graph_df.ts.values

    full_data = Data(sources, destinations, timestamps, edge_idxs, labels)

    random.seed(2020)

    node_set = set(sources) | set(destinations)
    n_total_unique_nodes = len(node_set)

    def edge_touches(nodes):
        # Boolean mask over all edges: true where the source or the destination is in `nodes`.
        return np.array([(a in nodes or b in nodes) for a, b in zip(sources, destinations)])

    # Compute nodes which appear at test time
    test_node_set = set(sources[timestamps > val_time]).union(
        set(destinations[timestamps > val_time]))
    # Sample nodes which we keep as new nodes (to test inductiveness), so that we have to remove all
    # their edges from training. random.sample() requires a sequence (passing a set raises TypeError
    # on Python 3.11+), hence the explicit list().
    new_test_node_set = set(random.sample(list(test_node_set), int(0.1 * n_total_unique_nodes)))

    # Mask saying for each source and destination whether they are new test nodes
    new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
    new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values

    # Mask which is true for edges with both destination and source not being new test nodes (because
    # we want to remove all edges involving any new test node)
    observed_edges_mask = np.logical_and(~new_test_source_mask, ~new_test_destination_mask)

    # For train we keep edges happening before the validation time which do not involve any new node
    # used for inductiveness
    train_mask = np.logical_and(timestamps <= val_time, observed_edges_mask)

    train_data = Data(sources[train_mask], destinations[train_mask], timestamps[train_mask],
                      edge_idxs[train_mask], labels[train_mask])

    # define the new nodes sets for testing inductiveness of the model
    train_node_set = set(train_data.sources).union(train_data.destinations)
    assert len(train_node_set & new_test_node_set) == 0
    new_node_set = node_set - train_node_set

    val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)
    test_mask = timestamps > test_time

    if different_new_nodes_between_val_and_test:
        n_new_nodes = len(new_test_node_set) // 2
        val_new_node_set = set(list(new_test_node_set)[:n_new_nodes])
        test_new_node_set = set(list(new_test_node_set)[n_new_nodes:])

        new_node_val_mask = np.logical_and(val_mask, edge_touches(val_new_node_set))
        new_node_test_mask = np.logical_and(test_mask, edge_touches(test_new_node_set))
    else:
        edge_contains_new_node_mask = edge_touches(new_node_set)
        new_node_val_mask = np.logical_and(val_mask, edge_contains_new_node_mask)
        new_node_test_mask = np.logical_and(test_mask, edge_contains_new_node_mask)

    # validation and test with all edges
    if set_ts_to_zero:
        # The masks above were computed from the real timestamps; only the stored values are zeroed.
        graph_df.ts = 0
        timestamps = graph_df.ts.values
        full_data.timestamps = timestamps
        # Keep the training timestamps the same length as the other training arrays.
        train_data.timestamps = timestamps[train_mask]

    val_timestamps = perturb_timestamp(timestamps[val_mask]) if perturb_timestamps else timestamps[val_mask]
    test_timestamps = perturb_timestamp(timestamps[test_mask]) if perturb_timestamps else timestamps[test_mask]
    new_node_val_timestamps = perturb_timestamp(timestamps[new_node_val_mask]) if perturb_timestamps else timestamps[new_node_val_mask]
    new_node_test_timestamps = perturb_timestamp(timestamps[new_node_test_mask]) if perturb_timestamps else timestamps[new_node_test_mask]
    if perturb_timestamps:
        # NOTE(review): the new-node subsets are perturbed independently of the full val / test
        # splits, so the same edge can carry different timestamps in val_data and new_node_val_data,
        # and full_data ends up with the new-node values for those edges. Confirm this is intended.
        full_data.timestamps[val_mask] = val_timestamps
        full_data.timestamps[test_mask] = test_timestamps
        full_data.timestamps[new_node_val_mask] = new_node_val_timestamps
        full_data.timestamps[new_node_test_mask] = new_node_test_timestamps
    val_data = Data(sources[val_mask], destinations[val_mask], val_timestamps,
                    edge_idxs[val_mask], labels[val_mask])

    test_data = Data(sources[test_mask], destinations[test_mask], test_timestamps,
                     edge_idxs[test_mask], labels[test_mask])

    # validation and test with edges that at least has one new node (not in training set)
    new_node_val_data = Data(sources[new_node_val_mask], destinations[new_node_val_mask],
                             new_node_val_timestamps, edge_idxs[new_node_val_mask], labels[new_node_val_mask])

    new_node_test_data = Data(sources[new_node_test_mask], destinations[new_node_test_mask],
                              new_node_test_timestamps, edge_idxs[new_node_test_mask],
                              labels[new_node_test_mask])

    print("Total #events: {}  -  #nodes: {}".format(full_data.n_interactions, full_data.n_unique_nodes))
    print("[Train] #events: {}  -  #nodes: {}".format(train_data.n_interactions, train_data.n_unique_nodes))
    print("[VAl] #vents: {}  -  #nodes: {}".format(val_data.n_interactions, val_data.n_unique_nodes))
    print("[Test] #events: {}  -  #nodes: {}".format(test_data.n_interactions, test_data.n_unique_nodes))
    print("[New-Val] #events: {}  -  #nodes: {}".format(new_node_val_data.n_interactions, new_node_val_data.n_unique_nodes))
    print("[New-Test] #events: {}  -  #nodes: {}".format(new_node_test_data.n_interactions, new_node_test_data.n_unique_nodes))
    print("{} nodes were used for the inductive testing, i.e. are never seen during training".format(
        len(new_test_node_set)))

    return node_features, edge_features, full_data, train_data, val_data, test_data, \
           new_node_val_data, new_node_test_data


def compute_time_statistics(sources, destinations, timestamps, set_uniform=False, dt=0.1):
    """
    Compute mean and standard deviation of the time elapsed since a node's previous interaction.

    For every interaction the gap is measured twice: once relative to the previous interaction of the
    same source id and once relative to the previous interaction of the same destination id. A node's
    first interaction is measured against time 0.

    :param sources:         Source node id of every interaction.
    :param destinations:    Destination node id of every interaction.
    :param timestamps:      Timestamp of every interaction.
    :param set_uniform:     If true, ignore the given values and use ``k * dt`` for interaction ``k``.
    :param dt:              Spacing used when ``set_uniform`` is true.
    :return:                ``(mean_src, std_src, mean_dst, std_dst)``.
    """
    if set_uniform:
        timestamps = np.arange(timestamps.size) * dt

    n_events = len(sources)

    def gaps_since_previous(node_ids):
        # Time between each interaction and the previous one of the same node (0 if there is none).
        previous_time = {}
        gaps = []
        for k in range(n_events):
            node = node_ids[k]
            now = timestamps[k]
            gaps.append(now - previous_time.get(node, 0))
            previous_time[node] = now
        return gaps

    source_gaps = gaps_since_previous(sources)
    destination_gaps = gaps_since_previous(destinations)
    assert len(source_gaps) == n_events
    assert len(destination_gaps) == n_events

    return (np.mean(source_gaps), np.std(source_gaps),
            np.mean(destination_gaps), np.std(destination_gaps))


def compute_average_time_difference(timestamps: pd.Series) -> float:
    """
    Return the mean gap between consecutive entries of a timestamp series.

    The gaps are taken with ``Series.diff()``, so the first entry contributes no gap. A series with
    fewer than two entries therefore yields NaN.

    :param timestamps:  The pandas series with the timestamps.
    :return:            The mean of the consecutive differences.
    """
    return timestamps.diff().mean()


def compute_average_time_differences_users(complete_data: pd.DataFrame, dataset: str,
                                           plot_results: bool = True) -> (list, list):
    """
    Compute the average time difference between consecutive events for every user in the data frame.

    Users are taken from column ``u`` and timestamps from column ``ts``. A user whose average is NaN
    (for example a user with a single event) is left out of BOTH returned lists, so that entry ``k``
    of the second list always belongs to entry ``k`` of the first. Users appear in order of first
    occurrence in ``complete_data``.

    :param complete_data:   The complete dataframe with the users and timestamps for every event.
    :param dataset:         The name of the dataset we are plotting.
    :param plot_results:    Whether to plot the obtained results (sorted in decreasing order). The
                            returned lists are not affected by plotting.
    :return:                A tuple of two lists of equal length. The first list has the users. The
                            second list has float values corresponding to the average time difference
                            for each of those users.
    """
    # One pass over the frame; sort=False keeps the groups in order of first appearance.
    averages_by_user = (
        complete_data.groupby('u', sort=False)['ts']
        .apply(compute_average_time_difference)
        .dropna()
    )
    user_list = averages_by_user.index.tolist()
    average_time_differences = averages_by_user.tolist()

    if plot_results:
        # sorted() builds a new list, so the returned values keep their per-user order.
        plt.plot(sorted(average_time_differences, reverse=True), '.')
        plt.title(f'Event distribution for dataset: {dataset}')
        plt.xlabel('User ID')
        plt.ylabel('Average time difference')
        plt.grid()
        plt.show()

    return user_list, average_time_differences