import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.pdedata import PDEDataBase
from utils.utils import decompress_time_data
from utils.fenics_utils import generate_complex_mesh3d


class Heat2D_CG(PDEDataBase):
    """
    Prepare data for the 2D heat equation (Heat2d-CG) on a perforated domain.

    The time-dependent 2D problem is discretised as one space-time problem:
    the mesh is 3D with coordinates ``(x, y, t)``, the circular holes are
    cylinders extruded along ``t``, and P1 Lagrange elements are used.

    Weak form: ``u_t v + grad_xy(u) . grad_xy(v)`` over the domain, plus a
    Robin condition ``du/dn = g - u`` (n = outward normal) on
      * big disks   (facet marker 1): g = 5.0
      * small disks (facet marker 2): g = 1.0
      * outer box   (facet marker 3): g = 0.1
    and ``u = 0`` (Dirichlet) at the initial time.

    After :meth:`init_problem` the assembled system is stored as
    ``self.A`` (scipy CSR matrix), ``self.b`` (column vector) and the dof
    coordinates as ``self.coords``.
    """
    def __init__(self) -> None:
        super().__init__()
        # Reference solution file; lines starting with '%' are comments.
        self.ref_data = np.loadtxt("data/ref/heat_complex.dat", comments="%")

    def preprocess(self):
        """Convert ``self.ref_data`` with ``decompress_time_data``.

        The spatial dimension, output dimension and the time part of the
        bounding box (entries after the ``2 * dim_s`` spatial bounds) are
        read from ``self.config``.

        NOTE: not idempotent -- each call re-processes ``self.ref_data``.
        """
        dim_s = self.config["Spatial Dimension"]
        self.ref_data = decompress_time_data(
            self.ref_data, dim_s,
            self.config["Output Dimension"],
            self.config["Bounding Box"][dim_s*2:])

    @staticmethod
    def _mark_boundaries(mesh, bbox, mesh_length,
                         big_disk_center, big_disk_radius,
                         small_disk_center, small_radius):
        """Return a facet ``MeshFunction`` tagging the Robin boundaries.

        Markers: 1 = big disks, 2 = small disks, 3 = outer box sides,
        0 = everything else, including facets lying on the ``t = t_start``
        and ``t = t_end`` faces.
        """
        boundary_markers = MeshFunction(
            'size_t', mesh, mesh.topology().dim() - 1)
        # Start from a known state: the MeshFunction is not guaranteed to be
        # zero-initialised, and stray values would leak into ds(1..3).
        boundary_markers.set_all(0)

        def near_circle(x, centers, radius):
            # Tolerance band on the *squared* distance (kept from the
            # original marking). NOTE(review): the band is dimensionally
            # loose (length^2 vs length) and is wide for small radii.
            return any(
                np.abs((x[0] - cx)**2 + (x[1] - cy)**2 - radius**2)
                < mesh_length
                for cx, cy in centers)

        class BigDiskBoundary(SubDomain):
            def inside(self, x, on_boundary):
                return on_boundary and near_circle(
                    x, big_disk_center, big_disk_radius)

        class SmallDiskBoundary(SubDomain):
            def inside(self, x, on_boundary):
                return on_boundary and near_circle(
                    x, small_disk_center, small_radius)

        class OuterBoundary(SubDomain):
            def inside(self, x, on_boundary):
                return on_boundary and (
                    np.abs(x[0] - bbox[0]) < DOLFIN_EPS or
                    np.abs(x[0] - bbox[1]) < DOLFIN_EPS or
                    np.abs(x[1] - bbox[2]) < DOLFIN_EPS or
                    np.abs(x[1] - bbox[3]) < DOLFIN_EPS)

        # Assumes bbox = [x0, x1, y0, y1, t0, t1] (the spatial bounds are
        # used above, the time bounds are the trailing two entries).
        t_start, t_end = bbox[-2], bbox[-1]

        class TimeFaces(SubDomain):
            # A facet is matched only if all its vertices and its midpoint
            # lie on a time face, so lateral hole/box facets are unaffected.
            def inside(self, x, on_boundary):
                return on_boundary and (
                    np.abs(x[2] - t_start) < DOLFIN_EPS or
                    np.abs(x[2] - t_end) < DOLFIN_EPS)

        BigDiskBoundary().mark(boundary_markers, 1)
        SmallDiskBoundary().mark(boundary_markers, 2)
        OuterBoundary().mark(boundary_markers, 3)
        # The loose disk band above also accepts time-face facets adjacent to
        # the holes; clear them so the Robin terms act only on the spatial
        # (lateral) boundaries and not on the t = t_end face.
        TimeFaces().mark(boundary_markers, 0)
        return boundary_markers

    def init_problem(self):
        """Build the space-time mesh, assemble the FE system and store it.

        Calls :meth:`preprocess` first (which rewrites ``self.ref_data``),
        then sets ``self.coords``, ``self.A`` and ``self.b``.
        """
        self.preprocess()

        # Create mesh: bounding box minus cylindrical voids along t.
        bbox = self.config["Bounding Box"]
        t_start, t_end = bbox[-2], bbox[-1]
        big_disk_center = [
            (-4, -3), (4, -3), (-4, 3), (4, 3),
            (-4, -9), (4, -9), (-4, 9), (4, 9),
            (0, 0), (0, 6), (0, -6)]
        big_disk_radius = 1
        small_disk_center = [
            (-3.2, -6), (-3.2, 6), (3.2, -6), (3.2, 6),
            (-3.2, 0), (3.2, 0)]
        small_radius = 0.4
        voids = {"cylinder": (
            [[cx, cy, 0, t_end, big_disk_radius]
             for cx, cy in big_disk_center]
            + [[cx, cy, 0, t_end, small_radius]
               for cx, cy in small_disk_center])
        }
        mesh_length = self.config["Grid Size (mesh length)"]
        mesh = generate_complex_mesh3d(bbox, voids, mesh_length)
        V = FunctionSpace(mesh, "Lagrange", 1)

        # Dof coordinates in (x, y, t).
        self.coords = V.tabulate_dof_coordinates()

        boundary_markers = self._mark_boundaries(
            mesh, bbox, mesh_length,
            big_disk_center, big_disk_radius,
            small_disk_center, small_radius)
        ds = Measure('ds', domain=mesh, subdomain_data=boundary_markers)

        # Initial condition u = 0 at t = t_start, imposed as a Dirichlet BC.
        def boundary_t0(x):
            return x[2] < t_start + DOLFIN_EPS

        bc_t0 = DirichletBC(V, Constant(0.0), boundary_t0)

        u = TrialFunction(V)
        v = TestFunction(V)

        # Only the spatial gradient enters the diffusion term; the third
        # coordinate is time and appears through u_t.
        grad_u = as_vector((u.dx(0), u.dx(1)))
        grad_v = as_vector((v.dx(0), v.dx(1)))
        u_t = u.dx(2)

        F = u_t * v * dx + inner(grad_u, grad_v) * dx - \
            (0.1 - u) * v * ds(3) - (1 - u) * v * ds(2) - \
            (5.0 - u) * v * ds(1)
        a = lhs(F)
        L = rhs(F)

        A = assemble(a)
        b = assemble(L)
        bc_t0.apply(A, b)

        # PETSc returns (indptr, indices, data); scipy wants the reverse.
        # Rows are the locally owned rows, columns use the global size.
        petsc_mat = as_backend_type(A).mat()
        indptr, indices, data = petsc_mat.getValuesCSR()
        self.A = sp.csr_matrix(
            (data, indices, indptr),
            shape=(len(indptr) - 1, petsc_mat.getSize()[1]))
        self.b = b.get_local().reshape(-1, 1)
