import torch
from torch import nn
import numpy as np
import math

from model.temporal_attention import TemporalAttentionLayer
from copy import deepcopy

class EmbeddingModule(nn.Module):
  """Base class that stores the configuration shared by the embedding modules.

  The constructor only records its arguments on the instance. The `memory`
  argument is accepted but not stored. `compute_embedding` and `debug` are
  no-op placeholders (both return None) that subclasses override.
  """

  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               dropout, args=None):
    super().__init__()

    # Graph inputs and the helpers used to query / encode them.
    self.node_features = node_features
    self.edge_features = edge_features
    self.neighbor_finder = neighbor_finder
    self.time_encoder = time_encoder

    # Size configuration.
    self.n_layers = n_layers
    self.n_node_features = n_node_features
    self.n_edge_features = n_edge_features
    self.n_time_features = n_time_features

    # Remaining settings.
    self.dropout = dropout
    self.embedding_dimension = embedding_dimension
    self.device = device
    self.args = args
    self.revision_decay = 0.0

  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                        use_time_proj=True):
    """Placeholder; does nothing and returns None in the base class."""
    return None

  def debug(self):
    """Placeholder; does nothing and returns None in the base class."""
    return None

class GraphMemoryCacheEmbedding(EmbeddingModule):
  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               n_heads=2, dropout=0.1, use_memory=True, args=None):
    """Build the message GRU, the optional state updater and the per-layer embedding cache.

    Args:
      node_features: per-node feature matrix; `shape[0]` is used as the number of nodes.
      edge_features: per-edge feature matrix.
      memory: forwarded to the base class (which does not store it).
      neighbor_finder: object used later by `compute_embedding` to sample temporal neighbors.
      time_encoder: module used later by `compute_embedding` to encode time deltas.
      n_layers: number of layers; sizes the cache (`n_layers + 1` slots per node).
      n_node_features: node feature width; also used as the embedding width.
      n_edge_features, n_time_features, embedding_dimension, dropout: stored by the base class.
      device: device on which the GRU cells, layer weights and cache are created.
      n_heads: accepted for signature compatibility; not used by this class.
      use_memory: if True, `compute_embedding` adds the memory rows to the node features.
      args: configuration object; this constructor reads `args.remove_dst_in_msg` and
        `args.n_layer`, so it must not be None.
    """
    super(GraphMemoryCacheEmbedding, self).__init__(node_features, edge_features, memory,
                                         neighbor_finder, time_encoder, n_layers,
                                         n_node_features, n_edge_features, n_time_features,
                                         embedding_dimension, device, dropout, args)
    self.use_memory = use_memory
    self.device = device
    self.revision_decay = 1.0
    # memory module
    self.embedding_len = n_node_features
    # The message is either the aggregated embedding alone, or that embedding
    # concatenated with the node's own features (twice the width).
    if self.args.remove_dst_in_msg:
      msg_length = int(1 * self.embedding_len)
    else:
      msg_length = int(2 * self.embedding_len)
    # Both switches are hard-coded off; the branches below are kept for experimentation.
    self.add_edge_feature_in_msg, self.add_node_feature_in_msg = False, False
    if self.add_edge_feature_in_msg:
      msg_length += edge_features.shape[1]
    if self.add_node_feature_in_msg:
      msg_length += node_features.shape[1]
    print("MSG input raw dimension in the MSG function (EMB): ",msg_length)
    self.MSG_MLP = torch.nn.ModuleList([
      nn.GRUCell(input_size=msg_length, hidden_size=self.embedding_len, device=self.device)])
    self.STATE_UPDATER_list =  ["None", "LINEAR","GRU","RNN"]
    self.STATE_decay_rate = 0.1
    # Models with more than one layer refine the cached state with a GRU. Every other
    # value of args.n_layer uses no state updater; this includes n_layer < 1, which
    # previously left STATE_UPDATER unassigned and crashed on the checks below.
    if self.args.n_layer > 1:
      self.STATE_UPDATER = self.STATE_UPDATER_list[2]
    else:
      self.STATE_UPDATER = self.STATE_UPDATER_list[0]
    if self.STATE_UPDATER=="LINEAR":
      print(self.STATE_decay_rate)
    if "GRU" in self.STATE_UPDATER:
      self.State_models = torch.nn.ModuleList([
        nn.GRUCell(input_size=self.embedding_len, hidden_size=self.embedding_len, device=self.device)])
    elif "RNN" in self.STATE_UPDATER:
      self.State_models = torch.nn.ModuleList([
        nn.RNNCell(input_size=self.embedding_len, hidden_size=self.embedding_len, device=self.device)])

    # Trainable mixing weights, used by compute_embedding when args.weighted_gnn is set.
    self.layer_weights = nn.Parameter(torch.Tensor([0.1,0.9]).to(self.device), requires_grad=True)

    self.n_nodes = node_features.shape[0]
    # Embedding cache: one row per (node, layer). Stored as a non-trainable parameter.
    self.dp_cache = nn.Parameter(torch.zeros((self.n_nodes,
                                              n_layers + 1,
                                              self.embedding_len)).to(self.device), requires_grad=False)
    print(self.dp_cache.shape)
    # Diagnostics written by update_cache: one row per layer, one column per update call.
    self.cache_var = np.zeros([n_layers,10000])
    self.current_iter = 0
    self.sync_ratio, self.sync_weight = [], []
    # ADDING DIFFUSION
    class NormalLinear(nn.Linear):
      """Linear layer whose weight and bias are drawn from N(0, 1/sqrt(in_features))."""
      def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.normal_(0, stdv)
        if self.bias is not None:
          self.bias.data.normal_(0, stdv)
    self.diffusion_embedding_layer = NormalLinear(1, self.n_node_features)
  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20,
                        time_diffs=None,
                        use_time_proj=True,
                        validation_mode=False,
                        leave_out_sets=None,
                        memory_old=None,
                        UPDATE_CACHE = True,
                        TIME_DIFF_IN_NEIGHBORS=False,  # True performs worse
                        ONLY_ASYNC_NODES = False,
                        REVISION_COMPUTATION=False, # mode of computation
                        LINEAR_DECAY=False,
                        USE_GRU_IN_EMB = True,
                        DECAY_GNN = False,
                        FORCE_ZERO_GNN = False,  # False # True # validation_mode
                        DIFF=False,
                        ):
    """Recursively compute `n_layers`-layer embeddings for `source_nodes`, using the cache.

    Args:
      memory: node memory; indexed by `source_nodes` and added to the node features
        when `self.use_memory` is set.
      source_nodes: numpy array of node ids.
      timestamps: numpy array of query times, one per source node.
      n_layers: recursion depth. With 0 the (memory-augmented) node features are returned.
      n_neighbors: number of temporal neighbors requested per node and layer.
      time_diffs, use_time_proj, REVISION_COMPUTATION: accepted but not used here.
      validation_mode: if False, only a random subset of the nodes is recomputed (the
        first node always, each other node with probability `args.rollout_ratio`); the
        remaining rows are read from the cache. If True, all nodes are recomputed.
      leave_out_sets: one list per recomputed node, forwarded to the neighbor finder.
      memory_old: if given, subtracted from `memory` for the revision pass.
      UPDATE_CACHE: write the new embeddings into the cache.
      TIME_DIFF_IN_NEIGHBORS, ONLY_ASYNC_NODES: enable the sync statistics; with
        ONLY_ASYNC_NODES the checked neighbors that count as synchronized are zeroed.
      LINEAR_DECAY: choose linear instead of exponential decay of `revision_decay`.
      USE_GRU_IN_EMB: pass the aggregated message through the message GRU and the
        configured state updater.
      DECAY_GNN: if False, `revision_decay` is reset to 1.0 before building the message.
      FORCE_ZERO_GNN: together with DECAY_GNN, zeroes the message.
      DIFF: return the difference to the previously cached state. Only effective in
        validation mode, because the other path overwrites the result.

    Returns:
      A tensor of embeddings. For `n_layers == 0`, the node features. In validation
      mode, one row per source node. Otherwise the node-feature tensor with recomputed
      rows replaced by the new embeddings and the other rows by cached values.
    """
    assert (n_layers >= 0)

    source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
    timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)

    # query node always has the start time -> time span == 0
    source_nodes_time_embedding = self.time_encoder(torch.zeros_like(
      timestamps_torch))

    source_node_features = self.node_features[source_nodes_torch, :]

    if self.use_memory: # base plus feature. is zero in ml100k, wiki
      source_node_features = memory[source_nodes, :] + source_node_features

    if n_layers == 0: # base case, note that 0-layer cache no update
      return source_node_features

    # ---- recursion ----
    full_source_nodes = deepcopy(source_nodes)
    threshold = 1 - self.args.rollout_ratio # drop-out rate
    ########## cache ###########
    if not validation_mode:
      # Split the batch into nodes to recompute (source_nodes / id_list) and nodes
      # served from the cache (cache_nodes / cache_id_list).
      source_nodes, cache_nodes, id_list, cache_id_list = \
        [source_nodes[0]], [], [0], [] # at least one node to roll-out
      for i, s in enumerate(full_source_nodes[1:]):
        if np.random.uniform()>threshold:
          source_nodes.append(s)
          id_list.append(i+1)
        else:
          cache_nodes.append(s)
          cache_id_list.append(i+1)
      source_nodes, cache_nodes = np.array(source_nodes), np.array(cache_nodes)
    else:
      id_list = np.arange(len(source_nodes))
      cache_id_list = []
    if leave_out_sets is None:
      leave_out_sets = [[] for i in range(len(id_list))]
    # NOTE(review): `timestamps` is passed un-subset while the nodes are restricted to
    # id_list; the lengths only match when every node is recomputed. Confirm.
    neighbors, edge_idxs, edge_times, leave_out_mask = self.neighbor_finder.get_temporal_neighbor(
      full_source_nodes[id_list],# source_nodes,
      timestamps,
      n_neighbors=n_neighbors,
      leave_out=self.args.leave_out,
      leave_out_sets=leave_out_sets)

    # Each neighbor inherits its source node's leave-out set plus the source node itself.
    # NOTE(review): leave_out_sets[j] is appended to in place, so the caller's lists
    # are mutated as a side effect. Confirm this is intended.
    neighbor_leave_out_sets=[]
    for i in range(int(len(id_list)*n_neighbors)):
      j = i//n_neighbors #
      neighbor_leave_out_sets.append(leave_out_sets[j])
      neighbor_leave_out_sets[i].append(full_source_nodes[id_list[j]])
      neighbor_leave_out_sets[i] = list(set(neighbor_leave_out_sets[i]))

    if ONLY_ASYNC_NODES or TIME_DIFF_IN_NEIGHBORS:
      timestamps_neighbors = np.repeat(timestamps, n_neighbors)
      # NOTE(review): this call unpacks three values whereas the call above unpacks
      # four. Confirm the finder's return arity for this argument set.
      _1, _2,  neighbors_last_update_times = self.neighbor_finder.get_temporal_neighbor(
        np.array(neighbors).reshape([-1]),# source_nodes,
        timestamps_neighbors,
        n_neighbors=1)
      neighbors_last_update_times = neighbors_last_update_times.reshape([-1,n_neighbors])
      sync_node = 0
      for i in range(neighbors_last_update_times.shape[0]):
        for j in range(int(1)):  # only the first neighbor column is inspected
          if neighbors[i, j] != 0 and edge_times[i, j] >= neighbors_last_update_times[i, j]:
            sync_node += 1
            if ONLY_ASYNC_NODES:
              neighbors[i, j]=0
      self.sync_ratio.append(sync_node/(sync_node+np.sum(np.sign(neighbors))))
      self.sync_weight.append(sync_node + np.sum(np.sign(neighbors)))
    else:
      self.sync_ratio.append(1.0)
      self.sync_weight.append(1.0)

    neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)
    edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

    edge_deltas = timestamps[:, np.newaxis] - edge_times

    assert not self.args.change_time_to_seq
    if self.args.change_time_to_seq:
      # Only reachable when assertions are disabled: replace time deltas by rank positions.
      replaced_edge_deltas = np.tile(np.arange(n_neighbors)[::-1], (len(source_nodes), 1)) + 1.0
      for i in range(len(source_nodes)):
        for j in range(n_neighbors - 1):
          if edge_deltas[i, -j - 1] == edge_deltas[i, -j - 2]:
            replaced_edge_deltas[i, -j - 2] = replaced_edge_deltas[i, -j - 1]
      edge_deltas = replaced_edge_deltas

    edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device) # % 1000000

    neighbors = neighbors.flatten()
    if memory_old is not None:
      _memory_old = memory_old
    else:
      _memory_old = 0.0
    if not self.args.wo_diff:
      # Revision pass: same recursion on the memory difference, without the GRU step.
      neighbor_revisions = self.compute_embedding(memory-_memory_old,
                                                   neighbors,
                                                   np.repeat(timestamps, n_neighbors),
                                                   n_layers=n_layers - 1,
                                                   n_neighbors=n_neighbors,
                                                   validation_mode=validation_mode,
                                                   leave_out_sets=neighbor_leave_out_sets,
                                                   USE_GRU_IN_EMB=False)
    neighbor_embeddings = self.compute_embedding(memory,
                                                 neighbors,
                                                 np.repeat(timestamps, n_neighbors),
                                                 n_layers=n_layers - 1,
                                                 n_neighbors=n_neighbors,
                                                 validation_mode=validation_mode,
                                                 leave_out_sets=neighbor_leave_out_sets,
                                                 USE_GRU_IN_EMB=True)
    if self.args.wo_diff:
      neighbor_revisions = torch.zeros_like(neighbor_embeddings)
    neighbor_embeddings += neighbor_revisions
    # Decay schedule for revision_decay, clamped at min_revision_decay.
    min_revision_decay = 0.00
    if LINEAR_DECAY:
      self.revision_decay -= 0.0001
    else:
      self.revision_decay = min_revision_decay + (self.revision_decay-min_revision_decay)*0.999
    if self.revision_decay<=min_revision_decay:
      self.revision_decay=min_revision_decay
    effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
    neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)

    edge_time_embeddings = self.time_encoder(edge_deltas_torch)

    assert not self.args.delete_time_or_seq
    if self.args.delete_time_or_seq:
      edge_time_embeddings = torch.zeros_like(edge_time_embeddings)
    if self.args.leave_out:
      # The last time-embedding channel is overwritten with the leave-out mask.
      edge_time_embeddings[:,:,-1] = torch.from_numpy(leave_out_mask[:,:,0]).to(self.device)

    edge_features = self.edge_features[edge_idxs, :]

    # Neighbor id 0 marks a masked (padding) neighbor slot.
    mask = neighbors_torch == 0
    source_embedding = self.aggregate(n_layers, source_node_features[id_list],
                                      source_nodes_time_embedding[id_list],
                                      neighbor_embeddings,
                                      edge_time_embeddings,
                                      edge_features,
                                      mask)
    if not DECAY_GNN:
      self.revision_decay = 1.0

    # Read the previously cached state before update_cache overwrites it. Binding it
    # here (rather than only inside the GRU branch) keeps DIFF usable when
    # USE_GRU_IN_EMB is False; the recursive calls only write layer n_layers - 1.
    old_state = None
    if USE_GRU_IN_EMB or DIFF:
      old_state = self.dp_cache[full_source_nodes[id_list], n_layers]

    if USE_GRU_IN_EMB:
      if self.add_node_feature_in_msg: # added after aggregate
        source_embedding = torch.cat([source_embedding,
                                      source_node_features[id_list]], dim=1)

      source_message = source_embedding
      destination_message = source_node_features[id_list]
      message = torch.cat([self.revision_decay * source_message,
                           self.revision_decay * destination_message], dim=1)
      if DECAY_GNN and FORCE_ZERO_GNN:
        message = 0.0 * message
      if self.args.remove_dst_in_msg:
        message = source_message
      # The GRU hidden state is the embedding of the same nodes one layer below.
      hidden = self.compute_embedding(memory,
                                     source_nodes,
                                     timestamps,
                                     n_layers=n_layers - 1,
                                     n_neighbors=n_neighbors,
                                     validation_mode=validation_mode)
      source_embedding = self.MSG_MLP[0](message, hidden) # Msg
      if self.STATE_UPDATER == "None":
        pass
      elif self.STATE_UPDATER=="LINEAR":
        source_embedding = self.STATE_decay_rate * old_state + (1-self.STATE_decay_rate) * source_embedding # linear
      elif self.STATE_UPDATER=="GRU" or self.STATE_UPDATER=="RNN":
        source_embedding = self.State_models[0](source_embedding, old_state) # gru
    if self.args.weighted_gnn and n_layers==self.args.n_layer:
      # Top layer only: normalized weighted mix of node features and embedding.
      source_embedding =  1/torch.sum(self.layer_weights)*(self.layer_weights[0] * source_node_features + \
                                              self.layer_weights[1] * source_embedding)
    if UPDATE_CACHE:
      self.update_cache(full_source_nodes[id_list],# source_nodes,
                        n_layers,source_embedding)

    result = source_embedding
    if DIFF:
      result = source_embedding-old_state
    if not validation_mode:
      # Recomputed rows get the fresh embedding; the others come from the cache.
      result = source_node_features
      result[id_list, :] = source_embedding
      result[cache_id_list,:] = self.dp_cache[cache_nodes,n_layers].clone().detach()
    return result
  def update_cache(self, source_nodes, n_layers, source_embedding):
    """Store `source_embedding` in the cache and record how much the entries changed.

    Args:
      source_nodes: ids of the nodes whose cache rows are replaced.
      n_layers: layer slot of the cache to write.
      source_embedding: new embeddings, one row per entry of `source_nodes`.

    Returns:
      Always True.
    """
    previous = self.dp_cache[source_nodes, n_layers]
    # Relative change: mean squared difference over mean squared previous value;
    # the 1e-7 keeps the denominator non-zero for an all-zero cache.
    var = (torch.mean(torch.square(previous - source_embedding))/ \
           torch.mean(torch.square(previous)+1e-7)).detach().cpu().numpy()
    var = int(100000*var)
    # cache_var has a fixed number of columns while current_iter grows without bound.
    # Wrap around so long runs overwrite the oldest diagnostics instead of raising
    # IndexError once the buffer is full.
    slot = self.current_iter % self.cache_var.shape[1]
    self.cache_var[n_layers-1, slot] = var
    self.current_iter += 1
    self.dp_cache[source_nodes,n_layers] = source_embedding.clone().detach()
    return True
  def detach_memory(self):
    """Detach the embedding cache from the autograd graph, in place."""
    cache = self.dp_cache
    cache.detach_()
  def backup_memory(self):
    """Return a detached copy of the current cache contents."""
    snapshot = self.dp_cache.data.detach()
    return snapshot.clone()
  def restore_memory(self, memory_backup):
    """Replace the cache contents with a detached copy of `memory_backup`."""
    restored = memory_backup.detach().clone()
    self.dp_cache.data = restored
  def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                neighbor_embeddings,
                edge_time_embeddings, edge_features, mask):
    """Placeholder aggregation step; returns None here and is overridden by subclasses."""
    return None
  def debug(self):
    """Debug hook; intentionally does nothing."""
    return None
class GraphMemoryGCNCacheEmbedding(GraphMemoryCacheEmbedding):
  """Cache embedding whose aggregation is a degree-scaled sum over the neighbor axis."""

  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True, args=None):
    # All arguments are passed straight through to the parent constructor.
    super().__init__(node_features, edge_features, memory,
                     neighbor_finder, time_encoder, n_layers,
                     n_node_features, n_edge_features,
                     n_time_features,
                     embedding_dimension, device,
                     n_heads, dropout,
                     use_memory, args)

  def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                neighbor_embeddings,
                edge_time_embeddings, edge_features, mask):
    """Sum the neighbor features of each node and scale the sum by 1/sqrt(degree).

    The degree of a row is its number of unmasked entries in `mask`; rows with
    degree 0 get a coefficient of 0. The sum itself runs over every neighbor
    slot, masked or not. `n_layer`, `source_node_features`,
    `source_nodes_time_embedding` and `edge_time_embeddings` are not used.
    """
    valid = 1 - mask.cpu().numpy().astype(int)
    degrees = np.sum(valid, axis=1)
    # 1/sqrt(degree) where the node has neighbors, 0 elsewhere.
    scale = np.zeros([len(degrees)])
    has_neighbors = degrees > 0
    scale[has_neighbors] = 1 / np.sqrt(degrees[has_neighbors])
    scale = torch.from_numpy(scale.astype(np.float32)).to(self.device)
    if self.add_edge_feature_in_msg: # add in aggregate
      stacked = torch.cat([neighbor_embeddings, edge_features], dim=2)
    else:
      stacked = neighbor_embeddings
    return torch.einsum("abc,a->ac", stacked, scale)
class GraphMemoryAttentionCacheEmbedding(GraphMemoryCacheEmbedding):
  """Cache embedding whose aggregation is one temporal attention layer per graph layer."""

  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True, args=None):
    super().__init__(node_features, edge_features, memory,
                     neighbor_finder, time_encoder, n_layers,
                     n_node_features, n_edge_features,
                     n_time_features,
                     embedding_dimension, device,
                     n_heads, dropout,
                     use_memory, args)
    # One attention layer per graph layer; aggregate() picks index n_layer - 1.
    layers = []
    for _ in range(n_layers):
      layers.append(TemporalAttentionLayer(
            n_node_features=n_node_features,
            n_neighbors_features=n_node_features,
            n_edge_features=n_edge_features,
            time_dim=n_time_features,
            n_head=n_heads,
            dropout=dropout,
            output_dimension=n_node_features))
    self.attention_models = torch.nn.ModuleList(layers)

  def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                neighbor_embeddings,
                edge_time_embeddings, edge_features, mask):
    """Run the attention layer for `n_layer` and return its first output.

    When `add_edge_feature_in_msg` is set, the edge features summed over the
    neighbor axis and scaled by 1/sqrt(degree) are concatenated to the result
    (degree = number of unmasked entries per row; 0 for rows without any).
    """
    layer = self.attention_models[n_layer - 1]
    attended = layer(source_node_features,
                     source_nodes_time_embedding,
                     neighbor_embeddings,
                     edge_time_embeddings,
                     edge_features,
                     mask)
    source_embedding, _ = attended
    if not self.add_edge_feature_in_msg:
      return source_embedding

    valid = 1 - mask.cpu().numpy().astype(int)
    degrees = np.sum(valid, axis=1)
    scale = np.zeros([len(degrees)])
    has_neighbors = degrees > 0
    scale[has_neighbors] = 1 / np.sqrt(degrees[has_neighbors])
    scale = torch.from_numpy(scale.astype(np.float32)).to(self.device)
    pooled_edges = torch.einsum("abc,a->ac", edge_features, scale)
    return torch.cat([source_embedding, pooled_edges], dim=1)


def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                         time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device,
                         n_heads=2, dropout=0.1, n_neighbors=None,
                         use_memory=True,
                         args=None):
  """Create the embedding module selected by `module_type`.

  Args:
    module_type: "memory_gcn_cache" or "memory_attention_cache".
    node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
    n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
    n_heads, dropout, use_memory, args: forwarded to the selected class's constructor.
    n_neighbors: accepted for signature compatibility; not used.

  Returns:
    A new instance of the selected embedding class.

  Raises:
    ValueError: if `module_type` is not one of the supported names.
  """
  if module_type == "memory_gcn_cache":
    module_cls = GraphMemoryGCNCacheEmbedding
  elif module_type == "memory_attention_cache":
    module_cls = GraphMemoryAttentionCacheEmbedding
  else:
    raise ValueError("Embedding Module {} not supported".format(module_type))

  # n_heads and use_memory were previously accepted but never forwarded, so the
  # classes silently fell back to their own defaults regardless of the caller's values.
  return module_cls(node_features=node_features,
                    edge_features=edge_features,
                    memory=memory,
                    neighbor_finder=neighbor_finder,
                    time_encoder=time_encoder,
                    n_layers=n_layers,
                    n_node_features=n_node_features,
                    n_edge_features=n_edge_features,
                    n_time_features=n_time_features,
                    embedding_dimension=embedding_dimension,
                    device=device,
                    n_heads=n_heads,
                    dropout=dropout,
                    use_memory=use_memory,
                    args=args)

