import torch
from torch import nn

class MLP(nn.Module):
    """Feed-forward network with SiLU activations and a sigmoid output.

    Layer order inside ``self.mlp``::

        Linear(input_dim -> hidden_dim), SiLU
        for k = 1 .. num_layers - 1:
            Linear(hidden_dim -> hidden_dim), SiLU, Dropout(p=dropout / k)
        Linear(hidden_dim -> output_dim), Sigmoid

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every hidden layer.
        output_dim: Size of the last dimension of the output.
        num_layers: Number of hidden Linear layers; values <= 1 give a single
            hidden layer.
        dropout: Base dropout probability. The k-th additional hidden block
            uses ``dropout / k``; there is no dropout after the first layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0.0):
        super().__init__()
        blocks = [nn.Linear(input_dim, hidden_dim), nn.SiLU()]
        # Deeper blocks get progressively weaker dropout: dropout / 1, dropout / 2, ...
        for depth in range(1, num_layers):
            blocks.extend((
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Dropout(p=dropout / depth),
            ))
        blocks.extend((nn.Linear(hidden_dim, output_dim), nn.Sigmoid()))
        self.mlp = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension is ``input_dim``.

        Returns a tensor whose last dimension is ``output_dim``, squashed by a
        sigmoid into the [0, 1] range.
        """
        return self.mlp(x)

# class MLP(nn.Module):
#     def __init__(
#         self,
#         input_dim: int,
#         hidden_dim: int,
#         output_dim: int,
#         num_layers: int = 1,
#         dropout: float = 0.0,
#         init_log_sigma: float = -10.0,
#         max_noise: float = 0.05,
#     ):
#         super().__init__()
#         # build all hidden layers (no final Sigmoid here)
#         layers = [nn.Linear(input_dim, hidden_dim), nn.SiLU()]
#         for i in range(num_layers - 1):
#             layers += [
#                 nn.Linear(hidden_dim, hidden_dim),
#                 nn.SiLU(),
#                 nn.Dropout(p=dropout / (i + 1)),
#             ]
#         # final linear → outputs μ
#         layers.append(nn.Linear(hidden_dim, output_dim))

#         self.mlp = nn.Sequential(*layers)

#         # noise parameters
#         # logσ is learnable, initialized to init_log_sigma (default -10.0, so σ = exp(-10) ≈ 4.5e-5)
#         self.log_sigma = nn.Parameter(torch.full((output_dim,), init_log_sigma))
#         # clamp noise to this range
#         self.max_noise = max_noise

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         # 1) compute base prediction μ
#         mu = self.mlp(x)                            # shape (B, output_dim)

#         # 2) sample in-graph noise ε ~ N(0,1)
#         sigma = self.log_sigma.exp()                # ensure σ>0
#         eps   = torch.randn_like(mu)

#         # 3) scale and clamp
#         noise = (sigma * eps).clamp(-self.max_noise, self.max_noise)

#         # 4) add to μ then squash
#         return torch.sigmoid(mu + noise)
    
# class MLP(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0.0):
#         super(MLP, self).__init__()
        
#         # Build hidden layers
#         layers = [nn.Linear(input_dim, hidden_dim), nn.SiLU()]
#         for _ in range(num_layers - 1):
#             layers += [nn.Linear(hidden_dim, hidden_dim), nn.Dropout(p=dropout), nn.SiLU()]
        
#         # Mean and log variance networks
#         self.hidden_layers = nn.Sequential(*layers)
#         self.mean_head = nn.Linear(hidden_dim, output_dim)
#         self.logvar_head = nn.Linear(hidden_dim, output_dim)
        
#         # For deterministic output with sigmoid (used for evaluation)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, x, sample=True, temperature=1.0):
#         """
#         Args:
#             x: Input tensor
#             sample: Whether to sample from distribution or return mean
#             temperature: Higher values = more exploration (scales the std)
#         """
#         features = self.hidden_layers(x)
        
#         # Get distribution parameters
#         means = self.mean_head(features)
#         logvars = self.logvar_head(features)
        
#         # Clamp logvars for stability
#         logvars = torch.clamp(logvars, -10, 2)
        
#         # Return deterministic prediction if not sampling
#         if not sample:
#             return self.sigmoid(means)
        
#         # Sample from the distribution with temperature scaling
#         std = torch.exp(0.5 * logvars) * temperature
#         eps = torch.randn_like(means)
#         samples = means + eps * std
        
#         # Apply sigmoid to keep in [0,1] range
#         return self.sigmoid(samples)

    
class FNN(nn.Module):
    """Fourier-feature network: a learnable frequency projection followed by an MLP.

    The input ``x`` is projected with the learnable matrix ``self.B`` and
    encoded as ``[sin(x @ B), cos(x @ B)]`` (concatenated on the last
    dimension). That encoding is fed to a SiLU MLP with a sigmoid output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, num_freq=200,
                 init_freq_scale=10.0):
        """
        Initialize the FNN model with Fourier features and an MLP.

        Parameters:
            input_dim (int): Size of the last dimension of the input.
            hidden_dim (int): Width of each hidden layer of the MLP.
            output_dim (int): Size of the last dimension of the output.
            num_layers (int, optional): Number of hidden Linear layers in the
                MLP; values <= 1 give a single hidden layer. Default is 2.
            num_freq (int, optional): Number of Fourier frequencies, i.e. the
                number of columns of ``B``. The MLP input width is
                ``2 * num_freq``. Default is 200.
            init_freq_scale (float, optional): Constant every entry of ``B``
                is set to at construction. Default is 10.0.
        """
        super().__init__()
        # Learnable projection of shape (input_dim, num_freq). Because it starts
        # as a constant, all num_freq columns are identical at construction, so
        # the sin/cos features are initially identical too. Only the MLP's
        # random weights break that symmetry during training.
        self.B = nn.Parameter(init_freq_scale * torch.ones(input_dim, num_freq))

        # NOTE(review): `fourier` is never used in forward(). It is kept so that
        # existing state_dicts/checkpoints (key "fourier.weight") and seeded
        # initialisation order stay compatible. Consider removing it in a
        # deliberate, checkpoint-breaking change.
        self.fourier = nn.Linear(input_dim, num_freq, bias=False)

        # MLP over the concatenated [sin, cos] features.
        layers = [nn.Linear(2 * num_freq, hidden_dim), nn.SiLU()]
        for _ in range(num_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.SiLU()]
        layers += [nn.Linear(hidden_dim, output_dim), nn.Sigmoid()]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` with Fourier features and apply the MLP.

        ``x`` must have ``input_dim`` as its last dimension (it is matrix-
        multiplied by ``B``). Returns a tensor whose last dimension is
        ``output_dim``, squashed into [0, 1] by the final sigmoid.
        """
        x_proj = torch.matmul(x, self.B)
        x_fourier = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return self.mlp(x_fourier)
    
# class FNN_FS(FNN):
#     def forward(self, x):
#         x_proj = torch.matmul(x, self.B)
#         x_fourier = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
#         y = self.mlp(x_fourier)


#         return y


