{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ee7061-8f7c-4d64-bbe0-15561238b4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "device = torch.device(\"cuda:0\") \n",
    "import numpy as np\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef46e355-ec62-4cd2-902f-091fb79fc218",
   "metadata": {},
   "outputs": [],
   "source": [
    "# constant physcial parameters\n",
    "dim = 2\n",
    "E = 100\n",
    "nu = 0.3\n",
    "mu = 0.5 * E / (1 + nu)\n",
    "lam = E * nu / ((1 + nu) * (1 - 2 * nu))\n",
    "kappa = E / (3 * (1 - 2 * nu))\n",
    "beta = .3\n",
    "xi = .5\n",
    "friction_angle = 45\n",
    "sin_phi = np.sin(friction_angle / 180 * np.pi)\n",
    "mohr_columb_friction = np.sqrt(2. / 3.) * 2. * sin_phi / (3. - sin_phi)\n",
    "M = mohr_columb_friction * dim / np.sqrt(2 / (6 - dim))\n",
    "print(mu)\n",
    "print(lam)\n",
    "print(kappa)\n",
    "y_max = 2 * mu ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d10aef3-41a5-4d73-ad16-8df6f26bbac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def deviatoric(input_tensor):\n",
    "    return input_tensor - torch.mean(input_tensor, dim=-1, keepdim=True)\n",
    "\n",
    "def kirchoff_stress(sigma):\n",
    "    B_hat_trial = sigma ** 2\n",
    "    J = torch.prod(sigma, dim=-1, keepdim = True)\n",
    "    prime = 0.5 * kappa * (J - 1. / J)\n",
    "    tau = mu * torch.pow(J, -2. / dim) * deviatoric(B_hat_trial) + J * prime\n",
    "    return tau\n",
    "\n",
    "def analytical_projection_sigma(sigma, logJp):\n",
    "    p0 = kappa * (1e-6 + torch.sinh(xi * torch.clamp(-logJp, min=0.)))\n",
    "    J = torch.prod(sigma, dim=-1, keepdim = True)\n",
    "    B_hat_trial = sigma ** 2\n",
    "    prime = 0.5 * kappa * (J - 1. / J)\n",
    "    p_trial = -prime * J\n",
    "    \n",
    "    pMin = beta * p0\n",
    "    pMax = p0\n",
    "    \n",
    "    case1_Je_new = torch.sqrt(-2. * pMax / kappa + 1)\n",
    "    case1_sigma = torch.tile(torch.pow(case1_Je_new, 1 / dim), (1, dim))\n",
    "    \n",
    "    case2_Je_new = torch.sqrt(2. * pMin / kappa + 1)\n",
    "    case2_sigma = torch.tile(torch.pow(case2_Je_new, 1 / dim), (1, dim))\n",
    "\n",
    "    y_s_half_coeff = 0.5 * (6 - dim) * (1. + 2. * beta)\n",
    "    y_p_half = M * M * (p_trial + pMin) * (p_trial - pMax)\n",
    "    \n",
    "    s_hat_trial = mu * torch.pow(J, -2. / dim) * deviatoric(B_hat_trial)\n",
    "    s_hat_trail_squared_norm = torch.sum(s_hat_trial ** 2, dim=-1, keepdim=True)\n",
    "    s_hat_trial_norm = torch.sqrt(s_hat_trail_squared_norm + 1e-12)\n",
    "    \n",
    "    y = y_s_half_coeff * s_hat_trail_squared_norm + y_p_half\n",
    "    \n",
    "    s_new_norm = torch.pow(J, 2. / dim) / mu * torch.sqrt(-y_p_half / y_s_half_coeff)\n",
    "    B_hat_dev = s_new_norm / s_hat_trial_norm * s_hat_trial\n",
    "    B_hat_new = B_hat_dev + B_hat_trial.mean(dim=-1, keepdim=True)\n",
    "    \n",
    "    case3_sigma = torch.where(y > y_max, torch.ones_like(case2_sigma), torch.sqrt(B_hat_new))\n",
    "    \n",
    "    new_sigma = torch.where(p_trial > pMax, case1_sigma, \n",
    "                            torch.where(p_trial < -pMin, case2_sigma, \n",
    "                                torch.where(y < 1e-12, sigma, case3_sigma)\n",
    "                                       )\n",
    "                           )\n",
    "    \n",
    "    return new_sigma\n",
    "    \n",
    "    \n",
    "def analytical_projection(F, logJp):\n",
    "    U, sigma, Vh = torch.linalg.svd(F)\n",
    "    Z = analytical_projection_sigma(sigma, logJp)\n",
    "    return U @ torch.diag_embed(Z) @ Vh, valid\n",
    "\n",
    "def target_stress(F, logJp):\n",
    "    U, sigma, Vh = torch.linalg.svd(F)\n",
    "    Z = analytical_projection_sigma(sigma, logJp)\n",
    "    tau = kirchoff_stress(Z)\n",
    "    grad = tau / sigma\n",
    "    stress = U @ torch.diag_embed(grad) @ Vh\n",
    "    return stress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88912318-a3b0-433e-84df-ca18012b2f86",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_angle(batch_size, device='cpu'):\n",
    "    return torch.from_numpy(np.random.uniform(size=[batch_size, 1], low=-np.pi, high=np.pi)).float().to(device)\n",
    "\n",
    "def random_2d_rotation(batch_size, device='cpu'):\n",
    "    theta = random_angle(batch_size, device)\n",
    "    R0 = torch.cat([torch.cos(theta)[:, :, None], torch.sin(theta)[:, :, None]], dim=-1)\n",
    "    R1 = torch.cat([-torch.sin(theta)[:, :, None], torch.cos(theta)[:, :, None]], dim=-1)\n",
    "    R = torch.cat([R0, R1], dim=-2)\n",
    "    return R\n",
    "\n",
    "def random_sigma(batch_size, device='cpu'):\n",
    "    return torch.exp(torch.from_numpy(np.random.uniform(size=[batch_size, dim], low=-1, high=1)).float().to(device))\n",
    "\n",
    "def random_logJp(batch_size, device='cpu'):\n",
    "    return torch.from_numpy(np.random.uniform(size=[batch_size, 1], low=-0.5, high=0)).float().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2f7928-3f41-4a7c-95b7-8a7440f7e067",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Activation(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x * torch.sigmoid(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "883e22a8-164c-4b47-ae3b-1401821a2cad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "hidden_dim = 32\n",
    "class PsiNet(nn.Module):\n",
    "    training = True\n",
    "    def __init__(self):\n",
    "        super(PsiNet, self).__init__()\n",
    "        self.layers = nn.ModuleList([\n",
    "            nn.Linear(dim * dim * 2 + 1, hidden_dim),\n",
    "            Activation(),\n",
    "            nn.Linear(hidden_dim,hidden_dim),\n",
    "            Activation(),\n",
    "            nn.Linear(hidden_dim,hidden_dim),\n",
    "            Activation(),\n",
    "            nn.Linear(hidden_dim,1)\n",
    "        ])\n",
    "    def forward(self, F, F_flag, logJp):\n",
    "        input = torch.cat([F.flatten(1), F_flag.flatten(1), logJp],axis=-1)\n",
    "        out = input\n",
    "        for i, l in enumerate(self.layers):\n",
    "            out = l(out)\n",
    "        return out\n",
    "    \n",
    "def training_output(model, F_, F_flag, logJp):\n",
    "    F_flag.requires_grad_()\n",
    "    F = F_.detach().clone().requires_grad_()\n",
    "    F_flag_duplicate = F_flag.detach().clone().requires_grad_()\n",
    "    target = target_stress(F_flag, logJp)\n",
    "\n",
    "    # forward\n",
    "    sigma_out = model(F, F_flag, logJp)\n",
    "    flag_out = model(F_flag_duplicate, F_flag, logJp)\n",
    "\n",
    "    # grad\n",
    "    flag_grad = torch.autograd.grad(flag_out, F_flag_duplicate, create_graph=True, retain_graph=True, grad_outputs=torch.ones_like(flag_out))[0]\n",
    "    sigma_grad = torch.autograd.grad(sigma_out, F, create_graph=True, retain_graph=True, grad_outputs=torch.ones_like(sigma_out))[0]\n",
    "    \n",
    "    # collect\n",
    "    stress = sigma_grad - flag_grad + target\n",
    "\n",
    "    return stress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a126bba-adf3-418e-8664-63f170888fe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_weights(m):\n",
    "    if isinstance(m, nn.Linear):\n",
    "        torch.nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))\n",
    "        m.bias.data.fill_(0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "450ce8a6-bfb6-429f-a57f-064480318740",
   "metadata": {},
   "outputs": [],
   "source": [
    "net = PsiNet()\n",
    "net = net.to(device)\n",
    "net.apply(init_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c23f245f-51c3-4e79-aadf-fa22acb20e5a",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Training\n",
    "```\n",
    "pip install tensorboard\n",
    "tensorboard --logdir=runs --bind_all\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cd354c8-88ac-44aa-bcca-1f21bb26d6d3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm, trange\n",
    "from torch.utils.tensorboard import SummaryWriter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cc95a04-ff2d-4124-ac1c-d7a12ff3f18b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_data(batch_size, hencky_perturb = 0.2):\n",
    "    sigma_flag = random_sigma(batch_size, device)\n",
    "    perturbed_hencky = torch.log(sigma_flag) + torch.from_numpy(np.random.uniform(-hencky_perturb, hencky_perturb, sigma_flag.shape)).float().to(device)\n",
    "    sigma = torch.exp(perturbed_hencky)\n",
    "    logJp = random_logJp(batch_size, device)\n",
    "    U1 = random_2d_rotation(batch_size, device)\n",
    "    U2 = random_2d_rotation(batch_size, device)\n",
    "    V1 = random_2d_rotation(batch_size, device)\n",
    "    V2 = random_2d_rotation(batch_size, device)\n",
    "    F_flag = U1 @ torch.diag_embed(sigma_flag) @ V1\n",
    "    F = U2 @ torch.diag_embed(sigma) @ V2\n",
    "    return F, F_flag, logJp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "135349e2-0730-494f-8f8a-e62ce3c9df5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_cost_function = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.95)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c5f8839-5601-4060-b9d9-cdf557979746",
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 20000\n",
    "batch_size = 2 ** 16\n",
    "hencky_perturb_scale = 0.2\n",
    "\n",
    "writer = SummaryWriter(comment=\"CamClay\")\n",
    "\n",
    "pbar = trange(iterations, unit=\"iters\")\n",
    "for epoch in pbar:\n",
    "    optimizer.zero_grad()\n",
    "    F, F_flag, logJp = prepare_data(batch_size, hencky_perturb_scale)\n",
    "    stress = training_output(net, F, F_flag, logJp)\n",
    "    target = target_stress(F, logJp)\n",
    "    loss = mse_cost_function(stress, target)\n",
    "    loss.backward()\n",
    "    \n",
    "    optimizer.step()\n",
    "    if optimizer.param_groups[0]['lr'] > 1e-4:\n",
    "        scheduler.step()\n",
    "    \n",
    "    with torch.autograd.no_grad():\n",
    "        pbar.set_description(f\"Epoch: {epoch}, Traning Loss: {loss.item()}\")\n",
    "        writer.add_scalar('Tatal_Loss/train', loss.item(), epoch)\n",
    "        writer.add_scalar('lr/train', optimizer.param_groups[0]['lr'], epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea4b59c-5a41-4907-a3e9-921253ce4596",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(net.state_dict(), \"params/CC.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27feb458-ec85-47ed-82da-af4fa1ed783f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# net = PsiNet()\n",
    "# net = net.to(device)\n",
    "# net.load_state_dict(torch.load(\"params/CC.pt\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2c3c69c-6c9c-406f-91e4-704aa2409e05",
   "metadata": {
    "tags": []
   },
   "source": [
    "# save model\n",
    "\n",
    "Save as a script_module to be loaded in C++"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f79d20-5ff5-4918-9a1f-c9a43fdc1a19",
   "metadata": {},
   "outputs": [],
   "source": [
    "F = torch.from_numpy(np.array([[[1, 0.1], [0.1, 1]]])).float().to(device)\n",
    "F_flag = torch.from_numpy(np.array([[[1, 0.1], [0.1, 1]]])).float().to(device)\n",
    "strain = torch.from_numpy(np.array([[-0.1]])).float().to(device)\n",
    "traced_script_module = torch.jit.trace(net, (F, F_flag, strain))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05198783-da11-49e5-bf29-a2099c1b3db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "traced_script_module.save(\"traced_model_cam_clay.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
