import random

import torch
from torch import nn


class ScaledDecoder(nn.Module):
    """Two-layer MLP decoder whose outputs are divided by a learned temperature.

    The GELU hidden representation feeds two heads:

    * ``linear1`` produces the ``nout`` outputs.
    * ``linear2`` produces logits over the fixed bank ``TEMPERATURES``.

    The softmax of the ``linear2`` logits mixes the bank into one temperature
    per input row. Because this is a convex combination, the temperature lies
    in ``[min(TEMPERATURES), max(TEMPERATURES)]``. The ``nout`` outputs are
    divided by that temperature.

    Args:
        ninp: Size of the last dimension of the input.
        nhid: Hidden width.
        nout: Size of the last dimension of the output.
    """

    # Fixed temperature bank mixed by the ``linear2`` head; its length sets
    # ``linear2``'s output width.
    TEMPERATURES = (1.0, 1.4, 1.7, 2.0, 5.0, 10.0, 20.0, 40.0, 80.0, 160.0)

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        self.linear = nn.Linear(ninp, nhid)
        self.linear1 = nn.Linear(nhid, nout)
        self.linear2 = nn.Linear(nhid, len(self.TEMPERATURES))

    def forward(self, x):
        """Decode ``x`` and divide by the per-row mixed temperature.

        The forward pass is deterministic and side-effect free.

        Args:
            x: Tensor whose last dimension is ``ninp``; leading dimensions
                are preserved.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``nout``.
        """
        hidden = nn.functional.gelu(self.linear(x))
        weights = self.linear2(hidden).softmax(-1)
        # Match the weights' dtype/device so the matmul also works for
        # float64/float16 models (a float32 bank would raise a dtype mismatch).
        bank = torch.tensor(
            self.TEMPERATURES, dtype=weights.dtype, device=weights.device
        )
        temps = weights @ bank
        # unsqueeze so one temperature broadcasts over all ``nout`` outputs.
        return self.linear1(hidden) / temps.unsqueeze(-1)


class FixedScaledDecoder(nn.Module):
    """MLP decoder whose outputs are divided by one learned scalar.

    The MLP is ``Linear(ninp, nhid) -> GELU -> Linear(nhid, nout)``. The
    divisor is the sum of the 10000-entry parameter ``T``. Each entry is
    initialised to ``1/10000``, so the divisor starts at approximately 1.

    NOTE(review): spreading a single scalar over many entries presumably
    scales its effective learning rate. Nothing here keeps the sum positive
    or away from zero — confirm that is intended.

    Args:
        ninp: Size of the last dimension of the input.
        nhid: Hidden width of the MLP.
        nout: Size of the last dimension of the output.
    """

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        layers = [nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout)]
        self.mapper = nn.Sequential(*layers)
        slots = 10000
        self.T = nn.Parameter(torch.ones(slots) / slots)

    def forward(self, x):
        """Apply the MLP to ``x`` and divide the result by ``T.sum()``."""
        divisor = self.T.sum()
        return self.mapper(x).div(divisor)
