import gmsh
import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.pdedata import PDEDataBase
from utils.fenics_utils import generate_complex_mesh2d


class Poisson2D_C(PDEDataBase):
    """Laplace problem on a rectangle with four circular voids.

    The domain is the bounding box from ``self.config`` with four circular
    holes removed. The solution is fixed to 1 on the outer rectangle and to 0
    on the circle boundaries. No source term is assembled; the right-hand
    side starts as a zero vector and only the Dirichlet rows become non-zero.
    """

    def __init__(self):
        super().__init__()
        # Reference data; text after '%' is treated as a comment by loadtxt.
        self.ref_data = np.loadtxt(
            "data/ref/poisson_classic.dat", comments="%")

    def init_problem(self):
        """Assemble the P1 linear system and expose it as SciPy objects.

        Reads ``"Bounding Box"`` and ``"Grid Size (mesh length)"`` from
        ``self.config`` (presumably populated by ``PDEDataBase`` -- confirm).

        Sets:
            self.coords: DOF coordinates from ``V.tabulate_dof_coordinates()``.
            self.A: system matrix (Dirichlet rows applied) as a
                ``scipy.sparse.csr_matrix``.
            self.b: right-hand side as a column vector of shape ``(n, 1)``.
        """
        # The boundary test below reads bbox as [x0, x1, y0, y1].
        bbox = self.config["Bounding Box"]
        # Each circular void is given as [cx, cy, r].
        voids = {
            'circ': [[0.3, 0.3, 0.1],
                     [-0.3, 0.3, 0.1],
                     [0.3, -0.3, 0.1],
                     [-0.3, -0.3, 0.1]]
        }
        mesh_length = self.config["Grid Size (mesh length)"]
        mesh = generate_complex_mesh2d(bbox, voids, mesh_length)
        V = FunctionSpace(mesh, "Lagrange", 1)

        # Get the dof coordinates
        self.coords = V.tabulate_dof_coordinates()

        def rec_boundary(x):
            """Return True for points on one of the four rectangle edges."""
            return near(x[0], bbox[0]) or \
                near(x[0], bbox[1]) or near(x[1], bbox[2]) or \
                near(x[1], bbox[3])

        def circ_boundary(x):
            """Return True for points on the rim of any circular void.

            The comparison is on squared distances: |d^2 - r^2| < 1e-3,
            which is roughly +/- 1e-3 / (2 r) in radius.
            """
            return any(
                near((x[0] - c[0]) ** 2 + (x[1] - c[1]) ** 2,
                     c[2] ** 2, eps=1e-3)
                for c in voids['circ'])

        # Dirichlet data: 1 on the outer rectangle, 0 on the holes.
        u_rec = Constant(1.0)
        bc_rec = DirichletBC(V, u_rec, rec_boundary)
        u_circ = Constant(0.0)
        bc_circ = DirichletBC(V, u_circ, circ_boundary)

        # Bilinear form of the Laplace operator.
        u = TrialFunction(V)
        v = TestFunction(V)
        a = inner(grad(u), grad(v)) * dx

        # Assemble the matrix; the RHS is a zero vector sized to A's rows.
        A = assemble(a)
        b = Vector(mesh.mpi_comm(), A.size(0))
        bc_rec.apply(A, b)
        bc_circ.apply(A, b)

        # Convert to SciPy. petsc4py returns the CSR arrays as
        # (indptr, indices, data).
        indptr, indices, data = as_backend_type(A).mat().getValuesCSR()
        # Pass the shape explicitly; otherwise SciPy infers the column count
        # from the largest stored column index.
        self.A = sp.csr_matrix(
            (data, indices, indptr), shape=(len(indptr) - 1, A.size(1)))
        self.b = b.get_local().reshape(-1, 1)
