import torch
from torch import nn
from torch import distributions as torchd

import models
import networks
import tools


class Random(nn.Module):
    """Exploration behavior that picks actions at random and learns nothing.

    It exposes ``actor`` and ``train`` methods but holds no parameters.

    Config attributes read:
        num_actions: size of the action dimension appended to the batch shape.
        actor_dist: ``'onehot'`` selects a discrete (one-hot) distribution;
            any other value selects a continuous Uniform(-1, 1) distribution.
    """

    def __init__(self, config):
        # nn.Module.__init__ must run so the module's internal registries
        # exist; without it .to(), .parameters(), .children() etc. raise
        # AttributeError.
        super().__init__()
        self._config = config

    def actor(self, feat):
        """Return a random action distribution matching ``feat``'s batch shape.

        Args:
            feat: Tensor whose leading dimensions (all but the last) are used
                as the batch shape; its last dimension is discarded.

        Returns:
            If ``config.actor_dist == 'onehot'``: ``tools.OneHotDist`` built
            from an all-zeros tensor (presumably zero logits, i.e. uniform
            over actions -- depends on tools.OneHotDist; TODO confirm).
            Otherwise: ``tools.ContDist`` wrapping an independent
            Uniform(-1, 1) for every action dimension.
        """
        # torch.Size is a tuple subclass, so build a tuple here;
        # ``torch.Size + list`` raises TypeError.
        shape = (*feat.shape[:-1], self._config.num_actions)
        # Allocate on feat's device so sampled actions live alongside the
        # features rather than always on the default device.
        if self._config.actor_dist == 'onehot':
            return tools.OneHotDist(torch.zeros(shape, device=feat.device))
        ones = torch.ones(shape, device=feat.device)
        return tools.ContDist(torchd.uniform.Uniform(-ones, ones))

    def train(self, start=None, context=None):
        """No-op training step: this behavior has nothing to learn.

        This method shadows ``nn.Module.train(mode)``. Both arguments default
        to ``None`` so that nn.Module machinery, such as a parent module's
        ``.train()`` / ``.eval()`` calling ``child.train(mode)``, can invoke
        it without a TypeError. It does not modify ``self.training``.

        Returns:
            ``(None, {})`` -- presumably (output, metrics) to mirror the
            train interface of other behaviors; TODO confirm with callers.
        """
        return None, {}