import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


from torch.autograd import Variable

class LayerNormalization(nn.Module):
    """Normalise the last dimension, then apply a learned gain and bias.

    ``a_2`` (initialised to ones) is the gain and ``b_2`` (initialised to
    zeros) is the bias; both hold ``d_hid`` entries. ``eps`` is added to the
    standard deviation (not the variance) before the division.
    """

    def __init__(self, d_hid, eps=1e-3):
        super(LayerNormalization, self).__init__()

        self.eps = eps
        self.a_2 = nn.Parameter(torch.ones(d_hid))
        self.b_2 = nn.Parameter(torch.zeros(d_hid))

    def forward(self, z):
        """Return ``z`` normalised over its last dimension.

        Inputs whose dimension 1 has size 1 are returned unchanged.
        """
        if z.size(1) == 1:
            return z

        # keepdim=True leaves a trailing axis of size 1, so both statistics
        # broadcast against z without an explicit expand.
        mean = z.mean(dim=-1, keepdim=True)
        std = z.std(dim=-1, keepdim=True)
        normalised = (z - mean) / (std + self.eps)

        gain = self.a_2.expand_as(normalised)
        bias = self.b_2.expand_as(normalised)
        return normalised * gain + bias


class SupportEncoder(nn.Module):
    """Two-layer feed-forward block with a residual connection and layer norm.

    The input is projected ``d_model -> d_inner`` (followed by ReLU) and back
    ``d_inner -> d_model``; dropout is applied to the second projection, the
    original input is added, and the sum is layer-normalised.
    """
    def __init__(self, d_model, d_inner, dropout=0.1):
        super(SupportEncoder, self).__init__()
        self.proj1 = nn.Linear(d_model, d_inner)
        self.proj2 = nn.Linear(d_inner, d_model)
        self.layer_norm = LayerNormalization(d_model)

        # Both projection weights are re-initialised with Xavier-normal;
        # the biases keep nn.Linear's own initialisation.
        for projection in (self.proj1, self.proj2):
            init.xavier_normal_(projection.weight)

        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``layer_norm(dropout(proj2(relu(proj1(x)))) + x)``."""
        hidden = self.proj1(x)
        hidden = self.relu(hidden)
        projected = self.proj2(hidden)
        return self.layer_norm(self.dropout(projected) + x)


class QueryEncoder(nn.Module):
    """Refine query embeddings by repeatedly attending over the support set.

    An ``nn.LSTMCell`` with hidden size ``2 * input_dim`` is run for
    ``process_step`` iterations. At each step the first ``input_dim`` entries
    of the new hidden state are added to the query, the result attends over
    the support rows (softmax of dot products), and the attended support
    read-out is concatenated to form the next hidden state.

    ``device`` is stored on the instance for backward compatibility; the
    recurrent state itself is allocated on the query's device and dtype.
    """
    def __init__(self, input_dim, process_step=4, device=None):
        super(QueryEncoder, self).__init__()
        self.device = device
        self.input_dim = input_dim
        self.process_step = process_step
        self.process = nn.LSTMCell(input_dim, 2 * input_dim)

        # TODO: try to train the initial hidden/cell state instead of
        # starting from zeros.

    def forward(self, support, query):
        '''
        support: (few, support_dim)
        query: (batch_size, query_dim)
        support_dim = query_dim

        return:
        (batch_size, query_dim)

        When ``process_step`` is zero or negative, ``query`` is returned
        unchanged.
        '''
        assert support.size()[1] == query.size()[1]

        # No processing steps requested: nothing to refine.
        if self.process_step <= 0:
            return query

        batch_size = query.size(0)
        state_shape = (batch_size, 2 * self.input_dim)
        # new_zeros inherits the query's device and dtype, so the state
        # always matches the tensors the LSTM cell is fed, regardless of
        # what `device` the module was constructed with.
        h_r = query.new_zeros(state_shape)
        c = query.new_zeros(state_shape)

        support_t = support.t()  # loop-invariant, (support_dim, few)
        for _ in range(self.process_step):
            h_r_, c = self.process(query, (h_r, c))
            h = query + h_r_[:, :self.input_dim]  # (batch_size, query_dim)
            attn = F.softmax(torch.matmul(h, support_t), dim=1)
            r = torch.matmul(attn, support)  # (batch_size, support_dim)
            h_r = torch.cat((h, r), dim=1)

        return h


class Matcher(nn.Module):
    """
    Matching metric based on KB Embeddings.

    Support and query examples are embedded by the supplied ``encoder``,
    passed through a shared ``SupportEncoder``, and each query is then scored
    by its dot product with the mean support representation after refinement
    by a ``QueryEncoder``.
    """
    def __init__(self, encoder, h_dim, out_dim, steps, dropout, device):
        super(Matcher, self).__init__()

        self.encoder = encoder
        # NOTE(review): this layer uses nn.Dropout's default probability
        # rather than the `dropout` argument, and no method in this class
        # applies it. Left as-is in case external code reads it — confirm.
        self.dropout = nn.Dropout()
        self.device = device

        d_model = (h_dim + out_dim) * 2

        self.support_encoder = SupportEncoder(d_model, 2 * d_model, dropout)
        self.query_encoder = QueryEncoder(d_model, steps, device=device)

    def loss(self, sample):
        '''
        Compute one matching score per query example.

        sample: dict with entries 'xs' (support) and 'xq' (query); each holds
            'triplets' (a numpy array), 's_hist' and 'o_hist'. The two
            elements of each history are combined support-first with `+`
            (presumably list concatenation — verify against the caller).

        query: (batch_size, 2)
        support: (few, 2)
        return: (batch_size, )

        Despite the method name, the returned values are raw matching
        scores; no loss is computed here.
        '''

        xs = torch.from_numpy(sample['xs']['triplets']).to(self.device)  # support
        xq = torch.from_numpy(sample['xq']['triplets']).to(self.device)  # query

        xs_sub_hist = sample['xs']['s_hist']
        xq_sub_hist = sample['xq']['s_hist']

        xs_obj_hist = sample['xs']['o_hist']
        xq_obj_hist = sample['xq']['o_hist']

        n_support = xs.size(0)

        # Support rows come first so the encoder output can be split back
        # at n_support below.
        x = torch.cat([xs, xq], 0)
        s_hist = [xs_sub_hist[0] + xq_sub_hist[0], xs_sub_hist[1] + xq_sub_hist[1]]
        o_hist = [xs_obj_hist[0] + xq_obj_hist[0], xs_obj_hist[1] + xq_obj_hist[1]]
        z = self.encoder.forward(x, s_hist, o_hist, n_support)

        support = z[:n_support]
        query = z[n_support:]

        # The same feed-forward encoder is applied to both sets.
        support_g = self.support_encoder(support)
        query_g = self.support_encoder(query)

        # Collapse the support set to a single mean vector of shape (1, d).
        support_g = torch.mean(support_g, dim=0, keepdim=True)
        query_f = self.query_encoder(support_g, query_g)

        # Unnormalised dot-product similarity, shape (batch_size, 1). Only
        # the trailing axis is squeezed so that a single-query batch still
        # yields shape (1,) instead of a 0-d tensor.
        matching_scores = torch.matmul(query_f, support_g.t()).squeeze(-1)
        return matching_scores

    def forward_(self, query_meta, support_meta):
        '''
        Score queries against the support with a siamese-style distance.

        Each of `query_meta` and `support_meta` unpacks into
        (left_connections, left_degrees, right_connections, right_degrees).

        NOTE(review): this method reads `self.neighbor_encoder` and
        `self.siamese`, neither of which is assigned in `__init__`; calling
        it raises AttributeError unless they are attached elsewhere.
        '''
        query_left_connections, query_left_degrees, query_right_connections, query_right_degrees = query_meta
        support_left_connections, support_left_degrees, support_right_connections, support_right_degrees = support_meta

        query_left = self.neighbor_encoder(query_left_connections, query_left_degrees)
        query_right = self.neighbor_encoder(query_right_connections, query_right_degrees)
        support_left = self.neighbor_encoder(support_left_connections, support_left_degrees)
        support_right = self.neighbor_encoder(support_right_connections, support_right_degrees)

        query = torch.cat((query_left, query_right), dim=-1) # tanh
        support = torch.cat((support_left, support_right), dim=-1) # tanh

        support_expand = support.expand_as(query)

        distances = F.sigmoid(self.siamese(torch.abs(support_expand - query))).squeeze()
        return distances