import time

import numpy as np
from dolfin import *
import cashocs


def l2_rel_error(u, u_ref):
    """Return the relative Euclidean (discrete L2) error of ``u`` w.r.t. ``u_ref``.

    Both inputs are flattened to 1-D first, so any array shapes with the same
    number of entries can be compared. The result is
    ``||u - u_ref||_2 / ||u_ref||_2``; a zero reference yields nan/inf.
    """
    flat = u.reshape(-1)
    flat_ref = u_ref.reshape(-1)
    residual_norm = np.sqrt(np.sum(np.square(flat - flat_ref)))
    reference_norm = np.sqrt(np.sum(np.square(flat_ref)))
    return residual_norm / reference_norm


def a_fun(xy):
    """Evaluate the reference diffusion coefficient at the points ``xy``.

    ``xy`` holds one point per row (x in column 0, y in column 1). The result
    keeps a column shape ``(N, 1)`` and equals
    ``1 / (1 + x^2 + y^2 + (x - 1)^2 + (y - 1)^2)``, which is strictly positive.
    """
    x = xy[:, 0:1]
    y = xy[:, 1:2]
    denominator = 1 + x**2 + y**2 + (x - 1)**2 + (y - 1)**2
    return 1 / denominator


def u_fun(xy):
    """Evaluate the reference state ``sin(pi x) * sin(pi y)`` at the points ``xy``.

    ``xy`` holds one point per row (x in column 0, y in column 1); the result
    has column shape ``(N, 1)``.
    """
    coord_x = xy[:, 0:1]
    coord_y = xy[:, 1:2]
    sin_x = np.sin(np.pi * coord_x)
    sin_y = np.sin(np.pi * coord_y)
    return sin_x * sin_y


def f_fun(xy):
    """Evaluate the source term ``f = -div(a grad u)`` at the points ``xy``.

    ``a`` is ``a_fun`` and ``u = sin(pi x) sin(pi y)`` (as in ``u_fun``), so
    the pair (a, u) exactly solves the state equation with this ``f``.
    Expanding the divergence:

        -div(a grad u) = -a * lap(u) - grad(a) . grad(u)

    with ``lap(u) = -2 pi^2 u`` and ``grad(a) = -a^2 grad(D)``, where
    ``D = 1/a = 1 + x^2 + y^2 + (x - 1)^2 + (y - 1)^2`` so that
    ``dD/dx = 2 (2x - 1)`` and ``dD/dy = 2 (2y - 1)``. This gives

        f = 2 pi^2 a u
            + 2 pi a^2 [(2x - 1) cos(pi x) sin(pi y)
                        + (2y - 1) sin(pi x) cos(pi y)]

    ``xy`` holds one point per row (x in column 0, y in column 1); the result
    has column shape ``(N, 1)``.
    """
    x, y = xy[:, 0:1], xy[:, 1:2]
    a = a_fun(xy)
    sin_x, cos_x = np.sin(np.pi * x), np.cos(np.pi * x)
    sin_y, cos_y = np.sin(np.pi * y), np.cos(np.pi * y)
    # -a * lap(u)
    diffusion = 2 * np.pi**2 * sin_x * sin_y * a
    # -grad(a) . grad(u); the factors (2x - 1), (2y - 1) come from
    # d/dx [x^2 + (x - 1)^2] = 2 (2x - 1), and likewise in y.
    transport = 2 * np.pi * (
        (2 * x - 1) * cos_x * sin_y + (2 * y - 1) * sin_x * cos_y
    ) * a**2
    return diffusion + transport


cashocs.set_log_level(cashocs.LogLevel.WARNING)

config = cashocs.load_config("scripts/config.ini")

# Regular mesh with resolution 100 (presumably cashocs' default unit square);
# boundary facets carry the markers 1-4 used for the Dirichlet conditions.
mesh, subdomains, boundaries, dx, ds, dS = cashocs.regular_mesh(100)

# Piecewise-linear Lagrange space shared by the state, the coefficient and all
# reference/data fields; the analytic expressions are sampled at its dofs.
V = FunctionSpace(mesh, "CG", 1)
coords = V.tabulate_dof_coordinates()

# Source term of the state equation.
f = Function(V)
f.vector()[:] = f_fun(coords).reshape(-1)

# Reference state and a noisy observation of it. The noise is additive
# Gaussian with mean 0 and standard deviation 0.1 (``normal`` takes the
# standard deviation, not the variance). A fixed seed keeps the synthetic
# data reproducible across invocations of this script.
rng = np.random.default_rng(0)
u_ref = Function(V)
u_ref.vector()[:] = u_fun(coords).reshape(-1)
u_sample = Function(V)
u_sample.vector()[:] = u_ref.vector()[:] + \
    rng.normal(0.0, 0.1, u_ref.vector().size())

# Reference coefficient; only used to measure the reconstruction error.
a_ref = Function(V)
a_ref.vector()[:] = a_fun(coords).reshape(-1)

# Dirichlet conditions for the *state* u: it takes the reference values on
# every boundary marker (u_fun = sin(pi x) sin(pi y) vanishes on the boundary
# of the unit square).
bcs = cashocs.create_dirichlet_bcs(V, u_ref, boundaries, [1, 2, 3, 4])

times = []
l2res = []

# The same problem is solved several times to average the wall-clock time;
# data and initial guess are identical in every run.
for _ in range(5):
    start = time.perf_counter()

    # Control: the diffusion coefficient to identify. It must start strictly
    # positive, otherwise a*inner(grad(u), grad(v))*dx is identically zero
    # and the first state solve is singular.
    a = interpolate(Constant(1.0), V)
    # Adjoint variable (the test function of the state equation) and state.
    v = Function(V)
    u = Function(V)

    # The weak form of the state equation -div(a grad u) = f,
    e = a * inner(grad(u), grad(v)) * dx - f * v * dx
    # and the least-squares misfit between the state and the noisy data.
    J = cashocs.IntegralFunctional(
        Constant(0.5) * (u - u_sample) * (u - u_sample) * dx
    )

    # cashocs expects (states, controls, adjoints) in this order: u is the
    # state governed by ``e`` and ``bcs``, a is the coefficient to identify.
    ocp = cashocs.OptimalControlProblem(e, bcs, J, u, a, v, config=config)

    ocp.solve("bfgs", rtol=1e-6, atol=0, max_iter=100)
    times.append(time.perf_counter() - start)
    # Relative L2 error of the stacked (coefficient, state) dof vectors.
    l2res.append(l2_rel_error(
        np.concatenate((a.vector()[:], u.vector()[:])),
        np.concatenate((a_ref.vector()[:], u_ref.vector()[:]))))

print("Average Time elapsed: %.2e" % (np.mean(times)))
print("L2-error = %.2e ± %.2e" % (np.mean(l2res), np.std(l2res)))
