import numpy as np
import torch
import scipy.sparse as sp
from dolfin import *

from src.tmpdedata import TimePDEData


class Heat2D_LT(TimePDEData):
    """
    Prepare data for the 2D long-time heat equation (Heat2d-LT).

    The operators assembled in ``init_problem`` and the source evaluated in
    ``compute_tm_rhs`` are consistent with

        u_t = DIFFUSIVITY * (u_xx + u_yy) + f(x, y, t, u),
        f   = SOURCE_AMPLITUDE * sin(SOURCE_K * u**2)
              * (1 + 2 * sin(pi * t / 4))
              * sin(SOURCE_M1 * pi * x) * sin(SOURCE_M2 * pi * y),

    with u = 0 on the whole boundary of the spatial bounding box and
    u(x, y, 0) = sin(INIT_COEF[0] * x) * sin(INIT_COEF[1] * y).
    The time-stepping loop itself is not in this class (presumably in
    ``TimePDEData`` -- confirm there).

    The problem constants are class attributes so that a subclass can
    change them without copying the discretisation code.
    """

    # Diffusion coefficient multiplying the stiffness term.
    DIFFUSIVITY = 0.001
    # Wave numbers (a, b) of the initial condition sin(a * x) * sin(b * y).
    # NOTE(review): the y wave number here is 3*pi while the source term uses
    # SOURCE_M2 * pi = 2*pi; confirm this mismatch is intentional.
    INIT_COEF = (4 * np.pi, 3 * np.pi)
    # Constants of the nonlinear source term used in ``compute_tm_rhs``.
    SOURCE_AMPLITUDE = 5
    SOURCE_K = 1
    SOURCE_M1 = 4
    SOURCE_M2 = 2

    def __init__(self) -> None:
        super().__init__()
        # Reference solution; text after '%' in the file is ignored as a comment.
        self.ref_data = np.loadtxt("data/ref/heat_longtime.dat", comments="%")

    @staticmethod
    def _petsc_to_csr(mat):
        """Convert an assembled dolfin (PETSc-backed) matrix to scipy CSR."""
        # petsc4py returns (indptr, indices, data); scipy wants them in the
        # opposite order. The column count is passed explicitly instead of
        # being inferred from the largest column index that happens to occur;
        # the row count is taken from indptr exactly as scipy would infer it.
        indptr, indices, data = as_backend_type(mat).mat().getValuesCSR()
        return sp.csr_matrix(
            (data, indices, indptr),
            shape=(len(indptr) - 1, mat.size(1)),
        )

    def init_problem(self):
        """
        Discretise the problem with P1 Lagrange elements on a rectangle mesh.

        Reads from ``self.config``:
            "Bounding Box": entries 0-3 are the spatial rectangle
                [x0, x1] x [y0, y1]; the last two entries are used as the
                start/end of the time interval when computing ``dt``.
            "Grid Size (x-direction)", "Grid Size (y-direction)": mesh cells.
            "Grid Size (t-direction)", "Sub-Time Intervals": their product is
                the number of time steps used to compute ``dt``.

        Sets:
            self.coords: dof coordinates of the P1 space.
            self.u0: initial condition evaluated at ``self.coords``,
                shape (n_dofs, 1).
            self.A: scipy CSR matrix of ``M + dt * DIFFUSIVITY * S`` (M mass,
                S stiffness) with the Dirichlet condition applied to its rows.
            self.b: right-hand-side vector created fresh and with the (zero)
                Dirichlet values applied, shape (n_dofs, 1).
            self.K: scipy CSR mass matrix M, without boundary conditions.
        """
        bbox = self.config["Bounding Box"]

        # Mesh and P1 function space over the spatial rectangle.
        mesh = RectangleMesh(
            Point(bbox[0], bbox[2]),
            Point(bbox[1], bbox[3]),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (y-direction)"],
        )
        V = FunctionSpace(mesh, "Lagrange", 1)
        self.coords = V.tabulate_dof_coordinates()

        dt = (bbox[-1] - bbox[-2]) / (
            self.config["Grid Size (t-direction)"]
            * self.config["Sub-Time Intervals"]
        )

        # Initial condition at the dof coordinates (column vector).
        coef_x, coef_y = self.INIT_COEF
        self.u0 = np.sin(coef_x * self.coords[:, 0:1]) * \
            np.sin(coef_y * self.coords[:, 1:2])

        # Homogeneous Dirichlet condition on the entire boundary of the
        # spatial bounding box (not only the unit square).
        def boundary_all(x, on_boundary):
            return on_boundary

        bc = DirichletBC(V, Constant(0.0), boundary_all)

        # Bilinear forms: system operator (mass + dt * diffusivity * stiffness)
        # and the plain mass matrix used to build right-hand sides.
        u = TrialFunction(V)
        v = TestFunction(V)
        a = u * v * dx + dt * self.DIFFUSIVITY * inner(grad(u), grad(v)) * dx
        k = u * v * dx

        A = assemble(a)
        K = assemble(k)
        b = Vector(mesh.mpi_comm(), A.size(0))

        # Only A and b receive the boundary condition; K is kept unmodified.
        bc.apply(A)
        bc.apply(b)

        self.A = self._petsc_to_csr(A)
        self.b = b.get_local().reshape(-1, 1)
        self.K = self._petsc_to_csr(K)

    def compute_tm_rhs(self, X, t, dt, u):
        """
        Evaluate ``dt * f`` for every requested time at the points ``X``.

        Args:
            X: 2-D tensor of spatial points (N, d); columns 0 and 1 are used
               as x and y.
            t: tensor of times; reshaped to (T, 1, 1).
            dt: time-step size multiplying the source.
            u: solution values entering the nonlinearity sin(SOURCE_K * u**2);
               assumed to broadcast against (T, N, 1) -- TODO confirm with
               the caller.

        Returns:
            Tensor ``dt * f`` broadcast over times and points (presumably
            shape (T, N, 1) when ``u`` has that shape).
        """
        t = t.reshape(-1, 1, 1)
        # Repeat the spatial points for each time: (T, N, d).
        X = X.unsqueeze(0).expand(t.shape[0], -1, -1)
        f = (
            (1 + 2 * torch.sin(torch.pi * t / 4))
            * torch.sin(self.SOURCE_M1 * torch.pi * X[..., 0:1])
            * torch.sin(self.SOURCE_M2 * torch.pi * X[..., 1:2])
        )
        f = dt * self.SOURCE_AMPLITUDE * torch.sin(self.SOURCE_K * u**2) * f
        return f
