import numpy as np

import torch
import torch.nn as nn

class scaled_relu(nn.Module):
    """ReLU activation multiplied by a constant: ``scale * relu(X)``.

    Args:
        scale: multiplier applied to the ReLU output.
    """

    def __init__(self, scale: float):
        super().__init__()
        self.act = nn.ReLU()
        self.scale = scale

    def extra_repr(self) -> str:
        # Surface the (non-parameter) scale in ``print(model)`` output.
        return f"scale={self.scale}"

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.scale * self.act(X)
       
class scaled_sigmoid(nn.Module):
    """Sigmoid activation multiplied by a constant: ``scale * sigmoid(X)``.

    The output therefore lies in the open interval ``(0, scale)`` for a
    positive ``scale``.

    Args:
        scale: multiplier applied to the sigmoid output.
    """

    def __init__(self, scale: float):
        super().__init__()
        self.act = nn.Sigmoid()
        self.scale = scale

    def extra_repr(self) -> str:
        # Surface the (non-parameter) scale in ``print(model)`` output.
        return f"scale={self.scale}"

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.scale * self.act(X)
              
class scaled_tanh(nn.Module):
    """Tanh activation multiplied by a constant: ``scale * tanh(X)``.

    Args:
        scale: multiplier applied to the tanh output.
    """

    def __init__(self, scale: float):
        super().__init__()
        # nn.Tanh (rather than the bare torch.tanh function) keeps this class
        # consistent with its siblings and lists ``act`` as a submodule.
        self.act = nn.Tanh()
        self.scale = scale

    def extra_repr(self) -> str:
        # Surface the (non-parameter) scale in ``print(model)`` output.
        return f"scale={self.scale}"

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.scale * self.act(X)
       
class scaled_swish(nn.Module):
    """Swish/SiLU activation multiplied by a constant: ``scale * X * sigmoid(X)``.

    Args:
        scale: multiplier applied to the SiLU output.
    """

    def __init__(self, scale: float):
        super().__init__()
        self.act = nn.SiLU()
        self.scale = scale

    def extra_repr(self) -> str:
        # Surface the (non-parameter) scale in ``print(model)`` output.
        return f"scale={self.scale}"

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.scale * self.act(X)

class scaled_mish(nn.Module):
    """Mish activation multiplied by a constant: ``scale * X * tanh(softplus(X))``.

    Args:
        scale: multiplier applied to the Mish output.
    """

    def __init__(self, scale: float):
        super().__init__()
        self.act = nn.Mish()
        self.scale = scale

    def extra_repr(self) -> str:
        # Surface the (non-parameter) scale in ``print(model)`` output.
        return f"scale={self.scale}"

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.scale * self.act(X)

class scaled_sin(nn.Module):
    """Periodic sine activation: ``scale * sin(freq * X)``.

    With the default ``freq = pi`` the activation has period 2 in ``X``,
    matching the original hard-coded behavior.

    Args:
        scale: amplitude multiplier applied to the sine output.
        freq: angular frequency applied to the input before ``sin``
            (keyword, defaults to ``torch.pi``).
    """

    def __init__(self, scale: float, freq: float = torch.pi):
        super().__init__()
        self.scale = scale
        self.freq = freq

    def extra_repr(self) -> str:
        # Surface the (non-parameter) hyperparameters in ``print(model)`` output.
        return f"scale={self.scale}, freq={self.freq}"

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.scale * torch.sin(self.freq * X)

class scaled_cos(nn.Module):
    """Periodic cosine activation: ``scale * cos(freq * X)``.

    With the default ``freq = pi`` the activation has period 2 in ``X``,
    matching the original hard-coded behavior.

    Args:
        scale: amplitude multiplier applied to the cosine output.
        freq: angular frequency applied to the input before ``cos``
            (keyword, defaults to ``torch.pi``).
    """

    def __init__(self, scale: float, freq: float = torch.pi):
        super().__init__()
        self.scale = scale
        self.freq = freq

    def extra_repr(self) -> str:
        # Surface the (non-parameter) hyperparameters in ``print(model)`` output.
        return f"scale={self.scale}, freq={self.freq}"

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.scale * torch.cos(self.freq * X)
