import logging
import numpy as np
import torch
from collections import defaultdict

from utils.utils import MergeLayer
from modules.embedding_module import get_embedding_module
from model.time_encoding import TimeEncode

from torch import nn

from modules.lazy_graph_embedding import BaseRNNEmbedding as lge

class LGN(torch.nn.Module):
  """Link-prediction model built around a lazy graph embedding store.

  The module wires together three components created in ``__init__``:
  a ``BaseRNNEmbedding`` store (``self.lge``), an embedding module obtained
  from ``get_embedding_module`` and a ``MergeLayer`` decoder
  (``self.affinity_score``) that scores a pair of node embeddings.
  """

  def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2,
               n_heads=2, dropout=0.1, use_memory=False,
               memory_update_at_start=True, message_dimension=100,
               memory_dimension=500,
               embedding_module_type="graph_attention", # "graph_attention","identity"
               message_function="mlp",
               mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
               std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
               memory_updater_type="gru",
               use_destination_embedding_in_message=False,
               use_source_embedding_in_message=False,
               dyrep=False,
               args=None):
    """
    Build the embedding store, the embedding module and the edge decoder.

    :param neighbor_finder: object forwarded to the embedding module
    :param node_features: numpy array, converted to a float32 tensor on `device`
    :param edge_features: numpy array, converted to a float32 tensor on `device`
    :param device: torch device used for the feature tensors and sub-modules
    :param n_layers: number of layers passed to the store and the embedding module
    :param n_heads: forwarded to the embedding module
    :param dropout: forwarded to the embedding module
    :param use_memory: stored and forwarded to the embedding module
    :param memory_dimension: stored as an attribute only when `use_memory` is true
    :param embedding_module_type: module type passed to `get_embedding_module`
    :param mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst:
    normalisation constants used by `compute_time_diff`
    :param n_neighbors: stored and forwarded to the embedding module
    :param use_destination_embedding_in_message, use_source_embedding_in_message, dyrep:
    stored as attributes; not otherwise read in this class
    :param memory_update_at_start, message_dimension, message_function, aggregator_type,
    memory_updater_type: accepted but not read by this constructor
    :param args: configuration namespace; must have a truthy `full_lazygraph` attribute
    :raises ValueError: if `args` is None or `args.full_lazygraph` is not truthy
    """
    super(LGN, self).__init__()

    self.args = args
    # An explicit check instead of `assert`: asserts are stripped under `python -O`, and the
    # default `args=None` would otherwise fail with an opaque AttributeError.
    if args is None or not getattr(args, "full_lazygraph", False):
      raise ValueError("LGN requires `args` with `full_lazygraph` enabled")

    self.n_layers = n_layers
    self.neighbor_finder = neighbor_finder
    self.device = device
    self.logger = logging.getLogger(__name__)

    self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device)
    self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device)

    self.n_node_features = self.node_raw_features.shape[1]
    self.n_nodes = self.node_raw_features.shape[0]
    self.n_edge_features = self.edge_raw_features.shape[1]
    # Embeddings share the width of the raw node features.
    self.embedding_dimension = self.n_node_features
    self.n_neighbors = n_neighbors
    self.embedding_module_type = embedding_module_type
    self.use_destination_embedding_in_message = use_destination_embedding_in_message
    self.use_source_embedding_in_message = use_source_embedding_in_message
    self.dyrep = dyrep

    self.use_memory = use_memory
    print("Time encoder dim:", self.n_node_features)
    self.time_encoder = TimeEncode(dimension=self.n_node_features)
    self.memory = None

    self.lge = lge(self.n_nodes, self.n_layers, self.n_node_features, self.node_raw_features,
                   history_limit=self.args.history_limit,
                   time_encoding=self.args.time_encoding,
                   device=device,
                   time_encoder=self.time_encoder,
                   args=self.args)
    if self.args.unique_time_encoding:
      # Reuse the store's time encoder for the embedding module as well.
      self.time_encoder = self.lge.time_encoder

    self.mean_time_shift_src = mean_time_shift_src
    self.std_time_shift_src = std_time_shift_src
    self.mean_time_shift_dst = mean_time_shift_dst
    self.std_time_shift_dst = std_time_shift_dst

    if self.use_memory:
      self.memory_dimension = memory_dimension

    self.embedding_module = get_embedding_module(module_type=embedding_module_type,
                                                 node_features=self.node_raw_features,
                                                 edge_features=self.edge_raw_features,
                                                 memory=self.memory, # not used
                                                 neighbor_finder=self.neighbor_finder,
                                                 time_encoder=self.time_encoder,
                                                 n_layers=self.n_layers,
                                                 n_node_features=self.n_node_features,
                                                 n_edge_features=self.n_edge_features,
                                                 n_time_features=self.n_node_features,
                                                 embedding_dimension=self.embedding_dimension,
                                                 device=self.device,
                                                 n_heads=n_heads, dropout=dropout,
                                                 use_memory=use_memory,
                                                 n_neighbors=self.n_neighbors,
                                                 args=args)
    # MLP to compute probability on an edge given two node embeddings
    self.affinity_score = MergeLayer(self.n_node_features, self.n_node_features,
                                     self.n_node_features,
                                     1)

  def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                  edge_idxs, n_neighbors=20, update_memory_and_message=True, validation_mode=False):
    """
    Compute temporal embeddings for sources, destinations, and negatively sampled destinations.

    After the embeddings are computed, the source/destination interactions are recorded in the
    lazy graph store (in both directions) and `post_update_cache` is called for the negatives.

    :param source_nodes [batch_size]: source ids
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destination
    :param edge_times [batch_size]: timestamp of interaction
    :param edge_idxs [batch_size]: index of interaction (not read by this method)
    :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
    layer
    :param update_memory_and_message: not read by this method
    :param validation_mode: forwarded to the embedding module when `args.emb_postprocess` is set
    :return: Temporal embeddings for sources, destinations and negatives
    :raises RuntimeError: if `args.full_lazygraph` is no longer truthy
    """
    if not self.args.full_lazygraph:
      # Fail loudly rather than fall through and return None implicitly.
      raise RuntimeError("compute_temporal_embeddings requires args.full_lazygraph")

    n_sources = len(source_nodes)
    n_destinations = len(destination_nodes)
    nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])

    if self.args.emb_postprocess:
      node_embedding = self._postprocessed_embeddings(nodes, source_nodes, destination_nodes,
                                                      negative_nodes, edge_times, n_neighbors,
                                                      validation_mode)
    else:
      node_embedding = self._cached_embeddings(nodes)

    # Rows follow the concatenation order: sources, destinations, negatives.
    source_node_embedding = node_embedding[:n_sources, :]
    destination_node_embedding = node_embedding[n_sources:n_sources + n_destinations, :]
    negative_node_embedding = node_embedding[n_sources + n_destinations:, :]

    # Record every interaction twice so each endpoint sees the other as its partner.
    self.lge.store_history(np.concatenate([source_nodes, destination_nodes], axis=0),
                           np.concatenate([destination_nodes, source_nodes], axis=0),
                           np.concatenate([edge_times, edge_times], axis=0))
    self.lge.post_update_cache(negative_nodes, 0, 0)
    return source_node_embedding, destination_node_embedding, negative_node_embedding

  def _cached_embeddings(self, nodes):
    """Return the store's final embeddings for `nodes`, optionally refreshing every node first."""
    if self.args.update_cache_at_start: # update cache for all nodes
      # NOTE(review): the index range has n_nodes + 1 entries; presumably the store keeps one
      # extra row - confirm against BaseRNNEmbedding.
      all_embeddings, _ = self.lge.get_final_embedding(np.arange(self.n_nodes + 1))
      return all_embeddings[nodes]
    batch_embeddings, _ = self.lge.get_final_embedding(nodes)
    return batch_embeddings

  def _postprocessed_embeddings(self, nodes, source_nodes, destination_nodes, negative_nodes,
                                edge_times, n_neighbors, validation_mode):
    """Run the embedding module on top of the store's embeddings for the concatenated `nodes`."""
    # Query all nodes so the embedding module can look up any neighbour.
    # NOTE(review): n_nodes + 1 indices are requested - confirm the store has that many rows.
    all_embeddings, all_embeddings_old = self.lge.get_final_embedding(np.arange(self.n_nodes + 1))
    timestamps = np.concatenate([edge_times, edge_times, edge_times])

    # Last column of the history buffer is taken as each node's latest update time.
    last_update = torch.from_numpy(self.lge.history_timestamp[:, -1]).to(self.device)
    time_diffs = self.compute_time_diff(edge_times, last_update, source_nodes, destination_nodes,
                                        negative_nodes)
    if self.args.remove_time:
      # Replace the time signal: constant diffs and a query time just after the last update.
      time_diffs = torch.ones_like(time_diffs)
      timestamps = last_update[nodes].cpu().numpy() + 1.0

    node_embedding = self.embedding_module.compute_embedding(memory=all_embeddings,
                                                             source_nodes=nodes,
                                                             timestamps=timestamps,
                                                             n_layers=self.n_layers,
                                                             n_neighbors=n_neighbors,
                                                             time_diffs=time_diffs,
                                                             validation_mode=validation_mode,
                                                             memory_old=all_embeddings_old) # time_diffs not used

    if self.args.memory_after_emb_postprocess:
      self.lge.dp_cache[:, 0, 0, :] = self.lge.base_embedding[:, :].detach().clone()
      self.lge.dp_cache[nodes, 1, 1, :] = node_embedding.detach().clone()
      memory_embedding = self.lge.get_final_embedding(nodes, force_memory_module=True)
      # Every other call site unpacks an (embedding, old_embedding) pair from
      # get_final_embedding; accept either return shape so the blend below always
      # operates on a tensor.
      if isinstance(memory_embedding, tuple):
        memory_embedding = memory_embedding[0]
      # Fixed-weight blend of the post-processed and the memory-module embeddings.
      node_embedding = 0.7 * node_embedding + 0.3 * memory_embedding
    return node_embedding

  def compute_time_diff(self, edge_times, last_update, source_nodes, destination_nodes, negative_nodes):
    """
    Compute normalised elapsed times between each interaction and each node's last update.

    Sources are normalised with the source mean/std shift; destinations and negatives both
    use the destination mean/std shift.

    :param edge_times [batch_size]: timestamp of interaction
    :param last_update: tensor indexed by node id holding the last update time
    :param source_nodes [batch_size]: source ids
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destination
    :return: 1-D tensor with the source, destination and negative diffs concatenated in that order
    """
    # Convert and move the timestamps once; they are shared by all three node groups.
    interaction_times = torch.LongTensor(edge_times).to(self.device)

    def normalised_diff(node_ids, mean_shift, std_shift):
      elapsed = interaction_times - last_update[node_ids].long()
      return (elapsed - mean_shift) / std_shift

    return torch.cat([
      normalised_diff(source_nodes, self.mean_time_shift_src, self.std_time_shift_src),
      normalised_diff(destination_nodes, self.mean_time_shift_dst, self.std_time_shift_dst),
      normalised_diff(negative_nodes, self.mean_time_shift_dst, self.std_time_shift_dst),
    ], dim=0)

  def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                 edge_idxs, n_neighbors=20, validation_mode=False):
    """
    Compute probabilities for edges between sources and destination and between sources and
    negatives by first computing temporal embeddings and then feeding them
    into the MLP decoder.

    :param source_nodes [batch_size]: source ids
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destination
    :param edge_times [batch_size]: timestamp of interaction
    :param edge_idxs [batch_size]: index of interaction
    :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
    layer
    :param validation_mode: forwarded to `compute_temporal_embeddings`
    :return: Probabilities for both the positive and negative edges
    """
    n_samples = len(source_nodes)
    source_node_embedding, destination_node_embedding, negative_node_embedding = self.compute_temporal_embeddings(
      source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors, validation_mode=validation_mode)

    # Score each source against its true destination and its negative in a single decoder pass.
    # squeeze(dim=0) only removes the leading axis when that axis has size 1.
    score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
                                torch.cat([destination_node_embedding,
                                           negative_node_embedding])).squeeze(dim=0)
    pos_score = score[:n_samples]
    neg_score = score[n_samples:]

    return pos_score.sigmoid(), neg_score.sigmoid()

  def set_neighbor_finder(self, neighbor_finder):
    """Replace the neighbor finder on both this model and its embedding module."""
    self.neighbor_finder = neighbor_finder
    self.embedding_module.neighbor_finder = neighbor_finder