from dolfin import *
import numpy as np


def l2_rel_error(u, u_ref):
    """Return the relative discrete l2 error ``||u - u_ref||_2 / ||u_ref||_2``.

    Both inputs are flattened before comparison, so any array-like with the
    same total number of entries is accepted (e.g. DOF vectors, column arrays
    or plain lists).

    Args:
        u: Approximate values.
        u_ref: Reference values; must contain as many entries as ``u``.

    Returns:
        The relative error. If ``u_ref`` is identically zero the ratio is
        undefined and NumPy yields ``nan``/``inf`` (with a RuntimeWarning).

    Raises:
        ValueError: If ``u`` and ``u_ref`` contain different numbers of
            entries. Without this check, a size-1 input would silently
            broadcast against the other and produce a meaningless value.
    """
    u = np.asarray(u).reshape(-1)
    u_ref = np.asarray(u_ref).reshape(-1)
    if u.size != u_ref.size:
        raise ValueError(
            "l2_rel_error: size mismatch (%d vs %d entries)" % (u.size, u_ref.size)
        )
    return np.linalg.norm(u - u_ref) / np.linalg.norm(u_ref)


def a_fun(xyt):
    """Evaluate the diffusion coefficient a(x, y) = 2 + sin(pi x) sin(pi y).

    The coefficient is time independent, so only the x and y columns of
    ``xyt`` are used.

    Args:
        xyt: 2-D array whose rows are points (x, y, t).

    Returns:
        Column array of shape ``(len(xyt), 1)`` with the coefficient values.
    """
    col_x = xyt[:, 0:1]
    col_y = xyt[:, 1:2]
    bump = np.sin(np.pi * col_x) * np.sin(np.pi * col_y)
    return 2 + bump

def u_fun(xyt):
    """Evaluate the exact solution u(x, y, t) = exp(-t) sin(pi x) sin(pi y).

    Args:
        xyt: 2-D array whose rows are points (x, y, t).

    Returns:
        Column array of shape ``(len(xyt), 1)`` with the solution values.
    """
    col_x, col_y, col_t = (xyt[:, k:k + 1] for k in range(3))
    decay = np.exp(-col_t)
    return decay * np.sin(np.pi * col_x) * np.sin(np.pi * col_y)

def f_fun(xyt):
    """Evaluate the manufactured source term f(x, y, t).

    f = u_t - div(a grad u), with u = exp(-t) sin(pi x) sin(pi y),
    a = 2 + sin(pi x) sin(pi y), and grad/div acting in (x, y) only.
    Expanded:

        f = exp(-t) * [ (4 pi^2 - 1) sx sy
                        + pi^2 (2 sx^2 sy^2 - cx^2 sy^2 - sx^2 cy^2) ]

    where sx = sin(pi x), cx = cos(pi x), and likewise for y.

    Args:
        xyt: 2-D array whose rows are points (x, y, t).

    Returns:
        Column array of shape ``(len(xyt), 1)`` with the source values.
    """
    pi = np.pi
    col_x, col_y, col_t = (xyt[:, k:k + 1] for k in range(3))
    sin_x, sin_y = np.sin(pi * col_x), np.sin(pi * col_y)
    cos_x, cos_y = np.cos(pi * col_x), np.cos(pi * col_y)
    # (4 pi^2 - 1) combines the -u from u_t with the 4 pi^2 u from -2 * Laplacian.
    linear_part = (4 * pi**2 - 1) * sin_x * sin_y
    # Terms from S * Laplacian(u) and grad(a) . grad(u), with S = sin(pi x) sin(pi y).
    quadratic_part = (
        2 * sin_x**2 * sin_y**2
        - cos_x**2 * sin_y**2
        - sin_x**2 * cos_y**2
    )
    return np.exp(-col_t) * (linear_part + pi**2 * quadratic_part)

# Create the space-time mesh on [-1, 1] x [-1, 1] x [0, 1]; the third
# coordinate plays the role of time t (see u_fun / f_fun).
bbox = [-1, 1, -1, 1, 0, 1]
mesh = BoxMesh(
    Point(bbox[0], bbox[2], bbox[4]),
    Point(bbox[1], bbox[3], bbox[5]),
    40,
    40,
    10,
)
# Continuous piecewise-linear (P1) Lagrange space on the space-time mesh.
V = FunctionSpace(mesh, "Lagrange", 1)

# Coordinates of the DOFs, one row per DOF. Some DOLFIN versions (before
# 2018.1) return a flat array here; reshape explicitly so the column slicing
# in a_fun/u_fun/f_fun always receives an (n_dofs, gdim) array.
coords = V.tabulate_dof_coordinates().reshape((-1, mesh.geometry().dim()))

# Nodal interpolants of the manufactured data: for P1 Lagrange elements the
# DOF values are the point values at the DOF coordinates.
c_ref = Function(V)
c_ref.vector()[:] = a_fun(coords).reshape(-1)
u_ref = Function(V)
u_ref.vector()[:] = u_fun(coords).reshape(-1)
f = Function(V)
f.vector()[:] = f_fun(coords).reshape(-1)

# Dirichlet boundary marker
def boundary(_, on_boundary):
    """Mark every exterior facet of the mesh as Dirichlet boundary.

    The point coordinates (first argument) are ignored; DOLFIN's
    ``on_boundary`` flag alone decides. On the space-time box this selects the
    lateral faces, the initial face t = 0, and the final face t = 1.
    """
    # NOTE(review): the exact solution is also imposed on the final-time face
    # t = 1; a time-marching formulation would prescribe only t = 0 and the
    # lateral faces. Confirm this choice is intended.
    return on_boundary

# Impose the reference solution as Dirichlet data wherever `boundary` marks.
bc = DirichletBC(V, u_ref, boundary)

# Space-time weak form of u_t - div(c_ref grad u) = f. Only the x and y
# derivatives enter the diffusion term; the derivative along the third
# (time) coordinate supplies u_t.
u_trial = TrialFunction(V)
v = TestFunction(V)
grad_u = as_vector((u_trial.dx(0), u_trial.dx(1)))
grad_v = as_vector((v.dx(0), v.dx(1)))
u_t = u_trial.dx(2)

a = u_t * v * dx + c_ref * inner(grad_u, grad_v) * dx
L = f * v * dx

# Compute the discrete solution with the MUMPS direct solver.
u = Function(V)
solve(a == L, u, bc, solver_parameters={"linear_solver": "mumps"})

# Report the relative discrete l2 error of the nodal values. This is a
# vector norm over DOFs, not a quadrature-based L2 norm.
rel_err = l2_rel_error(u.vector().get_local(), u_ref.vector().get_local())
print("Relative l2-error = %e" % rel_err)