{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from firedrake import *\n",
    "#from dataset_processing.generate_random_conductivity import random_field\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 50x50 unit-square mesh and its spatial coordinates.\n",
    "mesh = UnitSquareMesh(50, 50)\n",
    "x, y = SpatialCoordinate(mesh)\n",
    "\n",
    "# Degree-1 \"CG\" function space and the test function used by every weak form below.\n",
    "V = FunctionSpace(mesh, family=\"CG\", degree=1)\n",
    "v = TestFunction(V)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy.random import default_rng\n",
    "\n",
    "\n",
    "def random_field(V, N=1, m=15, σ=1.4, seed=2023):\n",
    "    \"\"\"Sample N random fields on V, each a scaled sum of m random Fourier modes.\n",
    "\n",
    "    Each mode draws two standard-normal amplitudes (cosine and sine) and two\n",
    "    wavenumbers from a normal distribution with mean 0 and standard deviation σ.\n",
    "    All draws come from one generator seeded with ``seed``, so the output is\n",
    "    reproducible. The sum of modes is scaled by sqrt(1 / m).\n",
    "\n",
    "    Returns a list with N entries, each the result of interpolating into V.\n",
    "    \"\"\"\n",
    "    generator = default_rng(seed)\n",
    "    x, y = SpatialCoordinate(V.ufl_domain())\n",
    "\n",
    "    def sample_field():\n",
    "        # Accumulate the m modes symbolically, then interpolate once.\n",
    "        total = 0\n",
    "        for _mode in range(m):\n",
    "            amp_cos, amp_sin = generator.standard_normal(2)\n",
    "            kx, ky = generator.normal(0, σ, 2)\n",
    "            phase = 2 * pi * (kx * x + ky * y)\n",
    "            total += Constant(amp_cos) * cos(phase) + Constant(amp_sin) * sin(phase)\n",
    "        return interpolate(sqrt(1 / m) * total, V)\n",
    "\n",
    "    return [sample_field() for _field in range(N)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# random_field returns a list; with the default N=1 it holds exactly one field.\n",
    "[k_exact] = random_field(V)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source term and an empty Function that will receive the solution.\n",
    "f = Function(V).interpolate(sin(pi * x) * sin(pi * y))\n",
    "u_exact = Function(V)\n",
    "\n",
    "# Residual form: exp(k_exact)-weighted gradient term minus the source term.\n",
    "diffusion = inner(exp(k_exact) * grad(u_exact), grad(v))\n",
    "load = inner(f, v)\n",
    "F = (diffusion - load) * dx\n",
    "\n",
    "# Zero Dirichlet condition on the whole boundary.\n",
    "bcs = [DirichletBC(V, Constant(0.0), \"on_boundary\")]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_poisson(solver_parameters, u, k, v, f, bcs):\n",
    "    \"\"\"Solve the residual problem for ``u`` in place.\n",
    "\n",
    "    The residual is the integral of exp(k) * grad(u) . grad(v) - f * v.\n",
    "    ``bcs`` and ``solver_parameters`` are forwarded unchanged to ``solve``.\n",
    "    Nothing is returned; the solution is written into ``u``.\n",
    "    \"\"\"\n",
    "    diffusion = inner(exp(k) * grad(u), grad(v))\n",
    "    load = inner(f, v)\n",
    "    residual = (diffusion - load) * dx\n",
    "    solve(residual == 0, u, bcs=bcs, solver_parameters=solver_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def time_poisson(solver_parameters, n_runs=10):\n",
    "    \"\"\"Time repeated Poisson solves on a freshly built 50x50 unit-square mesh.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    solver_parameters : dict or None\n",
    "        Forwarded unchanged to ``solve_poisson``.\n",
    "    n_runs : int, optional\n",
    "        Number of consecutive solves to time (default 10).\n",
    "\n",
    "    Prints the elapsed seconds of every solve followed by the total elapsed\n",
    "    time of the loop. Returns nothing.\n",
    "    \"\"\"\n",
    "    mesh = UnitSquareMesh(50, 50)\n",
    "    V = FunctionSpace(mesh, \"CG\", 1)\n",
    "    v = TestFunction(V)\n",
    "\n",
    "    x, y = SpatialCoordinate(mesh)\n",
    "    f = Function(V).interpolate(sin(pi * x) * sin(pi * y))\n",
    "    u = Function(V)\n",
    "    # The coefficient field must live on the mesh built here, so it is sampled\n",
    "    # locally instead of being read from a notebook-level global.\n",
    "    k, = random_field(V)\n",
    "    bcs = [DirichletBC(V, Constant(0.0), \"on_boundary\")]\n",
    "\n",
    "    # perf_counter is monotonic, so the measured intervals cannot be skewed\n",
    "    # by system clock adjustments.\n",
    "    total_start = time.perf_counter()\n",
    "    for _ in range(n_runs):\n",
    "        run_start = time.perf_counter()\n",
    "        solve_poisson(solver_parameters, u, k, v, f, bcs)\n",
    "        print(time.perf_counter() - run_start)\n",
    "    print(f\"Elapse time: {time.perf_counter() - total_start}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Removing cached TSFC kernels from /home/nacime/firedrake/.cache/tsfc\r\n",
      "Removing cached PyOP2 code from /home/nacime/firedrake/.cache/pyop2\r\n",
      "Removing cached pytools files from /home/nacime/.cache/pytools\r\n"
     ]
    }
   ],
   "source": [
    "# Shell call that removes the cached TSFC kernels, PyOP2 code and pytools files.\n",
    "!firedrake-clean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/nacime/firedrake/lib/python3.8/site-packages/pytools/persistent_dict.py:520: UserWarning: could not obtain lock -- delete '/home/nacime/.cache/pytools/pdict-v4-loopy-schedule-cache-v4-2022.1-islpy2022.2.1-cgen2020.1-3988272b385fd770a3427b539e8b27e367f9db33-v1-py3.8.16.final.0/80723ff9814a050c7ed20e3b81fbe629ae2ce6cfad2490e0e7811a9d948c1cfe.lock' if necessary\n",
      "  self.store(key, value, _skip_if_present=True, _stacklevel=1 + _stacklevel)\n"
     ]
    }
   ],
   "source": [
    "# Run the timing loop without any explicit solver parameters.\n",
    "time_poisson(solver_parameters=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time ten consecutive solves; the residual form is rebuilt on every pass.\n",
    "t = []\n",
    "cpu_t = time.time()\n",
    "for _ in range(10):\n",
    "    tick = time.time()\n",
    "\n",
    "    diffusion = inner(exp(k_exact) * grad(u_exact), grad(v))\n",
    "    F = (diffusion - inner(f, v)) * dx\n",
    "    solve(F == 0, u_exact, bcs=bcs)\n",
    "\n",
    "    # Seconds spent on this pass.\n",
    "    temp_t = time.time() - tick\n",
    "    t.append(temp_t)\n",
    "    print(temp_t)\n",
    "print(f\"Elapse time: {time.time() - cpu_t}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same timing experiment, but through an explicit linear variational solver\n",
    "# that is rebuilt from scratch on every pass.\n",
    "t = []\n",
    "u = TrialFunction(V)\n",
    "cpu_t = time.time()\n",
    "for _ in range(10):\n",
    "    tick = time.time()\n",
    "\n",
    "    # Fresh solution container for this pass.\n",
    "    w = Function(V)\n",
    "    F = (inner(exp(k_exact) * grad(u), grad(v)) - inner(f, v)) * dx\n",
    "    # Split the form into its bilinear and linear parts.\n",
    "    A = lhs(F)\n",
    "    b = rhs(F)\n",
    "    vpb = LinearVariationalProblem(A, b, w, bcs=bcs)\n",
    "    solver = LinearVariationalSolver(vpb)\n",
    "    solver.solve()\n",
    "\n",
    "    # Seconds spent on this pass.\n",
    "    temp_t = time.time() - tick\n",
    "    t.append(temp_t)\n",
    "    print(temp_t)\n",
    "print(f\"Elapse time: {time.time() - cpu_t}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# random_field returns a list of fields, while plotting (and exp(k_exact) in\n",
    "# the cells below) needs a single Function, so unpack exactly one field.\n",
    "# NOTE(review): N=15 produced 15 fields; 15 modes (m=15) was presumably meant - confirm.\n",
    "k_exact, = random_field(V, m=15)\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
    "collection = tripcolor(k_exact, axes=axes[0], alpha=1, cmap='jet')\n",
    "# Pass ax explicitly so each colorbar is attached to the panel it describes.\n",
    "fig.colorbar(collection, ax=axes[0]);\n",
    "collection = tricontour(k_exact, axes=axes[1], alpha=1)\n",
    "fig.colorbar(collection, ax=axes[1]);\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source term for the variable-coefficient problem.\n",
    "f = Function(V).interpolate(sin(pi * x) * sin(pi * y))\n",
    "# f = Function(V).assign(1)\n",
    "\n",
    "u_exact = Function(V)\n",
    "F = (inner(exp(k_exact) * grad(u_exact), grad(v)) - inner(f, v)) * dx\n",
    "bcs = [DirichletBC(V, Constant(0.0), \"on_boundary\")]\n",
    "# Solve PDE; the solution is written into u_exact.\n",
    "solve(F == 0, u_exact, bcs=bcs)\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
    "collection = tripcolor(u_exact, axes=axes[0], alpha=1, cmap='jet')\n",
    "# Pass ax explicitly so each colorbar is attached to the panel it describes.\n",
    "fig.colorbar(collection, ax=axes[0]);\n",
    "collection = tricontour(u_exact, axes=axes[1], alpha=1)\n",
    "fig.colorbar(collection, ax=axes[1]);\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the residual of the computed solution and display its vector norm.\n",
    "diffusion = inner(exp(k_exact) * grad(u_exact), grad(v))\n",
    "F = (diffusion - inner(f, v)) * dx\n",
    "residual = assemble(F, bcs=bcs)\n",
    "residual.dat._vec.norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constant-coefficient check: with this source term and zero boundary values,\n",
    "# -div(grad(u)) = f has the solution sin(pi x) sin(pi y).\n",
    "f = Function(V).interpolate(2 * pi ** 2 * sin(pi * x) * sin(pi * y))\n",
    "# f = Function(V).assign(1)\n",
    "\n",
    "u_exact = Function(V)\n",
    "F = (inner(grad(u_exact), grad(v)) - inner(f, v)) * dx\n",
    "bcs = [DirichletBC(V, Constant(0.0), \"on_boundary\")]\n",
    "# Solve PDE; the solution is written into u_exact.\n",
    "solve(F == 0, u_exact, bcs=bcs)\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
    "collection = tripcolor(u_exact, axes=axes[0], alpha=1, cmap='jet')\n",
    "# Pass ax explicitly so each colorbar is attached to the panel it describes.\n",
    "fig.colorbar(collection, ax=axes[0]);\n",
    "collection = tricontour(u_exact, axes=axes[1], alpha=1)\n",
    "fig.colorbar(collection, ax=axes[1]);\n",
    "plt.show()\n",
    "\n",
    "# Squared L2 distance between the computed and the reference solution.\n",
    "u_reference = Function(V).interpolate(sin(pi * x) * sin(pi * y))\n",
    "print(assemble((u_exact - u_reference) ** 2 * dx))\n",
    "# F is symbolic in u_exact, so it already reflects the solved values and can\n",
    "# be assembled directly to get the residual norm.\n",
    "assemble(F, bcs=bcs).dat._vec.norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `cmaps` was not defined anywhere in the notebook; it is assumed\n",
    "# to be the sorted list of registered colormap names - confirm the indices in\n",
    "# `cs` still select the intended maps for the installed matplotlib version.\n",
    "cmaps = plt.colormaps()\n",
    "cs = [44, 45, 62, 63, 50, 51, 84, 85, 124, 125, 140, 141]\n",
    "\n",
    "n = 3\n",
    "m = 4\n",
    "fig, axes = plt.subplots(n, m, figsize=(15, 15))\n",
    "for i, cmap in enumerate(cs):\n",
    "    # Fill the n x m grid row by row.\n",
    "    row, col = divmod(i, m)\n",
    "    ax = axes[row, col]\n",
    "    collection = tripcolor(k_exact, axes=ax, alpha=1, cmap=cmaps[cmap])\n",
    "    ax.set_title(cmaps[cmap])\n",
    "    # Attach the colorbar to the panel it belongs to.\n",
    "    fig.colorbar(collection, ax=ax);\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# random_field returns a list of fields, but tripcolor needs a single Function.\n",
    "# NOTE(review): N=15 produced 15 fields; 15 modes (m=15) was presumably meant - confirm.\n",
    "k, = random_field(V, m=15)\n",
    "tripcolor(k, cmap='jet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "firedrake",
   "language": "python",
   "name": "firedrake"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
