import math
import torch
import torch.nn.functional as F
from torch import nn

class Tanh(nn.Module):
    """Module wrapper around the element-wise hyperbolic tangent."""

    def forward(self, x):
        """Return ``tanh(x)`` computed element-wise."""
        activated = torch.tanh(x)
        return activated

class TanhGPN(nn.Module):
    """Affinely rescaled tanh: ``1.4674 * tanh(x) + 0.3886``.

    NOTE(review): the scale/shift constants are taken as given; presumably
    they come from a normalisation scheme ("GPN") - confirm with the source.
    """

    def forward(self, x):
        """Apply tanh element-wise, then the fixed scale and shift."""
        scale, shift = 1.4674, 0.3886
        return scale * torch.tanh(x) + shift

class ReLU(nn.Module):
    """Module wrapper around the rectified linear unit."""

    def forward(self, x):
        """Return ``relu(x)`` computed element-wise."""
        rectified = F.relu(x)
        return rectified

class ReLUGPN(nn.Module):
    """ReLU multiplied by the fixed factor ``1.4142``.

    NOTE(review): the factor is taken as given; presumably it comes from a
    normalisation scheme ("GPN") - confirm with the source.
    """

    def forward(self, x):
        """Apply ReLU element-wise, then the fixed scale."""
        scale = 1.4142
        return scale * F.relu(x)

class LeakyReLU(nn.Module):
    """Module wrapper around leaky ReLU with a fixed negative slope of 0.01."""

    def forward(self, x):
        """Return ``leaky_relu(x)`` with slope 0.01 on the negative side."""
        return F.leaky_relu(x, negative_slope=0.01)

class LeakyReLUGPN(nn.Module):
    """Leaky ReLU (negative slope 0.01) multiplied by the fixed factor ``1.4141``.

    NOTE(review): the factor is taken as given; presumably it comes from a
    normalisation scheme ("GPN") - confirm with the source.
    """

    def forward(self, x):
        """Apply leaky ReLU element-wise, then the fixed scale."""
        scale = 1.4141
        return scale * F.leaky_relu(x, negative_slope=0.01)

class ELU(nn.Module):
    """Module wrapper around the exponential linear unit (default alpha)."""

    def forward(self, x):
        """Return ``elu(x)`` computed element-wise."""
        activated = F.elu(x)
        return activated

class ELUGPN(nn.Module):
    """Affinely rescaled ELU: ``1.2234 * elu(x) + 0.0742``.

    NOTE(review): the scale/shift constants are taken as given; presumably
    they come from a normalisation scheme ("GPN") - confirm with the source.
    """

    def forward(self, x):
        """Apply ELU element-wise, then the fixed scale and shift."""
        scale, shift = 1.2234, 0.0742
        return scale * F.elu(x) + shift

class SELU(nn.Module):
    """Module wrapper around the scaled exponential linear unit."""

    def forward(self, x):
        """Return ``selu(x)`` computed element-wise."""
        activated = F.selu(x)
        return activated

class SELUGPN(nn.Module):
    """Affinely rescaled SELU: ``0.9660 * selu(x) + 0.2585``.

    NOTE(review): the scale/shift constants are taken as given; presumably
    they come from a normalisation scheme ("GPN") - confirm with the source.
    """

    def forward(self, x):
        """Apply SELU element-wise, then the fixed scale and shift."""
        scale, shift = 0.9660, 0.2585
        return scale * F.selu(x) + shift

class GELU(nn.Module):
    """Sigmoid-gated activation ``x * sigmoid(1.702 * x)``.

    This is the sigmoid-based form, not the erf-based ``F.gelu``.
    """

    def forward(self, x):
        """Return ``x`` multiplied element-wise by ``sigmoid(1.702 * x)``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

class GELUGPN(nn.Module):
    """Affinely rescaled sigmoid-gated activation.

    Computes ``1.4915 * x * sigmoid(1.702 * x) - 0.9097``.

    NOTE(review): the scale/shift constants are taken as given; presumably
    they come from a normalisation scheme ("GPN") - confirm with the source.
    """

    def forward(self, x):
        """Apply the sigmoid-gated activation, then the fixed scale and shift."""
        gate = torch.sigmoid(1.702 * x)
        # Same evaluation order as the one-line form: (scale * x) * gate - shift.
        return 1.4915 * x * gate - 0.9097


class OrthLinear(nn.Module):
    """Square linear layer whose weight starts out as an orthogonal matrix.

    The forward pass computes ``x @ W_bar`` where ``W_bar`` is ``weight``
    itself or, when ``constraint`` is true, ``weight`` with every column
    rescaled to unit L2 norm. A bias is added only when ``constraint`` is
    false, and the result is optionally passed through ``nn.BatchNorm1d``.

    Args:
        in_channels: Size of both the input and the output feature dimension.
        constraint: If true, normalise the weight columns on every forward
            pass and use no bias.
        with_bn: If true, apply ``nn.BatchNorm1d(in_channels)`` to the output.
    """

    def __init__(self, in_channels, constraint=False, with_bn=False):
        super().__init__()

        self.in_channels = in_channels
        self.constraint = constraint

        # Orthogonal initialisation: for the SVD  w = U diag(S) Vh  of a random
        # Gaussian matrix, the product U @ Vh is an orthogonal matrix.
        # torch.linalg.svd replaces the deprecated torch.svd; it returns Vh
        # (already transposed), so no extra ``.t()`` is needed.
        gaussian = torch.randn(in_channels, in_channels)
        u, _, vh = torch.linalg.svd(gaussian, full_matrices=False)
        self.weight = nn.Parameter(u @ vh)

        # ``False`` (not ``None``) stays the "absent" marker for ``bias`` and
        # ``bn`` so existing code inspecting these attributes keeps working.
        if constraint:
            self.bias = False
        else:
            self.bias = nn.Parameter(torch.zeros(1, in_channels))

        if with_bn:
            self.bn = nn.BatchNorm1d(in_channels)
        else:
            self.bn = False

    @property
    def W_bar(self):
        """Effective weight used by ``forward``, derived from ``weight``.

        Computed on access instead of being cached on the module. A cached
        copy was only written in training mode, so eval-mode forwards either
        failed (no training forward had run yet) or used a weight that lagged
        behind the latest parameter update; the cached non-leaf tensor also
        prevented ``copy.deepcopy`` of the module.
        """
        if self.constraint:
            column_norm = self.weight.pow(2).sum(0, keepdim=True).sqrt()
            return self.weight / column_norm
        return self.weight

    def forward(self, x):
        """Apply the layer to ``x``.

        Args:
            x: Input whose last dimension equals ``in_channels``.

        Returns:
            ``x @ W_bar`` plus the bias (when unconstrained), passed through
            batch norm when it is enabled.
        """
        y = x @ self.W_bar

        if not self.constraint:
            y = y + self.bias

        if self.bn is not False:
            y = self.bn(y)

        return y
