import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from code_utils.config import ARGS
from code_utils.model import get_sorted_s_r_embeddings


class MeanAggregator(nn.Module):
    """Mean-pool grouped embeddings and pad them into fixed-length sequences.

    ``forward`` asks ``get_sorted_s_r_embeddings`` for a flat matrix of
    embeddings plus the size of each consecutive group of rows, averages the
    rows of every group, concatenates each averaged vector with the embedding
    of the corresponding subject, and writes the result into a zero-padded
    tensor of shape ``(batch, seq_len, 2 * h_dim)``.

    Args:
        h_dim: Output width of the optional linear encoder; the padded output
            has ``2 * h_dim`` features.
        emb_dim: Embedding width; the linear encoder consumes ``2 * emb_dim``
            input features.
        dropout: Dropout probability stored in ``self.dropout``. Note that
            ``forward`` does not currently apply this layer.
        snap_encoder: Encoder selector. Only the value ``1`` is interpreted
            here: it enables the linear projection and the ReLU.
        rel_num: Number of relations; ``2 * rel_num`` is forwarded to
            ``get_sorted_s_r_embeddings``.
        seq_len: Length of the padded sequence dimension of the output.
        device: Device on which the padded output tensor is allocated.
    """

    def __init__(self, h_dim, emb_dim, dropout, snap_encoder, rel_num, seq_len, device):
        super(MeanAggregator, self).__init__()

        self.device = device
        self.seq_len = seq_len
        self.h_dim = h_dim
        self.emb_dim = emb_dim
        self.dropout = nn.Dropout(dropout)
        self.enc_type = snap_encoder
        self.num_rels = rel_num * 2

        # Linear encoder (used only when enc_type == 1). The attribute names
        # are kept so that existing state dicts continue to load.
        self.gcn_w = nn.Linear(2 * self.emb_dim, self.h_dim)
        # Extra additive bias on top of the Linear layer's own bias,
        # starting at zero.
        self.gcn_b = nn.Parameter(torch.zeros(self.h_dim))
        # In-place initialiser; the non-underscore variant is deprecated.
        init.xavier_normal_(self.gcn_w.weight)

    def forward(self, s_hist, s, r, ent_embeds, rel_embeds):
        """Build the padded, mean-pooled sequence tensor for a batch.

        Args:
            s_hist, s, r, rel_embeds: Passed through unchanged to
                ``get_sorted_s_r_embeddings``.
            ent_embeds: Passed to ``get_sorted_s_r_embeddings`` and also
                called with one element of the sorted subjects to obtain that
                subject's embedding (presumably an ``nn.Embedding`` -- confirm
                against the caller).

        Returns:
            A tuple ``(s_embed_seq_tensor, nonzero_hist_lens, batch_mask)``
            where ``s_embed_seq_tensor`` has shape
            ``(len(nonzero_hist_lens), seq_len, 2 * h_dim)`` and the other two
            values are returned exactly as produced by
            ``get_sorted_s_r_embeddings``.
        """
        (embeds, _unused_split, s_sorted, _unused_r_sorted, nonzero_hist_lens,
         all_lens, _unused_all_r, batch_mask) = get_sorted_s_r_embeddings(
            s, r, ent_embeds, rel_embeds, s_hist, self.device, self.num_rels)

        if self.enc_type == 1:
            embeds = self.gcn_w(embeds) + self.gcn_b

        # Everything used for pooling lives on the same device (and uses the
        # same dtype) as ``embeds`` so the sparse product and the division
        # work regardless of where the embeddings are stored.
        pool_device = embeds.device
        group_lens = torch.as_tensor(all_lens, dtype=torch.long, device=pool_device)
        num_groups = group_lens.numel()
        num_rows = int(group_lens.sum())

        # Row i of the pooling matrix has ones over the consecutive slice of
        # ``embeds`` rows that belong to group i.
        group_ids = torch.repeat_interleave(
            torch.arange(num_groups, device=pool_device), group_lens)
        row_ids = torch.arange(num_rows, device=pool_device)
        # The explicit size keeps the row count equal to ``len(all_lens)``
        # even if trailing groups are empty.
        pool_matrix = torch.sparse_coo_tensor(
            torch.stack([group_ids, row_ids], dim=0),
            torch.ones(num_rows, dtype=embeds.dtype, device=pool_device),
            size=(num_groups, num_rows))

        embeds_sum = torch.sparse.mm(pool_matrix, embeds)
        # Clamp so that an empty group yields a zero vector instead of NaN.
        divisor = group_lens.clamp(min=1).to(embeds_sum.dtype).view(-1, 1)
        embeds_mean = embeds_sum / divisor

        if self.enc_type == 1:
            embeds_mean = F.relu(embeds_mean)

        # One chunk of pooled vectors per batch entry.
        per_entry = torch.split(embeds_mean, nonzero_hist_lens.tolist())
        s_embed_seq_tensor = torch.zeros(
            len(nonzero_hist_lens), self.seq_len, 2 * self.h_dim,
            device=self.device)

        # NOTE(review): the concatenated width (pooled width + subject
        # embedding width) must equal 2 * h_dim for the assignment below to
        # succeed -- confirm the configured dimensions satisfy this.
        for i, pooled in enumerate(per_entry):
            steps = len(pooled)
            pooled = pooled.to(self.device)
            subject = ent_embeds(s_sorted[i]).to(self.device).repeat(steps, 1)
            # Fill the first ``steps`` positions; the rest stay zero-padded.
            s_embed_seq_tensor[i, :steps] = torch.cat((pooled, subject), dim=1)

        return s_embed_seq_tensor, nonzero_hist_lens, batch_mask


