{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2f524c57-4b5e-45a8-9e65-2e93128adb3e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import numpy\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "M = 2  # rows (tokens) per sample: signal row, cp-noise row, plus M - 2 weak-noise rows\n",
    "d = 1024  # length of every token vector; input width of the q/k/v projections\n",
    "n_ = 100  # number of test samples drawn in get_test_loss (n_ / 2 per class)\n",
    "dh = 512  # output width of the query and key projections\n",
    "dv = 512  # output width of the value projection\n",
    "cp = 4  # standard deviation of the noise row that follows the signal row\n",
    "class TF(nn.Module):\n",
    "    \"\"\"Single-head attention layer followed by a frozen linear readout.\n",
    "\n",
    "    Queries, keys and values are bias-free linear maps of the ``d``-wide tokens.\n",
    "    The scalar output comes from ``fc``, whose weights are excluded from training.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.q = nn.Linear(d, dh, bias=False)\n",
    "        self.k = nn.Linear(d, dh, bias=False)\n",
    "        self.v = nn.Linear(d, dv, bias=False)\n",
    "        self.fc = nn.Linear(dv, 1, bias=False)\n",
    "        # The readout keeps its random initial weights: gradients are switched off.\n",
    "        self.fc.requires_grad_(False)\n",
    "        # Shrink the default initialisation of the three projections by a factor of 16.\n",
    "        with torch.no_grad():\n",
    "            for proj in (self.q, self.k, self.v):\n",
    "                proj.weight.div_(16)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a (batch, tokens, d) tensor to a (batch, 1) score.\"\"\"\n",
    "        queries, keys, values = self.q(x), self.k(x), self.v(x)\n",
    "        # Unscaled dot-product scores, normalised over the key axis.\n",
    "        weights = torch.softmax(queries @ keys.transpose(1, 2), dim=2)\n",
    "        # Add up the attention rows of all queries, then divide by 16.\n",
    "        pooled = weights.sum(dim=1, keepdim=True) / 16\n",
    "        return self.fc((pooled @ values).squeeze(1))\n",
    "\n",
    "def make_mu1(mu):\n",
    "    \"\"\"Return a length-``d`` float vector that is ``mu`` at index 0 and zero elsewhere.\"\"\"\n",
    "    signal = numpy.zeros((d,))\n",
    "    signal[0] = mu\n",
    "    return signal\n",
    "\n",
    "def make_mu2(mu):\n",
    "    \"\"\"Return a length-``d`` float vector that is ``mu`` at index 1 and zero elsewhere.\"\"\"\n",
    "    signal = numpy.zeros((d,))\n",
    "    signal[1] = mu\n",
    "    return signal\n",
    "\n",
    "def make_noise(strength):\n",
    "    \"\"\"Draw a length-``d`` vector of independent N(0, strength**2) values from NumPy's global RNG.\"\"\"\n",
    "    return numpy.random.normal(loc=0, scale=strength, size=d)\n",
    "\n",
    "def _make_sample(signal):\n",
    "    \"\"\"Stack one (M, d) sample: the signal row, a cp-noise row, then M - 2 weak-noise rows.\"\"\"\n",
    "    rows = [signal.reshape(1, d), make_noise(cp).reshape(1, d)]\n",
    "    rows.extend(make_noise(0.2).reshape(1, d) for _ in range(M - 2))\n",
    "    return numpy.concatenate(rows, 0)\n",
    "\n",
    "\n",
    "def _build_dataset(num_pairs, mu1, mu2, flip_prob):\n",
    "    \"\"\"Draw num_pairs (+1, -1) sample pairs; each label is flipped with probability flip_prob.\"\"\"\n",
    "    samples, labels = [], []\n",
    "    for _ in range(num_pairs):\n",
    "        for signal, clean_label in ((mu1, 1.), (mu2, -1.)):\n",
    "            samples.append(_make_sample(signal))\n",
    "            # Label flip. No random number is consumed when flipping is disabled.\n",
    "            flipped = flip_prob > 0 and np.random.rand() <= flip_prob\n",
    "            labels.append([-clean_label if flipped else clean_label])\n",
    "    return numpy.stack(samples), numpy.array(labels, dtype=numpy.float32)\n",
    "\n",
    "\n",
    "def get_test_loss(n, mu, label_flip_prob=0.001, epochs=1000, lr=0.1):\n",
    "    \"\"\"Train a fresh TF model on n noisy-label samples and return its test loss.\n",
    "\n",
    "    Args:\n",
    "        n: Number of training samples; n // 2 are drawn for each of the two classes.\n",
    "        mu: Magnitude of the single non-zero coordinate of the two signal vectors.\n",
    "        label_flip_prob: Probability that a training label is flipped.\n",
    "        epochs: Number of full-batch SGD steps.\n",
    "        lr: SGD learning rate.\n",
    "\n",
    "    Returns:\n",
    "        A 0-dim tensor with the SoftMarginLoss on n_ freshly drawn test samples\n",
    "        whose labels are never flipped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If n is smaller than 2, so that no training pair can be drawn.\n",
    "    \"\"\"\n",
    "    num_train_pairs = int(n // 2)\n",
    "    if num_train_pairs < 1:\n",
    "        raise ValueError('n must be at least 2 to draw one sample per class')\n",
    "\n",
    "    # Fall back to the CPU so the experiment also runs on machines without CUDA.\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    mu1 = make_mu1(mu)\n",
    "    mu2 = make_mu2(mu)\n",
    "    # Training data is drawn first, then test data, so the NumPy RNG stream order is fixed.\n",
    "    train_x, train_y = _build_dataset(num_train_pairs, mu1, mu2, label_flip_prob)\n",
    "    test_x, test_y = _build_dataset(int(n_ // 2), mu1, mu2, 0.0)\n",
    "\n",
    "    # Converting one stacked array is far cheaper than handing torch a list of arrays.\n",
    "    D = torch.tensor(train_x, dtype=torch.float32, device=device)\n",
    "    D_Y = torch.tensor(train_y, dtype=torch.float32, device=device)\n",
    "    D_ = torch.tensor(test_x, dtype=torch.float32, device=device)\n",
    "    D_Y_ = torch.tensor(test_y, dtype=torch.float32, device=device)\n",
    "\n",
    "    model = TF().to(device)\n",
    "    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0)\n",
    "    loss_fn = nn.SoftMarginLoss()\n",
    "\n",
    "    model.train()\n",
    "    for _ in range(epochs):\n",
    "        optimizer.zero_grad()\n",
    "        training_loss = loss_fn(model(D), D_Y)\n",
    "        training_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Only the final test loss is returned, so evaluate once and without autograd.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        test_loss = loss_fn(model(D_), D_Y_)\n",
    "    return test_loss\n",
    "\n",
    "# Result grid: `size` independent repetitions (rows) by 100 training-set-size columns.\n",
    "size = 5\n",
    "n_scale = 2\n",
    "mu_scale = 1  # not referenced by the sweep below\n",
    "matrix = np.zeros((size, 100))\n",
    "for trial in range(size):\n",
    "    print(trial)\n",
    "    for col in range(100):\n",
    "        # Ten consecutive columns share one training-set size: n_scale, 2 * n_scale, ...\n",
    "        train_size = ((col + 10) // 10) * n_scale\n",
    "        matrix[trial, col] = get_test_loss(train_size, 16)\n",
    "# NOTE: savetxt writes plain text even though the file name ends in .npy.\n",
    "np.savetxt('0.001.npy', matrix)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a67ee89-e8db-45b1-a1ed-cfc63fc4f786",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
