from spaghettini import quick_register

import numpy as np
import torch
from torch import nn
from torch.nn import Softplus, ReLU

from src.utils.gumbel_softmax import gumbel_softmax

# Shared, deterministically seeded generator handed to ``gumbel_softmax``.
# TODO: This is not ideal - try to get this in the discretizer.
RNG = torch.Generator().manual_seed(0)  # manual_seed returns the generator itself


@quick_register
class OneHotDiscretizer(nn.Module):
    """Project features to logits and discretize them per head with Gumbel-softmax.

    Args:
        feat_to_logit_projector: Callable (presumably an ``nn.Module``) that maps
            ``inputs`` to a logits tensor whose leading dimension is the batch size.
        num_heads: Number of groups the per-sample logits are split into. The
            number of logits per sample must be divisible by this value.
        temperature: Temperature forwarded to ``gumbel_softmax``.
    """

    def __init__(self, feat_to_logit_projector, num_heads=1, temperature=1.):
        super().__init__()
        self.feat_to_logit_projector = feat_to_logit_projector
        self.num_heads = num_heads
        self.temperature = temperature

    def forward(self, inputs):
        """Sample discretized tokens for every head and flatten them per sample.

        Args:
            inputs: Tensor whose first dimension is the batch size ``bs``.

        Returns:
            Tensor of shape ``(bs, -1)`` holding the ``gumbel_softmax`` output for
            all heads concatenated. Presumably each head is one-hot along its
            last axis; confirm against ``src.utils.gumbel_softmax``.
        """
        bs = inputs.shape[0]

        # Project to logits.
        logits = self.feat_to_logit_projector(inputs)

        # Reshape to (bs, num_heads, k) so each head is discretized separately.
        # ``reshape`` (unlike ``view``) also works when the projector returns a
        # non-contiguous tensor; for contiguous tensors it is equivalent.
        logit_heads = logits.reshape(bs, self.num_heads, -1)

        # Discretize logits in each head using the shared module-level generator.
        oh_tokens = gumbel_softmax(logits=logit_heads, temperature=self.temperature, rng=RNG)

        # Flatten the heads back into a single feature axis per sample.
        return oh_tokens.reshape(bs, -1)
