import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveEmbedding(nn.Module):
    """Token embedding whose width may shrink across vocabulary clusters.

    The vocabulary ``[0, n_token)`` is partitioned into contiguous clusters
    by ``cutoffs``. With ``div_val == 1`` a single ``d_embed``-wide table
    covers every token and is projected to ``d_proj`` only when the two
    sizes differ. Otherwise cluster ``i`` gets its own table of width
    ``d_embed // div_val ** i`` plus a projection to ``d_proj``.

    The result of ``forward`` is multiplied by ``sqrt(d_proj)``.

    Args:
        n_token: Vocabulary size.
        d_embed: Embedding width of the first (or only) cluster.
        d_proj: Width of the returned embeddings.
        cutoffs: Sequence of cluster end indices; ``n_token`` is appended
            automatically as the end of the last cluster.
        div_val: Factor by which the embedding width is divided for each
            successive cluster.
        sample_softmax: When greater than zero (and ``div_val == 1``) the
            embedding table is created with ``sparse=True``.

    Note:
        The projection parameters (``emb_projs_<i>``) are allocated without
        initialization; nothing in this class fills them, so the caller is
        expected to initialize them before use.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        # list() lets callers pass any sequence (e.g. a tuple), not only a list.
        self.cutoffs = list(cutoffs) + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        # Cluster i covers token ids [cutoff_ends[i], cutoff_ends[i + 1]).
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.n_emb_projs = 0
        # ParameterList is not supported by DataParallel, so every projection
        # is registered as a plain module attribute named ``emb_projs_<i>``.
        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs_0 = nn.Parameter(torch.empty(d_proj, d_embed))
                self.n_emb_projs += 1
        else:
            bounds = zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bounds):
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                setattr(self, f'emb_projs_{i}',
                        nn.Parameter(torch.empty(d_proj, d_emb_i)))
                self.n_emb_projs += 1

    def _emb_proj(self, i):
        """Return the projection parameter registered as ``emb_projs_<i>``."""
        return getattr(self, f'emb_projs_{i}')

    def forward(self, inp):
        """Embed a tensor of token ids.

        Args:
            inp: Integer tensor of token ids, of any shape.

        Returns:
            Tensor of shape ``(*inp.shape, d_proj)`` scaled by
            ``sqrt(d_proj)``. (With ``div_val == 1`` and
            ``d_proj == d_embed`` no projection is applied.) In the
            multi-cluster path, ids that fall in no cluster keep an
            all-zero embedding.
        """
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self._emb_proj(0))
        else:
            inp_flat = inp.contiguous().view(-1)
            # Allocate the output in the parameters' dtype so index_copy_
            # below also works when the module was cast (e.g. .double()).
            emb_flat = torch.zeros(
                [inp_flat.size(0), self.d_proj],
                dtype=self._emb_proj(0).dtype,
                device=inp.device,
            )
            bounds = zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bounds):
                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # squeeze(1) keeps the index 1-D even when exactly one
                # position matches (a bare squeeze() would make it 0-dim).
                indices_i = mask_i.nonzero().squeeze(1)

                if indices_i.numel() == 0:
                    continue

                # Shift ids so they index into this cluster's own table.
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self._emb_proj(i))

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed = emb_flat.view(*inp.size(), self.d_proj)

        embed.mul_(self.emb_scale)

        return embed