import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.pdedata import PDEDataBase


class Heat2D_MS(PDEDataBase):
    """
    Prepare data for the 2D heat equation (Heat2d-MS).

    The equation is the anisotropic heat equation

        u_t = a * u_xx + b * u_yy,   a = 1 / (500 pi)^2,   b = 1 / pi^2,

    discretised as a single space-time problem: time is the third
    coordinate of a 3D box mesh, so one linear system covers the whole
    time interval. Dirichlet data are zero on the four spatial faces and
    the reference solution on the first time slice; no condition is
    imposed on the final time slice.
    """

    def __init__(self) -> None:
        super().__init__()
        # (a, b): diffusion coefficients in the x- and y-directions.
        self.pde_coef = (1 / np.square(500 * np.pi), 1 / np.square(np.pi))
        # (k0, k1): wave numbers of the initial profile in x and y.
        self.init_coef = (20 * np.pi, np.pi)
        # Separable exact solution
        #   u(x, y, t) = sin(k0 x) sin(k1 y) exp(-(a k0^2 + b k1^2) t),
        # evaluated row-wise on an array whose columns are (x, y, t);
        # the result is a column vector of shape (N, 1).
        self.ref_sol = lambda x: (
            np.sin(self.init_coef[0] * x[:, 0:1])
            * np.sin(self.init_coef[1] * x[:, 1:2])
            * np.exp(-self._decay_rate() * x[:, 2:3])
        )

    def _decay_rate(self):
        """Return a * k0^2 + b * k1^2, the exponential decay rate of ref_sol."""
        return (self.pde_coef[0] * self.init_coef[0]**2
                + self.pde_coef[1] * self.init_coef[1]**2)

    def init_problem(self):
        """
        Assemble the space-time finite element system ``A u = b``.

        Reads from ``self.config``:

        - ``"Bounding Box"``: ``[x_min, x_max, y_min, y_max, t_min, t_max]``.
        - ``"Grid Size (x-direction)"``, ``"Grid Size (y-direction)"`` and
          ``"Grid Size (t-direction)"``: cell counts along each axis.

        Sets:

        - ``self.coords``: coordinates of every degree of freedom, in the
          same order as the rows of ``self.A``.
        - ``self.A``: the system matrix, with Dirichlet rows applied, as a
          ``scipy.sparse.csr_matrix``.
        - ``self.b``: the right-hand side as a column vector of shape (N, 1).
        """
        # Create mesh and define function space
        bbox = self.config["Bounding Box"]
        x_min, x_max = bbox[0], bbox[1]
        y_min, y_max = bbox[2], bbox[3]
        t_min, t_max = bbox[4], bbox[5]
        nx = self.config["Grid Size (x-direction)"]
        ny = self.config["Grid Size (y-direction)"]
        nt = self.config["Grid Size (t-direction)"]
        mesh = BoxMesh(Point(x_min, y_min, t_min),
                       Point(x_max, y_max, t_max), nx, ny, nt)
        V = FunctionSpace(mesh, "Lagrange", 1)

        # Get the dof coordinates
        self.coords = V.tabulate_dof_coordinates()

        # Face-detection tolerances: a tiny fraction of one cell width per
        # axis. This absorbs round-off in vertex coordinates for any box size
        # (an absolute DOLFIN_EPS does not once coordinates are not O(1)),
        # while interior vertices, at least one cell away, are never matched.
        tol_x = 1e-8 * (x_max - x_min) / nx
        tol_y = 1e-8 * (y_max - y_min) / ny
        tol_t = 1e-8 * (t_max - t_min) / nt

        # Dirichlet boundary: the four spatial faces of the configured box.
        def boundary_xy(x):
            return (x[0] < x_min + tol_x or x[0] > x_max - tol_x
                    or x[1] < y_min + tol_y or x[1] > y_max - tol_y)

        # Dirichlet boundary: the initial time slice t = t_min.
        def boundary_t0(x):
            return x[2] < t_min + tol_t

        # Define boundary conditions.
        # Zero spatial data agrees with ref_sol only where the faces lie on
        # zeros of the sines; this holds, up to round-off, for the unit square
        # with the default wave numbers.
        bc_xy = DirichletBC(V, Constant(0.0), boundary_xy)
        # The initial data is ref_sol evaluated on the t = t_min slice. The
        # exp factor equals 1 when t_min = 0, so the default setup matches the
        # plain sin * sin profile.
        u_t0 = Expression(
            "sin(init_coef0 * x[0]) * sin(init_coef1 * x[1])"
            " * exp(-decay * x[2])",
            degree=2, init_coef0=float(self.init_coef[0]),
            init_coef1=float(self.init_coef[1]),
            decay=float(self._decay_rate()))
        bc_t0 = DirichletBC(V, u_t0, boundary_t0)

        # Define variational problem:
        #   int (u_t v + a u_x v_x + b u_y v_y) dx = 0
        u = TrialFunction(V)
        v = TestFunction(V)

        grad_u = as_vector((self.pde_coef[0] * u.dx(0),
                            self.pde_coef[1] * u.dx(1)))
        grad_v = as_vector((v.dx(0), v.dx(1)))
        u_t = u.dx(2)

        a = u_t * v * dx + inner(grad_u, grad_v) * dx

        # Compute matrix and right-hand side vector. The form has no source
        # term, so b only receives the boundary values written by apply().
        # This relies on Vector() starting out zero-filled (dolfin default).
        # bc_t0 is applied last, so its values win on edges shared with the
        # spatial faces.
        A = assemble(a)
        b = Vector(mesh.mpi_comm(), A.size(0))
        bc_xy.apply(A, b)
        bc_t0.apply(A, b)

        # Convert PETSc matrix and vector to scipy sparse arrays.
        # getValuesCSR() returns (indptr, indices, data); scipy expects
        # (data, indices, indptr). Pass the shape explicitly rather than
        # letting scipy infer it from the largest column index present.
        indptr, indices, data = as_backend_type(A).mat().getValuesCSR()
        self.A = sp.csr_matrix((data, indices, indptr),
                               shape=(len(indptr) - 1, A.size(1)))
        self.b = b.get_local().reshape(-1, 1)
