import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence

class Attention(nn.Module):
    """
    Additive soft-attention over a set of encoder vectors.

    Adapted from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/caption.py
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of each encoder vector
        :param decoder_dim: size of the decoder's hidden state
        :param attention_dim: width of the shared space both inputs are projected into
        """
        super().__init__()
        # Projections of encoder vectors and decoder state into the attention space.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Collapses each attention-space vector to a single unnormalised score.
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        # Normalises scores across the encoder positions (dim 1).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        Compute attention weights and the attention-weighted encoder summary.

        :param encoder_out: tensor of dimension (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: tensor of dimension (batch_size, decoder_dim)
        :return: tuple ``(context, alpha)`` where ``context`` is
            (batch_size, encoder_dim) and ``alpha`` is (batch_size, num_pixels)
        """
        projected_encoder = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        projected_hidden = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)

        # Broadcast the decoder projection over every encoder position.
        joint = self.relu(projected_encoder + torch.unsqueeze(projected_hidden, 1))
        scores = torch.squeeze(self.full_att(joint), 2)  # (batch_size, num_pixels)

        alpha = self.softmax(scores)
        weighted = encoder_out * torch.unsqueeze(alpha, 2)
        context = weighted.sum(dim=1)  # (batch_size, encoder_dim)
        return context, alpha

class StateEncoder(nn.Module):
    """
    Encode an observation into a sequence of ``encoder_dim``-sized tokens.

    The output holds one token per grid cell (25), followed by one inventory
    token and one goal token.
    """

    def __init__(self, embed_dim, encoder_dim):
        """
        :param embed_dim: size of the grid / inventory / goal embedding vectors
        :param encoder_dim: size of every output token
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.encoder_dim = encoder_dim
        self.grid_fc = nn.Linear(embed_dim, encoder_dim)
        self.onehot_fc = nn.Linear(7, encoder_dim)
        self.inv_fc = nn.Linear(embed_dim, encoder_dim)
        self.goal_fc = nn.Linear(embed_dim, encoder_dim)

    def forward(self, grid_embedding, grid_onehot, inventory, goal):
        """
        :param grid_embedding: per-cell embeddings with last dimension ``embed_dim``
            (25 cells per sample)
        :param grid_onehot: per-cell features with last dimension 7 (25 cells per sample)
        :param inventory: inventory embedding; reshaped to (batch_size, 1, embed_dim)
        :param goal: goal embedding with last dimension ``embed_dim``;
            presumably (batch_size, embed_dim) so that it concatenates with the grid
        :return: tensor of dimension (batch_size, 25 + 1 + 1, encoder_dim)
        """
        # Each grid cell is the sum of its two separately activated projections.
        grid = F.relu(self.grid_fc(grid_embedding)) + F.relu(self.onehot_fc(grid_onehot))
        grid = grid.view(-1, 25, self.encoder_dim)
        # Use the configured embedding size instead of a hard-coded 300 so the
        # encoder works for any embed_dim.
        inventory = inventory.view(-1, 1, self.embed_dim)
        inventory = F.relu(self.inv_fc(inventory))
        goal = F.relu(self.goal_fc(goal)).unsqueeze(1)
        encoder_out = torch.cat((grid, inventory, goal), dim=1)
        return encoder_out

class SimpleNetwork(nn.Module):
    """
    Feed-forward action head.

    Encodes the grid, goal and inventory into a 48-dim feature vector,
    concatenates it with an externally supplied ``hidden`` vector and maps the
    result to one score per action.
    """

    def __init__(self, num_actions, embed_dim, hidden_dim=32):
        """
        :param num_actions: number of action scores produced by the final layer
        :param embed_dim: size of the grid / goal / inventory embedding vectors
        :param hidden_dim: size of the ``hidden`` vector passed to ``forward``
            (default 32)
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim

        self.fc1 = nn.Linear(embed_dim, 150)
        self.fc2 = nn.Linear(7, 20)
        self.fc3 = nn.Linear(170, 90)  # 150 (embedding) + 20 (one-hot) per cell
        self.fc4 = nn.Linear(embed_dim, 150)
        self.fc5 = nn.Linear(2250 + 150, 512)  # 25 cells * 90 + goal
        self.fc_inv = nn.Linear(embed_dim, 50)
        self.fc55 = nn.Linear(512 + 50, 48)

        # Observation features (48) joined with the external hidden vector.
        self.fc6 = nn.Linear(48 + hidden_dim, 48)
        self.fc7 = nn.Linear(48, num_actions)

    def forward(self, grid_embedding, grid_onehot, inventory, goal, hidden):
        """
        :param grid_embedding: per-cell embeddings, last dimension ``embed_dim``, 25 cells per sample
        :param grid_onehot: per-cell features, last dimension 7, 25 cells per sample
        :param inventory: (batch_size, embed_dim) inventory embedding
        :param goal: (batch_size, embed_dim) goal embedding
        :param hidden: (batch_size, hidden_dim) extra conditioning vector
        :return: (batch_size, num_actions) unnormalised action scores
        """
        # Encode each grid cell from its embedding and its one-hot features.
        c1 = F.relu(self.fc1(grid_embedding))
        c2 = F.relu(self.fc2(grid_onehot))
        c1 = c1.view(-1, 25, 150)
        c2 = c2.view(-1, 25, 20)
        combined_grids = torch.cat((c1, c2), dim=2)
        c3 = F.relu(self.fc3(combined_grids))
        c3 = c3.view(-1, 25 * 90)

        # Fuse the flattened grid with the goal, then with the inventory.
        c4 = F.relu(self.fc4(goal))
        combined_grid_goal = torch.cat((c3, c4), dim=1)
        c6 = F.relu(self.fc5(combined_grid_goal))
        temp_inv = F.relu(self.fc_inv(inventory))
        combined_inventory = torch.cat((c6, temp_inv), dim=1)
        features = F.relu(self.fc55(combined_inventory))

        all_comb = torch.cat((features, hidden), dim=1)

        c6 = F.relu(self.fc6(all_comb))
        c7 = self.fc7(c6)

        return c7

class AskYourHumanModel(nn.Module):
    """
    Joint language-and-action model.

    An attention LSTM decoder runs over the encoded state (teacher-forced on the
    subgoal tokens in ``forward``, greedy in ``predict``). The last decoder
    hidden state of every sample is then passed, together with the raw
    observation, to ``SimpleNetwork`` to score actions.
    """

    def __init__(self, num_actions, vocab, embed_weights, max_seq_len=20, encoder_dim=128, decoder_dim=32, freeze_lang_model=False):
        """
        :param num_actions: number of actions scored by the action model
        :param vocab: vocabulary object; must support ``len()`` and expose a
            ``word2idx`` mapping containing ``'<start>'`` and ``'<end>'``
        :param embed_weights: numpy array loaded as the frozen word-embedding weight
        :param max_seq_len: step limit for greedy decoding in ``predict``
        :param encoder_dim: feature size of the encoded state tokens
        :param decoder_dim: size of the decoder LSTM state
        :param freeze_lang_model: if True, only the action model's parameters stay trainable
        """
        super().__init__()
        self.vocab = vocab
        vocab_size = len(self.vocab)
        embed_dim = 300
        dropout_p = 0.5

        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size

        self.state_encoder = StateEncoder(embed_dim, encoder_dim)

        self.attention = Attention(encoder_dim, decoder_dim, encoder_dim)
        # The last vocabulary index is used as the padding index.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=vocab_size - 1)
        self.embedding.load_state_dict({'weight': torch.from_numpy(embed_weights)})
        self.embedding.weight.requires_grad = False

        self.dropout = nn.Dropout(p=dropout_p)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # linear layer to find scores over vocabulary

        self.lang_criterion = nn.CrossEntropyLoss()

        # Action head. It receives the decoder hidden state, so its expected
        # hidden width must equal decoder_dim (its default is 32).
        self.action_model = SimpleNetwork(num_actions, embed_dim)

        if freeze_lang_model:
            # Freeze everything, then re-enable only the action head.
            for p in self.parameters():
                p.requires_grad = False
            for p in self.action_model.parameters():
                p.requires_grad = True

    def forward(self, obs, labels=None, is_target=False):
        """
        Teacher-forced decoding followed by action scoring.

        :param obs: dict with keys ``'grid_embedding'``, ``'grid_onehot'``,
            ``'inventory'``, ``'goal'``, ``'subgoal'`` (token ids) and
            ``'decode_lengths'``. Only the first ``batch_size_t`` rows are decoded
            at each step, so samples are assumed to be sorted by decreasing
            decode length.
        :param labels: optional target token ids; when given, a language loss is computed
        :param is_target: unused
        :return: ``(actions, lang_loss, None)``; ``lang_loss`` is None when
            ``labels`` is None
        """
        grid_embedding, grid_onehot, inventory, goal = obs['grid_embedding'], obs['grid_onehot'], obs['inventory'], obs['goal']
        device = grid_embedding.device
        # (B, 25 + 1 + 1, encoder_dim): grid cells, inventory token, goal token.
        encoder_out = self.state_encoder(grid_embedding, grid_onehot, inventory, goal)

        # Initialise the LSTM state from the mean encoded token.
        mean_encoder_out = encoder_out.mean(dim=1)
        h, c = self.init_h(mean_encoder_out), self.init_c(mean_encoder_out)

        # Embed the subgoal tokens used for teacher forcing.
        embeddings = self.embedding(obs['subgoal'])
        decode_lengths = obs['decode_lengths']
        max_decode_length = max(decode_lengths)
        batch_size, num_tokens, _ = encoder_out.shape

        predictions = torch.zeros(batch_size, max_decode_length, self.vocab_size, device=device)
        alphas = torch.zeros(batch_size, max_decode_length, num_tokens, device=device)
        # Holds the most recent hidden state of every sample.
        hiddens = h.clone()

        for t in range(max_decode_length):
            # Number of samples whose sequence is still running at step t.
            batch_size_t = sum(l > t for l in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            hiddens[:batch_size_t] = h
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        # Score actions from the observation and each sample's final hidden state.
        actions = self.action_model(grid_embedding, grid_onehot, inventory, goal, hiddens)

        if labels is not None:
            # Packing drops the padded positions before the loss is computed.
            packed_predictions = pack_padded_sequence(predictions, decode_lengths, batch_first=True)[0]
            targets = pack_padded_sequence(labels, decode_lengths, batch_first=True)[0]
            lang_loss = self.lang_criterion(packed_predictions, targets)
            # Penalise attention weights whose sum over time deviates from 1.
            lang_loss = lang_loss + 0.5 * ((1. - alphas.sum(dim=1)) ** 2).mean()
        else:
            lang_loss = None
        return actions, lang_loss, None

    def predict(self, obs, deterministic=True, history=None):
        """
        Greedily decode a subgoal and return the best action for the first sample.

        :param obs: dict with keys ``'grid_embedding'``, ``'grid_onehot'``,
            ``'inventory'`` and ``'goal'``
        :param deterministic: must be True; sampling is not supported
        :param history: unused
        :return: index (int) of the highest-scoring action for sample 0
        """
        assert deterministic, "only deterministic is currently supported"
        grid_embedding, grid_onehot, inventory, goal = obs['grid_embedding'], obs['grid_onehot'], obs['inventory'], obs['goal']
        device = grid_embedding.device
        start_idx = self.vocab.word2idx['<start>']
        end_idx = self.vocab.word2idx['<end>']

        # Only an int is returned, so no autograd graph is needed.
        with torch.no_grad():
            # (B, 25 + 1 + 1, encoder_dim)
            encoder_out = self.state_encoder(grid_embedding, grid_onehot, inventory, goal)

            mean_encoder_out = encoder_out.mean(dim=1)
            h, c = self.init_h(mean_encoder_out), self.init_c(mean_encoder_out)
            hiddens = h.clone()

            k = encoder_out.size(0)
            k_prev_words = torch.full((k, 1), start_idx, dtype=torch.long, device=device)  # (k, 1)
            # Row r of h / c / encoder_out belongs to original sample incomplete_inds[r].
            incomplete_inds = list(range(k))

            step = 1
            while True:
                embeddings = self.embedding(k_prev_words).squeeze(1)
                awe, _ = self.attention(encoder_out, h)  # (s, encoder_dim)
                gate = self.sigmoid(self.f_beta(h))  # (s, encoder_dim)
                awe = gate * awe
                h, c = self.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)
                hiddens[incomplete_inds] = h
                scores = F.log_softmax(self.fc(h), dim=1)  # (s, vocab_size)

                _, indices = scores.max(dim=1)
                tokens = indices.tolist()
                assert len(tokens) == len(incomplete_inds)

                # Keep unfinished rows in ascending order so the tensor rows stay
                # aligned with incomplete_inds on the next step.
                keep = [row for row, token in enumerate(tokens) if token != end_idx]
                incomplete_inds = [incomplete_inds[row] for row in keep]

                if not incomplete_inds:
                    break

                # Subset the ones that aren't finished.
                h = h[keep]
                c = c[keep]
                encoder_out = encoder_out[keep]
                k_prev_words = indices[keep].unsqueeze(1)

                # Stop if decoding has been going on too long.
                if step > self.max_seq_len:
                    break

                step += 1

            actions = self.action_model(grid_embedding, grid_onehot, inventory, goal, hiddens)

        _, action = actions[0].max(0)
        return action.item()


class AskYourHumanMLP(nn.Module):
    """
    Language-free baseline that scores actions from the observation alone.

    Shares the constructor and ``forward`` / ``predict`` signatures of the
    language model so the two can be swapped; ``vocab`` and ``embed_weights``
    are accepted but not used.
    """

    def __init__(self, num_actions, vocab, embed_weights):
        """
        :param num_actions: number of action scores produced by the final layer
        :param vocab: unused
        :param embed_weights: unused
        """
        super().__init__()
        embed_dim = 300
        self.embed_dim = embed_dim

        # Per-cell encoders: embedding -> 150, one-hot -> 20, joined -> 90.
        self.fc1 = nn.Linear(embed_dim, 150)
        self.fc2 = nn.Linear(7, 20)
        self.fc3 = nn.Linear(170, 90)
        # Goal encoder and the grid+goal fusion layer (25 cells * 90 + 150).
        self.fc4 = nn.Linear(embed_dim, 150)
        self.fc5 = nn.Linear(2250+150, 512)
        # Inventory encoder and the final feature layer.
        self.fc_inv = nn.Linear(embed_dim, 50)
        self.fc55 = nn.Linear(512+50, 48)

        # No decoder hidden state is concatenated here, hence 48 -> 48.
        self.fc6 = nn.Linear(48, 48)
        self.fc7 = nn.Linear(48, num_actions)

    def forward(self, obs, labels=None, is_target=False):
        """
        :param obs: dict with keys ``'grid_embedding'``, ``'grid_onehot'``,
            ``'inventory'`` and ``'goal'``
        :param labels: unused
        :param is_target: unused
        :return: ``(actions, None, None)`` with ``actions`` of dimension
            (batch_size, num_actions)
        """
        grid_embedding = obs['grid_embedding']
        grid_onehot = obs['grid_onehot']
        inventory = obs['inventory']
        goal = obs['goal']

        # Encode every grid cell, then flatten the 25 cells into one vector.
        cell_embed = F.relu(self.fc1(grid_embedding))
        cell_onehot = F.relu(self.fc2(grid_onehot))
        cells = torch.cat((cell_embed.view(-1, 25, 150), cell_onehot.view(-1, 25, 20)), dim=2)
        flat_grid = F.relu(self.fc3(cells)).view(-1, 25*90)

        # Fuse grid with goal, then with inventory.
        goal_feat = F.relu(self.fc4(goal))
        grid_goal = F.relu(self.fc5(torch.cat((flat_grid, goal_feat), dim=1)))
        inv_feat = F.relu(self.fc_inv(inventory))
        features = F.relu(self.fc55(torch.cat((grid_goal, inv_feat), dim=1)))

        actions = self.fc7(F.relu(self.fc6(features)))
        return actions, None, None

    def predict(self, obs, deterministic=True, history=None):
        """
        Return the index of the highest-scoring action for the first sample.

        :param obs: observation dict as accepted by ``forward``
        :param deterministic: must be True; sampling is not implemented
        :param history: unused
        :return: action index as a Python int
        """
        assert deterministic, "Undeterministic sampling has not been implemented."
        scores = self(obs)[0]
        best = scores[0].max(0)[1]
        return best.item()

class InverseModelStateEncoder(nn.Module):
    """
    Arrange an observation as a sequence of ``embed_dim``-sized tokens.

    Output layout along dim 1: 25 grid-embedding tokens, 25 projected one-hot
    tokens, 1 inventory token, 1 goal token.
    """

    def __init__(self, embed_dim):
        """
        :param embed_dim: size of every token in the output sequence
        """
        super().__init__()
        self.embed_dim = embed_dim
        # Lifts the 7-dim per-cell features to the embedding width.
        self.onehot_fc = nn.Linear(7, embed_dim)

    def forward(self, obs):
        """
        :param obs: dict with keys ``'grid_embedding'``, ``'grid_onehot'``,
            ``'inventory'`` and ``'goal'``
        :return: tensor of dimension (batch_size, 25 + 25 + 1 + 1, embed_dim)
        """
        dim = self.embed_dim
        grid_embedding = obs['grid_embedding']
        grid_onehot = obs['grid_onehot']
        inventory = obs['inventory']
        goal = obs['goal']

        onehot_tokens = self.onehot_fc(grid_onehot).view(-1, 25, dim)
        grid_tokens = grid_embedding.view(-1, 25, dim)
        inventory_token = inventory.view(-1, 1, dim)
        goal_token = goal.view(-1, 1, dim)

        tokens = (grid_tokens, onehot_tokens, inventory_token, goal_token)
        return torch.cat(tokens, dim=1)

class AskYourHumanMLPInverseModel(nn.Module):
    """
    Inverse model: score actions from a pair of consecutive observations.

    ``vocab`` and ``embed_weights`` are accepted but not used.
    """

    def __init__(self, num_actions, vocab, embed_weights):
        """
        :param num_actions: number of action scores produced by the final layer
        :param vocab: unused
        :param embed_weights: unused
        """
        super().__init__()
        embed_dim = 300
        # Tokenise the observation, then shrink every token to 48 features.
        encoder_layers = [
            InverseModelStateEncoder(embed_dim),
            nn.Linear(embed_dim, 48),
        ]
        self.state_encoder = nn.Sequential(*encoder_layers)
        # 2496 = 52 tokens * 48 features, the flattened encoder output.
        head_layers = [
            nn.Linear(2496, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        ]
        self.net = nn.Sequential(*head_layers)

    def forward(self, obs, next_obs):
        """
        :param obs: observation dict before the action
        :param next_obs: observation dict after the action
        :return: (batch_size, num_actions) unnormalised action scores
        """
        before = self.state_encoder(obs)
        after = self.state_encoder(next_obs)
        # Activate the change between the two encoded states.
        delta = F.relu(after - before)
        batch, tokens, channels = delta.shape
        flat = delta.view(batch, tokens * channels)
        return self.net(flat)
