import numpy as np
import torch

from tgnnexplainer.xgraph.models.ext.tgat.graph import NeighborFinder

class MergeLayer(torch.nn.Module):
    def __init__(self, dim1, dim2, dim3, dim4):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
        self.fc2 = torch.nn.Linear(dim3, dim4)
        self.act = torch.nn.ReLU()

        torch.nn.init.xavier_normal_(self.fc1.weight)
        torch.nn.init.xavier_normal_(self.fc2.weight)

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], dim=1)
        h = self.act(self.fc1(x))
        return self.fc2(h)


class MLP(torch.nn.Module):
    def __init__(self, dim, drop=0.3):
        super().__init__()
        self.fc_1 = torch.nn.Linear(dim, 80)
        self.fc_2 = torch.nn.Linear(80, 10)
        self.fc_3 = torch.nn.Linear(10, 1)
        self.act = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=drop, inplace=False)

    def forward(self, x):
        x = self.act(self.fc_1(x))
        x = self.dropout(x)
        x = self.act(self.fc_2(x))
        x = self.dropout(x)
        return self.fc_3(x).squeeze(dim=1)


class EarlyStopMonitor(object):
    def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
        self.max_round = max_round
        self.num_round = 0

        self.epoch_count = 0
        self.best_epoch = 0

        self.last_best = None
        self.higher_better = higher_better
        self.tolerance = tolerance

    def early_stop_check(self, curr_val):
        if not self.higher_better:
            curr_val *= -1
        if self.last_best is None:
            self.last_best = curr_val
        elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
            self.last_best = curr_val
            self.num_round = 0
            self.best_epoch = self.epoch_count
        else:
            self.num_round += 1

        self.epoch_count += 1

        return self.num_round >= self.max_round


class RandEdgeSampler(object):
    def __init__(self, src_list, dst_list, seed=None):
        self.seed = None
        self.src_list = np.unique(src_list)
        self.dst_list = np.unique(dst_list)

        if seed is not None:
            self.seed = seed
            self.random_state = np.random.RandomState(self.seed)

    def sample(self, size):
        if self.seed is None:
            src_index = np.random.randint(0, len(self.src_list), size)
            dst_index = np.random.randint(0, len(self.dst_list), size)
        else:

            src_index = self.random_state.randint(0, len(self.src_list), size)
            dst_index = self.random_state.randint(0, len(self.dst_list), size)
        return self.src_list[src_index], self.dst_list[dst_index]

    def reset_random_state(self):
        self.random_state = np.random.RandomState(self.seed)


def get_neighbor_finder(data, uniform, max_node_idx=None):
    max_node_idx = max(data.sources.max(), data.destinations.max()) if max_node_idx is None else max_node_idx
    adj_list = [[] for _ in range(max_node_idx + 1)]
    for source, destination, edge_idx, timestamp in zip(data.sources, data.destinations,
                                                        data.edge_idxs,
                                                        data.timestamps):
        adj_list[source].append((destination, edge_idx, timestamp))
        adj_list[destination].append((source, edge_idx, timestamp))

    # return NeighborFinder(adj_list, uniform=uniform)
    return NeighborFinder(adj_list, uniform=False)


# class NeighborFinder:
# 	def __init__(self, adj_list, uniform=False, seed=None):
# 		self.node_to_neighbors = []
# 		self.node_to_edge_idxs = []
# 		self.node_to_edge_timestamps = []

# 		for neighbors in adj_list:
# 			# Neighbors is a list of tuples (neighbor, edge_idx, timestamp)
# 			# We sort the list based on timestamp
# 			sorted_neighhbors = sorted(neighbors, key=lambda x: x[2])
# 			self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighhbors]))
# 			self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighhbors]))
# 			self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighhbors]))

# 		self.uniform = uniform

# 		if seed is not None:
# 			self.seed = seed
# 			self.random_state = np.random.RandomState(self.seed)

# 	def find_before(self, src_idx, cut_time):
# 		"""
# 		Extracts all the interactions happening before cut_time for user src_idx in the overall interaction graph. The returned interactions are sorted by time.

# 		Returns 3 lists: neighbors, edge_idxs, timestamps

# 		"""
# 		i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)

# 		return self.node_to_neighbors[src_idx][:i], self.node_to_edge_idxs[src_idx][:i], self.node_to_edge_timestamps[src_idx][:i]

# 	def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
# 		"""
# 		Given a list of users ids and relative cut times, extracts a sampled temporal neighborhood of each user in the list.

# 		Params
# 		------
# 		src_idx_l: List[int]
# 		cut_time_l: List[float],
# 		num_neighbors: int
# 		"""
# 		assert (len(source_nodes) == len(timestamps))

# 		tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1
# 		# NB! All interactions described in these matrices are sorted in each row by time
# 		neighbors = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
# 			np.int32)  # each entry in position (i,j) represent the id of the item targeted by user src_idx_l[i] with an interaction happening before cut_time_l[i]
# 		edge_times = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
# 			np.float32)  # each entry in position (i,j) represent the timestamp of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i]
# 		edge_idxs = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
# 			np.int32)  # each entry in position (i,j) represent the interaction index of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i]

# 		for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
# 			source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node,
# 																									 timestamp)  # extracts all neighbors, interactions indexes and timestamps of all interactions of user source_node happening before cut_time

# 			if len(source_neighbors) > 0 and n_neighbors > 0:
# 				if self.uniform:  # if we are applying uniform sampling, shuffles the data above before sampling
# 					sampled_idx = np.random.randint(0, len(source_neighbors), n_neighbors)

# 					neighbors[i, :] = source_neighbors[sampled_idx]
# 					edge_times[i, :] = source_edge_times[sampled_idx]
# 					edge_idxs[i, :] = source_edge_idxs[sampled_idx]

# 					# re-sort based on time
# 					pos = edge_times[i, :].argsort()
# 					neighbors[i, :] = neighbors[i, :][pos]
# 					edge_times[i, :] = edge_times[i, :][pos]
# 					edge_idxs[i, :] = edge_idxs[i, :][pos]
# 				else:
# 					# Take most recent interactions
# 					source_edge_times = source_edge_times[-n_neighbors:]
# 					source_neighbors = source_neighbors[-n_neighbors:]
# 					source_edge_idxs = source_edge_idxs[-n_neighbors:]

# 					assert (len(source_neighbors) <= n_neighbors)
# 					assert (len(source_edge_times) <= n_neighbors)
# 					assert (len(source_edge_idxs) <= n_neighbors)

# 					neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors
# 					edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times
# 					edge_idxs[i, n_neighbors - len(source_edge_idxs):] = source_edge_idxs

# 		return neighbors, edge_idxs, edge_times