import torch
from torch.nn import Module

class ConstantMixtureWeights(Module):
    """Unconditional mixture weights: a single learnable vector.

    ``forward()`` returns the same ``num_components``-element parameter no
    matter what was passed to ``condition``. The vector is initialized to
    zeros (presumably interpreted downstream as unnormalized log-weights,
    i.e. a uniform mixture after a softmax -- confirm with callers).

    Args:
        num_components: Number of mixture components (length of the vector).
        device: Device on which the weight vector is allocated.
    """

    def __init__(self, num_components, device=torch.device('cpu')):
        super().__init__()
        # NOTE: recorded for reference only; it is not updated if the module
        # is later moved with .to()/.cuda().
        self.device = device
        self.num_components = num_components
        # Assigning an nn.Parameter to a Module attribute registers it as the
        # parameter 'weights'; no explicit register_parameter call is needed.
        self.weights = torch.nn.Parameter(
            torch.zeros(self.num_components, device=self.device)
        )

    def condition(self, x):
        """Accept and ignore a conditioning input.

        Exists so this class exposes the same ``condition``/``forward``
        interface as ``ReluMixtureWeights``; the weights do not depend on
        ``x``.
        """

    def forward(self):
        """Return the learnable weight parameter (shape ``(num_components,)``)."""
        return self.weights
        
class ReluMixtureWeights(Module):
    """Mixture weights produced by an MLP from a conditioning tensor.

    The network is ``Linear(dimensions, width)``, then ``depth`` layers of
    ``Linear(width, width)``, then ``Linear(width, num_components)``. A
    leaky-ReLU follows every layer except the last, so outputs are
    unconstrained reals (presumably normalized, e.g. by softmax, downstream
    -- confirm with callers).

    Args:
        num_components: Number of mixture components (output size).
        dimensions: Size of the last axis of the conditioning tensor.
        width: Hidden-layer size.
        depth: Number of ``width -> width`` layers between the input and
            output layers.
        device: Device on which the layers and the initial (all-zero)
            conditioner are created.
    """

    def __init__(self, num_components, dimensions, width, depth, device=torch.device('cpu')):
        super().__init__()
        # NOTE: recorded for reference only; it is not updated if the module
        # is later moved with .to()/.cuda().
        self.device = device
        self.num_components = num_components
        self.dimensions = dimensions
        self.depth = depth
        # Layers are created in input -> hidden -> output order and moved with
        # .to() after construction, so parameter initialization consumes the
        # RNG in the same order as before.
        layers = [torch.nn.Linear(dimensions, width).to(self.device)]
        layers += [torch.nn.Linear(width, width).to(self.device) for _ in range(depth)]
        layers.append(torch.nn.Linear(width, num_components).to(self.device))
        self.network = torch.nn.ModuleList(layers)
        # A registered buffer (unlike a plain tensor attribute) is moved and
        # cast by Module.to()/.cuda()/.double() together with the weights.
        # persistent=False keeps it out of state_dict, so checkpoint keys are
        # unchanged.
        self.register_buffer(
            'conditioner',
            torch.zeros([self.dimensions]).to(self.device),
            persistent=False,
        )

    def condition(self, x):
        """Set the conditioning input used by subsequent ``forward()`` calls.

        Args:
            x: Tensor whose last axis has size ``dimensions``; any leading
                axes are carried through the linear layers as batch axes.
                It is stored as-is (not copied or moved to the module's
                device). Because ``conditioner`` is a registered buffer,
                assigning anything other than a tensor (or None) raises
                TypeError.
        """
        self.conditioner = x

    def forward(self):
        """Run the MLP on the current conditioner.

        Returns:
            Tensor of shape ``conditioner.shape[:-1] + (num_components,)``.
        """
        first, *hidden, last = self.network
        out = torch.nn.functional.leaky_relu(first(self.conditioner))
        for layer in hidden:
            out = torch.nn.functional.leaky_relu(layer(out))
        # No activation on the output layer.
        return last(out)

