import torch
from torch import nn
import random
from collections import defaultdict
from copy import deepcopy
import numpy as np
from model.time_encoding import TimeEncode

class BaseRNNEmbedding(nn.Module):
  def __init__(self, n_nodes, k_layers, embedding_len, node_feature, history_limit,
               time_encoding="none",
               device="cpu",
               time_encoder=None,
               args=None):
    """Build the recurrent node-embedding module.

    Args:
      n_nodes: number of real nodes; every per-node table is allocated with
        ``n_nodes + 1`` rows.
      k_layers: number of recurrent layers stacked on top of the base embedding.
      embedding_len: width of every embedding / hidden-state vector.
      node_feature: accepted for signature compatibility; not used here.
      history_limit: number of past interactions remembered per node.
      time_encoding: ``"none"`` (no time features), ``"baseline"`` (build an own
        ``TimeEncode`` of width ``args.time_dim``) or ``"copy"`` (reuse
        ``time_encoder``, whose output width is taken to be ``embedding_len``).
      device: device on which parameters and caches are created.
      time_encoder: encoder to reuse; required when ``time_encoding == "copy"``.
      args: option namespace. This constructor reads ``time_dim`` in
        ``"baseline"`` mode; the other methods read further flags from it.

    Raises:
      ValueError: if ``time_encoding`` is not one of the supported modes, or if
        ``"copy"`` is requested without a ``time_encoder``.
    """
    # Fail early with a clear message instead of an AttributeError on
    # ``self.time_len`` further down the initialisation chain.
    if time_encoding not in ("none", "baseline", "copy"):
      raise ValueError(
        "time_encoding must be 'none', 'baseline' or 'copy', got %r" % (time_encoding,))
    if time_encoding == "copy" and time_encoder is None:
      raise ValueError("time_encoding='copy' requires a time_encoder")
    super(BaseRNNEmbedding, self).__init__()
    self.args = args
    self.n_nodes = n_nodes
    self.k_layers = k_layers # use a multi-layer rnn
    self.embedding_len = embedding_len
    self.history_limit = history_limit
    self.device = device
    print("BaseRNNEmbedding device:", self.device)
    self.time_encoding = time_encoding
    if self.time_encoding == "baseline":
      self.time_len = self.args.time_dim
      print("TimeEncode dim in BaseRNNEmbedding module: ",self.time_len)
      self.time_encoder = TimeEncode(dimension=self.time_len)
    elif self.time_encoding == "copy":
      self.time_len = self.embedding_len
      self.time_encoder = time_encoder
    # Trainable tensors first: the cache initialisation reads base_embedding.
    self.__init_traiable_var__()
    self.__init_cache_var__()
    self.last = None


  def __init_traiable_var__(self):
    """Create the trainable tensors and one recurrent cell per layer.

    Creates:
      * ``base_embedding`` -- zero-initialised ``(n_nodes + 1, embedding_len)``
        parameter (the layer-0 embedding table).
      * ``final_weights`` -- ``(1, 1, k_layers + 1)`` parameter of ones, used to
        mix the per-layer outputs.
      * ``tmp_1 .. tmp_N`` -- recurrent cells (``nn.RNNCell`` when
        ``args.base_cell == "rnn"``, ``nn.GRUCell`` otherwise) with
        ``N = max(3, k_layers)``.
      * ``kernel`` -- list indexed by layer: ``kernel[0]`` is ``None`` and
        ``kernel[i]`` is ``tmp_i``.

    Reads ``args.remove_dst_in_msg`` and ``args.base_cell``, plus
    ``self.time_len`` when a time encoding is enabled.
    """
    self.base_embedding = nn.Parameter(
      torch.zeros((self.n_nodes + 1, self.embedding_len)).to(self.device), requires_grad=True)
    self.final_weights = nn.Parameter(
      torch.ones((1, 1, self.k_layers + 1)).to(self.device), requires_grad=True)

    # Cell input = source embedding [+ destination embedding] [+ time encoding].
    n_embedding_parts = 1 if self.args.remove_dst_in_msg else 2
    msg_length = int(n_embedding_parts * self.embedding_len)
    if self.time_encoding != "none":
      msg_length += int(self.time_len)

    # Build only the requested cell type instead of creating GRU cells and
    # replacing them afterwards.
    cell_cls = nn.RNNCell if self.args.base_cell == "rnn" else nn.GRUCell

    # tmp_1..tmp_3 are always created so parameter names in a state_dict do not
    # depend on k_layers <= 3; deeper stacks simply get additional cells.
    n_cells = max(3, int(self.k_layers))
    self.tmp_0 = None # layer 0 is the plain embedding table and has no cell
    self.kernel = [self.tmp_0]
    for layer in range(1, n_cells + 1):
      cell = cell_cls(input_size=msg_length, hidden_size=self.embedding_len, device=self.device)
      setattr(self, "tmp_%d" % layer, cell) # registers the cell as a submodule
      self.kernel.append(cell)

  def __init_cache_var__(self):
    """Reset the embedding cache and the interaction history.

    Intended to be re-runnable to reset state (the original note says it is
    called at the start of each epoch -- confirm against the training loop).

    Creates:
      * ``dp_cache`` -- non-trainable tensor of shape
        ``(n_nodes + 1, k_layers + 1, history_limit + 1, embedding_len)``.
        Layer 0 holds the current base embedding in every history slot except
        slot 0; everything else (all of slot 0 and all higher layers) is zero.
      * ``history`` -- ``(n_nodes + 1, history_limit)`` array of neighbour ids,
        zero-filled; id 0 is the fake node used for "no history". The array
        keeps NumPy's default float64 dtype.
      * ``history_timestamp`` -- array of the same shape, zero-filled.
    """
    self.dp_cache = nn.Parameter(torch.zeros((self.n_nodes + 1,
                                              self.k_layers + 1,
                                              self.history_limit + 1,
                                              self.embedding_len)).to(self.device), requires_grad=False)
    # np.zeros already yields the null-history id 0 everywhere, so no
    # element-by-element fill is needed.
    self.history = np.zeros([self.n_nodes + 1, self.history_limit])
    self.history_timestamp = np.zeros([self.n_nodes + 1, self.history_limit])
    # The cache is a plain value store: fill it without recording autograd
    # history so it never links back to base_embedding.
    with torch.no_grad():
      # Broadcast each node's base embedding over all history slots of layer 0 ...
      self.dp_cache[:, 0, :, :] = self.base_embedding.view(self.n_nodes + 1, 1, self.embedding_len)
      # ... then clear slot 0 of every layer (this also clears layer 0, slot 0).
      self.dp_cache[:, :, 0, :] = 0
    self.dp_cache.detach_()

  def store_history(self, start_nodes, end_nodes, edge_times):
    """Append one new interaction to the history of every node in ``start_nodes``.

    For each listed node the neighbour-id and timestamp windows are shifted one
    step towards index 0 (dropping the oldest entry) and the new neighbour /
    time is written at the last index. The per-node cache is shifted one slot
    the other way, so slot 0 keeps its value and slot ``j + 1`` receives the old
    slot ``j``. ``post_update_cache`` is then called once per layer and the cache
    is detached from the autograd graph.

    When ``args.remove_time`` is set, the stored timestamp is replaced by the
    previous entry's timestamp plus 1.0.
    """
    windows = ((self.history, end_nodes), (self.history_timestamp, edge_times))
    for table, newest in windows:
      table[start_nodes, :-1] = table[start_nodes, 1:]
      table[start_nodes, -1] = newest

    if self.args.remove_time:
      synthetic_time = self.history_timestamp[start_nodes, -2] + 1.0
      self.history_timestamp[start_nodes, -1] = synthetic_time

    # Copy first so source and destination slots never alias during the shift.
    older_slots = self.dp_cache[start_nodes, :, :-1, :].clone()
    self.dp_cache[start_nodes, :, 1:, :] = older_slots

    # Subclass hook: refresh slot 0 of each layer with the pre-update parameters.
    for layer in range(self.k_layers + 1):
      self.post_update_cache(start_nodes, layer, 0)
    self.dp_cache.detach_()

  def differentiable_embedding(self, ns, k, self_loop = 0, end_loop = False):
    """Recursively compute the layer-``k`` embedding of nodes ``ns``.

    One step combines, for every node, a "source message" (embedding of the
    neighbour stored in its history), a "destination message" (the node's own
    cached state), and optionally a time-delta encoding, and feeds them together
    with a hidden state into ``self.kernel[k]``. Nodes whose history entry is the
    fake node 0 keep their hidden state unchanged. The detached result is written
    to ``dp_cache[ns, k, self_loop]``.

    Args:
      ns: iterable of node ids; each element is converted with ``int()``.
      k: layer index. ``0`` is the base embedding table.
      self_loop: history offset. ``0`` addresses the most recently stored
        interaction; ``history_limit`` addresses the oldest cache slot and stops
        the recursion.
      end_loop: only consulted in the ``k == 0`` branch when
        ``args.base_memory_level != 0``; otherwise just forwarded to recursive
        calls.

    Returns:
      Tensor with one row of width ``embedding_len`` per node in ``ns``.

    Flags read from ``self.args``: ``base_memory_level``, ``msg_from_cache``,
    ``dst_msg_asynchronous``, ``remove_dst_in_msg``, ``remove_time_in_msg``,
    ``reduce_to_base_case``, ``diff_msg``.
    """
    ns = [int(node) for node in ns] # ids may arrive as floats (e.g. rows of self.history); normalise to ints
    ########### base case ##########
    if k == 0: # base case for source_message
      # Mirror the current base embedding into the layer-0 cache slot.
      self.dp_cache[ns, k, self_loop, :] = self.base_embedding[ns, :]
      if self.args.base_memory_level == 0:
        return self.base_embedding[ns, :]
      else:
        if end_loop: # return the level 1 as the final reduced embedding
          return self.dp_cache[ns, 1, self_loop, :].detach().clone()
        else: # fix bug of msg_not_from_cache useless
          # Compute layer 1 instead; end_loop=True makes nested k == 0 calls read the layer-1 cache.
          return self.differentiable_embedding(ns, 1, self_loop, end_loop=True)
    if self_loop == self.history_limit: # base case for hidden when reduce_to_base_case
      return self.dp_cache[ns, k, self_loop, :] # the oldest cache version
    ########### recursion ##########
    # store_history writes the newest entry at the last index, so offset 0 maps to history_limit - 1.
    history_idx = int(self.history_limit - 1 - self_loop)
    history = self.history[ns, history_idx]
    # 1.0 where a real neighbour is stored, 0.0 for the fake node 0 (assuming ids are non-negative).
    mask = torch.sign(torch.from_numpy(history.astype(np.float32))).to(self.device) # [3*bs]
    ########### source_message ##########
    source_level = k - 1
    if self.args.msg_from_cache:
      if k==1 and self.args.base_memory_level!=0:
        source_level = 1
      # NOTE(review): self.history is a float array here; this relies on the
      # indexing machinery converting it to integer indices -- confirm.
      source_message = self.dp_cache[self.history[ns, history_idx], source_level, 0, :]
    else:
      source_message = self.differentiable_embedding(self.history[ns, history_idx], source_level, 0, end_loop=end_loop)
    ########### destination_message ##########
    destination_level = k
    if self.args.dst_msg_asynchronous:
      destination_level = k - 1
      if k==1 and self.args.base_memory_level!=0:
        destination_level = 1
    destination_message = self.dp_cache[ns, destination_level, self_loop + 1, :] # same as TGN's
    message = torch.cat([source_message, destination_message] ,dim=1)
    if self.args.remove_dst_in_msg:
      # The concatenation above is discarded; only the source part is used.
      message = source_message
    ########### time_encoding ##########
    if self.time_encoding != "none": # same as TGN time_encoding
      # Time elapsed between this history entry and the one stored before it.
      # NOTE(review): for history_idx == 0 the index history_idx - 1 is -1,
      # which selects the NEWEST timestamp rather than an older one -- looks
      # like an unintended wrap-around; confirm the intended delta.
      source_time_delta =  torch.from_numpy(self.history_timestamp[ns, history_idx]- \
                                            self.history_timestamp[ns, history_idx - 1]).float().to(self.device)
      source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(ns), -1)
      if self.args.remove_time_in_msg:
        # Keep the message width but blank out the time features.
        source_time_delta_encoding = torch.zeros_like(source_time_delta_encoding)
      message = torch.cat([message,
                           source_time_delta_encoding],
                           dim=1) # print(source_message.shape)
    ########### hidden state ##########
    hidden_level = k
    if self.args.reduce_to_base_case:
      # Recompute the previous state recursively, down to self_loop == history_limit.
      hidden = self.differentiable_embedding(ns, hidden_level, self_loop + 1, end_loop=end_loop)
    else:
      hidden = self.dp_cache[ns, hidden_level, self_loop + 1, :] # same as TGN memory
    if self.args.diff_msg:
      tensor_res = self.kernel[k](message, hidden)
    else:
      # Detach the message so gradients reach the cell only through `hidden` and the cell weights.
      tensor_res = self.kernel[k](message.detach().clone(), hidden)
    # Row-wise select: keep `hidden` where mask == 0, take the cell output where mask == 1.
    tensor_res = torch.einsum("a,ab->ab", 1 - mask, hidden) \
                 + torch.einsum("a,ab->ab", mask, tensor_res)
    # update cache with the newest kernel params (post trained for the last history). will be used in the future
    self.dp_cache[ns, k, self_loop, :] = tensor_res.detach().clone() # same as TGN's memory update
    return tensor_res

  def get_final_embedding(self, nodes, force_memory_module=False):
    """Return the embedding of ``nodes`` together with its previous value.

    Args:
      nodes: node ids used to index the embedding tables. Outside the
        ``args.only_base`` short-cut with ``args.stack_layers`` set, it must also
        expose ``.shape[0]`` (e.g. a NumPy array).
      force_memory_module: only read when ``args.only_base`` is set. If true the
        layer-1 recurrent embedding is returned instead of the raw base
        embedding.

    Returns:
      ``(embedding, old_result)``. ``old_result`` is a detached copy taken
      before any cache update made by this call: the base embedding rows when
      ``args.only_base`` is set, otherwise ``dp_cache[nodes, k_layers, 0]``.
      ``embedding`` is the softmax(``final_weights``)-weighted mix of all layer
      outputs when ``args.stack_layers`` is set, else the top layer's output.

    Side effects:
      On the multi-layer path ``self.res`` holds the per-layer outputs (a list,
      or the stacked tensor when ``args.stack_layers`` is set).
    """
    if self.args.only_base:
      old_result = self.base_embedding[nodes, :].detach().clone()
      if force_memory_module:
        return self.differentiable_embedding(nodes, 1, 0), old_result
      return self.base_embedding[nodes, :], old_result

    # Snapshot first: the layer computations below overwrite cache slot 0.
    old_result = self.dp_cache[nodes, self.k_layers, 0, :].detach().clone()
    self.res = [self.differentiable_embedding(nodes, layer, 0)
                for layer in range(self.k_layers + 1)]
    if not self.args.stack_layers:
      # return highest level memory
      return self.res[-1], old_result

    self.res = torch.stack(self.res, dim=1) # (n, k_layers + 1, embedding_len)
    # (1, 1, k_layers + 1) weights broadcast over the node dimension in matmul,
    # so no per-node copy of the weights is needed.
    layer_weights = torch.softmax(self.final_weights, dim=2)
    mixed = torch.matmul(layer_weights, self.res) # (n, 1, embedding_len)
    return mixed.view([nodes.shape[0], -1]), old_result

  def post_update_cache(self, ns, k, sl):
    """Hook called by ``store_history`` once per layer after the cache shift.

    ``store_history`` passes the updated node ids as ``ns``, the layer index as
    ``k`` and ``0`` as ``sl``. This implementation does nothing and returns
    ``None``.
    """
    return

  def detach_memory(self):
    """Cut ``dp_cache`` loose from the autograd graph, in place."""
    cache = self.dp_cache
    cache.detach_()

  def backup_memory(self):
    """Snapshot the mutable state of the module.

    Returns:
      Tuple ``(dp_cache, history, history_timestamp, base_embedding)`` of
      independent copies; the two tensors are detached clones of the
      parameters' data.
    """
    cache_copy = self.dp_cache.data.detach().clone()
    history_copy = deepcopy(self.history)
    timestamp_copy = deepcopy(self.history_timestamp)
    base_copy = self.base_embedding.data.detach().clone()
    return cache_copy, history_copy, timestamp_copy, base_copy

  def restore_memory(self, memory_backup):
    """Load state from a 4-item sequence laid out like ``backup_memory``'s result.

    ``memory_backup`` is indexed as ``(dp_cache, history, history_timestamp,
    base_embedding)``. Every item is copied, so the backup can be reused.
    """
    # Make all copies before touching any attribute.
    cache_copy = memory_backup[0].detach().clone()
    history_copy = deepcopy(memory_backup[1])
    timestamp_copy = deepcopy(memory_backup[2])
    base_copy = memory_backup[3].detach().clone()

    self.dp_cache.data = cache_copy
    self.history = history_copy
    self.history_timestamp = timestamp_copy
    self.base_embedding.data = base_copy

  def debug(self):
    """Debugging hook; does nothing here and returns ``None``."""
    return None
