{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "70853b2e-d56c-4d29-89b6-e6b3039e25fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-inf) tensor(nan)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.optim as optim\n",
    "import math\n",
    "from what_adam_cant_do_thorough \n",
    "\n",
    "func_range = 100\n",
    "\n",
    "input_dim = 100\n",
    "x = 2*func_range* torch.rand(10000, input_dim) - func_range  \n",
    "\n",
    "\n",
    "function = schaffer_f7\n",
    "\n",
    "y =  function(x)\n",
    "\n",
    "mean = y.mean()\n",
    "std = y.std()\n",
    "\n",
    "print(mean, std)\n",
    "\n",
    "std = std\n",
    "\n",
    "target_function = lambda x: (function(x) - mean)/std\n",
    "\n",
    "def generate_batch(batch_size=1000, input_dim=input_dim):\n",
    "    x = 2*func_range* torch.rand(batch_size, input_dim) - func_range  \n",
    "    y = target_function(x)\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "608287e5-d62a-406b-9a7b-d5feab1a4a89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ackley function\n",
    "def ackley(x, a=20, b=0.2, c=2 * math.pi):\n",
    "    D = x.shape[1]\n",
    "    sum_sq_term = torch.sum(x ** 2, dim=1) / D\n",
    "    cos_term = torch.sum(torch.cos(c * x), dim=1) / D\n",
    "    result = -a * torch.exp(-b * torch.sqrt(sum_sq_term)) - torch.exp(cos_term) + a + math.e\n",
    "    return result\n",
    "\n",
    "def alpine_1(x):\n",
    "    return torch.sum(torch.abs(x * torch.sin(x) + 0.1 * x), dim=1)\n",
    "\n",
    "def rastrigin(x): \n",
    "    n = x.shape[1]\n",
    "    return 10 * n + torch.sum(x**2 - 10 * torch.cos(2 * torch.pi * x), dim=1)\n",
    "\n",
    "def xin_she_yang_1(x): \n",
    "    u = torch.sum(torch.abs(x), dim=-1)\n",
    "    \n",
    "    v = torch.exp(-torch.sum(torch.sin(x**2), dim=-1))\n",
    "\n",
    "    return u * v\n",
    "\n",
    "def schaffer_f7(x):\n",
    "    term1 = torch.sqrt(x[:, :-1]**2 + x[:, 1:]**2)  \n",
    "    wraparound_distance = torch.sqrt(x[:, -1]**2 + x[:, 0]**2)  \n",
    "    term1 = torch.cat((term1, wraparound_distance.unsqueeze(1)), dim=1)\n",
    "    return torch.sum(term1**0.5 * (1 + torch.sin(50 * term1**0.2)**2), dim=1)/x.shape[1]\n",
    "\n",
    "\n",
    "def expanded_schaffer_f6(x): \n",
    "    term1 = torch.sqrt(x[:, :-1]**2 + x[:, 1:]**2)\n",
    "    term1 = torch.cat([term1, torch.sqrt(x[:,-1]**2 + x[:,0]**2).reshape(-1,1)],dim=1)\n",
    "    return torch.sum(0.5 + (torch.sin(term1)**2 - 0.5) / (1 + 0.001 * term1**2)**2, dim=1)\n",
    "\n",
    "def xin_she_yang_3(x):\n",
    "    bracket_1 = torch.exp(-torch.sum((x/15)**10, dim=1)) - 2*torch.exp(-torch.sum(x**2,dim=1))\n",
    "    ans =  bracket_1 * torch.prod(torch.cos(x)**2, dim=1)\n",
    "    return (1 + ans)\n",
    "\n",
    "def xin_she_yang_5(x):\n",
    "    ans = (torch.sum(torch.sin(x)**2, dim=1) - torch.exp(torch.sum(x**2,dim=1)))*(torch.exp(-torch.sum(torch.sin(x.abs()**0.5)**2,dim=1)))\n",
    "    return (1 + ans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "493e0f4f-7054-4269-a03b-6ef5e33f8ee0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Define the fully connected neural network\n",
    "class FCNN(nn.Module):\n",
    "    def __init__(self, input_dim=100, output_dim=1):\n",
    "        super(FCNN, self).__init__()\n",
    "        # Define layers\n",
    "        self.fc1 = nn.Linear(input_dim, 1000)\n",
    "        self.fc2 = nn.Linear(1000, 1000)\n",
    "        self.fc3 = nn.Linear(1000, 1000)\n",
    "        self.fc4 = nn.Linear(1000, 1000)\n",
    "        self.fc5 = nn.Linear(1000, 1000)\n",
    "        self.fc6 = nn.Linear(1000, output_dim)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.fc1(x))\n",
    "        \n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = F.relu(self.fc3(x))\n",
    "        x = F.relu(self.fc4(x))\n",
    "        x = F.relu(self.fc5(x))\n",
    "        x = self.fc6(x)  # No activation for final output (regression task)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76bf727f-2257-4ac9-b9e7-3a01b3b8e69b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop\n",
    "def train_model(model, optimizer, criterion, num_epochs=10, batch_size=1000):\n",
    "    for epoch in range(num_epochs):\n",
    "        # Generate a batch of data\n",
    "        x_batch, y_batch = generate_batch(batch_size)\n",
    "        \n",
    "        # Forward pass\n",
    "        predictions = model(x_batch.reshape(-1,input_dim))\n",
    "        assert predictions.squeeze().size() == y_batch.squeeze().size()\n",
    "        loss = criterion(predictions.squeeze(), y_batch.squeeze())\n",
    "\n",
    "        \n",
    "        # Backward pass and optimization\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        \n",
    "        \n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        # Print the loss for every epoch\n",
    "        if epoch % 10 == 0:\n",
    "            print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}\") #, Loss oscillations {loss_oscillations.item():.4f} Loss convex {loss_convex.item():.4f}\")\n",
    "            print(predictions.std(), 'std')\n",
    "            \n",
    "# Initialize the model, optimizer, and loss function\n",
    "\n",
    "output_dim = 1\n",
    "model = FCNN(input_dim, output_dim)\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "criterion = nn.MSELoss()\n",
    "\n",
    "# Train the model\n",
    "train_model(model, optimizer, criterion, num_epochs=1000, batch_size=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4a05ddd-09d3-47ba-a4da-a6597116798e",
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = generate_batch(batch_size=1000)\n",
    "model.eval()\n",
    "y_pred = model(x.reshape(-1,input_dim))\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7260f5c7-d285-4d92-be54-0b882df1d374",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion(torch.ones(y.squeeze().shape)*y.mean(), y.squeeze())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "945c796f-384d-49fd-9f3e-e085f8a54645",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion(y_pred.squeeze(), y.squeeze()) #if this is less than the above quantity, we consider the function learnable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c3c919-4950-4e92-890b-9ca3a6959c99",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
