import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import rnn


def duplicate(output, mask, lens, act_sizes):
    """
    Duplicate the output based on the action sizes.
    """
    output = torch.cat([output[i:i+1].repeat(j, 1, 1) for i, j in enumerate(act_sizes)], dim=0)
    mask = torch.cat([mask[i:i+1].repeat(j, 1) for i, j in enumerate(act_sizes)], dim=0)
    lens = list(itertools.chain.from_iterable([lens[i:i+1] * j for i, j in enumerate(act_sizes)]))
    return output, mask, lens


def get_aggregated(output, lens, method):
    """
    Get the aggregated hidden state of the encoder.
    B x D
    """
    if method == 'mean':
        return torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
    elif method == 'last':
        return torch.stack([output[i, j-1, :] for i, j in enumerate(lens)], dim=0)
    elif method == 'first':
        return output[:, 0, :]


class EncoderRNN(nn.Module):
    def __init__(self, input_size, num_units, nlayers, concat,
                 bidir, layernorm, return_last):
        super().__init__()
        self.layernorm = (layernorm == 'layer')
        if layernorm:
            self.norm = nn.LayerNorm(input_size)

        self.rnns = []
        for i in range(nlayers):
            if i == 0:
                input_size_ = input_size
                output_size_ = num_units
            else:
                input_size_ = num_units if not bidir else num_units * 2
                output_size_ = num_units
            self.rnns.append(
                nn.GRU(input_size_, output_size_, 1,
                       bidirectional=bidir, batch_first=True))

        self.rnns = nn.ModuleList(self.rnns)
        self.init_hidden = nn.ParameterList(
            [nn.Parameter(
                torch.zeros(size=(2 if bidir else 1, 1, num_units)),
                requires_grad=True) for _ in range(nlayers)])
        self.concat = concat
        self.nlayers = nlayers
        self.return_last = return_last

        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            for rnn_layer in self.rnns:
                for name, p in rnn_layer.named_parameters():
                    if 'weight_ih' in name:
                        torch.nn.init.xavier_uniform_(p.data)
                    elif 'weight_hh' in name:
                        torch.nn.init.orthogonal_(p.data)
                    elif 'bias' in name:
                        p.data.fill_(0.0)
                    else:
                        p.data.normal_(std=0.1)

    def get_init(self, bsz, i):
        return self.init_hidden[i].expand(-1, bsz, -1).contiguous()

    def forward(self, inputs, input_lengths=None):
        bsz, slen = inputs.size(0), inputs.size(1)
        if self.layernorm:
            inputs = self.norm(inputs)
        output = inputs
        outputs = []
        lens = 0
        if input_lengths is not None:
            lens = input_lengths  # .data.cpu().numpy()
        for i in range(self.nlayers):
            hidden = self.get_init(bsz, i)
            # output = self.dropout(output)
            if input_lengths is not None:
                output = rnn.pack_padded_sequence(output, lens,
                                                  batch_first=True,
                                                  enforce_sorted=False)
            output, hidden = self.rnns[i](output, hidden)
            if input_lengths is not None:
                output, _ = rnn.pad_packed_sequence(output, batch_first=True)
                if output.size(1) < slen:
                    # used for parallel
                    # padding = Variable(output.data.new(1, 1, 1).zero_())
                    padding = torch.zeros(
                        size=(1, 1, 1), dtype=output.type(),
                        device=output.device())
                    output = torch.cat(
                        [output,
                         padding.expand(
                             output.size(0),
                             slen - output.size(1),
                             output.size(2))
                         ], dim=1)
            if self.return_last:
                outputs.append(
                    hidden.permute(1, 0, 2).contiguous().view(bsz, -1))
            else:
                outputs.append(output)
        if self.concat:
            return torch.cat(outputs, dim=2)
        return outputs[-1]


class BiAttention(nn.Module):
    def __init__(self, input_size, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.input_linear = nn.Linear(input_size, 1, bias=False)
        self.memory_linear = nn.Linear(input_size, 1, bias=False)
        self.dot_scale = nn.Parameter(
            torch.zeros(size=(input_size,)).uniform_(1. / (input_size ** 0.5)),
            requires_grad=True)
        self.init_parameters()

    def init_parameters(self):
        return

    def forward(self, context, memory, mask):
        bsz, input_len = context.size(0), context.size(1)
        memory_len = memory.size(1)
        context = self.dropout(context)
        memory = self.dropout(memory)

        input_dot = self.input_linear(context)
        memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)
        cross_dot = torch.bmm(
            context * self.dot_scale,
            memory.permute(0, 2, 1).contiguous())
        att = input_dot + memory_dot + cross_dot
        att = att - 1e30 * (1 - mask[:, None])

        weight_one = F.softmax(att, dim=-1)
        output_one = torch.bmm(weight_one, memory)
        weight_two = (F.softmax(att.max(dim=-1)[0], dim=-1)
                      .view(bsz, 1, input_len))
        output_two = torch.bmm(weight_two, context)
        return torch.cat(
            [context, output_one, context * output_one,
             output_two * output_one],
            dim=-1)