import time

import numpy as np
from dolfin import *
import cashocs


def l2_rel_error(u, u_ref):
    """Return the relative Euclidean error ||u - u_ref|| / ||u_ref||.

    Both inputs are flattened to 1-D (via ``reshape(-1)``) before they are
    compared, so e.g. column vectors and flat vectors can be mixed.

    Args:
        u: approximation (NumPy array).
        u_ref: reference values (NumPy array).

    Returns:
        The ratio of the two Euclidean norms. If ``u_ref`` is identically
        zero, the denominator is zero and NumPy yields ``inf``/``nan`` for
        floating inputs (with a RuntimeWarning) rather than raising.
    """

    def _euclid(vec):
        # Plain sqrt-of-sum-of-squares norm of a 1-D array.
        return np.sqrt(np.sum(vec**2))

    reference = u_ref.reshape(-1)
    return _euclid(u.reshape(-1) - reference) / _euclid(reference)


def a_fun(xyt):
    """Reference diffusion coefficient a(x, y) = 2 + sin(pi x) sin(pi y).

    The coefficient does not depend on time: only columns 0 (x) and 1 (y)
    of ``xyt`` are read.

    Args:
        xyt: 2-D array with one point per row; columns are (x, y, t).

    Returns:
        Array of shape (n_points, 1).
    """
    scaled = np.pi * xyt[:, 0:2]  # pi * x and pi * y side by side
    return 2 + np.sin(scaled[:, 0:1]) * np.sin(scaled[:, 1:2])


def u_fun(xyt):
    """Manufactured exact state u(x, y, t) = exp(-t) sin(pi x) sin(pi y).

    Args:
        xyt: 2-D array with one point per row; columns are (x, y, t).

    Returns:
        Array of shape (n_points, 1).
    """
    x, y, t = (xyt[:, k:k + 1] for k in range(3))
    decay = np.exp(-t)
    return decay * np.sin(np.pi * x) * np.sin(np.pi * y)


def f_fun(xyt):
    """Source term f = u_t - div(a grad u) for the manufactured solution.

    With u = exp(-t) sin(pi x) sin(pi y) and a = 2 + sin(pi x) sin(pi y),
    the identity u_t - div(a grad u) = f holds exactly for

        f = exp(-t) * [ (4 pi^2 - 1) sx sy
                        + pi^2 (2 sx^2 sy^2 - cx^2 sy^2 - sx^2 cy^2) ]

    where sx = sin(pi x), cx = cos(pi x) and likewise for y.

    Args:
        xyt: 2-D array with one point per row; columns are (x, y, t).

    Returns:
        Array of shape (n_points, 1).
    """
    pi = np.pi
    x = xyt[:, 0:1]
    y = xyt[:, 1:2]
    t = xyt[:, 2:3]

    sin_x, sin_y = np.sin(pi * x), np.sin(pi * y)
    cos_x, cos_y = np.cos(pi * x), np.cos(pi * y)

    # Terms coming from u_t and from a * Laplacian(u) (linear in sx*sy).
    linear_part = (4 * pi**2 - 1) * sin_x * sin_y
    # Terms coming from the variable coefficient (quadratic in sines/cosines).
    quadratic_part = pi**2 * (
        2 * sin_x**2 * sin_y**2
        - cos_x**2 * sin_y**2
        - sin_x**2 * cos_y**2
    )
    return np.exp(-t) * (linear_part + quadratic_part)


cashocs.set_log_level(cashocs.LogLevel.WARNING)

# Solver settings; the path is resolved relative to the current working dir.
config = cashocs.load_config("scripts/config.ini")

# Space-time box [-1, 1] x [-1, 1] x [0, 1]: mesh coordinates 0 and 1 are the
# spatial variables x, y and coordinate 2 plays the role of time t.
mesh, subdomains, boundaries, dx, ds, dS = cashocs.regular_box_mesh(
    start_x=-1, start_y=-1, start_z=0, end_x=1, end_y=1, end_z=1,
    n=20
)

V = FunctionSpace(mesh, "CG", 1)
# Analytic fields are evaluated at the dof coordinates and written straight
# into the coefficient vectors, i.e. nodal interpolation into V.
coords = V.tabulate_dof_coordinates()

f = Function(V)
f.vector()[:] = f_fun(coords).reshape(-1)

# Noisy observation of the state: exact solution plus i.i.d. Gaussian noise
# with mean 0 and standard deviation 0.1 at every dof. A fixed seed makes the
# observation, and therefore the reported errors, reproducible between runs.
rng = np.random.default_rng(0)
u_ref = Function(V)
u_ref.vector()[:] = u_fun(coords).reshape(-1)
u_sample = Function(V)
u_sample.vector()[:] = u_ref.vector()[:] + rng.normal(
    0, 0.1, u_ref.vector().size()
)

# Ground-truth coefficient used for the Dirichlet data and for the error.
a_ref = Function(V)
a_ref.vector()[:] = a_fun(coords).reshape(-1)

# Dirichlet data a = a_ref on boundary markers 1-4 (presumably the lateral
# faces x = +-1, y = +-1 in cashocs' box numbering -- confirm). The
# coefficient `a` is handed to cashocs.OptimalControlProblem in the *state*
# slot, so these are boundary conditions for `a`, not for `u`.
bcs = cashocs.create_dirichlet_bcs(V, a_ref, boundaries, [1, 2, 3, 4])

times = []
l2res = []

# Repeat the whole identification to average its cost. Each repetition starts
# from freshly created (zero-valued) Functions and reuses the same noisy data
# `u_sample`.
# NOTE(review): because the observation and the solver settings are identical
# in every repetition, the reported std of the L2 error reflects only
# run-to-run numerical differences, not sensitivity to the noise -- confirm
# that this is intended.
for _ in range(5):
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which follows the (adjustable) system clock.
    start = time.perf_counter()
    a = Function(V)
    v = Function(V)
    u = Function(V)

    # Space-time weak form of u_t - div(a grad u) = f: only the first two
    # mesh coordinates are spatial, coordinate 2 is time.
    grad_u = as_vector((u.dx(0), u.dx(1)))
    grad_v = as_vector((v.dx(0), v.dx(1)))
    u_t = u.dx(2)
    e = u_t * v * dx + a * inner(grad_u, grad_v) * dx - f * v * dx

    # Least-squares misfit between the field u and the noisy observation.
    J = cashocs.IntegralFunctional(
        Constant(0.5) * (u - u_sample) * (u - u_sample) * dx
    )

    # Positional order is (state forms, bcs, cost, states, controls,
    # adjoints): the coefficient `a` is the state (solved from `e` with the
    # Dirichlet data in `bcs`), `u` is the optimised control and `v` is the
    # adjoint/test function.
    ocp = cashocs.OptimalControlProblem(e, bcs, J, a, u, v, config=config)
    ocp.solve("bfgs", rtol=1e-6, atol=0, max_iter=100)

    end = time.perf_counter()
    times.append(end - start)
    # Joint relative error of the recovered (coefficient, state) pair.
    l2res.append(l2_rel_error(
        np.concatenate((a.vector()[:], u.vector()[:])),
        np.concatenate((a_ref.vector()[:], u_ref.vector()[:])),
    ))

print(f"Average Time elapsed: {np.mean(times):.2e}")
print(f"L2-error = {np.mean(l2res):.2e} ± {np.std(l2res):.2e}")
