{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ee7061-8f7c-4d64-bbe0-15561238b4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "device = torch.device(\"cuda:0\")\n",
    "import numpy as np\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef46e355-ec62-4cd2-902f-091fb79fc218",
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = 2\n",
    "E = 10\n",
    "nu = 0.3\n",
    "mu = 0.5 * E / (1 + nu)\n",
    "lam = E * nu / ((1 + nu) * (1 - 2 * nu))\n",
    "print(mu)\n",
    "print(lam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d10aef3-41a5-4d73-ad16-8df6f26bbac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analytical_projection_sigma(sigma, yield_strain):\n",
    "    epsilon = torch.log(sigma)\n",
    "    trace_epsilon = torch.sum(epsilon, dim=-1, keepdim=True)\n",
    "    epsilon_hat = epsilon - trace_epsilon / dim\n",
    "    s_hat = torch.nn.functional.normalize(epsilon_hat, dim=-1, eps=1e-6)\n",
    "    epsilon_hat_norm = torch.linalg.norm(epsilon_hat, dim=-1, keepdim=True)\n",
    "    delta_gamma = epsilon_hat_norm - yield_strain\n",
    "    H = epsilon - delta_gamma * s_hat\n",
    "    Z = torch.exp(H)\n",
    "    return torch.where(delta_gamma > 0, Z, sigma)\n",
    "\n",
    "def analytical_projection(F, yield_strain):\n",
    "    \"\"\"Apply the singular-value projection to a matrix (or batch of matrices) `F`.\n",
    "\n",
    "    `F` is decomposed with an SVD, its singular values are passed through\n",
    "    `analytical_projection_sigma`, and the matrix is reassembled with the original\n",
    "    left and right singular vectors.\n",
    "    \"\"\"\n",
    "    left, singular_values, right_h = torch.linalg.svd(F)\n",
    "    projected = analytical_projection_sigma(singular_values, yield_strain)\n",
    "    return torch.matmul(torch.matmul(left, torch.diag_embed(projected)), right_h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88912318-a3b0-433e-84df-ca18012b2f86",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_angle(batch_size, device='cpu'):\n",
    "    return torch.from_numpy(np.random.uniform(size=[batch_size, 1], low=-np.pi, high=np.pi)).float().to(device)\n",
    "\n",
    "def random_2d_rotation(batch_size, device='cpu'):\n",
    "    \"\"\"Build `batch_size` random 2x2 rotation matrices.\n",
    "\n",
    "    Each matrix is [[cos t, sin t], [-sin t, cos t]] for an angle t drawn by\n",
    "    `random_angle`; the result has shape (batch_size, 2, 2).\n",
    "    \"\"\"\n",
    "    angle = random_angle(batch_size, device)\n",
    "    cos_t, sin_t = torch.cos(angle), torch.sin(angle)\n",
    "    top_row = torch.cat([cos_t, sin_t], dim=-1)\n",
    "    bottom_row = torch.cat([-sin_t, cos_t], dim=-1)\n",
    "    return torch.stack([top_row, bottom_row], dim=-2)\n",
    "\n",
    "def random_sigma(batch_size, device='cpu'):\n",
    "    \"\"\"Sample singular values whose logarithms are uniform on [-1, 1).\n",
    "\n",
    "    Returns a float32 tensor of shape (batch_size, dim) on `device`, where `dim` is\n",
    "    the notebook-level spatial dimension. Uses NumPy's global random state.\n",
    "    \"\"\"\n",
    "    log_sigma = np.random.uniform(-1, 1, size=(batch_size, dim))\n",
    "    return torch.exp(torch.from_numpy(log_sigma).float().to(device))\n",
    "\n",
    "def random_yield_strain(batch_size, device='cpu'):\n",
    "    return torch.from_numpy(np.random.uniform(size=[batch_size, 1], low=0, high=1)).float().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52e5f128-63c4-4e77-acd0-56d86b0e9d84",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_stress(F, yield_strain):\n",
    "    \"\"\"Reference stress for `F`, evaluated on the projected singular values.\n",
    "\n",
    "    The singular values of `F` are projected with `analytical_projection_sigma`.\n",
    "    From the log of the projected values the principal quantity\n",
    "    `2 * mu * eps + lam * tr(eps)` is formed, divided by the *unprojected* singular\n",
    "    values, and rotated back with the singular vectors of `F`.\n",
    "    \"\"\"\n",
    "    left, singular_values, right_h = torch.linalg.svd(F)\n",
    "    log_projected = torch.log(analytical_projection_sigma(singular_values, yield_strain))\n",
    "    volumetric = log_projected.sum(dim=-1, keepdim=True)\n",
    "    principal = (2 * mu * log_projected + lam * volumetric) / singular_values\n",
    "    return left @ torch.diag_embed(principal) @ right_h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2f7928-3f41-4a7c-95b7-8a7440f7e067",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Activation(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x * torch.sigmoid(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7f0481c-6834-480c-95ea-1bcff4bafef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eye_like(tensor):\n",
    "    \"\"\"Return one `dim` x `dim` identity matrix per leading-axis entry of `tensor`.\n",
    "\n",
    "    The result has shape (tensor.shape[0], dim, dim) and lives on `tensor`'s device;\n",
    "    `dim` is the notebook-level spatial dimension.\n",
    "    \"\"\"\n",
    "    identity = torch.eye(dim, device=tensor.device)\n",
    "    return identity.repeat(tensor.shape[0], 1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "883e22a8-164c-4b47-ae3b-1401821a2cad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "hidden_dim = 32\n",
    "\n",
    "\n",
    "class PsiNet(nn.Module):\n",
    "    \"\"\"Fully connected network mapping (F, F_flag, yield_strain) to one scalar per sample.\n",
    "\n",
    "    The flattened `F` and `F_flag` (dim * dim entries each) are concatenated with\n",
    "    `yield_strain` (one entry) and passed through three `hidden_dim`-wide linear\n",
    "    layers, each followed by `Activation`, and a final linear layer of width 1.\n",
    "    \"\"\"\n",
    "\n",
    "    training = True\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        stack = [nn.Linear(dim * dim * 2 + 1, hidden_dim), Activation()]\n",
    "        for _ in range(2):\n",
    "            stack += [nn.Linear(hidden_dim, hidden_dim), Activation()]\n",
    "        stack.append(nn.Linear(hidden_dim, 1))\n",
    "        self.layers = nn.ModuleList(stack)\n",
    "\n",
    "    def forward(self, F, F_flag, yield_strain):\n",
    "        features = torch.cat([F.flatten(1), F_flag.flatten(1), yield_strain], dim=-1)\n",
    "        for layer in self.layers:\n",
    "            features = layer(features)\n",
    "        return features\n",
    "    \n",
    "def training_output(model, F_, F_flag, yield_strain):\n",
    "    \"\"\"Stress predicted by `model`, anchored to the analytic stress at `F_flag`.\n",
    "\n",
    "    The returned tensor is\n",
    "        grad_F model(F, F_flag, y) - grad_X model(X, F_flag, y) at X = F_flag\n",
    "        + target_stress(F_flag, y),\n",
    "    so it equals `target_stress(F_flag, yield_strain)` whenever `F_` equals `F_flag`.\n",
    "\n",
    "    Args:\n",
    "        model: callable taking (F, F_flag, yield_strain) and returning one scalar per sample.\n",
    "        F_: matrices at which the stress is wanted; a detached copy is differentiated,\n",
    "            so the caller's tensor is not modified.\n",
    "        F_flag: reference matrices; passed to the model as its second input.\n",
    "        yield_strain: per-sample yield strain forwarded to the model and to `target_stress`.\n",
    "\n",
    "    Returns:\n",
    "        Tensor shaped like `F_`. The gradients are taken with `create_graph=True`, so a\n",
    "        loss on the result can be back-propagated to the model parameters.\n",
    "    \"\"\"\n",
    "    def _input_gradient(first_input):\n",
    "        # Gradient of the model output with respect to its first argument only.\n",
    "        # grad_outputs of ones sums the per-sample scalars before differentiating.\n",
    "        out = model(first_input, F_flag, yield_strain)\n",
    "        return torch.autograd.grad(\n",
    "            out,\n",
    "            first_input,\n",
    "            grad_outputs=torch.ones_like(out),\n",
    "            create_graph=True,\n",
    "            retain_graph=True,\n",
    "        )[0]\n",
    "\n",
    "    # `F_flag` is deliberately NOT switched to requires_grad here: neither gradient below\n",
    "    # is taken with respect to it, and flipping the flag in place would modify the\n",
    "    # caller's tensor and make backward() also differentiate the SVD in target_stress.\n",
    "    F = F_.detach().clone().requires_grad_()\n",
    "    F_flag_duplicate = F_flag.detach().clone().requires_grad_()\n",
    "    target = target_stress(F_flag, yield_strain)\n",
    "\n",
    "    sigma_grad = _input_gradient(F)\n",
    "    flag_grad = _input_gradient(F_flag_duplicate)\n",
    "\n",
    "    return sigma_grad - flag_grad + target\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a126bba-adf3-418e-8664-63f170888fe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_weights(m):\n",
    "    if isinstance(m, nn.Linear):\n",
    "        torch.nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))\n",
    "        m.bias.data.fill_(0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "450ce8a6-bfb6-429f-a57f-064480318740",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network on the target device, then re-initialise every nn.Linear in place.\n",
    "net = PsiNet().to(device)\n",
    "net.apply(init_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c23f245f-51c3-4e79-aadf-fa22acb20e5a",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cd354c8-88ac-44aa-bcca-1f21bb26d6d3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm, trange\n",
    "from torch.utils.tensorboard import SummaryWriter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cc95a04-ff2d-4124-ac1c-d7a12ff3f18b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_data(batch_size, hencky_perturb=0.2):\n",
    "    \"\"\"Draw one random training batch on the notebook-level `device`.\n",
    "\n",
    "    `F_flag` is assembled from random singular values; `F` from the same values with\n",
    "    their logarithms shifted by uniform noise in [-hencky_perturb, hencky_perturb).\n",
    "    Each matrix gets its own pair of random 2D rotations on the left and right.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (F, F_flag, yield_strain); `yield_strain` has shape (batch_size, 1).\n",
    "    \"\"\"\n",
    "    sigma_flag = random_sigma(batch_size, device)\n",
    "    log_noise = np.random.uniform(-hencky_perturb, hencky_perturb, sigma_flag.shape)\n",
    "    sigma = torch.exp(torch.log(sigma_flag) + torch.from_numpy(log_noise).float().to(device))\n",
    "    yield_strain = random_yield_strain(batch_size, device)\n",
    "    # Rotations are drawn in the order U1, U2, V1, V2.\n",
    "    U1, U2, V1, V2 = (random_2d_rotation(batch_size, device) for _ in range(4))\n",
    "    F_flag = U1 @ torch.diag_embed(sigma_flag) @ V1\n",
    "    F = U2 @ torch.diag_embed(sigma) @ V2\n",
    "    return F, F_flag, yield_strain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "135349e2-0730-494f-8f8a-e62ce3c9df5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean-squared-error loss, Adam optimiser, and a step-wise learning-rate decay\n",
    "# (multiply by 0.95 after every 1000 scheduler steps).\n",
    "mse_cost_function = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-2)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1000, gamma=0.95)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c5f8839-5601-4060-b9d9-cdf557979746",
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 20000\n",
    "batch_size = 2 ** 16\n",
    "hencky_perturb_scale = 0.1\n",
    "min_lr = 1e-4  # the scheduler is no longer stepped once the learning rate is at or below this\n",
    "\n",
    "writer = SummaryWriter(comment=\"VonMises\")\n",
    "\n",
    "pbar = trange(iterations, unit=\"iters\")\n",
    "try:\n",
    "    for epoch in pbar:\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Fresh random batch every iteration; the loss compares the network-based\n",
    "        # stress at F against the analytic stress at F.\n",
    "        F, F_flag, yield_strain = prepare_data(batch_size, hencky_perturb_scale)\n",
    "        stress = training_output(net, F, F_flag, yield_strain)\n",
    "        target = target_stress(F, yield_strain)\n",
    "\n",
    "        loss = mse_cost_function(stress, target)\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        if optimizer.param_groups[0]['lr'] > min_lr:\n",
    "            scheduler.step()\n",
    "\n",
    "        # Read the scalar loss once and reuse it for the progress bar and TensorBoard.\n",
    "        loss_value = loss.item()\n",
    "        current_lr = optimizer.param_groups[0]['lr']\n",
    "        pbar.set_description(f\"Epoch: {epoch}, Training Loss: {loss_value}\")\n",
    "        writer.add_scalar('Total_Loss/train', loss_value, epoch)\n",
    "        writer.add_scalar('lr/train', current_lr, epoch)\n",
    "finally:\n",
    "    # Flush and close the event file even when training is interrupted.\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea4b59c-5a41-4907-a3e9-921253ce4596",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# torch.save does not create missing directories, so make sure the target folder exists\n",
    "# before writing; otherwise the save fails only after the whole training run.\n",
    "os.makedirs(\"params\", exist_ok=True)\n",
    "torch.save(net.state_dict(), \"params/VM.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2c3c69c-6c9c-406f-91e4-704aa2409e05",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Save model\n",
    "\n",
    "Trace the network into a TorchScript module and save it so that it can be loaded from C++."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f79d20-5ff5-4918-9a1f-c9a43fdc1a19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example inputs (a batch of one) used to record the trace.\n",
    "F = torch.tensor([[[1, 0.1], [0.1, 1]]], dtype=torch.float32, device=device)\n",
    "F_flag = F.clone()\n",
    "strain = torch.tensor([[0.1]], dtype=torch.float32, device=device)\n",
    "traced_script_module = torch.jit.trace(net, (F, F_flag, strain))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05198783-da11-49e5-bf29-a2099c1b3db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialise the traced module to disk.\n",
    "torch.jit.save(traced_script_module, \"traced_model_von_mises.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
