from collections import OrderedDict
import math

import torch
import torch.nn.functional as F
from torch import nn

import params

class DenseLayer(nn.Module):
    """One fully connected layer followed by a SELU activation.

    Args:
        input_size: Size of the last axis of the input.
        output_size: Size of the last axis of the output.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        pre_activation = self.linear(x)
        return F.selu(pre_activation)


class DenseBlock(nn.Module):
    """DenseNet-style stack of fully connected SELU layers.

    Each layer receives the concatenation (along the last axis) of the block
    input and the outputs of all previous layers. Every layer except the last
    emits ``growth_rate`` features. The last layer emits ``output_size``
    features, and that is what ``forward`` returns.

    Args:
        num_layers: Number of DenseLayers in the block; must be >= 1.
        growth_rate: Output width of every non-final layer.
        input_size: Size of the last axis of the block input.
        output_size: Size of the last axis of the block output.

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(self, num_layers, growth_rate, input_size, output_size):
        super(DenseBlock, self).__init__()

        if num_layers < 1:
            raise ValueError("num_layers must be >= 1, got {}".format(num_layers))

        # Layer i (0-based) sees the block input plus i earlier outputs, each
        # growth_rate wide. The previous code always created a separate first
        # layer, so num_layers == 1 produced two layers with mismatched widths.
        modules = [DenseLayer(growth_rate * i + input_size, growth_rate)
                   for i in range(num_layers - 1)]
        # The final layer maps everything accumulated so far to output_size.
        # With num_layers == 1 it is the only layer (input_size -> output_size).
        modules.append(DenseLayer(growth_rate * (num_layers - 1) + input_size, output_size))
        self.layers = nn.ModuleList(modules)

    def forward(self, x):
        inputs = [x]
        for layer in self.layers:
            output = layer(torch.cat(inputs, dim=-1))
            inputs.append(output)
        return inputs[-1]

class DenseDecoder(nn.Module):
    """Map an ``input_size`` feature vector to ``output_size`` with a 5-layer DenseBlock.

    The growth rate of the block is taken from ``params.dense_growth_size``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.dense = DenseBlock(5, params.dense_growth_size, input_size, output_size)

    def forward(self, x):
        return self.dense(x)


class DenseEncoder(nn.Module):
    """Encode a batch of examples into one vector per batch element.

    The pipeline is:
    1. Embed each variable with the injected ``embedding`` module.
    2. Append the variable's type vector.
    3. Project it with ``var_encoder`` (+ SELU).
    4. Flatten the ``state_len`` variable codes of each example.
    5. Run them through a 10-layer DenseBlock.
    6. Average over the example axis.

    Args:
        embedding: Module applied to the (reshaped) values. It is passed in
            rather than created here.
    """

    def __init__(self, embedding):
        super().__init__()

        self.state_len = params.state_len
        self.embedding = embedding
        var_feature_size = params.max_list_len * params.embedding_size + params.type_vector_len
        self.var_encoder = nn.Linear(var_feature_size, params.var_encoder_size)
        self.dense = DenseBlock(10, params.dense_growth_size,
                                params.var_encoder_size * self.state_len,
                                params.dense_output_size)

    def forward(self, x, typ):
        """Return a ``(batch, params.dense_output_size)`` encoding of ``x``/``typ``."""
        features, batch_count, example_count = self.embed_state(x, typ)
        var_codes = F.selu(self.var_encoder(features))
        # One flat vector per example: state_len * var_encoder_size entries.
        per_example = self.dense(var_codes.view(batch_count, example_count, -1))
        pooled = per_example.mean(dim=1)
        return pooled.view(batch_count, -1)

    def embed_state(self, x, typ):
        """Embed every variable of every example and append its type vector.

        Args:
            x: Variable values. Axis 2 must have ``self.state_len`` entries and
                axis 3 must have ``params.max_list_len`` entries. It is reshaped
                to ``(batch, -1, params.integer_range + 1)`` before the
                embedding, so presumably each list element is one-hot encoded
                along a trailing axis of that size. TODO confirm with callers.
            typ: Type vectors. They are cast to float and concatenated to the
                embedded values along the last axis. Assumed to be
                ``(batch, examples, state_len, params.type_vector_len)``;
                verify against callers.

        Returns:
            ``(features, num_batches, num_examples)``, where ``features`` has
            shape ``(batch, examples, state_len, -1)``.
        """
        assert x.size()[2] == self.state_len, "Example with invalid length received!"
        assert x.size()[3] == params.max_list_len, "Example with invalid length received!"

        batch_count, example_count = x.size()[0], x.size()[1]

        flat_values = x.contiguous().view(batch_count, -1, params.integer_range + 1)
        embedded = self.embedding(flat_values)
        embedded = embedded.view(batch_count, example_count, self.state_len, -1)
        type_features = typ.contiguous().float()
        return torch.cat((embedded, type_features), dim=-1), batch_count, example_count

class DenseQueryEncoder(DenseEncoder):
    """DenseEncoder for queries with a fixed state length of 4.

    state_len = 3 input + 1 output = 4
    """

    def __init__(self, embedding):
        super().__init__(embedding)

        # Override the parent's state length, then rebuild the dense block so
        # its input width matches the new number of variables per example.
        self.state_len = 4
        dense_input_size = params.var_encoder_size * self.state_len
        self.dense = DenseBlock(10, params.dense_growth_size, dense_input_size,
                                params.dense_output_size)

class DensePerIOEncoder(DenseQueryEncoder):
    """Per-I/O variant of DenseQueryEncoder (state_len = 3 input + 1 output = 4).

    It is identical to DenseQueryEncoder except that ``forward`` keeps one
    encoding per example instead of averaging over the example axis.
    """

    def __init__(self, embedding):
        super().__init__(embedding)

    def forward(self, x, typ):
        """Return per-example encodings of shape ``(batch, examples, params.dense_output_size)``."""
        features, batch_count, example_count = self.embed_state(x, typ)
        var_codes = F.selu(self.var_encoder(features))
        flattened = var_codes.view(batch_count, example_count, -1)
        return self.dense(flattened)



class DenseEncoder_ori(nn.Module):
    """Older DenseEncoder variant that reads type vectors out of ``x`` itself.

    Along the last axis of ``x``, the first ``params.type_vector_len`` entries
    are the type vector and the rest are the list values. The values are looked
    up in this module's own ``nn.Embedding``, so they must be integer indices.

    Args:
        embedding: Accepted for constructor compatibility with DenseEncoder but
            ignored. A fresh ``nn.Embedding`` is created instead.
    """

    def __init__(self, embedding):
        super(DenseEncoder_ori, self).__init__()

        self.embedding = nn.Embedding(params.integer_range + 1, params.embedding_size)
        self.var_encoder = nn.Linear(params.max_list_len * params.embedding_size + params.type_vector_len,
                                     params.var_encoder_size)
        self.dense = DenseBlock(10, params.dense_growth_size, params.var_encoder_size * params.state_len,
                                params.dense_output_size)

    def forward(self, x):
        """Return a ``(batch, params.dense_output_size)`` encoding, averaged over examples."""
        x, num_batches, num_examples = self.embed_state(x)
        x = F.selu(self.var_encoder(x))
        # Concatenate the state_len variable codes of each example.
        x = x.view(num_batches, num_examples, -1)
        x = self.dense(x)
        x = x.mean(dim=1)
        return x.view(num_batches, -1)

    def embed_state(self, x):
        """Split ``x`` into types and values, embed the values, and re-append the types.

        Returns:
            ``(features, num_batches, num_examples)``, where ``features`` has
            shape ``(batch, examples, params.state_len, -1)``.
        """
        types = x[:, :, :, :params.type_vector_len]
        values = x[:, :, :, params.type_vector_len:]

        assert values.size()[2] == params.state_len, "Example with invalid length received!"
        assert values.size()[3] == params.max_list_len, "Example with invalid length received!"

        num_batches = x.size()[0]
        num_examples = x.size()[1]

        # The previous code one-hot encoded the whole of `x` here and then never
        # used the result. That wasted memory, and it could raise on out-of-range
        # entries, so it was removed. The embedding consumes integer indices directly.
        embedded_values = self.embedding(values.contiguous().view(num_batches, -1))
        embedded_values = embedded_values.view(num_batches, num_examples, params.state_len, -1)
        types = types.contiguous().float()
        return torch.cat((embedded_values, types), dim=-1), num_batches, num_examples
class DenseEncoder_notype(nn.Module):
    """DenseEncoder variant that ignores the type vectors.

    ``x`` still carries ``params.type_vector_len`` type entries at the start of
    its last axis. They are sliced off and discarded, so ``var_encoder`` sees
    only the embedded list values.

    Args:
        embedding: Accepted for constructor compatibility but ignored. A fresh
            ``nn.Embedding`` is created instead.
    """

    def __init__(self, embedding):
        super(DenseEncoder_notype, self).__init__()

        self.embedding = nn.Embedding(params.integer_range + 1, params.embedding_size)
        self.var_encoder = nn.Linear(params.max_list_len * params.embedding_size,
                                     params.var_encoder_size)
        self.dense = DenseBlock(10, params.dense_growth_size, params.var_encoder_size * params.state_len,
                                params.dense_output_size)

    def forward(self, x):
        """Return a ``(batch, params.dense_output_size)`` encoding, averaged over examples."""
        x, num_batches, num_examples = self.embed_state(x)
        x = F.selu(self.var_encoder(x))
        # Concatenate the state_len variable codes of each example.
        x = x.view(num_batches, num_examples, -1)
        x = self.dense(x)
        x = x.mean(dim=1)
        return x.view(num_batches, -1)

    def embed_state(self, x):
        """Embed the list values of ``x``; the leading type entries are dropped.

        Returns:
            ``(embedded_values, num_batches, num_examples)``, where
            ``embedded_values`` has shape ``(batch, examples, params.state_len, -1)``.
        """
        # Only the values are needed. The type prefix used to be sliced and
        # converted to float here as well, but that result was never used.
        values = x[:, :, :, params.type_vector_len:]

        assert values.size()[2] == params.state_len, "Example with invalid length received!"
        assert values.size()[3] == params.max_list_len, "Example with invalid length received!"

        num_batches = x.size()[0]
        num_examples = x.size()[1]

        embedded_values = self.embedding(values.contiguous().view(num_batches, -1))
        embedded_values = embedded_values.view(num_batches, num_examples, params.state_len, -1)
        return embedded_values, num_batches, num_examples
    
class RNNEncoder(nn.Module):
    """Embed a batch of token ids and run them through a recurrent network.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Width of both the embeddings and the RNN hidden state.
        input_dropout_p: Dropout probability applied to the embeddings.
        dropout_p: Dropout between stacked recurrent layers. It only has an
            effect when ``n_layers > 1``.
        n_layers: Number of stacked recurrent layers.
        bidirectional: Whether the RNN is bidirectional.
        rnn_cell: Name of a ``torch.nn`` recurrent class, looked up after
            upper-casing (e.g. ``'lstm'`` -> ``nn.LSTM``, ``'gru'`` -> ``nn.GRU``).
        variable_lengths: If True, ``forward`` packs the batch with
            ``input_lengths`` before the RNN and pads the output afterwards.
    """

    def __init__(self, vocab_size, hidden_size,
                 input_dropout_p=0, dropout_p=0, n_layers=1,
                 bidirectional=False, rnn_cell='lstm', variable_lengths=False):
        super(RNNEncoder, self).__init__()

        self.rnn_cell = getattr(nn, rnn_cell.upper())
        self.variable_lengths = variable_lengths

        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.input_dropout = nn.Dropout(p=input_dropout_p)
        # PyTorch applies RNN dropout only between stacked layers. With a single
        # layer it has no effect and PyTorch emits a UserWarning, so pass 0 there.
        rnn_dropout = dropout_p if n_layers > 1 else 0
        self.rnn = self.rnn_cell(hidden_size, hidden_size, n_layers,
                                 batch_first=True, bidirectional=bidirectional,
                                 dropout=rnn_dropout)

        self.init_weights()

    def init_weights(self):
        """Initialise the embedding table uniformly in [-0.1, 0.1]."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input_var, input_lengths=None, h0=None):
        """Encode ``input_var`` (batch-first token ids).

        Args:
            input_var: Token-id tensor of shape ``(batch, seq_len)``.
            input_lengths: Tensor of true sequence lengths. Required when
                ``variable_lengths`` is True; it is moved to CPU for packing.
            h0: Optional initial hidden state forwarded to the RNN.

        Returns:
            ``(output, hidden)`` from the RNN. When ``variable_lengths`` is
            True, ``output`` is re-padded to the longest length in
            ``input_lengths``.

        Raises:
            ValueError: If ``variable_lengths`` is True and ``input_lengths`` is None.
        """
        if self.variable_lengths and input_lengths is None:
            raise ValueError("input_lengths is required when variable_lengths=True")
        embedded = self.input_dropout(self.embedding(input_var))
        if self.variable_lengths:
            embedded = nn.utils.rnn.pack_padded_sequence(
                    embedded, input_lengths.cpu(), batch_first=True, enforce_sorted=False)
        output, hidden = self.rnn(embedded, h0)
        if self.variable_lengths:
            output, _ = nn.utils.rnn.pad_packed_sequence(
                    output, batch_first=True)
        return output, hidden

class MLP(nn.Module):
    """Stack of Linear layers with ReLU + Dropout between consecutive layers.

    Layer names are ``fc{i}``, ``relu{i}`` and ``dropout{i}``. The last layer
    ``fc{num_layers-1}`` has no activation. Every Linear weight is drawn from
    N(0, sqrt(2 / fan_in)) and every bias starts at zero.

    Args:
        input_size: Input width of the first layer.
        hidden_size: Width of every intermediate layer.
        num_classes: Output width of the last layer.
        num_layers: Number of Linear layers.
        dropout_p: Dropout probability after each hidden ReLU.
    """

    def __init__(self, input_size, hidden_size, num_classes,
                 num_layers=1, dropout_p=0.0):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = num_classes

        last_index = num_layers - 1
        named = OrderedDict()
        for index in range(num_layers):
            in_features = input_size if index == 0 else hidden_size
            out_features = num_classes if index == last_index else hidden_size
            linear = nn.Linear(in_features, out_features)
            # He-style normal init with std sqrt(2 / fan_in), and a zero bias.
            linear.weight.data.normal_(0.0, math.sqrt(2. / in_features))
            linear.bias.data.zero_()
            named['fc' + str(index)] = linear
            if index < last_index:
                named['relu' + str(index)] = nn.ReLU()
                named['dropout' + str(index)] = nn.Dropout(p=dropout_p)
        self.layers = nn.Sequential(named)

    def params_to_train(self):
        """Return an iterator over the trainable parameters of the layer stack."""
        return self.layers.parameters()

    def forward(self, x):
        return self.layers(x)


class TransformerEncoder(nn.Module):
    """Transformer encoder over token ids, mean-pooled into one vector per sequence.

    Args:
        input_dim: Vocabulary size of the token embedding.
        hid_dim: Model width.
        n_layers: Number of EncoderLayer blocks.
        n_heads: Attention heads per layer.
        pf_dim: Hidden width of the position-wise feed-forward sublayer.
        dropout: Dropout probability.
        max_length: Number of learned positions. Longer inputs cannot be
            embedded by ``pos_embedding``.
    """

    def __init__(self, 
                 input_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim,
                 dropout=0, 
                 max_length=12+2):
        super().__init__()

        # Token id 0 is treated as padding when building the attention mask.
        self.src_pad_idx = 0 

        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim,
                                                  dropout) 
                                     for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
        # Plain Python float instead of a tensor pinned to the GPU with .cuda().
        # This works on CPU-only machines and on whatever device the module is moved to.
        self.scale = math.sqrt(hid_dim)
        
    def make_src_mask(self, src):
        """Return a ``(batch, 1, 1, src_len)`` boolean mask that is False at padding tokens."""
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask

    def forward(self, src):
        """Encode ``src`` (``(batch, src_len)`` token ids) into ``(batch, hid_dim)``.

        NOTE(review): the final mean runs over all positions, including padding
        (the mask only affects attention inside the layers). Confirm that this
        is intended before changing it, because trained weights depend on it.
        """
        src_len = src.shape[1]

        src_mask = self.make_src_mask(src) 

        # Shape (1, src_len), created directly on src's device. The addition
        # below broadcasts it over the batch, so it is not repeated per example.
        pos = torch.arange(0, src_len, device=src.device).unsqueeze(0)

        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        
        for layer in self.layers:
            src = layer(src, src_mask)
        src = src.mean(1)
        return src
    
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention is followed by residual + LayerNorm. A position-wise
    feed-forward sublayer is then followed by residual + LayerNorm. Dropout is
    applied to each sublayer output before the residual addition.
    """

    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim,  
                 dropout):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        attended, _ = self.self_attention(src, src, src, src_mask)
        normed = self.self_attn_layer_norm(src + self.dropout(attended))
        projected = self.positionwise_feedforward(normed)
        return self.ff_layer_norm(normed + self.dropout(projected))

class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hid_dim: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        # Plain Python float instead of a tensor pinned to the GPU with .cuda().
        # This works on CPU-only machines and on whatever device the module is moved to.
        self.scale = math.sqrt(self.head_dim)
        
    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: Tensors of shape ``(batch, seq_len, hid_dim)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_heads, query_len, key_len)``. Scores at positions
                where it equals 0 are set to -1e10 before the softmax.

        Returns:
            ``(output, attention)`` with shapes ``(batch, query_len, hid_dim)``
            and ``(batch, n_heads, query_len, key_len)``.
        """
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # Split the width into heads: (batch, n_heads, seq_len, head_dim).
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        
        attention = torch.softmax(energy, dim=-1)

        # Merge the heads back into (batch, query_len, hid_dim).
        x = torch.matmul(self.dropout(attention), V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)

        return x, attention

class PositionwiseFeedforwardLayer(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        hid_dim: Input and output width.
        pf_dim: Hidden width.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = torch.relu(self.fc_1(x))
        return self.fc_2(self.dropout(activated))