import torch
from torch import nn

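# Two-layer networks with a forward(t, x) signature defined in this file:
# ['l2_relu', 'l2_relu2', 'l2_Sigmoid', 'l2_Tanh', 'l2_SiLU', 'l2_sq', 'l2_cub', 'l2_elu', 'l2_softplus', 'l2_sigmoid2']
# (l3_relu2 is a three-layer forward(t, x) variant; simple_tnet and Symmetric_net take t only.)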

class l2_Sigmoid(nn.Module):
    """Two-layer MLP on the concatenation of ``t`` and ``x`` with a sigmoid hidden layer.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        # Attribute names and registration order determine state_dict keys
        # and parameters() ordering, so keep them stable.
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)
        self.activate = nn.Sigmoid()

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(sigmoid(fc1(cat(t, x))))``, concatenating along the last dim."""
        hidden = self.activate(self.fc1(torch.cat((t, x), dim=-1)))
        return self.fc2(hidden)
    
class l2_Tanh(nn.Module):
    """Two-layer MLP on the concatenation of ``t`` and ``x`` with a tanh hidden layer.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        # Attribute names and registration order determine state_dict keys
        # and parameters() ordering, so keep them stable.
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)
        self.activate = nn.Tanh()

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(tanh(fc1(cat(t, x))))``, concatenating along the last dim."""
        hidden = self.activate(self.fc1(torch.cat((t, x), dim=-1)))
        return self.fc2(hidden)
    
class l2_relu2(nn.Module):
    """Two-layer MLP on ``cat(t, x)`` whose hidden layer is ReLU followed by squaring.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        # Attribute names and registration order determine state_dict keys
        # and parameters() ordering, so keep them stable.
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)
        self.activate = nn.ReLU()

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(relu(fc1(cat(t, x)))**2)``, concatenating along the last dim."""
        pre = self.fc1(torch.cat((t, x), dim=-1))
        hidden = self.activate(pre).pow(2)  # squared ReLU
        return self.fc2(hidden)
    
class l2_relu(nn.Module):
    """Two-layer MLP on the concatenation of ``t`` and ``x`` with a ReLU hidden layer.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        # Attribute names and registration order determine state_dict keys
        # and parameters() ordering, so keep them stable.
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)
        self.activate = nn.ReLU()

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(relu(fc1(cat(t, x))))``, concatenating along the last dim."""
        hidden = self.activate(self.fc1(torch.cat((t, x), dim=-1)))
        return self.fc2(hidden)
    
class l2_SiLU(nn.Module):
    """Two-layer MLP on the concatenation of ``t`` and ``x`` with a SiLU hidden layer.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        # Attribute names and registration order determine state_dict keys
        # and parameters() ordering, so keep them stable.
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)
        self.activate = nn.SiLU()

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(silu(fc1(cat(t, x))))``, concatenating along the last dim."""
        hidden = self.activate(self.fc1(torch.cat((t, x), dim=-1)))
        return self.fc2(hidden)

class l2_sq(nn.Module):
    """Two-layer MLP on ``cat(t, x)`` whose hidden activation is plain squaring.

    Unlike its siblings, this network has no ``activate`` submodule: the hidden
    pre-activation is squared directly.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(fc1(cat(t, x))**2)``, concatenating along the last dim."""
        pre = self.fc1(torch.cat((t, x), dim=-1))
        return self.fc2(pre.pow(2))
    
class l2_cub(nn.Module):
    """Two-layer MLP on ``cat(t, x)`` whose hidden activation is plain cubing.

    Unlike most siblings, this network has no ``activate`` submodule: the hidden
    pre-activation is cubed directly.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(fc1(cat(t, x))**3)``, concatenating along the last dim."""
        pre = self.fc1(torch.cat((t, x), dim=-1))
        return self.fc2(pre.pow(3))
    

class l2_elu(nn.Module):
    """Two-layer MLP on the concatenation of ``t`` and ``x`` with an ELU hidden layer.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        # Attribute names and registration order determine state_dict keys
        # and parameters() ordering, so keep them stable.
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)
        self.activate = nn.ELU()  # default alpha

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(elu(fc1(cat(t, x))))``, concatenating along the last dim."""
        hidden = self.activate(self.fc1(torch.cat((t, x), dim=-1)))
        return self.fc2(hidden)
    
class l2_softplus(nn.Module):
    """Two-layer MLP on the concatenation of ``t`` and ``x`` with a softplus hidden layer.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        # Attribute names and registration order determine state_dict keys
        # and parameters() ordering, so keep them stable.
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)
        self.activate = nn.Softplus()  # default beta/threshold

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(softplus(fc1(cat(t, x))))``, concatenating along the last dim."""
        hidden = self.activate(self.fc1(torch.cat((t, x), dim=-1)))
        return self.fc2(hidden)
    
class l2_sigmoid2(nn.Module):
    """Two-layer MLP on ``cat(t, x)`` whose hidden layer is sigmoid followed by squaring.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        # Attribute names and registration order determine state_dict keys
        # and parameters() ordering, so keep them stable.
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)
        self.activate = nn.Sigmoid()

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(sigmoid(fc1(cat(t, x)))**2)``, concatenating along the last dim."""
        pre = self.fc1(torch.cat((t, x), dim=-1))
        hidden = self.activate(pre).pow(2)  # squared sigmoid
        return self.fc2(hidden)
    
class l3_relu2(nn.Module):
    """Three-layer MLP on ``cat(t, x)`` with squared-ReLU activations after both hidden layers.

    Args:
        d: Size of the last dimension of ``x``. ``fc1`` expects ``d + 1`` input
            features once ``t`` is concatenated, so ``t`` presumably carries a
            single feature column (confirm at the call site).
        outdim: Number of output features.
        net_width: Number of units in each of the two hidden layers.
    """

    def __init__(self, d: int, outdim: int, net_width: int):
        super().__init__()
        # Attribute names and registration order determine state_dict keys
        # and parameters() ordering, so keep them stable.
        self.fc1 = nn.Linear(d + 1, net_width)
        self.fc2 = nn.Linear(net_width, net_width)
        self.fc3 = nn.Linear(net_width, outdim)
        # A single stateless ReLU module is shared by both hidden layers.
        self.activate = nn.ReLU()

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc3(relu(fc2(relu(fc1(cat(t, x)))**2))**2)``."""
        h = torch.cat((t, x), dim=-1)
        for layer in (self.fc1, self.fc2):
            h = self.activate(layer(h)).pow(2)
        return self.fc3(h)

class simple_tnet(nn.Module):
    """One-hidden-layer ReLU MLP that maps ``t`` alone to ``outdim`` features.

    Args:
        outdim: Number of output features.
        net_width: Number of hidden units.
    """

    def __init__(self, outdim: int, net_width: int):
        super().__init__()
        self.fc1 = nn.Linear(1, net_width)
        self.fc2 = nn.Linear(net_width, outdim)
        # nn.Tanh() and nn.Sigmoid() were left as commented-out alternatives
        # here; swap them in to change the nonlinearity.
        self.activate = nn.ReLU()

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(relu(fc1(t)))``.

        ``t`` is expected to be N x 1 (``fc1`` takes one input feature); the
        result is N x outdim.
        """
        hidden = self.activate(self.fc1(t))
        return self.fc2(hidden)

class Symmetric_net(nn.Module):
    """Map ``t`` to a batch of symmetric ``outdim x outdim`` matrices.

    A one-hidden-layer ReLU MLP produces ``outdim**2`` values per row of ``t``.
    These are reshaped into square matrices ``A`` and symmetrized as
    ``(A + A^T) / 2``.

    Args:
        outdim: Side length of each output matrix.
        net_width: Number of hidden units.
    """

    def __init__(self, outdim: int, net_width: int):
        super().__init__()
        self.fc1 = nn.Linear(1, net_width)
        self.fc2 = nn.Linear(net_width, outdim ** 2)
        self.outdim = outdim
        # nn.Tanh() and nn.Sigmoid() were left as commented-out alternatives
        # here; swap them in to change the nonlinearity.
        self.activate = nn.ReLU()

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Return the symmetrized matrices, shape N x outdim x outdim.

        ``t`` is expected to be N x 1 (``fc1`` takes one input feature).
        """
        side = self.outdim
        flat = self.fc2(self.activate(self.fc1(t)))  # N x outdim^2
        square = flat.reshape(-1, side, side)
        # Averaging with the transpose makes each matrix exactly symmetric.
        return (square + square.transpose(1, 2)) / 2