from math import sqrt, log
import warnings

import torch

from .base import FeatureMap

class PositiveRandomFeatures(FeatureMap):
    """Positive Random Features for approximating the softmax kernel.

    Implements the positive random feature map of Choromanski et al.
    ("Rethinking Attention with Performers"):

        phi(x) = exp(omega^T x - ||x||^2 / 2) / sqrt(n_dims)

    where the entries of ``omega`` are drawn i.i.d. from N(0, 1). With this
    normalisation E[phi(x) . phi(y)] = exp(x^T y), and all features are
    strictly positive.

    Arguments
    ---------
        query_dimensions: int, The input query dimensions in order to sample
                          the noise matrix
        n_dims: int, The size of the feature map (default: query_dimensions)
        redraw: int, Redraw the random matrix every 'redraw' calls to
                new_feature_map; must be >= 1 (default: 1)
        deterministic_eval: bool, Only redraw the random matrix during training
                            (default: True)

    Raises
    ------
        ValueError: if ``redraw`` is smaller than 1.
    """
    def __init__(self, query_dimensions, n_dims=None, redraw=1, deterministic_eval=True):
        super().__init__(query_dimensions)

        # Fail early instead of a ZeroDivisionError inside new_feature_map.
        if redraw < 1:
            raise ValueError(
                "redraw must be a positive integer, got {!r}".format(redraw)
            )

        self.n_dims = n_dims or query_dimensions
        self.query_dimensions = query_dimensions
        self.redraw = redraw
        self.deterministic_eval = deterministic_eval

        # Buffer holding the sampled projection matrix omega of shape
        # (query_dimensions, n_dims); it stays zero until the first redraw.
        self.register_buffer(
            "omega",
            torch.zeros(self.query_dimensions, self.n_dims)
        )
        # Counts redraw-eligible calls; starts at -1 so the very first call
        # (counter 0) always samples a fresh omega.
        self._calls = -1

    def new_feature_map(self, device):
        """Possibly resample ``omega`` on ``device``.

        Does nothing in eval mode when ``deterministic_eval`` is set. Otherwise
        a new matrix is drawn on every ``redraw``-th call (including the
        first one).
        """
        # If we are not training skip the generation of a new feature map
        if self.deterministic_eval and not self.training:
            return

        # Only redraw the new feature map every self.redraw times
        self._calls += 1
        if self._calls % self.redraw != 0:
            return

        # Assigning a tensor to an existing buffer name replaces the buffer.
        # Keep the buffer's current dtype so a module cast with .double() or
        # .half() is not silently reset to float32 on every redraw.
        self.omega = torch.randn(
            self.query_dimensions,
            self.n_dims,
            device=device,
            dtype=self.omega.dtype,
        )

    def forward(self, x):
        """Map ``x`` of shape (..., query_dimensions) to (..., n_dims).

        Computes exp(x @ omega - ||x||^2 / 2 - log(n_dims) / 2), i.e. the
        positive random features with their normalisation applied in
        log-space, which also keeps the exponent from overflowing for inputs
        with a large norm.
        """
        # (..., query_dimensions) @ (query_dimensions, n_dims) -> (..., n_dims)
        u = x.matmul(self.omega)
        # exp(-||x||^2 / 2) / sqrt(n_dims), folded into the exponent.
        offset = 0.5 * (x * x).sum(dim=-1, keepdim=True) + 0.5 * log(self.n_dims)
        phi = torch.exp(u - offset)
        return phi