import torch.nn as nn
import torch.nn.functional as F


class Logistic_Reg(nn.Module):
    """Single affine map ``y = x @ W.T + b`` (one ``nn.Linear`` layer).

    No sigmoid/softmax is applied in ``forward``; the raw linear outputs are
    returned and any normalisation is left to the caller / loss function.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Kept inside an nn.Sequential so parameter names stay "model.0.*".
        affine = nn.Linear(input_dim, output_dim)
        self.model = nn.Sequential(affine)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the raw outputs."""
        outputs = self.model(x)
        return outputs

class FCN(nn.Module):
    """Fully connected network: input dropout, three LeakyReLU hidden layers, linear head.

    Layer layout inside ``self.model`` (indices matter for state_dict keys):
    ``Dropout, Linear, LeakyReLU, Linear, LeakyReLU, Linear, LeakyReLU, Linear``.
    The final layer has no activation, so raw outputs are returned.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        hidden_dim: width of each of the three hidden layers.
        dropout: dropout probability applied to the input.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256, dropout=0.0):
        super().__init__()
        widths = (input_dim, hidden_dim, hidden_dim, hidden_dim)
        stack = [nn.Dropout(dropout)]
        # Three Linear+LeakyReLU pairs: in->hidden, hidden->hidden, hidden->hidden.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.1)]
        stack.append(nn.Linear(hidden_dim, output_dim))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the whole stack and return the final layer's output."""
        outputs = self.model(x)
        return outputs


class SAM_FCN_OP(nn.Module):
    """Four-layer LeakyReLU MLP with a parallel set of additive "eps" layers.

    Every depth has two ``nn.Linear`` modules of identical shape: a ``*_w``
    layer holding the base parameters and a ``*_w_eps`` layer holding a
    perturbation. ``forward`` runs the base network only. ``forward_w_eps``
    sums the outputs of both layers at every depth, which is equivalent to a
    single affine layer with weight ``W + E`` and bias ``b + e``.
    NOTE(review): presumably used for sharpness-aware-minimisation style weight
    perturbation -- confirm against the training loop.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        hidden_dim: width of the three hidden layers.
        dropout: dropout probability applied to the input of ``forward_w_eps``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256, dropout=0.0):
        super(SAM_FCN_OP, self).__init__()
        # Registration order is kept unchanged: it fixes both the RNG-driven
        # default initialisation and the state_dict layout.
        self.linear_layer1_w = nn.Linear(input_dim, hidden_dim)
        self.act1_w = nn.LeakyReLU(0.1)
        self.linear_layer2_w = nn.Linear(hidden_dim, hidden_dim)
        self.act2_w = nn.LeakyReLU(0.1)
        self.linear_layer3_w = nn.Linear(hidden_dim, hidden_dim)
        self.act3_w = nn.LeakyReLU(0.1)
        self.linear_layer4_w = nn.Linear(hidden_dim, output_dim)

        self.input_dropout_w_eps = nn.Dropout(dropout)
        self.linear_layer1_w_eps = nn.Linear(input_dim, hidden_dim)
        self.linear_layer2_w_eps = nn.Linear(hidden_dim, hidden_dim)
        self.linear_layer3_w_eps = nn.Linear(hidden_dim, hidden_dim)
        self.linear_layer4_w_eps = nn.Linear(hidden_dim, output_dim)

    def reset_weights_eps(self, zero_bias=False):
        """Zero the weight matrices of all four ``*_w_eps`` layers.

        Args:
            zero_bias: if True, also zero the ``*_w_eps`` biases. By default
                the biases are left untouched (initially ``nn.Linear``'s default
                random init). In that case ``forward_w_eps`` still differs from
                ``forward`` by those bias terms after a reset.
        """
        eps_layers = (
            self.linear_layer1_w_eps,
            self.linear_layer2_w_eps,
            self.linear_layer3_w_eps,
            self.linear_layer4_w_eps,
        )
        for layer in eps_layers:
            nn.init.zeros_(layer.weight)
            if zero_bias:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Run the base (``*_w``) network only; no dropout, no eps layers."""
        x = self.act1_w(self.linear_layer1_w(x))
        x = self.act2_w(self.linear_layer2_w(x))
        x = self.act3_w(self.linear_layer3_w(x))
        return self.linear_layer4_w(x)

    def forward_w_eps(self, x):
        """Run the perturbed network: each layer computes ``w(x) + w_eps(x)``.

        Input dropout (``dropout`` constructor argument) is applied first; it is
        active only in training mode, as for any ``nn.Dropout``.
        """
        # Previously constructed but never applied, so `dropout` was ignored.
        x = self.input_dropout_w_eps(x)
        x = self.act1_w(self.linear_layer1_w(x) + self.linear_layer1_w_eps(x))
        x = self.act2_w(self.linear_layer2_w(x) + self.linear_layer2_w_eps(x))
        x = self.act3_w(self.linear_layer3_w(x) + self.linear_layer3_w_eps(x))
        return self.linear_layer4_w(x) + self.linear_layer4_w_eps(x)








