import warnings

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from time_encode import TimeEncode


class MLP(nn.Module):
    """
    Two-layer perceptron that runs in double precision.

    Forward pass: [dropout(p / 2)] -> linear -> activation -> [dropout(p)] -> [batch norm] -> linear,
    where the bracketed steps are enabled by ``dropout`` / ``batch_norm``.
    """

    # Supported activation names mapped to the module class that implements them.
    _ACTIVATIONS = {None: nn.Identity, 'relu': nn.ReLU, 'tanh': nn.Tanh}

    def __init__(self, input_size, output_size, hidden_size, activation='relu', batch_norm=True, dropout=False,
                 dropout_probability=0.5):
        """
        :param input_size:           Width of the input features.
        :param output_size:          Width of the output.
        :param hidden_size:          Width of the hidden layer.
        :param activation:           None, 'relu' or 'tanh', applied after the first linear layer.
        :param batch_norm:           Batch-normalise the hidden layer before the second linear layer.
        :param dropout:              Apply dropout to the input (p / 2) and to the hidden layer (p).
        :param dropout_probability:  The dropout probability ``p``.
        :raises ValueError:          If ``activation`` is not one of the supported values.
        """
        super(MLP, self).__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError("Unsupported activation {!r}; expected None, 'relu' or 'tanh'".format(activation))
        self.first_linear = nn.Linear(input_size, hidden_size).double()
        self.second_linear = nn.Linear(hidden_size, output_size).double()
        self.batch_norm = batch_norm
        self.dropout = dropout
        if dropout:
            self.first_dropout_layer = nn.Dropout(dropout_probability / 2)
            self.second_dropout_layer = nn.Dropout(dropout_probability)
        if batch_norm:
            # NOTE: first_batch_norm is never applied in forward(); it is kept so the
            # state_dict keys (and therefore existing checkpoints) stay unchanged.
            self.first_batch_norm = nn.BatchNorm1d(input_size).double()
            self.second_batch_norm = nn.BatchNorm1d(hidden_size).double()
        self.activation = self._ACTIVATIONS[activation]()

    def forward(self, x):
        """
        :param x:   Input batch, cast to double to match the layer weights. [B, input_size]
        :return:    Output of the second linear layer (no final activation). [B, output_size]
        """
        x = x.double()
        if self.dropout:
            x = self.first_dropout_layer(x)
        x = self.activation(self.first_linear(x))
        if self.dropout:
            x = self.second_dropout_layer(x)
        if self.batch_norm:
            x = self.second_batch_norm(x)
        return self.second_linear(x)


class Linear(nn.Module):
    """Double-precision linear layer with optional input batch norm and an activation."""

    # Supported activation names mapped to the module class that implements them.
    _ACTIVATIONS = {None: nn.Identity, 'relu': nn.ReLU, 'tanh': nn.Tanh}

    def __init__(self, input_size, output_size, batch_norm=False, activation=None):
        """
        :param input_size:   Width of the input features.
        :param output_size:  Width of the output.
        :param batch_norm:   Batch-normalise the input before the linear projection.
        :param activation:   None, 'relu' or 'tanh', applied to the projection.
        :raises ValueError:  If ``activation`` is not one of the supported values.
        """
        super(Linear, self).__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError("Unsupported activation {!r}; expected None, 'relu' or 'tanh'".format(activation))
        self.linear = nn.Linear(input_size, output_size).double()
        self.batch_norm = batch_norm
        # NOTE: built even when batch_norm is False; removing it would change the state_dict keys.
        self.batch_norm_layer = nn.BatchNorm1d(input_size).double()
        self.activation = self._ACTIVATIONS[activation]()

    def forward(self, x):
        """
        :param x:   Input batch, cast to double to match the layer weights. [B, input_size]
        :return:    ``activation(linear([batch_norm](x)))``. [B, output_size]
        """
        x = x.double()
        if self.batch_norm:
            x = self.batch_norm_layer(x)
        return self.activation(self.linear(x))


class LinearMixer(nn.Module):
    """
    Mix two inputs as ``activation(W_s @ source + W_d @ destination)`` in double precision.

    Each input can optionally be batch-normalised before its projection.
    """

    # Supported activation names mapped to the module class that implements them.
    _ACTIVATIONS = {None: nn.Identity, 'relu': nn.ReLU, 'tanh': nn.Tanh}

    def __init__(self, source_size, destination_size, hidden_size, activation=None, batch_norm=False):
        """
        :param source_size:       Width of the ``source`` input.
        :param destination_size:  Width of the ``destination`` input.
        :param hidden_size:       Width of both projections and of the output.
        :param activation:        None, 'relu' or 'tanh', applied to the summed projections.
        :param batch_norm:        Batch-normalise each input before projecting it.
        :raises ValueError:       If ``activation`` is not one of the supported values.
        """
        super(LinearMixer, self).__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError("Unsupported activation {!r}; expected None, 'relu' or 'tanh'".format(activation))
        self.linear_source = nn.Linear(source_size, hidden_size).double()
        self.linear_destination = nn.Linear(destination_size, hidden_size).double()
        self.batch_norm = batch_norm
        if batch_norm:
            self.batch_norm_source = nn.BatchNorm1d(source_size).double()
            self.batch_norm_destination = nn.BatchNorm1d(destination_size).double()
        self.activation = self._ACTIVATIONS[activation]()

    def forward(self, source, destination):
        """
        :param source:       First input, cast to double. [B, source_size]
        :param destination:  Second input, cast to double. [B, destination_size]
        :return:             Mixed representation. [B, hidden_size]
        """
        source = source.double()
        destination = destination.double()
        if self.batch_norm:
            source = self.batch_norm_source(source)
            destination = self.batch_norm_destination(destination)
        return self.activation(self.linear_source(source) + self.linear_destination(destination))


class OGN(nn.Module):
    """
    Online temporal graph model that scores observed edges against negative samples.

    Per-node state is held in plain tensors and overwritten in place, with detached values, for
    every processed batch. That state is:
      - the node features;
      - the exponentially decayed neighbour sums ``A`` / ``B``;
      - visit counts;
      - neighbour-degree histograms;
      - last event times.
    Trainable parameters live only in the sub-modules created in ``__init__``.
    """

    def __init__(self, node_features, edge_features, device, mean_time_shift_src, std_time_shift_src,
                 mean_time_shift_dst, std_time_shift_dst, randomize_node=False, dt=0.1,
                 activation='tanh', epsilon=1e-8, update_type='mean', harmonic_weight=0.99,
                 uniform_time=True, alpha=1e-4, dropout=False, dropout_probability=0.3, batch_norm=True,
                 max_value=1000, bins=10, save_degree=False, consider_time=True, remove_neighbors=False):
        """
        :param node_features:        Initial node features, array-like. [N, F]
        :param edge_features:        Edge features, indexed by ``edge_idxs``. [E, F_e]
        :param device:               Device on which all state tensors are allocated.
        :param mean_time_shift_src:  Mean used to standardise source time deltas.
        :param std_time_shift_src:   Std used to standardise source time deltas.
        :param mean_time_shift_dst:  Mean used to standardise destination / negative time deltas.
        :param std_time_shift_dst:   Std used to standardise destination / negative time deltas.
        :param randomize_node:       Replace the initial features by standard-normal noise of the same shape.
        :param dt:                   Step of the synthetic clock used when ``uniform_time`` is True.
        :param activation:           Activation of the final source/destination mixer.
        :param epsilon:              Small constant guarding divisions.
        :param update_type:          Stored; not used by this class.
        :param harmonic_weight:      Stored; not used by this class.
        :param uniform_time:         Ignore ``edge_times`` and timestamp the k-th event as ``k * dt``.
        :param alpha:                Decay rate of the neighbour summaries, ``exp(-alpha * delta_t)``.
        :param dropout:              Enable dropout in the classification MLP.
        :param dropout_probability:  Dropout probability of the classification MLP.
        :param batch_norm:           Enable batch norm in the classification MLP.
        :param max_value:            Upper edge of the last log-spaced degree bin; must be >= 1.
        :param bins:                 Number of degree bins.
        :param save_degree:          Record positive / negative destination degrees in ``degree_list``.
        :param consider_time:        Mix a time encoding into the node embeddings.
        :param remove_neighbors:     Skip mixing the neighbour summary into the embeddings.
        :raises ValueError:          If ``max_value`` < 1.
        """
        super(OGN, self).__init__()
        if max_value < 1:
            raise ValueError('max_value must be >= 1, got {}'.format(max_value))
        if randomize_node:
            # A plain double tensor rather than an nn.Parameter. The state is overwritten in place,
            # which a grad-requiring leaf forbids. It is also reassigned with plain tensors by
            # _reset_memory/load_features, and the values written into it are double.
            self.node_features = torch.randn(*node_features.shape, device=device, dtype=torch.double)
        else:
            self.node_features = torch.tensor(node_features, device=device)
        self.remove_neighbors = remove_neighbors
        self.visits = torch.zeros(self.node_features.shape[0], device=device)
        self.neighbor_visits = torch.zeros(self.node_features.shape[0], bins, device=device)
        # clone(): detach() alone shares storage, so the in-place feature updates would also
        # overwrite this snapshot and _reset_memory could never restore the initial features.
        self.original_features = self.node_features.detach().clone()
        self.A = torch.zeros(*self.node_features.shape, device=device).double()
        self.B = torch.zeros(self.node_features.shape[0], 1, device=device).double() + epsilon
        self.mean_time_shift_src = mean_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_src = std_time_shift_src
        self.std_time_shift_dst = std_time_shift_dst
        # Legacy misspelled attribute, kept for any external code that still reads it.
        self.std_time_shirt_dst = std_time_shift_dst
        self.n_nodes = self.node_features.shape[0]
        self.n_node_features = self.node_features.shape[1]
        self.n_neighbors = torch.zeros(self.n_nodes, 1, device=device) + epsilon
        self.edge_features = torch.tensor(edge_features, device=device)
        self.n_edge_features = self.edge_features.shape[1]
        self.last_event_times = torch.zeros(self.n_nodes, device=device)
        self.processed_events = 0
        self.dt = dt
        self.time_encode = TimeEncode(self.n_node_features)
        self.mixer = Linear(self.n_node_features * 2, self.n_node_features)
        self.embedding = LinearMixer(self.n_node_features, self.n_node_features, self.n_node_features,
                                     activation=activation)
        self.classification_mlp = MLP(self.n_node_features, 1, self.n_node_features,
                                      dropout_probability=dropout_probability, dropout=dropout, batch_norm=batch_norm)
        self.update_neighbors_src = LinearMixer(self.n_node_features, self.n_node_features, self.n_node_features,
                                                activation='tanh')
        self.update_neighbors_dest = LinearMixer(self.n_node_features, self.n_node_features, self.n_node_features,
                                                 activation='tanh')
        self.embed_edges = LinearMixer(self.n_node_features, self.n_edge_features, self.n_node_features,
                                       activation='relu')
        self.embed_neighbors = Linear(self.n_node_features, self.n_node_features, activation='tanh')
        self.add_degree = LinearMixer(self.n_node_features, self.n_node_features, self.n_node_features,
                                      activation='relu')
        self.embed_structure = LinearMixer(bins, bins, self.n_node_features, activation='relu')
        self.embed_degree = Linear(bins, self.n_node_features, activation='tanh').double()
        self.uniform_time = uniform_time
        self.device = device
        self.update_type = update_type
        self.harmonic_weight = harmonic_weight
        self.t = 1
        self.alpha = alpha
        self.epsilon = epsilon
        # Log-spaced upper bin edges from 1 to max_value. np.logspace takes base-10 exponents,
        # so the stop argument must be log10(max_value), not max_value itself.
        self.bins = torch.tensor(np.logspace(start=0, stop=np.log10(max_value), num=bins), device=device)
        self.num_bins = bins
        self.save_degree = save_degree
        self.consider_time = consider_time
        self.degree_list = {
            'positive': [],
            'negative': []
        }

    def find_position(self, values):
        """
        One-hot encode each value into the first bin whose upper edge is >= the value.

        Values above the last edge are placed in the last bin.

        :param values:  1-D tensor of counts. [B]
        :return:        Long one-hot tensor. [B, num_bins]
        """
        indices = torch.searchsorted(self.bins, values.to(self.bins.dtype).contiguous())
        indices = indices.clamp(max=self.num_bins - 1)
        return F.one_hot(indices, self.num_bins)

    def _reset_counter(self):
        """Reset the processed-event counter that drives the synthetic clock (``uniform_time``)."""
        self.processed_events = 0

    def _reset_memory(self):
        """
        Restore the per-node state to its initial values.

        This covers features, neighbour sums, last event times, visit counts and neighbour
        histograms. ``degree_list`` is left untouched.
        """
        self.A = torch.zeros(*self.node_features.shape, device=self.device).double()
        self.B = torch.zeros(self.node_features.shape[0], 1, device=self.device).double() + self.epsilon
        self.last_event_times = torch.zeros(self.n_nodes, device=self.device)
        self.processed_events = 0
        # Copy, so later in-place feature updates cannot corrupt the stored initial features.
        self.node_features = self.original_features.clone()
        self.visits = torch.zeros(self.node_features.shape[0], device=self.device)
        self.neighbor_visits = torch.zeros(self.node_features.shape[0], self.num_bins, device=self.device)

    def get_neighbors(self, sources, destinations, negatives):
        """
        Get the summary embedding of the neighbours each node has interacted with.

        The summary is ``embed_neighbors(A / B)``, the decayed mean kept by
        ``update_neighbors_embeddings``.

        :param sources:         The source node ids. [B]
        :param destinations:    The destination node ids. [B]
        :param negatives:       The negative sample node ids. [B]
        :return:                Three tensors ([B, F]) with the summarised neighbour information of each node.
        """
        source_neighbors = self.embed_neighbors(self.A[sources, :] / self.B[sources, :])
        destination_neighbors = self.embed_neighbors(self.A[destinations, :] / self.B[destinations, :])
        negative_neighbors = self.embed_neighbors(self.A[negatives, :] / self.B[negatives, :])

        return source_neighbors, destination_neighbors, negative_neighbors

    def update_neighbors_embeddings(self, sources, destinations, time_difference_source, time_difference_destination,
                                    time_difference_negative):
        """
        Fold this batch's interactions into the decayed neighbour summaries.

        For each source, ``A`` is scaled by ``exp(-alpha * time_difference_source)`` and the
        destination's current features are added. ``B`` is scaled the same way and incremented
        by one. ``A / B`` (read by ``get_neighbors``) is therefore, up to the epsilon in ``B``, an
        exponentially time-weighted mean of neighbour features. Destinations are updated
        symmetrically with the source features.

        NOTE: ``update_type`` / ``harmonic_weight`` are not consulted here. When a node id repeats
        within one batch, only one of its indexed writes survives.

        :param sources:                      The source node ids. [B]
        :param destinations:                 The destination node ids. [B]
        :param time_difference_source:       Raw time since each source's previous event. [B]
        :param time_difference_destination:  Raw time since each destination's previous event. [B]
        :param time_difference_negative:     Unused; kept for signature compatibility.
        """
        decay_source = torch.exp(-self.alpha * time_difference_source.unsqueeze(1))
        decay_destination = torch.exp(-self.alpha * time_difference_destination.unsqueeze(1))
        self.A[sources, :] = self.A[sources, :] * decay_source + self.node_features[destinations, :]
        self.B[sources, :] = self.B[sources, :] * decay_source + 1
        self.A[destinations, :] = self.A[destinations, :] * decay_destination + self.node_features[sources, :]
        self.B[destinations, :] = self.B[destinations, :] * decay_destination + 1

    def update_features(self, sources, destinations, negatives, source_embeddings, destination_embeddings,
                        negative_embeddings, time_difference_source, time_difference_destination,
                        time_difference_negative):
        """
        Mix neighbour information into the embeddings and write the results back as node state.

        The neighbour summaries are read before they are updated, so they reflect the state prior
        to this batch. The new source/destination features are stored before
        ``update_neighbors_embeddings`` runs, so that update adds the *new* features. Negative
        samples are never written.

        :param sources:                      The source node ids. [B]
        :param destinations:                 The destination node ids. [B]
        :param negatives:                    The negative sample node ids. [B]
        :param source_embeddings:            The source node features. [B, F]
        :param destination_embeddings:       The destination node features. [B, F]
        :param negative_embeddings:          The negative sample node features. [B, F]
        :param time_difference_source:       Raw time since each source's previous event. [B]
        :param time_difference_destination:  Raw time since each destination's previous event. [B]
        :param time_difference_negative:     Raw time since each negative's previous event. [B]
        :return:                             Three tensors ([B, F]) with the updated features for the
                                             source, destination and negative sample embeddings.
        """
        source_neighbors, destination_neighbors, negative_neighbors = self.get_neighbors(sources, destinations,
                                                                                         negatives)
        if self.remove_neighbors:
            source_embeddings_update = source_embeddings
            destination_embeddings_update = destination_embeddings
            negative_embeddings_update = negative_embeddings
        else:
            source_embeddings_update = self.update_neighbors_src(source_embeddings, source_neighbors)
            destination_embeddings_update = self.update_neighbors_dest(destination_embeddings, destination_neighbors)
            # Negatives play the destination role, hence the destination mixer.
            negative_embeddings_update = self.update_neighbors_dest(negative_embeddings, negative_neighbors)

        # Store detached copies; cast to the state's dtype because indexed assignment needs matching dtypes.
        self.node_features[sources, :] = source_embeddings_update.detach().to(self.node_features.dtype)
        self.node_features[destinations, :] = destination_embeddings_update.detach().to(self.node_features.dtype)

        self.update_neighbors_embeddings(sources, destinations, time_difference_source, time_difference_destination,
                                         time_difference_negative)

        return source_embeddings_update, destination_embeddings_update, negative_embeddings_update

    def save_features(self):
        """Return a copy of the current node features."""
        return torch.clone(self.node_features)

    def load_features(self, node_features):
        """
        Replace the node features with a copy of ``node_features``.

        The argument is copied so later in-place updates leave it intact and it can be restored again.
        """
        self.node_features = node_features.clone()

    def get_degree_list(self):
        """Return the recorded degrees (only populated when ``save_degree`` is True)."""
        return self.degree_list

    def _advance_time(self, source_nodes, destination_nodes, negative_nodes, edge_times):
        """
        Timestamp a batch of events and measure the time since each node's previous event.

        Side effects: increments ``processed_events`` and stores the new timestamps in
        ``last_event_times`` for sources and destinations. Negatives are only read.

        :return: ``(raw, normalized)``. Each is a tuple of three [B] tensors for the sources,
                 destinations and negatives. ``normalized`` is standardised with the constructor's
                 source / destination mean and std.
        """
        source_size = source_nodes.shape[0]
        # Read the previous timestamps before they are overwritten below.
        last_source = self.last_event_times[source_nodes]
        last_destination = self.last_event_times[destination_nodes]
        last_negative = self.last_event_times[negative_nodes]

        if self.uniform_time:
            # Synthetic clock: the k-th processed event (1-based) happens at time k * dt.
            new_events_ts = (torch.arange(source_size, device=self.device) + 1 + self.processed_events) * self.dt
        else:
            new_events_ts = torch.tensor(edge_times, device=self.device).float()
        self.processed_events += source_size
        self.last_event_times[source_nodes] = new_events_ts.detach()
        self.last_event_times[destination_nodes] = new_events_ts.detach()

        raw = (new_events_ts - last_source,
               new_events_ts - last_destination,
               new_events_ts - last_negative)
        normalized = ((raw[0] - self.mean_time_shift_src) / self.std_time_shift_src,
                      (raw[1] - self.mean_time_shift_dst) / self.std_time_shift_dst,
                      (raw[2] - self.mean_time_shift_dst) / self.std_time_shift_dst)
        return raw, normalized

    def _neighbor_histogram(self, nodes):
        """Return each node's neighbour-degree histogram, scaled so its largest bin is ~1. [B, num_bins]"""
        counts = self.neighbor_visits[nodes]
        return counts / (torch.max(counts, dim=1, keepdim=True)[0] + self.epsilon)

    def _structure_embeddings(self, source_nodes, destination_nodes, negative_nodes):
        """
        Embed each node's degree bin and neighbour-degree histogram, then record the batch.

        The embeddings use the counts from before this batch. Afterwards:
          - each source's histogram gains its destination's degree bin, and vice versa;
          - both visit counts are incremented.
        Negatives are only read.

        :return: Three tensors ([B, F]) for the sources, destinations and negatives.
        """
        index_sources = self.find_position(self.visits[source_nodes]).double()
        index_destination = self.find_position(self.visits[destination_nodes]).double()
        index_negative = self.find_position(self.visits[negative_nodes]).double()
        if self.save_degree:
            self.degree_list['positive'].append(self.visits[destination_nodes])
            self.degree_list['negative'].append(self.visits[negative_nodes])

        source_structure = self.embed_structure(self._neighbor_histogram(source_nodes), index_sources)
        destination_structure = self.embed_structure(self._neighbor_histogram(destination_nodes), index_destination)
        negative_structure = self.embed_structure(self._neighbor_histogram(negative_nodes), index_negative)

        # NOTE: indexed += does not accumulate over node ids repeated within one batch;
        # only one of the writes for such an id survives.
        self.neighbor_visits[source_nodes, :] += index_destination
        self.neighbor_visits[destination_nodes, :] += index_sources
        self.visits[source_nodes] += 1
        self.visits[destination_nodes] += 1
        return source_structure, destination_structure, negative_structure

    def _compute_node_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs):
        """
        Run the shared pipeline: embed the batch's nodes and advance all per-node state.

        :return: Three tensors ([B, F]) for the sources, destinations and negatives.
        """
        # Advanced indexing copies, so these are snapshots taken before any state is modified.
        node_embeddings = (self.node_features[source_nodes, :],
                           self.node_features[destination_nodes, :],
                           self.node_features[negative_nodes, :])
        edge_embeddings = self.edge_features[edge_idxs]

        raw_deltas, normalized_deltas = self._advance_time(source_nodes, destination_nodes, negative_nodes,
                                                           edge_times)

        if self.consider_time:
            # squeeze(1) rather than squeeze(), so a batch of one keeps its batch axis.
            # NOTE(review): assumes TimeEncode maps [B, 1] to [B, 1, F] (or [B, F]) -- confirm.
            time_encodings = [self.time_encode(delta.unsqueeze(1)).squeeze(1) for delta in normalized_deltas]
            time_embeddings = [self.mixer(torch.cat([embedding.float(), encoding], dim=1))
                               for embedding, encoding in zip(node_embeddings, time_encodings)]
        else:
            time_embeddings = [embedding.float() for embedding in node_embeddings]

        edge_embed_source, edge_embed_destination, edge_embed_negative = [
            self.embed_edges(embedding, edge_embeddings) for embedding in time_embeddings]

        source_update, destination_update, negative_update = self.update_features(
            source_nodes, destination_nodes, negative_nodes,
            edge_embed_source, edge_embed_destination, edge_embed_negative,
            *raw_deltas)
        source_structure, destination_structure, negative_structure = self._structure_embeddings(
            source_nodes, destination_nodes, negative_nodes)

        return (self.add_degree(source_update, source_structure),
                self.add_degree(destination_update, destination_structure),
                self.add_degree(negative_update, negative_structure))

    def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                   edge_idxs):
        """
        Score a batch of observed edges and their negative samples, advancing the model state.

        :param source_nodes:       Source node ids. [B]
        :param destination_nodes:  Destination node ids of the observed edges. [B]
        :param negative_nodes:     Negative-sample destination ids. [B]
        :param edge_times:         Event timestamps; used only when ``uniform_time`` is False. [B]
        :param edge_idxs:          Row indices into ``edge_features``. [B]
        :return:                   ``(positive, negative)`` probabilities, each [B, 1].
        """
        source_size = source_nodes.shape[0]
        source_embeddings, destination_embeddings, negative_embeddings = self._compute_node_embeddings(
            source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs)

        # Score (source, destination) and (source, negative) pairs in one pass.
        sources = torch.cat([source_embeddings, source_embeddings])
        destinations = torch.cat([destination_embeddings, negative_embeddings])
        output = self.classification_mlp(self.embedding(sources, destinations))

        positive = output[:source_size, :]
        negative = output[source_size:, :]

        if torch.isnan(positive).any():
            warnings.warn('OGN produced NaN scores for positive edges', RuntimeWarning)

        return positive.sigmoid(), negative.sigmoid()

    def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                   edge_idxs):
        """
        Compute the final node embeddings for a batch, advancing the model state.

        Uses the same pipeline as ``compute_edge_probabilities``, including ``consider_time``, but
        stops before the source/destination mixer and the classifier.

        :param source_nodes:       Source node ids. [B]
        :param destination_nodes:  Destination node ids of the observed edges. [B]
        :param negative_nodes:     Negative-sample destination ids. [B]
        :param edge_times:         Event timestamps; used only when ``uniform_time`` is False. [B]
        :param edge_idxs:          Row indices into ``edge_features``. [B]
        :return:                   Three tensors ([B, F]) for the sources, destinations and negatives.
        """
        return self._compute_node_embeddings(source_nodes, destination_nodes, negative_nodes, edge_times,
                                             edge_idxs)
