import gmsh
import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.pdedata import PDEDataBase
from utils.fenics_utils import generate_complex_mesh2d


class Poisson2D_CG(PDEDataBase):
    """P1 (CG1) FEniCS discretization of a 2D screened Poisson problem.

    The weak form assembled in ``init_problem`` corresponds to
    ``-div(grad u) + k^2 u = f`` on a rectangle with circular voids, with
    ``u = 0.2`` on the rectangle edges and ``u = 1.0`` on the void circles.
    The resulting linear system is exposed as ``self.A`` (SciPy CSR matrix)
    and ``self.b`` (column vector), together with the DOF coordinates
    ``self.coords``.
    """

    def __init__(self):
        super().__init__()
        # Reference solution; lines starting with '%' are skipped as comments.
        self.ref_data = np.loadtxt(
            "data/ref/poisson_boltzmann2d.dat", comments="%")

    def init_problem(self):
        """Build the mesh, assemble the discrete system and store it on self.

        Reads ``self.config["Bounding Box"]`` and
        ``self.config["Grid Size (mesh length)"]`` (presumably populated by
        ``PDEDataBase`` -- not visible here).

        Sets:
            self.coords: DOF coordinates of the P1 Lagrange space.
            self.A: system matrix with both Dirichlet conditions applied, as a
                ``scipy.sparse.csr_matrix``.
            self.b: right-hand side with Dirichlet values applied, as an
                array of shape (n, 1).
        """
        # Create mesh and define function space.
        # bbox[0], bbox[1] bound x[0]; bbox[2], bbox[3] bound x[1]
        # (see rec_boundary below).
        bbox = self.config["Bounding Box"]
        # Each circular void is [center_x, center_y, radius].
        voids = {
            'circ': [[0.5, 0.5, 0.2],
                     [0.4, -0.4, 0.4],
                     [-0.2, -0.7, 0.1],
                     [-0.6, 0.5, 0.3]]
        }
        mesh_length = self.config["Grid Size (mesh length)"]
        mesh = generate_complex_mesh2d(bbox, voids, mesh_length)
        V = FunctionSpace(mesh, "Lagrange", 1)

        # Get the dof coordinates
        self.coords = V.tabulate_dof_coordinates()

        # Dirichlet boundary for the rectangle edges. The on_boundary guard is
        # kept for consistency with circ_boundary (the default `near`
        # tolerance is already tight enough to exclude interior nodes).
        def rec_boundary(x, on_boundary):
            return on_boundary and (
                near(x[0], bbox[0]) or near(x[0], bbox[1])
                or near(x[1], bbox[2]) or near(x[1], bbox[3]))

        # Dirichlet boundary for the circles.
        def circ_boundary(x, on_boundary):
            # The tolerance applies to the *squared* distance. For a chord of
            # length h on a circle, the midpoint satisfies r^2 - d^2 = h^2 / 4
            # regardless of r, so facets up to h ~ 2*sqrt(1e-3) ~ 0.063 should
            # pass a midpoint test -- presumably why this form was chosen.
            # The same band extends ~1e-3 / (2r) *outside* each circle
            # (~5e-3 for r = 0.1), so without the on_boundary guard interior
            # nodes near a void would be constrained on fine meshes.
            return on_boundary and any(
                near((x[0] - cx) ** 2 + (x[1] - cy) ** 2, r ** 2, eps=1e-3)
                for cx, cy, r in voids['circ'])

        # Define boundary conditions: 0.2 on the rectangle, 1.0 on the circles.
        bc_rec = DirichletBC(V, Constant(0.2), rec_boundary)
        bc_circ = DirichletBC(V, Constant(1.0), circ_boundary)

        # Define variational problem
        u = TrialFunction(V)
        v = TestFunction(V)
        k = 8
        amplitude = 10
        mu = (1, 4)

        # Source term f; the keyword `A` must match the name used in the
        # C++ expression string.
        f_expr = Expression(
            "A * (pow(mu0, 2) + pow(x[0], 2) + pow(mu1, 2) + pow(x[1], 2))"
            "* sin(mu0 * pi * x[0]) * sin(mu1 * pi * x[1])",
            degree=2,
            A=amplitude,
            mu0=mu[0],
            mu1=mu[1],
        )
        # NOTE(review): f is interpolated onto P1 before integration, which
        # adds an interpolation error; using f_expr directly in L would avoid
        # it (kept as-is to preserve the existing numerical results).
        f = interpolate(f_expr, V)

        # Bilinear form (grad u, grad v) + k^2 (u, v) and linear form (f, v)
        a = (inner(grad(u), grad(v)) + k**2 * u * v) * dx
        L = f * v * dx

        # Assemble the system and apply both Dirichlet conditions in place.
        A_mat = assemble(a)
        b_vec = assemble(L)
        bc_rec.apply(A_mat, b_vec)
        bc_circ.apply(A_mat, b_vec)

        # Convert the PETSc matrix to a SciPy CSR matrix. getValuesCSR only
        # returns locally owned rows, so this assumes a serial run.
        petsc_mat = as_backend_type(A_mat).mat()
        indptr, indices, data = petsc_mat.getValuesCSR()
        self.A = sp.csr_matrix(
            (data, indices, indptr), shape=petsc_mat.getSize())
        self.b = b_vec.get_local().reshape(-1, 1)
