import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.pdedata import PDEDataBase


class Poisson2D_MS(PDEDataBase):
    """2D Poisson problem with block-wise ("many subdomain") coefficients.

    The bounding box is divided into a ``_N_BLOCKS[0] x _N_BLOCKS[1]`` grid of
    rectangular blocks. Each block has its own diffusion coefficient
    (``a_cof``) and its own 2x2 set of sine-series coefficients for the source
    term (``f_cof``). ``init_problem`` discretises the problem with P1
    Lagrange elements and stores the linear system as ``self.A`` (scipy CSR)
    and ``self.b`` (column vector).
    """

    # Number of blocks along (x, y); must agree with the coefficient files.
    _N_BLOCKS = (5, 5)
    # Small absolute padding applied to the box when locating blocks. Shifting
    # points by +_PAD and enlarging each side by 2*_PAD keeps points lying on
    # the box boundary strictly inside the first/last block, so the floor()
    # below always yields an index in range.
    _PAD = 1e-5

    def __init__(self):
        super().__init__()
        # NOTE(review): presumably the reference solution consumed by the
        # base class -- confirm against PDEDataBase.
        self.ref_data = np.loadtxt(
            "data/ref/poisson_manyarea.dat", comments="%")
        # One diffusion coefficient per block, indexed as a_cof[ix, iy].
        self.a_cof = np.loadtxt("data/ref/poisson_a_coef.dat")
        # A 2x2 sine-series coefficient matrix per block, f_cof[ix, iy, i, j].
        self.f_cof = np.loadtxt(
            "data/ref/poisson_f_coef.dat").reshape(*self._N_BLOCKS, 2, 2)

    def init_problem(self):
        """Assemble the FE system and store it on the instance.

        Reads ``self.config["Bounding Box"]`` as ``(xmin, xmax, ymin, ymax)``
        and the grid sizes, then sets ``self.bbox``, ``self.coords`` (dof
        coordinates), ``self.A`` (scipy CSR matrix) and ``self.b`` (column
        vector of shape ``(n_dofs, 1)``).
        """
        # Create mesh and define function space
        self.bbox = self.config["Bounding Box"]
        mesh = RectangleMesh(
            Point(self.bbox[0], self.bbox[2]),
            Point(self.bbox[1], self.bbox[3]),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (y-direction)"],
        )
        V = FunctionSpace(mesh, "Lagrange", 1)

        # Get the dof coordinates (same ordering as the Function vectors)
        self.coords = V.tabulate_dof_coordinates()

        # Nodal interpolation of the block-wise coefficients into P1.
        # NOTE(review): `a` is piecewise constant per block, but a continuous
        # P1 interpolant smears the jump over one element at block
        # interfaces; a DG0 space would keep it sharp if that is intended.
        a = Function(V)
        f = Function(V)
        a.vector()[:] = [self.compute_a(x) for x in self.coords]
        f.vector()[:] = [self.compute_f(x) for x in self.coords]

        # Weak form of -div(a grad u) = f
        u = TrialFunction(V)
        v = TestFunction(V)
        a_form = a * inner(grad(u), grad(v)) * dx
        L = f * v * dx

        # Boundary term: adding a*u*v*ds imposes the natural condition
        # a du/dn = -a u, i.e. the Robin condition du/dn + u = 0.
        a_form += a * u * v * ds

        # Compute matrix and right-hand side vector
        A = assemble(a_form)
        b = assemble(L)

        # getValuesCSR() returns (indptr, indices, values); reversing gives
        # the (data, indices, indptr) tuple that csr_matrix accepts.
        # NOTE(review): both this and get_local() return only the rows owned
        # by this process, so this presumably assumes a serial run.
        csr = as_backend_type(A).mat().getValuesCSR()[::-1]
        self.A = sp.csr_matrix(csr)
        self.b = b.get_local().reshape(-1, 1)

    def _locate(self, x):
        """Find the block containing point ``x`` and ``x``'s offset inside it.

        Returns ``(dom, res, block_size)``:
          * ``dom`` -- int32 array ``(ix, iy)`` of block indices,
          * ``res`` -- offset of the (padded) point from the block's
            lower-left corner,
          * ``block_size`` -- ``(width, height)`` of one (padded) block.
        Requires ``self.bbox`` (set by ``init_problem``).
        """
        bbox = self.bbox
        nx, ny = self._N_BLOCKS
        pad = self._PAD
        block_size = np.array([(bbox[1] - bbox[0] + 2 * pad) / nx,
                               (bbox[3] - bbox[2] + 2 * pad) / ny])
        reduced_x = x - np.array(bbox[::2]) + pad
        dom = np.floor(reduced_x / block_size).astype("int32")
        return dom, reduced_x - dom * block_size, block_size

    def compute_a(self, x):
        """Return the diffusion coefficient of the block containing ``x``."""
        dom, _, _ = self._locate(x)
        return self.a_cof[dom[0], dom[1]]

    def compute_f(self, x):
        """Evaluate the source term at point ``x``.

        With ``(xi, eta) = res / block_size`` the block-local coordinates
        (roughly in ``[0, 1)``) and ``c = f_cof[ix, iy]``::

            f = c[0, 0] + sum_{i,j} c[i, j] * sin(i*pi*xi) * sin(j*pi*eta)

        Terms with ``i == 0`` or ``j == 0`` vanish because ``sin(0) == 0``,
        so ``c[0, 0]`` enters only through the explicit initial value, where
        it acts as a constant offset for the block.
        """
        dom, res, block_size = self._locate(x)
        coef = self.f_cof[dom[0], dom[1]]
        local = res / block_size  # loop-invariant; hoisted out of the series
        ans = coef[0, 0]
        for i in range(coef.shape[0]):
            for j in range(coef.shape[1]):
                tmp = np.sin(np.pi * np.array((i, j)) * local)
                ans += coef[i, j] * tmp[0] * tmp[1]
        return ans
