import gmsh
import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.pdedata import PDEDataBase
from utils.fenics_utils import generate_complex_mesh3d


class Poisson3D_CG(PDEDataBase):
    """3D reaction-diffusion problem on a box with spherical voids (P1 CG).

    Assembles the linear system ``A u = b`` arising from the weak form

        (mu * grad u, grad v) + (k^2 * u, v) = (f, v)   for all v in V,

    where V is the continuous piecewise-linear Lagrange space. No Dirichlet
    conditions and no boundary integrals are applied, so every boundary
    (outer box and void surfaces) carries the homogeneous natural
    (zero-flux) condition. ``mu`` and ``k`` are piecewise constant with a
    jump across the plane ``z = interface_z``.

    Attributes set by this class:
        ref_data: array loaded from ``data/ref/poisson_3d.dat``.
        coords: dof coordinates of V, one ``(x, y, z)`` row per dof.
        A: assembled system matrix as a ``scipy.sparse.csr_matrix``.
        b: assembled right-hand side as a column vector of shape ``(n, 1)``.
    """

    def __init__(self):
        super().__init__()
        # Lines beginning with '%' in the reference file are skipped as comments.
        self.ref_data = np.loadtxt(
            "data/ref/poisson_3d.dat", comments="%")

    def init_problem(self):
        """Build the mesh and assemble the discrete system into ``self.A``/``self.b``.

        Reads ``self.config["Bounding Box"]`` and
        ``self.config["Grid Size (mesh length)"]``. ``self.config`` is
        presumably populated by ``PDEDataBase`` (not visible here).
        """
        # Create mesh and define function space.
        # NOTE(review): each 'sphere' entry presumably is [cx, cy, cz, radius]
        # of a void cut out of the bounding box -- confirm against
        # generate_complex_mesh3d.
        bbox = self.config["Bounding Box"]
        voids = {
            'sphere': [[0.4, 0.3, 0.6, 0.2],
                       [0.6, 0.7, 0.6, 0.2],
                       [0.2, 0.8, 0.7, 0.1],
                       [0.6, 0.2, 0.3, 0.1]]
        }
        mesh_length = self.config["Grid Size (mesh length)"]
        mesh = generate_complex_mesh3d(bbox, voids, mesh_length)
        V = FunctionSpace(mesh, "Lagrange", 1)

        # Coordinates of every dof of V, one (x, y, z) row per dof.
        self.coords = V.tabulate_dof_coordinates()

        # Problem parameters.
        amplitudes = (20, 100)    # A0, A1 in the source term
        wavenumbers = (1, 10, 5)  # m0, m1, m2: per-axis frequencies
        k_values = (8, 10)        # k for z < interface_z / z >= interface_z
        mu_values = (1, 1)        # mu for z < interface_z / z >= interface_z
        interface_z = 0.5

        # Source term, with r2 = x^2 + y^2 + z^2:
        #   f = A0 * exp(sin(m0 pi x) + sin(m1 pi y) + sin(m2 pi z)) * (r2 - 1) / (r2 + 1)
        #     + A1 * sin(m0 pi x) * sin(m1 pi y) * sin(m2 pi z)
        f_expr = Expression(
            "A0 * exp(sin(m0 * pi * x[0]) + sin(m1 * pi * x[1]) + sin(m2 * pi * x[2])) * (pow(x[0], 2) + pow(x[1], 2) + pow(x[2], 2) - 1) / (pow(x[0], 2) + pow(x[1], 2) + pow(x[2], 2) + 1)"
            "+ A1 * sin(m0 * pi * x[0]) * sin(m1 * pi * x[1]) * sin(m2 * pi * x[2])",
            degree=2,
            A0=amplitudes[0],
            A1=amplitudes[1],
            m0=wavenumbers[0],
            m1=wavenumbers[1],
            m2=wavenumbers[2],
        )
        f = interpolate(f_expr, V)

        # Piecewise-constant coefficients, sampled at the dof locations.
        # NOTE(review): because they live in the P1 space V, mu and k are
        # linearly interpolated inside elements that straddle z = interface_z,
        # which smears the jump over one element layer. A DG0 space would keep
        # it sharp, but would change the assembled system.
        mu = self._piecewise_in_z(V, self.coords, mu_values, interface_z)
        k = self._piecewise_in_z(V, self.coords, k_values, interface_z)

        # Bilinear and linear forms.
        u = TrialFunction(V)
        v = TestFunction(V)
        a = mu * inner(grad(u), grad(v)) * dx + k**2 * u * v * dx
        L = f * v * dx

        # Assemble the stiffness + mass matrix and the load vector.
        system_matrix = assemble(a)
        load_vector = assemble(L)

        self.A = self._to_scipy_csr(system_matrix)
        self.b = load_vector.get_local().reshape(-1, 1)

    @staticmethod
    def _piecewise_in_z(V, coords, values, interface_z):
        """Return a Function on V taking ``values[0]`` where z < interface_z, else ``values[1]``.

        The value is assigned dof by dof, using the z coordinate of each dof
        in ``coords``.
        """
        z = np.asarray(coords)[:, 2]
        func = Function(V)
        func.vector()[:] = np.where(
            z < interface_z, values[0], values[1]).astype(np.float64)
        return func

    @staticmethod
    def _to_scipy_csr(matrix):
        """Convert an assembled dolfin matrix (PETSc backend) to ``scipy.sparse.csr_matrix``."""
        petsc_mat = as_backend_type(matrix).mat()
        # petsc4py returns the CSR arrays in (row pointers, column indices, values) order.
        indptr, indices, data = petsc_mat.getValuesCSR()
        # Pass the shape explicitly; otherwise scipy infers the column count
        # from the largest stored column index.
        n_rows = len(indptr) - 1
        n_cols = petsc_mat.getSize()[1]
        return sp.csr_matrix((data, indices, indptr), shape=(n_rows, n_cols))
