import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class FCNN(nn.Module):
    """Fully connected network with Tanh activations and a single output.

    For each width in ``hidden_sizes`` a ``Linear -> Tanh`` pair is appended,
    followed by a final ``Linear`` projecting to one output feature.

    Args:
        n_dim: Number of input features.
        hidden_sizes: Widths of the hidden layers, in order. Any iterable of
            ints is accepted; defaults to ``(1024, 512, 256)``.
    """

    # Tuple default instead of a list: a list default would be a single object
    # shared across every FCNN instantiation.
    def __init__(self, n_dim, hidden_sizes=(1024, 512, 256)):
        super().__init__()
        layers = []
        in_dim = n_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(in_dim, h))
            layers.append(nn.Tanh())
            in_dim = h
        layers.append(nn.Linear(in_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network; ``x``'s last dim must equal ``n_dim``, output's last dim is 1."""
        return self.net(x)
    
class ackley():
    """Ackley benchmark function with an autograd-based gradient.

    f(x) = -a * exp(-b * sqrt(mean(x_i^2))) - exp(mean(cos(c * x_i))) + a + e
    with a = 20, b = 0.2, c = 2*pi; f(0) = 0.

    Args:
        n_dim: Stored on the instance. ``f_func`` takes the dimension from the
            input's last axis, not from this attribute.
        device: Stored on the instance; not used by the methods in this class.
    """

    def __init__(self, n_dim=2, device='cpu'):
        self.n_dim = n_dim
        self.device = device

    def f_func(self, x):
        """Evaluate the Ackley function for each point.

        x: Tensor of shape [B, n_dim]; extra leading batch dims are also
           accepted, the last axis is always the coordinate axis.
        returns: Tensor of shape [B] (i.e. ``x.shape[:-1]``).
        """
        # constants
        a = 20.0
        b = 0.2
        c = 2 * np.pi
        n_dim = x.shape[-1]

        # mean of squares term.
        # sqrt has an infinite derivative at 0, so a plain sqrt makes autograd
        # return NaN at x = 0. The double-where computes sqrt only on a safe
        # input: forward values are unchanged and the gradient at 0 becomes 0.
        mean_sq = torch.sum(x**2, dim=-1) / n_dim
        nonzero = mean_sq > 0
        safe_mean_sq = torch.where(nonzero, mean_sq, torch.ones_like(mean_sq))
        rms = torch.where(nonzero, torch.sqrt(safe_mean_sq), torch.zeros_like(mean_sq))
        term1 = -a * torch.exp(-b * rms)

        # mean of cosines term
        mean_cos = torch.sum(torch.cos(c * x), dim=-1) / n_dim
        term2 = -torch.exp(mean_cos)

        return term1 + term2 + a + np.e

    def grad_f(self, x):
        """Gradient of ``f_func`` with respect to ``x``.

        x: Tensor [B, n_dim]. It is detached and copied, so its own graph and
           ``requires_grad`` flag are left untouched.
        returns: Tensor [B, n_dim] holding the gradient at each point, detached.
        """
        # enable_grad so this also works when called under torch.no_grad()
        with torch.enable_grad():
            x = x.detach().clone().requires_grad_(True)
            # points are independent, so the gradient of the sum is the
            # per-point gradient, shaped like x
            fx = self.f_func(x).sum()
            grad = torch.autograd.grad(fx, x)[0]
        return grad.detach()