import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.nlpdedata import NonLinearPDEData
from utils.utils import decompress_time_data


class Burgers1D_C(NonLinearPDEData):
    """
    Prepare data for the 1D Burgers' equation (Burgers1d-C).

    The problem is discretised on a 2D space-time rectangle mesh (x along
    the first axis, t along the second) with P1 Lagrange elements. The
    Dirichlet data are ``u = 0`` on the left/right edges and
    ``u = -sin(pi * x)`` on the initial-time edge; the viscosity is
    ``0.01 / pi``.

    NOTE(review): ``self.config`` and ``self.u_data`` are read but never
    assigned in this class; presumably they are provided by
    ``NonLinearPDEData`` -- confirm.
    """
    def __init__(self) -> None:
        """Load the reference solution from ``data/ref/burgers1d.dat``."""
        super().__init__()
        # Text following a "%" in the data file is treated as a comment.
        self.ref_data = np.loadtxt("data/ref/burgers1d.dat", comments="%")

    def preprocess(self) -> None:
        """Expand the loaded reference data with ``decompress_time_data``.

        Replaces ``self.ref_data`` with the helper's result. The helper
        receives the spatial dimension, the output dimension and the
        bounding-box entries that follow the spatial bounds (presumably the
        time interval -- confirm against ``decompress_time_data``).
        """
        dim_s = self.config["Spatial Dimension"]
        self.ref_data = decompress_time_data(
            self.ref_data, dim_s,
            self.config["Output Dimension"],
            self.config["Bounding Box"][dim_s*2:])

    def init_problem(self) -> None:
        """Build the mesh, function space and Dirichlet boundary conditions.

        Sets ``self.V``, ``self.coords``, ``self.bc_lr``, ``self.bc_lr_zero``,
        ``self.bc_t0`` and ``self.bc_t0_zero``. Also runs ``preprocess()``,
        so ``self.ref_data`` is rewritten on every call.
        """
        self.preprocess()

        # Create mesh and define function space
        bbox = self.config["Bounding Box"]
        x_min, x_max, t_min, t_max = bbox[0], bbox[1], bbox[2], bbox[3]
        mesh = RectangleMesh(Point(x_min, t_min),
            Point(x_max, t_max),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (t-direction)"])
        V = FunctionSpace(mesh, "Lagrange", 1)
        self.V = V

        # Get the dof coordinates
        self.coords = V.tabulate_dof_coordinates()

        # The boundary predicates take their edges from the same bounding
        # box as the mesh, so they stay consistent with it for any
        # configured domain (previously fixed at x = -1, x = 1 and t = 0).

        # Define Dirichlet boundary (x = x_min or x = x_max)
        def boundary_lr(x):
            return x[0] < x_min + DOLFIN_EPS or \
                x[0] > x_max - DOLFIN_EPS
        # Define Dirichlet boundary (t = t_min)
        def boundary_t0(x):
            return x[1] < t_min + DOLFIN_EPS

        # Define boundary conditions
        u_lr = Constant(0.0)
        self.bc_lr = DirichletBC(V, u_lr, boundary_lr)
        # The left/right condition is already homogeneous, so the same
        # object doubles as its "zero" variant.
        self.bc_lr_zero = self.bc_lr
        u_t0 = Expression("-sin(pi * x[0])", degree=2)
        self.bc_t0 = DirichletBC(V, u_t0, boundary_t0)
        self.bc_t0_zero = DirichletBC(V, Constant(0.0), boundary_t0)

    def init_newton_system(self) -> None:
        """Assemble the Newton linear system around ``self.u_data``.

        Requires ``init_problem()`` to have been called first (uses
        ``self.V`` and the boundary-condition objects). Sets ``self.A`` to
        the Jacobian as a ``scipy.sparse.csr_matrix`` and ``self.b`` to a
        column vector of shape ``(n_dofs, 1)``.

        The assembled increment system is ``A du = -F(u)``; ``A u`` is then
        added to the right-hand side, so ``A x = b`` is posed for the
        updated iterate ``x = u + du`` rather than for the increment.
        """
        # Define variational problem
        u = Function(self.V)
        u.vector()[:] = self.u_data
        v = TestFunction(self.V)

        # Apply boundary conditions to u
        self.bc_lr.apply(u.vector())
        self.bc_t0.apply(u.vector())

        # Define Newton system: residual of u_t + u u_x - nu u_xx = 0 with
        # the diffusion term integrated by parts in x only.
        nu = 0.01 / np.pi
        F = u.dx(1)*v*dx + u*u.dx(0)*v*dx + nu*u.dx(0)*v.dx(0)*dx
        dF = derivative(F, u)

        # Compute matrix and right-hand side vector
        A = assemble(dF)
        b = assemble(-F)

        # Zero the boundary locations: the homogeneous conditions are used
        # here because the increment must vanish where u already carries
        # the prescribed boundary values.
        self.bc_lr_zero.apply(A)
        self.bc_t0_zero.apply(A)
        self.bc_lr_zero.apply(b)
        self.bc_t0_zero.apply(b)

        # Convert PETSc matrix and vector to scipy sparse arrays.
        # getValuesCSR() yields (indptr, indices, data); reversing it gives
        # the (data, indices, indptr) order that csr_matrix expects.
        csr = as_backend_type(A).mat().getValuesCSR()[::-1]
        self.A = sp.csr_matrix(csr)
        self.b = b.get_local().reshape(-1, 1)
        # Shift the right-hand side from the increment to the full iterate.
        self.b += self.A @ u.vector()[:].reshape(-1, 1)
