import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    """
    Three-layer MLP: Linear -> LeakyReLU -> Dropout -> Linear -> LeakyReLU
    -> Dropout -> Linear, producing raw (un-activated) outputs.
    """

    def __init__(self, input_dim, output_dim, vectorize_input=False, *,
                 hidden_dim=100, dropout=0.2, negative_slope=0.2):
        """
        Initialize the discriminator.

        :param input_dim: Size of the last dimension of the input tensor
        :param output_dim: Size of the last dimension of the output tensor
        :param vectorize_input: Stored as ``self.vectorize_input``.
            NOTE(review): ``forward`` in this class does not read it; confirm
            whether callers elsewhere rely on it.
        :param hidden_dim: Width of both hidden layers (default 100)
        :param dropout: Dropout probability after each hidden layer (default 0.2)
        :param negative_slope: LeakyReLU negative slope (default 0.2)
        """
        super(Discriminator, self).__init__()
        # NOTE(review): debug output kept for backward compatibility;
        # consider switching to the logging module.
        print("input_dim: ", input_dim)
        self.vectorize_input = vectorize_input

        # Layer order is fixed so state_dict keys (main.0 / main.3 / main.6)
        # stay compatible with checkpoints saved by earlier versions.
        self.main = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, input_tensor):
        """
        Run the MLP on ``input_tensor``.

        ``nn.Linear`` acts on the last dimension, so that dimension must equal
        ``input_dim``; leading dimensions are preserved.

        :param input_tensor: Input tensor whose last dimension is ``input_dim``
        :return: Tensor whose last dimension is ``output_dim`` (no final activation)
        """
        output_tensor = self.main(input_tensor)
        return output_tensor


class CombinedArchitectureSingle(nn.Module):
    """
    Wrap a single network and apply a final activation to its output.

    The activation is chosen by looking ``cost_function_v`` up in
    ``div_to_act_func``.
    """
    def __init__(self, single_architecture, cost_function_v=1):
        """
        :param single_architecture: Module whose output is fed to the final activation
        :param cost_function_v: Key into ``div_to_act_func`` selecting the final activation
        :raises KeyError: If ``cost_function_v`` is not a key of ``div_to_act_func``.
            NOTE(review): the default ``1`` is not currently a key, so constructing
            this class without ``cost_function_v`` raises; confirm the intended default.
        """
        super(CombinedArchitectureSingle, self).__init__()
        # NOTE(review): nn.Softmax() is built without an explicit ``dim``; PyTorch
        # then infers the dimension (with a deprecation warning). Consider passing
        # ``dim`` explicitly once the expected output rank is confirmed.
        self.div_to_act_func = {
            "Jensen-Shannon": nn.Sigmoid(),
            "KL": nn.Identity(),
            4: nn.Softplus(),
            "SL": nn.Sigmoid(),
            6: nn.Softmax(),
            7: nn.Softplus(),
            8: nn.Softmax(),
            9: nn.Softplus(),
            10: nn.Sigmoid(),
            11: nn.Softmax(),
            12: nn.Sigmoid(),
        }
        self.cost_function_version = cost_function_v
        self.single_architecture = single_architecture
        try:
            self.final_activation = self.div_to_act_func[cost_function_v]
        except KeyError:
            # Same exception type as before, but with an actionable message.
            raise KeyError(
                f"Unsupported cost_function_v {cost_function_v!r}; "
                f"expected one of {list(self.div_to_act_func)}"
            ) from None

    def forward(self, input_tensor_1):
        """
        Run the wrapped network, then apply the selected final activation.

        :param input_tensor_1: Input passed unchanged to ``single_architecture``
        :return: ``final_activation(single_architecture(input_tensor_1))``
        """
        intermediate_1 = self.single_architecture(input_tensor_1)
        output_tensor_1 = self.final_activation(intermediate_1)
        return output_tensor_1
