import math
import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.pdedata import PDEDataBase


class Wave2D_MS(PDEDataBase):
    """
    Prepare data for the 2D wave equation (Wave2d-MS).

    The reference solution is a sum of two separable modes evaluated at
    points ``(x, y, t)``::

        u = c1 * sin(m1*pi*x) * sinh(n1*pi*y) * cos(p1*pi*t)
          + c2 * sinh(m2*pi*x) * sin(n2*pi*y) * cos(p2*pi*t)

    ``init_problem`` discretises the problem on a space-time box mesh
    (time is the third mesh coordinate) and stores the resulting linear
    system in ``self.A`` / ``self.b`` and the dof coordinates in
    ``self.coords``.
    """
    def __init__(self) -> None:
        """Set up the analytic reference solution ``self.ref_sol``."""
        super().__init__()
        # Mode numbers along x (m), y (n) and t (p), and mode amplitudes (c).
        m1, m2 = 1, 3
        n1, n2 = 1, 2
        p1, p2 = 1, 1
        c1 = c2 = 1

        def ref_sol(x):
            """
            Evaluate the reference solution.

            ``x`` is indexed as ``x[:, 0:1]``, ``x[:, 1:2]``, ``x[:, 2:3]``
            (columns x, y, t); the result keeps the column shape produced
            by those slices.
            """
            xs, ys, ts = x[:, 0:1], x[:, 1:2], x[:, 2:3]
            return c1 * np.sin(m1 * np.pi * xs) * \
                    np.sinh(n1 * np.pi * ys) * np.cos(p1 * np.pi * ts) + \
                c2 * np.sinh(m2 * np.pi * xs) * \
                    np.sin(n2 * np.pi * ys) * np.cos(p2 * np.pi * ts)

        self.ref_sol = ref_sol

    def init_problem(self):
        """
        Assemble the finite-element linear system for the problem.

        Reads ``"Bounding Box"`` (indexed as x_min, x_max, y_min, y_max,
        t_min, t_max) and the three ``"Grid Size (...-direction)"`` entries
        from ``self.config`` (expected to be populated outside this class).

        Sets:
            self.coords: dof coordinates of the P1 function space.
            self.A: system matrix as a ``scipy.sparse.csr_matrix``.
            self.b: right-hand side as a column ``numpy`` array.
        """
        # Coefficient applied (squared) to the y-derivative term.
        coef = math.sqrt(2)

        # Create mesh and define function space
        bbox = self.config["Bounding Box"]
        mesh = BoxMesh(Point(bbox[0], bbox[2], bbox[4]),
            Point(bbox[1], bbox[3], bbox[5]),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (y-direction)"],
            self.config["Grid Size (t-direction)"])
        V = FunctionSpace(mesh, "Lagrange", 1)

        # Get the dof coordinates
        self.coords = V.tabulate_dof_coordinates()

        # Define Dirichlet boundary on the faces of the configured box:
        # x = x_min or x = x_max or y = y_min or y = y_max or t = t_min.
        # The t = t_max face is deliberately not part of the Dirichlet set.
        x_min, x_max = bbox[0], bbox[1]
        y_min, y_max = bbox[2], bbox[3]
        t_min = bbox[4]
        # Scale the tolerance with the box magnitude so it is not absorbed
        # by rounding for boxes far from unit size; for the unit box this
        # is exactly DOLFIN_EPS.
        tol = DOLFIN_EPS * max(1.0, *(abs(float(c)) for c in bbox[:5]))

        def boundary(x):
            return x[0] < x_min + tol or x[0] > x_max - tol or \
                x[1] < y_min + tol or x[1] > y_max - tol or \
                x[2] < t_min + tol

        # Define boundary condition: interpolate the reference solution at
        # the dofs (assumes dof-coordinate order matches the vector order,
        # as in a serial run).
        u_b = Function(V)
        u_b.vector()[:] = self.ref_sol(self.coords).reshape(-1)
        bc = DirichletBC(V, u_b, boundary)

        # Define variational problem
        u = TrialFunction(V)
        v = TestFunction(V)

        grad_u = as_vector((u.dx(0), (coef**2) * u.dx(1)))
        u_t = u.dx(2)
        grad_v = as_vector((v.dx(0), v.dx(1)))
        v_t = v.dx(2)

        # Kept under its own name so the coefficient above is not shadowed.
        bilinear_form = inner(grad_u, grad_v)*dx + u_t*v_t*dx

        # Compute matrix and right-hand side vector. There is no source
        # term, so the right-hand side is a freshly created vector (relied
        # upon to be zero-filled) that only receives the boundary values.
        A = assemble(bilinear_form)
        b = Vector(mesh.mpi_comm(), A.size(0))
        bc.apply(A, b)

        # Convert PETSc matrix and vector to scipy sparse arrays.
        # getValuesCSR yields (indptr, indices, data); scipy expects the
        # reverse order (data, indices, indptr).
        csr = as_backend_type(A).mat().getValuesCSR()[::-1]
        self.A = sp.csr_matrix(csr)
        self.b = b.get_local().reshape(-1, 1)
