import torch
from torch.nn import Module

       
class ReluConditioner(Module):
    """Fully connected LeakyReLU network evaluated on a stored conditioning input.

    The network maps ``dimensions`` -> ``width``, then applies ``depth`` hidden
    ``width`` -> ``width`` layers, and finally maps back to ``dimensions``.
    Every linear layer except the first is preceded by a LeakyReLU.

    Usage is two-step: call :meth:`condition` to store an input, then call the
    module with no arguments to evaluate the network on that input.

    Args:
        dimensions: Input and output feature size.
        width: Hidden feature size.
        depth: Number of hidden ``width`` -> ``width`` linear layers.
        device: Device the layers and the default conditioner are placed on.
    """

    def __init__(self, dimensions, width, depth, device=torch.device('cpu')):
        super().__init__()
        self.device = device
        self.dimensions = dimensions
        self.depth = depth

        # Layers are created in exactly the order they appear in the
        # Sequential, so parameter initialisation consumes the RNG in order
        # and the state_dict keys (network.0, network.2, ...) are stable.
        layers = [torch.nn.Linear(dimensions, width)]
        for _ in range(depth):
            layers.append(torch.nn.LeakyReLU())
            layers.append(torch.nn.Linear(width, width))
        layers.append(torch.nn.LeakyReLU())
        layers.append(torch.nn.Linear(width, dimensions))
        self.network = torch.nn.Sequential(*layers).to(self.device)

        # Default conditioning input: a zero vector of length ``dimensions``.
        # NOTE(review): this is a plain attribute, not a registered buffer, so
        # Module.to()/.cuda() will not move it (nor update self.device) —
        # confirm that condition() is always called after moving the module.
        self.conditioner = torch.zeros(self.dimensions, device=self.device)

    def condition(self, x):
        """Store ``x`` as the input used by subsequent :meth:`forward` calls."""
        self.conditioner = x

    def forward(self):
        """Evaluate the network on the currently stored conditioning input."""
        stored_input = self.conditioner
        return self.network(stored_input)

class ReluMeanScale(Module):
    """Pair of :class:`ReluConditioner` networks producing a mean and a positive scale.

    Both sub-networks share the same conditioning input. Calling the module
    returns ``(mean, scale)``, where ``mean`` is the mean network's output plus
    ``mean_offset``, and ``scale`` is ``exp`` of the scale network's output
    (clamped to [-20, 20]) multiplied by ``scale_offset``.

    Args:
        dimensions: Input/output feature size of both sub-networks.
        mean_parameters: Mapping read for the ``'width'`` and ``'depth'`` keys
            of the mean network.
        scale_parameters: Mapping read for the ``'width'`` and ``'depth'`` keys
            of the scale network.
        device: Device the sub-networks and default tensors are placed on.
    """

    def __init__(self, dimensions, mean_parameters, scale_parameters, device=torch.device('cpu')):
        super().__init__()
        self.device = device
        self.dimensions = dimensions
        self.mean_parameters = mean_parameters
        self.scale_parameters = scale_parameters

        # Mean network is built first so parameter initialisation order is kept.
        self.mean_network = ReluConditioner(
            dimensions, mean_parameters['width'], mean_parameters['depth'], device=device)
        self.scale_network = ReluConditioner(
            dimensions, scale_parameters['width'], scale_parameters['depth'], device=device)

        self.conditioner = torch.zeros(self.dimensions, device=self.device)
        # Identity offsets (add 0, multiply by 1) until set_offset() is called.
        # NOTE(review): these are plain attributes, not registered buffers, so
        # they are absent from state_dict() and not moved by Module.to() —
        # confirm that offsets are re-applied after loading/moving a model.
        self.mean_offset = torch.zeros(self.dimensions, device=self.device)
        self.scale_offset = torch.ones(self.dimensions, device=self.device)

    def condition(self, x):
        """Store ``x`` here and pass it on to both sub-networks."""
        self.conditioner = x
        for sub_network in (self.mean_network, self.scale_network):
            sub_network.condition(x)

    def forward(self):
        """Return ``(mean, scale)`` for the currently stored conditioning input."""
        mean = self.mean_network() + self.mean_offset
        # Clamp before exp to keep the scale within [e^-20, e^20].
        log_scale = self.scale_network().clamp(min=-20.0, max=20.0)
        scale = log_scale.exp() * self.scale_offset
        return mean, scale

    def set_offset(self, mean, scale):
        """Replace the additive mean offset and the multiplicative scale offset.

        Both arguments are moved to ``self.device`` via ``.to``, so they are
        expected to be tensors.
        """
        target = self.device
        self.mean_offset = mean.to(target)
        self.scale_offset = scale.to(target)
