{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ee7061-8f7c-4d64-bbe0-15561238b4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "device = torch.device(\"cuda:0\")\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef46e355-ec62-4cd2-902f-091fb79fc218",
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = 2\n",
    "E = 10\n",
    "nu = 0.3\n",
    "mu = 0.5 * E / (1 + nu)\n",
    "lam = E * nu / ((1 + nu) * (1 - 2 * nu))\n",
    "friction_angle = 30\n",
    "sin_phi = np.sin(friction_angle / 180 * np.pi)\n",
    "alpha = np.sqrt(2. / 3.) * 2. * sin_phi / (3. - sin_phi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66edbae6-1edb-4f51-b98d-1ed90ab2a2d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_epsilon_hat_norm(sigma):\n",
    "    epsilon = torch.log(sigma)\n",
    "    trace_epsilon = torch.sum(epsilon, dim=-1, keepdim=True)\n",
    "    epsilon_hat = epsilon - trace_epsilon / dim\n",
    "    return torch.linalg.norm(epsilon_hat, dim=-1, keepdim=True)\n",
    "\n",
    "def kirchoff_stress(sigma):\n",
    "    log_sigma_prod = torch.sum(torch.log(sigma), dim=-1, keepdim=True)\n",
    "    stress = mu * (sigma ** 2 - 1) + lam * log_sigma_prod\n",
    "    return stress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5fff823-3b54-448b-a972-2ebaf3cfec2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def yield_criterion(sigma, yield_strain):\n",
    "    trace_epsilon = torch.sum(torch.log(sigma), dim=-1, keepdim=True)\n",
    "    tau = kirchoff_stress(sigma)\n",
    "    trace_tau = torch.sum(tau, dim=-1, keepdim=True)\n",
    "    y = torch.sum(torch.pow(tau - trace_tau / dim, 2), dim=-1, keepdim=True) - (2 * mu * yield_strain)**2\n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23163fae-9b13-43f7-ba6c-51e640ef2225",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_angle(batch_size, device='cpu'):\n",
    "    return torch.from_numpy(np.random.uniform(size=[batch_size, 1], low=-np.pi, high=np.pi)).float().to(device)\n",
    "\n",
    "def random_2d_rotation(batch_size, device='cpu'):\n",
    "    theta = random_angle(batch_size, device)\n",
    "    R0 = torch.cat([torch.cos(theta)[:, :, None], torch.sin(theta)[:, :, None]], dim=-1)\n",
    "    R1 = torch.cat([-torch.sin(theta)[:, :, None], torch.cos(theta)[:, :, None]], dim=-1)\n",
    "    R = torch.cat([R0, R1], dim=-2)\n",
    "    return R\n",
    "\n",
    "def random_yield_strain(batch_size, device='cpu'):\n",
    "    return torch.from_numpy(np.random.uniform(size=[batch_size, 1], low=0, high=0.1)).float().to(device)\n",
    "\n",
    "def random_sigma(batch_size, device='cpu', bound=1):\n",
    "    return torch.exp(torch.from_numpy(np.random.uniform(size=[batch_size, dim], low=-bound, high=bound)).float().to(device))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94eb72a6-2e61-4698-be34-f8424ed397fe",
   "metadata": {},
   "source": [
    "# Return mapping Net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfa8dc1c-f77c-42c0-ba96-d9047cbc5f3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Activation(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x * torch.sigmoid(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12c27c2e-c542-4f2c-8d2e-ace3b7e80c48",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GammaNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GammaNet, self).__init__()\n",
    "        hidden_dim = 16\n",
    "        self.model = nn.ModuleList([\n",
    "            nn.Linear(dim + 1, hidden_dim),\n",
    "            Activation(),\n",
    "            nn.Linear(hidden_dim,hidden_dim),\n",
    "            Activation(),\n",
    "            nn.Linear(hidden_dim,hidden_dim),\n",
    "            Activation(),\n",
    "            nn.Linear(hidden_dim,1)\n",
    "        ])\n",
    "        \n",
    "    def forward(self, sigma, yield_strain):\n",
    "        delta_gamma = torch.concat([sigma, yield_strain], axis=-1)\n",
    "        for i, l in enumerate(self.model):\n",
    "            delta_gamma = l(delta_gamma)\n",
    "        return delta_gamma\n",
    "    \n",
    "    def project_strain(self, sigma, yield_strain):\n",
    "        y = yield_criterion(sigma, yield_strain)\n",
    "        epsilon = torch.log(sigma)\n",
    "        trace_epsilon = torch.sum(epsilon, dim=-1, keepdim=True)\n",
    "        epsilon_hat = epsilon - trace_epsilon / dim\n",
    "        epsilon_hat_norm = torch.sqrt(torch.sum(epsilon_hat**2, dim=-1, keepdim=True) + 1e-12)\n",
    "        s_hat = epsilon_hat / epsilon_hat_norm\n",
    "        delta_gamma = self.forward(sigma, yield_strain)\n",
    "        H = epsilon - torch.minimum(delta_gamma, epsilon_hat_norm) * s_hat\n",
    "        H = torch.where(y > 0, H, epsilon)\n",
    "        return torch.exp(H)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63202a7a-df09-47dc-af71-1906ad1779d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_weights(m):\n",
    "    if isinstance(m, nn.Linear):\n",
    "        torch.nn.init.uniform_(m.weight, -0.01, 0.01)\n",
    "        m.bias.data.fill_(0)\n",
    "gamma_net = GammaNet()\n",
    "gamma_net = gamma_net.to(device)\n",
    "gamma_net.apply(init_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc902f69-4ff2-47a5-b37c-897dd8b02148",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm, trange\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "mse_cost_function = torch.nn.MSELoss() # Mean squared error\n",
    "optimizer = torch.optim.Adam(gamma_net.parameters(), lr=1e-2)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.95)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d10aef3-41a5-4d73-ad16-8df6f26bbac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 100000\n",
    "batch_size = 2 ** 16\n",
    "gamma_net = gamma_net.double()\n",
    "\n",
    "pbar = trange(iterations, unit=\"iters\")\n",
    "for epoch in pbar:\n",
    "    optimizer.zero_grad()\n",
    "    sigma = random_sigma(batch_size, device, 1).double()\n",
    "    yield_strain = random_yield_strain(batch_size, device).double()\n",
    "    y_sigma = yield_criterion(sigma, yield_strain)\n",
    "    delta_gamma = gamma_net(sigma, yield_strain)\n",
    "    Z = gamma_net.project_strain(sigma, yield_strain)\n",
    "    y = yield_criterion(Z, yield_strain)\n",
    "    epsilon_hat_norm = compute_epsilon_hat_norm(sigma)\n",
    "    y_loss = torch.mean(torch.where(y_sigma > 0, y ** 2, torch.zeros_like(y_sigma)))\n",
    "    gamma_loss = torch.mean(torch.where((y_sigma > 0), torch.clamp(delta_gamma - epsilon_hat_norm, 0), torch.zeros_like(delta_gamma)))\n",
    "    loss = y_loss + 100 * gamma_loss\n",
    "    loss.backward()\n",
    "    \n",
    "    optimizer.step()\n",
    "    if optimizer.param_groups[0]['lr'] > 1e-4:\n",
    "        scheduler.step()\n",
    "    \n",
    "    with torch.autograd.no_grad():\n",
    "        pbar.set_description(f\"Epoch: {epoch}, y_loss: {y_loss.item()}, gamma_loss: {gamma_loss.item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "505f8ac5-7ee4-4b35-b3cd-2647c16d7319",
   "metadata": {},
   "outputs": [],
   "source": [
    "gamma_net.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "709a5aba-f0fb-4407-9048-5702107df3d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_stress(F, yield_strain):\n",
    "    U, sigma, Vh = torch.linalg.svd(F)\n",
    "    Z = gamma_net.project_strain(sigma, yield_strain)\n",
    "    tau = kirchoff_stress(Z)\n",
    "    grad = tau / sigma\n",
    "    stress = U @ torch.diag_embed(grad) @ Vh\n",
    "    return stress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf73874b-f124-4780-bcef-a1b2a1ce3ace",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eye_like(tensor):\n",
    "    return torch.tile(torch.eye(dim, device=tensor.device), (tensor.shape[0], 1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "883e22a8-164c-4b47-ae3b-1401821a2cad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "hidden_dim = 32\n",
    "class PsiNet(nn.Module):\n",
    "    training = True\n",
    "    def __init__(self):\n",
    "        super(PsiNet, self).__init__()\n",
    "        self.layers = nn.ModuleList([\n",
    "            nn.Linear(dim * dim * 2 + 1, hidden_dim),\n",
    "            Activation(),\n",
    "            nn.Linear(hidden_dim,hidden_dim),\n",
    "            Activation(),\n",
    "            nn.Linear(hidden_dim,hidden_dim),\n",
    "            Activation(),\n",
    "            nn.Linear(hidden_dim,1)\n",
    "        ])\n",
    "    def forward(self, F, F_flag, yield_strain):\n",
    "        input = torch.cat([F.flatten(1), F_flag.flatten(1), yield_strain],axis=-1)\n",
    "        out = input\n",
    "        for i, l in enumerate(self.layers):\n",
    "            out = l(out)\n",
    "        return out\n",
    "    \n",
    "def training_output(model, F_, F_flag, yield_strain):\n",
    "    F_flag.requires_grad_()\n",
    "    F = F_.detach().clone().requires_grad_()\n",
    "    F_flag_duplicate = F_flag.detach().clone().requires_grad_()\n",
    "    target = target_stress(F_flag, yield_strain)\n",
    "\n",
    "    # forward\n",
    "    sigma_out = model(F, F_flag, yield_strain)\n",
    "    flag_out = model(F_flag_duplicate, F_flag, yield_strain)\n",
    "\n",
    "    # grad\n",
    "    flag_grad = torch.autograd.grad(flag_out, F_flag_duplicate, create_graph=True, retain_graph=True, grad_outputs=torch.ones_like(flag_out))[0]\n",
    "    sigma_grad = torch.autograd.grad(sigma_out, F, create_graph=True, retain_graph=True, grad_outputs=torch.ones_like(sigma_out))[0]\n",
    "\n",
    "    # collect\n",
    "    stress = sigma_grad - flag_grad + target\n",
    "\n",
    "    return stress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "223c2121-da86-4b1c-8f61-38d7c7a59c8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_weights(m):\n",
    "    if isinstance(m, nn.Linear):\n",
    "        torch.nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))\n",
    "        m.bias.data.fill_(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "450ce8a6-bfb6-429f-a57f-064480318740",
   "metadata": {},
   "outputs": [],
   "source": [
    "net = PsiNet()\n",
    "net = net.to(device)\n",
    "net.apply(init_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c23f245f-51c3-4e79-aadf-fa22acb20e5a",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cd354c8-88ac-44aa-bcca-1f21bb26d6d3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm, trange\n",
    "from torch.utils.tensorboard import SummaryWriter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0281d77-6037-46be-88dc-60e98828e850",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_data(batch_size, hencky_perturb = 0.2):\n",
    "    sigma_flag = random_sigma(batch_size, device)\n",
    "    perturbed_hencky = torch.log(sigma_flag) + torch.from_numpy(np.random.uniform(-hencky_perturb, hencky_perturb, sigma_flag.shape)).float().to(device)\n",
    "    sigma = torch.exp(perturbed_hencky)\n",
    "    yield_strain = random_yield_strain(batch_size, device)\n",
    "    U1 = random_2d_rotation(batch_size, device)\n",
    "    U2 = random_2d_rotation(batch_size, device)\n",
    "    V1 = random_2d_rotation(batch_size, device)\n",
    "    V2 = random_2d_rotation(batch_size, device)\n",
    "    F_flag = U1 @ torch.diag_embed(sigma_flag) @ V1\n",
    "    F = U2 @ torch.diag_embed(sigma) @ V2\n",
    "    return F, F_flag, yield_strain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc327fb4-7b0e-4b84-b848-68bb302b2ee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_cost_function = torch.nn.MSELoss() # Mean squared error\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.95)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c5f8839-5601-4060-b9d9-cdf557979746",
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 20000\n",
    "batch_size = 2 ** 16\n",
    "hencky_perturb_scale = 0.1\n",
    "\n",
    "writer = SummaryWriter(comment=\"neohookean_vonmises\")\n",
    "\n",
    "pbar = trange(iterations, unit=\"iters\")\n",
    "for epoch in pbar:\n",
    "    optimizer.zero_grad()\n",
    "    F, F_flag, yield_strain = prepare_data(batch_size, hencky_perturb_scale)\n",
    "    stress = training_output(net, F, F_flag, yield_strain)\n",
    "    target = target_stress(F, yield_strain)\n",
    "    \n",
    "    loss = mse_cost_function(stress, target)\n",
    "    loss.backward() # This is for computing gradients using backward propagation\n",
    "    \n",
    "    optimizer.step() # This is equivalent to : theta_new = theta_old - alpha * derivative of J w.r.t theta\n",
    "    if optimizer.param_groups[0]['lr'] > 1e-4:\n",
    "        scheduler.step()\n",
    "    \n",
    "    with torch.autograd.no_grad():\n",
    "        pbar.set_description(f\"Epoch: {epoch}, Traning Loss: {loss.item()}\")\n",
    "        writer.add_scalar('Tatal_Loss/train', loss.item(), epoch)\n",
    "        writer.add_scalar('lr/train', optimizer.param_groups[0]['lr'], epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4be6ebfa-86ed-41fa-bae1-bdf74a58276f",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(net.state_dict(), \"params/NK_VM.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2c3c69c-6c9c-406f-91e4-704aa2409e05",
   "metadata": {},
   "source": [
    "# save model\n",
    "\n",
    "Save as a script_module to be loaded in C++"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f79d20-5ff5-4918-9a1f-c9a43fdc1a19",
   "metadata": {},
   "outputs": [],
   "source": [
    "F = torch.from_numpy(np.array([[[1, 0.1], [0.1, 1]]])).float().to(device)\n",
    "F_flag = torch.from_numpy(np.array([[[1, 0.1], [0.1, 1]]])).float().to(device)\n",
    "yield_strain = torch.from_numpy(np.array([[0.5]])).float().to(device)\n",
    "traced_script_module = torch.jit.trace(net.forward, (F, F_flag, yield_strain))\n",
    "traced_script_module.save(\"neohookean_von_mises.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc55f61e-ea95-4a67-8ca9-da545a95188d",
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = torch.from_numpy(np.array([[0.9, 0.8]])).float().to(device)\n",
    "traced_script_module2 = torch.jit.trace(gamma_net, (sigma, yield_strain))\n",
    "traced_script_module2.save(\"neohookean_von_mises_return_mapping.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
