import torch
import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.nlpdedata import NonLinearPDEData
from utils.fenics_utils import compute_dof_map_ns2d


class NS2D_C(NonLinearPDEData):
    """
    Prepare data for the 2D Navier-Stokes equation (NS2D_C).

    Steady incompressible Navier-Stokes flow on the rectangle given by
    ``self.config["Bounding Box"]``. The top wall drives the flow with a
    parabolic velocity profile of amplitude ``a`` (``a = 4``, matching the
    reference file ``data/ref/lid_driven_a4.dat``). The other walls are
    no-slip, and the pressure is pinned to zero at the lower-left corner.

    The problem is discretized with Taylor-Hood elements (P2 velocity /
    P1 pressure). ``init_newton_system`` assembles one Newton step and
    exposes it as scipy/numpy data in ``self.A`` / ``self.b``.
    """
    def __init__(self) -> None:
        super().__init__()
        self.ref_data = np.loadtxt("data/ref/lid_driven_a4.dat", comments="%")

    def init_problem(self):
        """
        Build the mesh, Taylor-Hood space, boundary conditions and dof maps.

        Reads ``self.config["Bounding Box"]`` (used as
        ``Point(bbox[0], bbox[2])`` / ``Point(bbox[1], bbox[3])``) and the
        x/y grid sizes from ``self.config``.

        Sets:
            ``self.W``: the Taylor-Hood space.
            ``self.coords``: evaluation coordinates.
            ``self.bc_*``: the Dirichlet conditions, plus their homogeneous
            ``*_zero`` counterparts.
            The select/reorder dof maps returned by
            ``compute_dof_map_ns2d``.
        """
        a = 4  # lid-velocity amplitude; must match the reference data file
        # Create mesh and define function space
        bbox = self.config["Bounding Box"]
        mesh = RectangleMesh(Point(bbox[0], bbox[2]),
            Point(bbox[1], bbox[3]),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (y-direction)"])
        # Domain extents: the boundary markers below follow the configured
        # bounding box instead of assuming the unit square.
        x_lo, x_hi = float(min(bbox[0], bbox[1])), float(max(bbox[0], bbox[1]))
        y_lo, y_hi = float(min(bbox[2], bbox[3])), float(max(bbox[2], bbox[3]))

        # Build the function space (Taylor-Hood)
        P2 = VectorElement("P", mesh.ufl_cell(), 2)
        P1 = FiniteElement("P", mesh.ufl_cell(), 1)
        TH = MixedElement([P2, P1])
        W = FunctionSpace(mesh, TH)
        self.W = W

        # Get the dof coordinates: velocity-node points first, then
        # pressure-node points.
        W_u = W.sub(0).collapse()
        W_p = W.sub(1).collapse()
        self.coords = np.concatenate((
            # [::2] keeps every other velocity dof coordinate; this assumes
            # the two velocity components are interleaved per node
            # (TODO confirm for the dolfin version in use).
            W_u.tabulate_dof_coordinates()[::2],
            W_p.tabulate_dof_coordinates(),
        ))

        # Dirichlet boundary: top (lid) wall, y = y_hi
        def boundary_top(x):
            return x[1] > (y_hi - DOLFIN_EPS)

        # Dirichlet boundary: left, right and bottom walls
        def boundary_other(x):
            return (x[0] < x_lo + DOLFIN_EPS
                    or x[0] > (x_hi - DOLFIN_EPS)
                    or x[1] < y_lo + DOLFIN_EPS)

        # Lid profile a * s * (1 - s), with s = (x - x_lo) / (x_hi - x_lo).
        # It vanishes at both top corners, so it agrees with the no-slip side
        # walls there. On the unit square it reduces to a*x[0]*(1-x[0]).
        u_top = Expression(
            ("a*(x[0] - x_lo)*(x_hi - x[0])/((x_hi - x_lo)*(x_hi - x_lo))",
             "0.0"),
            degree=2, a=a, x_lo=x_lo, x_hi=x_hi)

        # Define boundary conditions. The *_zero variants are the homogeneous
        # versions that init_newton_system applies to the assembled system.
        self.bc_top = DirichletBC(W.sub(0), u_top, boundary_top)
        self.bc_top_zero = DirichletBC(W.sub(0), Constant((0.0, 0.0)), boundary_top)
        self.bc_other = DirichletBC(W.sub(0), Constant((0.0, 0.0)), boundary_other)
        self.bc_other_zero = self.bc_other  # already homogeneous
        # Pin the pressure at the lower-left corner. "pointwise" is needed
        # because a single point cannot be matched facet-wise.
        corner = CompiledSubDomain(
            "x[0] < x_lo + DOLFIN_EPS && x[1] < y_lo + DOLFIN_EPS",
            x_lo=x_lo, y_lo=y_lo)
        self.bc_point = DirichletBC(W.sub(1), Constant(0.0), corner, "pointwise")
        self.bc_point_zero = self.bc_point  # already homogeneous

        # Compute dof map
        self.select_p_map, self.reorder_p_map, \
            self.select_u0_map, self.reorder_u0_map, \
            self.select_u1_map, self.reorder_u1_map = compute_dof_map_ns2d(W)

    def rearrange_u_data(self, u_data):
        """
        Scatter nodal ``[u0, u1, p]`` data into a single dof-ordered column.

        Args:
            u_data: ``np.ndarray`` or ``torch.Tensor`` with columns
                ``[u0, u1, p]``. The first rows hold velocity-node values.
                The last ``len(self.reorder_p_map)`` rows hold
                pressure-node values. The rows are presumably ordered like
                ``self.coords``; confirm against the caller.

        Returns:
            A column of shape ``(len_u0 + len_u1 + len_p, 1)``, of the same
            kind as ``u_data``. Entry ``select_*_map[i]`` receives component
            value ``reorder_*_map[i]``. The torch result keeps the input's
            device and dtype; the numpy result is ``float64``.

        Raises:
            ValueError: if ``u_data`` is neither a numpy array nor a tensor.
        """
        len_u0 = len(self.reorder_u0_map)
        len_u1 = len(self.reorder_u1_map)
        len_p = len(self.reorder_p_map)
        tot_len = len_u0 + len_u1 + len_p
        # Velocity components live in the leading rows and pressure in the
        # trailing len_p rows (see self.coords).
        u0_data = u_data[:-len_p, 0:1]
        u1_data = u_data[:-len_p, 1:2]
        p_data = u_data[-len_p:, 2:3]

        if isinstance(u_data, np.ndarray):
            u_data = np.zeros((tot_len, 1))
        elif isinstance(u_data, torch.Tensor):
            u_data = torch.zeros((tot_len, 1),
                device=u_data.device, dtype=u_data.dtype)
        else:
            raise ValueError("Unknown data type")
        u_data[self.select_u0_map] = u0_data[self.reorder_u0_map]
        u_data[self.select_u1_map] = u1_data[self.reorder_u1_map]
        u_data[self.select_p_map] = p_data[self.reorder_p_map]
        return u_data

    def init_newton_system(self):
        """
        Assemble one Newton step for the steady Navier-Stokes residual.

        ``self.u_data`` is the current iterate ``w``. The Dirichlet values
        are imposed on it first. The method then stores ``self.A`` (scipy
        CSR Jacobian, with Dirichlet rows replaced by the homogeneous-BC
        application) and ``self.b`` (numpy column). These satisfy
        ``self.A @ w_new = self.b`` with ``J (w_new - w) = -F(w)``.
        """
        # Define variational problem
        w = Function(self.W)
        w.vector()[:] = self.u_data
        v, q = TestFunctions(self.W)

        # Apply boundary conditions to u
        self.bc_top.apply(w.vector())
        self.bc_other.apply(w.vector())
        self.bc_point.apply(w.vector())

        # Define Newton system (nu is the kinematic viscosity coefficient)
        nu = 1/100
        u, p = split(w)
        F = nu*inner(grad(u), grad(v))*dx + dot(dot(grad(u), u), v)*dx \
            - p*div(v)*dx - q*div(u)*dx
        dF = derivative(F, w)

        # Compute matrix and right-hand side vector
        A = assemble(dF)
        b = assemble(-F)

        # Zero the boundary locations (homogeneous BCs for the increment)
        self.bc_top_zero.apply(A)
        self.bc_other_zero.apply(A)
        self.bc_point_zero.apply(A)
        self.bc_top_zero.apply(b)
        self.bc_other_zero.apply(b)
        self.bc_point_zero.apply(b)

        # Convert the PETSc matrix and vector to scipy/numpy. Pass the shape
        # explicitly instead of letting scipy infer it from the largest
        # column index present.
        petsc_mat = as_backend_type(A).mat()
        indptr, indices, data = petsc_mat.getValuesCSR()
        self.A = sp.csr_matrix((data, indices, indptr), shape=petsc_mat.getSize())
        self.b = b.get_local().reshape(-1, 1)
        # Shift the RHS so the system is solved for the new iterate rather
        # than the increment: A @ w_new = A @ w - F. Dirichlet rows are
        # identity rows with a zero RHS, so they keep w's boundary values.
        self.b += self.A @ w.vector().get_local().reshape(-1, 1)
