import numpy as np
import scipy.sparse as sp
from scipy.interpolate import NearestNDInterpolator
from dolfin import *

from src.pdedata import PDEDataBase
from utils.utils import decompress_time_data


class Heat2D_VC(PDEDataBase):
    """
    Prepare data for the 2D heat equation with a variable coefficient (Heat2d-VC).

    The problem is discretised as one space-time system on a 3D box mesh
    whose third coordinate plays the role of time:

        u_t - div(a(x, y) grad u) = f   inside the bounding box,
        u = 0                           on the spatial walls,
        u = 0                           at the initial time.

    After ``init_problem`` the instance exposes ``self.coords`` (dof
    coordinates), ``self.A`` (scipy CSR system matrix) and ``self.b``
    (right-hand side column vector).
    """

    # Amplitude and per-axis wave numbers of the source term
    # f = A * sin(m0*pi*x) * sin(m1*pi*y) * sin(m2*pi*t).
    SOURCE_AMPLITUDE = 200
    SOURCE_MODES = (1, 5, 1)

    # Relative tolerance (scaled by the coordinate magnitude) used to decide
    # whether a point lies on a face of the bounding box.
    BOUNDARY_RTOL = 1e-10

    def __init__(
        self,
        ref_path: str = "data/ref/heat_darcy.dat",
        coef_path: str = "data/ref/heat_2d_coef_256.dat",
    ) -> None:
        """Load the reference solution and the sampled coefficient field.

        Args:
            ref_path: Reference-data file; lines starting with "%" are skipped.
            coef_path: Coefficient samples; the first ``Spatial Dimension``
                columns are coordinates and the next column is the value.
        """
        super().__init__()
        self.ref_data = np.loadtxt(ref_path, comments="%")
        self.coef_data = np.loadtxt(coef_path)

    def preprocess(self):
        """Expand ``self.ref_data`` with ``decompress_time_data``.

        NOTE(review): this overwrites ``self.ref_data`` in place, so running
        it twice (e.g. calling ``init_problem`` twice) presumably decompresses
        already-decompressed data — confirm against ``decompress_time_data``.
        """
        dim_s = self.config["Spatial Dimension"]
        self.ref_data = decompress_time_data(
            self.ref_data,
            dim_s,
            self.config["Output Dimension"],
            # Bounds after the spatial ones (the time range when dim_s == 2).
            self.config["Bounding Box"][dim_s * 2:],
        )

    def init_problem(self):
        """Assemble the space-time linear system ``self.A @ u = self.b``.

        Sets ``self.coords``, ``self.A`` (Dirichlet rows already applied) and
        ``self.b`` with shape ``(n_dofs, 1)``. Assumes a serial run: the PETSc
        matrix and vector are read back as process-local data only.
        """
        self.preprocess()

        # Space-time mesh: coordinates 0 and 1 are spatial, coordinate 2 is time.
        bbox = self.config["Bounding Box"]
        xmin, xmax, ymin, ymax, tmin, tmax = bbox[:6]
        mesh = BoxMesh(
            Point(xmin, ymin, tmin),
            Point(xmax, ymax, tmax),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (y-direction)"],
            self.config["Grid Size (t-direction)"],
        )
        V = FunctionSpace(mesh, "Lagrange", 1)

        # Coordinates of every degree of freedom (one row per dof).
        self.coords = V.tabulate_dof_coordinates()

        # Nearest-neighbour interpolation of the sampled coefficient onto the
        # dofs, using only the spatial coordinates of each dof.
        dim_s = self.config["Spatial Dimension"]
        kappa = Function(V)
        interp = NearestNDInterpolator(
            self.coef_data[:, :dim_s], self.coef_data[:, dim_s])
        kappa.vector()[:] = interp(self.coords[:, :dim_s])

        # Boundary-detection tolerance scaled with the coordinate magnitude so
        # round-off in generated mesh vertices cannot drop boundary dofs.
        scale = max(
            1.0,
            max(abs(float(c)) for c in (xmin, xmax, ymin, ymax, tmin, tmax)),
        )
        tol = max(DOLFIN_EPS, self.BOUNDARY_RTOL * scale)

        # Dirichlet walls are taken from the bounding box so they always
        # coincide with the faces of the mesh built above.
        def boundary_xy(x):
            return (x[0] < xmin + tol or x[0] > xmax - tol
                    or x[1] < ymin + tol or x[1] > ymax - tol)

        # Initial-time face t = tmin.
        def boundary_t0(x):
            return x[2] < tmin + tol

        bc_xy = DirichletBC(V, Constant(0.0), boundary_xy)
        bc_t0 = DirichletBC(V, Constant(0.0), boundary_t0)

        # Source term.
        m0, m1, m2 = self.SOURCE_MODES
        f = Expression(
            "A * sin(m0 * pi * x[0]) * sin(m1 * pi * x[1]) * sin(m2 * pi * x[2])",
            degree=2, A=self.SOURCE_AMPLITUDE, m0=m0, m1=m1, m2=m2,
        )

        # Variational forms: diffusion acts only in the two spatial
        # directions, the time derivative is the derivative along x[2].
        u = TrialFunction(V)
        v = TestFunction(V)
        grad_u = as_vector((u.dx(0), u.dx(1)))
        grad_v = as_vector((v.dx(0), v.dx(1)))
        u_t = u.dx(2)

        a_form = u_t * v * dx + kappa * inner(grad_u, grad_v) * dx
        L_form = f * v * dx

        # Assemble and impose the boundary / initial conditions.
        A_mat = assemble(a_form)
        b_vec = assemble(L_form)
        bc_xy.apply(A_mat, b_vec)
        bc_t0.apply(A_mat, b_vec)

        # PETSc yields (indptr, indices, data); scipy expects
        # (data, indices, indptr). The shape is passed explicitly instead of
        # letting scipy infer it from the column indices.
        petsc_mat = as_backend_type(A_mat).mat()
        indptr, indices, data = petsc_mat.getValuesCSR()
        self.A = sp.csr_matrix(
            (data, indices, indptr), shape=petsc_mat.getSize())
        self.b = b_vec.get_local().reshape(-1, 1)
