import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.pdedata import PDEDataBase


class Wave1D_C(PDEDataBase):
    """
    Data preparation for the 1D wave equation benchmark (Wave1d-C).

    The problem is discretised on a unit-square mesh whose first coordinate
    is x and whose second coordinate is t. ``init_problem`` stores the dof
    coordinates, the system matrix and the right-hand side on the instance
    as ``coords``, ``A`` and ``b``.
    """
    def __init__(self) -> None:
        super().__init__()

        def reference_solution(x):
            # Width-1 slices keep the result two-dimensional, shape (n, 1).
            pos = x[:, 0:1]
            time = x[:, 1:2]
            base_mode = np.sin(np.pi * pos) * np.cos(2 * np.pi * time)
            high_mode = 0.5 * np.sin(4 * np.pi * pos) * np.cos(8 * np.pi * time)
            return base_mode + high_mode

        # Reference solution evaluated on rows of (x, t) points.
        self.ref_sol = reference_solution

    def init_problem(self):
        """
        Assemble the finite-element system and store it on the instance.

        Reads the grid resolution from ``self.config`` (not set in this
        class; presumably provided through PDEDataBase -- confirm) and sets:

        * ``self.coords`` -- dof coordinates of the P1 function space,
        * ``self.A`` -- system matrix as a ``scipy.sparse.csr_matrix``,
        * ``self.b`` -- right-hand side as a column vector of shape (n, 1).
        """
        # Unit-square mesh and piecewise-linear Lagrange space on it
        nx = self.config["Grid Size (x-direction)"]
        nt = self.config["Grid Size (t-direction)"]
        mesh = UnitSquareMesh(nx, nt)
        space = FunctionSpace(mesh, "Lagrange", 1)

        # Keep the dof coordinates for later use
        self.coords = space.tabulate_dof_coordinates()

        # Dirichlet boundary markers; dolfin requires plain functions here.
        def on_spatial_ends(p):
            # x = 0 or x = 1
            return p[0] < DOLFIN_EPS or p[0] > 1.0 - DOLFIN_EPS

        def on_initial_time(p):
            # t = 0
            return p[1] < DOLFIN_EPS

        # Zero value on the x-ends, prescribed profile at t = 0
        initial_profile = Expression(
            "sin(pi * x[0]) + 0.5 * sin(4 * pi * x[0])", degree=2)
        conditions = (
            DirichletBC(space, Constant(0.0), on_spatial_ends),
            DirichletBC(space, initial_profile, on_initial_time),
        )

        # Bilinear form: the vector (-4u_x, u_t) paired with grad(v)
        trial = TrialFunction(space)
        test = TestFunction(space)
        flux = as_vector((-4*trial.dx(0), trial.dx(1)))
        bilinear = inner(flux, grad(test))*dx

        # Assemble the matrix; the right-hand side is a fresh vector that
        # only receives the boundary values.
        matrix = assemble(bilinear)
        rhs = Vector(mesh.mpi_comm(), matrix.size(0))
        # Same order as declared: x-end condition first, then t = 0.
        for condition in conditions:
            condition.apply(matrix, rhs)

        # PETSc returns (indptr, indices, data); scipy's CSR constructor
        # expects (data, indices, indptr).
        row_ptr, col_idx, values = as_backend_type(matrix).mat().getValuesCSR()
        self.A = sp.csr_matrix((values, col_idx, row_ptr))
        self.b = np.reshape(rhs.get_local(), (-1, 1))
