import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

# from aggregators.gcn import RGCNAggregator
from aggregators.mean import MeanAggregator
from aggregators.self_attention import Transformer
from code_utils import config
from code_utils.config import ARGS


class NeighborhoodEncoder(nn.Module):
    """Encode the neighbourhood histories of the subject and object of each triplet.

    For each query triplet ``(s, r, o)`` the subject history and the object
    history are aggregated by a snapshot encoder (``rel_encoder_s`` /
    ``rel_encoder_o``, which are the same ``MeanAggregator`` instance) and then
    summarised by a sequence encoder (a GRU or a ``Transformer``, shared
    between subjects and objects). ``forward`` returns the entity embeddings of
    ``s`` and ``o`` concatenated with the two history encodings.

    Args:
        rel_embds: Optional pretrained relation embeddings (numpy array); only
            read when ``ent_embds`` is given.
        ent_embds: Optional pretrained entity embeddings (numpy array), or
            ``None`` to keep the random initialisation.
        rel_num: Number of relations; the relation table has
            ``2 * rel_num + 1`` rows.
        ent_num: Number of entities; the entity table has ``ent_num + 1`` rows,
            the last one being the padding row.
        emb_dim: Embedding dimension.
        seq_encoder: ``'gru'`` or ``'att'``.
        seq_len: History length, used to size the attention readout layer.
        snap_encoder: ``0`` (sequence input width ``3 * h_dim``) or ``1``
            (``2 * h_dim``).
        h_dim: Hidden size passed to the snapshot aggregator.
        out_dim: Width of each history encoding returned by the sequence encoder.
        n_head: Number of attention heads (``'att'`` only).
        dropout: Dropout probability.
        finetune: Whether pretrained embeddings stay trainable.
        mask: Whether attention masks are passed to the Transformer.
        device: Device the history encodings are moved to.
    """

    def __init__(self, rel_embds, ent_embds, rel_num, ent_num, emb_dim,
                 seq_encoder, seq_len, snap_encoder, h_dim, out_dim, n_head, dropout, finetune, mask, device):
        super(NeighborhoodEncoder, self).__init__()

        self.device = device

        # Encoder setting
        self.dropout = nn.Dropout(dropout)
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.snap_encoder = snap_encoder
        self.seq_encoder = seq_encoder
        self.seq_len = seq_len
        self.n_head = n_head
        self.mask = mask

        # Embedding setting
        self.emb_dim = emb_dim
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.finetune = finetune

        self.init_encoder(dropout)
        self.init_embeds(rel_emb=rel_embds, ent_emb=ent_embds)

    def init_encoder(self, dropout):
        """Build the snapshot aggregator and the (shared) sequence encoder.

        Raises:
            ValueError: If ``snap_encoder`` is not 0 or 1.
        """
        # Multiplier of ``self.h_dim`` giving the per-step feature width fed to
        # the sequence encoder (0: mean aggregator, 1: "GCN" variant, which
        # currently also uses MeanAggregator).
        width_factor = {0: 3, 1: 2}.get(self.snap_encoder)
        if width_factor is None:
            raise ValueError(
                f"Unknown snap_encoder: {self.snap_encoder!r} (expected 0 or 1)")
        h_dim = width_factor * self.h_dim

        self.rel_encoder_s = MeanAggregator(h_dim=self.h_dim,
                                            emb_dim=self.emb_dim,
                                            rel_num=self.rel_num,
                                            dropout=dropout,
                                            snap_encoder=self.snap_encoder,
                                            seq_len=self.seq_len,
                                            device=self.device)

        if self.seq_encoder == 'gru':
            self.sub_encoder = nn.GRU(h_dim, self.out_dim, batch_first=True)
        else:
            self.sub_encoder = Transformer(d_emb=h_dim,
                                           d_model=h_dim,
                                           n_layers=1,
                                           d_inner=256,
                                           n_head=self.n_head,
                                           device=self.device)
            # The Transformer output is flattened over the sequence before this
            # projection. Assumes it yields ``h_dim`` (= d_model) features per
            # step -- TODO confirm. Same size as before for snap_encoder == 1.
            self.wo = nn.Linear(self.seq_len * h_dim, self.out_dim)

        # Subjects and objects share the same encoders (and parameters).
        self.ob_encoder = self.sub_encoder
        self.rel_encoder_o = self.rel_encoder_s

    def init_embeds(self, ent_emb, rel_emb):
        """Create the entity/relation embedding tables, optionally from pretrained arrays.

        Args:
            ent_emb: Array expected to have ``ent_num`` rows of ``emb_dim``
                values, or ``None`` to keep the random initialisation.
            rel_emb: Array expected to have ``2 * rel_num`` rows; only read
                when ``ent_emb`` is not ``None``.
        """
        self.ent_embeds = nn.Embedding(self.ent_num + 1, self.emb_dim,
                                       padding_idx=self.ent_num)
        # NOTE(review): the table has 2 * rel_num + 1 rows but padding_idx is
        # rel_num rather than the last row -- confirm this is intended.
        self.rel_embeds = nn.Embedding(self.rel_num * 2 + 1, self.emb_dim,
                                       padding_idx=self.rel_num)
        if ent_emb is None:
            return

        for table, pretrained in ((self.ent_embeds, ent_emb), (self.rel_embeds, rel_emb)):
            if pretrained is None:
                continue
            # Append an all-zero row for the extra (padding) slot of the table.
            padded = np.concatenate((pretrained, np.zeros((1, pretrained.shape[1]))))
            with torch.no_grad():
                table.weight.copy_(torch.from_numpy(padded))
            # Must be ``requires_grad``; any other attribute name is silently
            # ignored and the pretrained vectors would always be trained.
            table.weight.requires_grad = bool(self.finetune)

    def _zero_state(self, n):
        """Placeholder encoding for a side without history input: ``(n, out_dim)`` zeros."""
        return torch.zeros((n, self.out_dim), dtype=torch.float32, device=self.device)

    def _gru_last_hidden(self, encoder, seq, lengths, n):
        """Run the GRU ``encoder`` over the padded batch ``seq`` and return its final hidden state."""
        if seq is None:
            return self._zero_state(n)
        packed = torch.nn.utils.rnn.pack_padded_sequence(self.dropout(seq),
                                                         lengths,
                                                         batch_first=True).to(self.device)
        _, last_hidden = encoder(packed)
        # Single-layer, unidirectional GRU: (1, batch, out_dim) -> (batch, out_dim).
        return last_hidden.squeeze(dim=0)

    def gru_encoder(self, n, s_input, o_input, s_nonzero_hist_lens, o_nonzero_hist_lens):
        """Summarise subject/object history sequences with the shared GRU.

        Returns:
            ``(s_h, o_h)``: final hidden states of shape ``(batch, out_dim)``, or
            ``(n, out_dim)`` zeros for a side whose input is ``None``.
        """
        s_h = self._gru_last_hidden(self.sub_encoder, s_input, s_nonzero_hist_lens, n)
        o_h = self._gru_last_hidden(self.ob_encoder, o_input, o_nonzero_hist_lens, n)
        return s_h, o_h

    def _attend(self, encoder, seq, seq_mask, n):
        """Run the attention ``encoder`` on ``seq`` and project the flattened output to ``out_dim``."""
        if seq is None:
            return self._zero_state(n)
        h = encoder(seq, seq_mask) if self.mask else encoder(seq)
        # Flatten everything after the batch dimension into one vector per row.
        h = h.reshape(h.size(0), -1)
        return F.relu(self.wo(self.dropout(h)))

    def attention_encoder(self, n, s_input, o_input, s_mask, o_mask):
        """Summarise subject/object history sequences with the shared Transformer.

        Returns:
            ``(s_h, o_h)`` of shape ``(batch, out_dim)`` (ReLU of the readout),
            or ``(n, out_dim)`` zeros for a side whose input is ``None``.
        """
        s_h = self._attend(self.sub_encoder, s_input, s_mask, n)
        o_h = self._attend(self.ob_encoder, o_input, o_mask, n)
        return s_h, o_h

    def _pad_rows(self, h, total):
        """Append zero rows (on ``h``'s device) so ``h`` has ``total`` rows, then move it to ``self.device``."""
        padding = h.new_zeros((total - h.size(0), self.out_dim))
        return torch.cat((h, padding), dim=0).to(self.device)

    def forward(self, triplets, s_hist, o_hist, n_support):
        """Encode a batch of query triplets.

        Args:
            triplets: Integer tensor whose columns 0, 1, 2 hold subject,
                relation and object ids.
            s_hist: Subject histories; ``s_hist[0]`` holds one history per
                triplet (only its lengths are read here; the whole object is
                passed to the aggregator).
            o_hist: Object histories, same layout as ``s_hist``.
            n_support: Unused.

        Returns:
            Tensor of shape ``(batch, 2 * emb_dim + 2 * out_dim)``.

        Raises:
            ValueError: If ``seq_encoder`` is neither ``'att'`` nor ``'gru'``.
        """
        s = triplets[:, 0]
        r = triplets[:, 1]
        o = triplets[:, 2]

        # Permutations sorting the batch by history length (longest first).
        # Presumably the aggregator returns rows in this order, possibly only
        # those with a non-empty history (hence the zero padding below) --
        # TODO confirm against MeanAggregator.
        _, s_idx = torch.LongTensor([len(h) for h in s_hist[0]]).sort(0, descending=True)
        _, o_idx = torch.LongTensor([len(h) for h in o_hist[0]]).sort(0, descending=True)

        s_input, s_nonzero_hist_lens, s_mask = self.rel_encoder_s(s_hist, s, r, self.ent_embeds, self.rel_embeds)
        o_input, o_nonzero_hist_lens, o_mask = self.rel_encoder_o(o_hist, o, r, self.ent_embeds, self.rel_embeds)

        n = len(o)
        if self.seq_encoder == 'att':
            s_h, o_h = self.attention_encoder(n, s_input, o_input, s_mask, o_mask)
        elif self.seq_encoder == 'gru':
            s_h, o_h = self.gru_encoder(n, s_input, o_input, s_nonzero_hist_lens, o_nonzero_hist_lens)
        else:
            raise ValueError(
                f"Unknown seq_encoder: {self.seq_encoder!r} (expected 'att' or 'gru')")

        s_h = self._pad_rows(s_h, len(s))
        o_h = self._pad_rows(o_h, len(o))

        # Inverse permutations, restoring the original batch order.
        _, s_actual_idx = s_idx.sort()
        _, o_actual_idx = o_idx.sort()
        # NOTE(review): o_h is reordered with the *subject* permutation and s_h
        # with the *object* one; this is only consistent if that cross-pairing
        # is intended -- confirm.
        return torch.cat((self.ent_embeds(s), self.ent_embeds(o), o_h[s_actual_idx], s_h[o_actual_idx]), dim=1)