import numpy as np
import scipy.sparse as sp
from scipy.interpolate import NearestNDInterpolator
from dolfin import *

from src.tmpdedata import TimePDEData


class Wave2D_CG(TimePDEData):
    """
    Prepare data for the 2D wave equation (Wave2d-CG).

    Builds a P1 (Lagrange degree 1) finite-element discretization on a
    rectangular mesh. A spatially varying coefficient ``c`` is sampled from a
    data file. The assembled system is exposed as scipy sparse matrices:

    * ``self.A`` -- ``M + dt**2 * S_c``, where ``M`` is the mass matrix and
      ``S_c`` is the stiffness matrix weighted by ``c``;
    * ``self.K`` -- the mass matrix ``M``, used to form the per-step
      right-hand side ``b = K @ (2*u1 - u0)``;
    * ``self.u0`` / ``self.u0_prev`` -- the initial state evaluated at the
      dof coordinates ``self.coords``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Reference solution; lines starting with '%' are skipped as comments.
        self.ref_data = np.loadtxt("data/ref/wave_darcy.dat", comments="%")
        # Coefficient samples. The first "Spatial Dimension" columns are
        # coordinates and the following column is the coefficient value
        # (see how init_problem slices it).
        self.coef_data = np.loadtxt("data/ref/darcy_2d_coef_256.dat")
        # Whether preprocess() has already rescaled coef_data's coordinates.
        self._coef_normalized = False

    def preprocess(self):
        """
        Rescale the coefficient coordinates with the affine map ``x -> 2x - 1``.

        This presumably maps stored coordinates from [0, 1] onto a [-1, 1]
        domain. Confirm this against the "Bounding Box" config.

        ``self.coef_data`` is modified in place, so the rescaling is applied
        at most once per instance. Repeated calls, such as calling
        ``init_problem`` more than once, are no-ops.
        """
        if getattr(self, "_coef_normalized", False):
            return
        dim_s = self.config["Spatial Dimension"]
        self.coef_data[:, :dim_s] = self.coef_data[:, :dim_s] * 2.0 - 1.0
        self._coef_normalized = True

    @staticmethod
    def _to_scipy_csr(matrix):
        """Convert an assembled dolfin matrix (PETSc backend) to scipy CSR."""
        # getValuesCSR() yields (indptr, indices, data). scipy's constructor
        # expects the triple in the order (data, indices, indptr).
        indptr, indices, data = as_backend_type(matrix).mat().getValuesCSR()
        return sp.csr_matrix((data, indices, indptr))

    def init_problem(self):
        """
        Assemble the discrete wave problem on the configured mesh.

        Config keys read: "Spatial Dimension", "Bounding Box" (entries 0-1
        are the x range, entries 2-3 the y range, and the last two the time
        range), "Grid Size (x-direction)", "Grid Size (y-direction)",
        "Grid Size (t-direction)" and "Sub-Time Intervals".

        Sets ``self.coords``, ``self.u0``, ``self.u0_prev``, ``self.A``,
        ``self.b`` (a dummy placeholder) and ``self.K``.
        """
        self.preprocess()
        dim_s = self.config["Spatial Dimension"]

        # Create mesh and define function space
        bbox = self.config["Bounding Box"]
        mesh = RectangleMesh(
            Point(bbox[0], bbox[2]),
            Point(bbox[1], bbox[3]),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (y-direction)"],
        )
        V = FunctionSpace(mesh, "Lagrange", 1)

        # Coordinates of every degree of freedom of V.
        self.coords = V.tabulate_dof_coordinates()

        # Nearest-neighbour interpolation of the coefficient samples onto the
        # dofs. This assumes the row order of `coords` matches the entry
        # order of `c.vector()`; verify this if running in parallel.
        c = Function(V)
        interp = NearestNDInterpolator(
            self.coef_data[:, :dim_s], self.coef_data[:, dim_s])
        c.vector()[:] = interp(self.coords[:, :dim_s])

        # Time step: the time span split into (t-grid size * sub-intervals).
        dt = (bbox[-1] - bbox[-2]) / (
            self.config["Grid Size (t-direction)"]
            * self.config["Sub-Time Intervals"])

        # Initial condition: Gaussian bump centred at mu with width sigma.
        mu = (-0.5, 0)
        sigma = 0.3

        def initial_condition(x):
            return np.exp(
                -((x[:, 0:1] - mu[0])**2 + (x[:, 1:2] - mu[1])**2)
                / (2 * sigma**2))

        self.u0 = initial_condition(self.coords)
        # The previous state starts equal to the initial state (presumably
        # encoding zero initial velocity). Store a copy so an in-place
        # update of one array cannot silently change the other.
        self.u0_prev = self.u0.copy()

        # Variational problem
        u = TrialFunction(V)
        v = TestFunction(V)
        a = u*v*dx + dt*dt*c*inner(grad(u), grad(v))*dx
        k = u*v*dx

        # System matrix A = M + dt^2 * S_c.
        self.A = self._to_scipy_csr(assemble(a))
        # Dummy b: the right-hand side is formed at each time step instead,
        # as b = K @ (2*u1 - u0).
        self.b = np.array(0)
        # Mass matrix K used to build that per-step right-hand side.
        self.K = self._to_scipy_csr(assemble(k))
