import torch
import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.nlpdedata import NonLinearPDEData
from utils.fenics_utils import compute_dof_map_ns2d
from utils.fenics_utils import generate_complex_mesh2d


class NS2D_CG(NonLinearPDEData):
    """
    Prepare data for the 2D Navier-Stokes equation (NS2D_CG).
    """
    def __init__(self) -> None:
        """Set up the base PDE-data object, then load the reference solution.

        The reference table is read with ``np.loadtxt``; any text following a
        ``%`` character on a line is treated as a comment.
        """
        super().__init__()
        ref_file = "data/ref/ns_0_obstacle.dat"
        self.ref_data = np.loadtxt(ref_file, comments="%")

    def init_problem(self):
        """Build the mesh, function space, DOF coordinates and boundary conditions.

        Reads ``self.config["Bounding Box"]`` and
        ``self.config["Grid Size (mesh length)"]`` and sets ``self.W`` (the
        Taylor-Hood mixed space), ``self.coords``, the ``self.bc_*`` Dirichlet
        conditions and the six select/reorder DOF maps returned by
        ``compute_dof_map_ns2d``.
        """
        bbox = self.config["Bounding Box"]
        mesh_size = self.config["Grid Size (mesh length)"]

        # One rectangular void cut out of the domain. Its entries presumably
        # follow the bbox layout [x_min, x_max, y_min, y_max] -- confirm
        # against generate_complex_mesh2d. The small overhang pushes the first
        # and last coordinates just beyond bbox[0] and bbox[3].
        overhang = 1e-5
        void_rect = [bbox[0] - overhang, bbox[1] / 2, bbox[3] / 2, bbox[3] + overhang]
        mesh = generate_complex_mesh2d(bbox, {"rec": [void_rect]}, mesh_size)

        # Taylor-Hood pair: P2 vector velocity, P1 scalar pressure.
        cell = mesh.ufl_cell()
        taylor_hood = MixedElement([
            VectorElement("P", cell, 2),
            FiniteElement("P", cell, 1),
        ])
        W = FunctionSpace(mesh, taylor_hood)
        self.W = W

        # Velocity-node coordinates followed by pressure-node coordinates.
        # [::2] keeps one row per velocity node, assuming the collapsed vector
        # space interleaves its two components -- TODO confirm.
        vel_xy = W.sub(0).collapse().tabulate_dof_coordinates()[::2]
        pres_xy = W.sub(1).collapse().tabulate_dof_coordinates()
        self.coords = np.concatenate((vel_xy, pres_xy))

        def is_inlet(x, on_boundary):
            # Boundary points on the line x[0] == bbox[0].
            return on_boundary and near(x[0], bbox[0])

        def is_outlet(x, on_boundary):
            # Boundary points on the line x[0] == bbox[1].
            return on_boundary and near(x[0], bbox[1])

        def is_wall(x, on_boundary):
            # All remaining boundary points (neither inlet nor outlet).
            return on_boundary and not (
                is_inlet(x, on_boundary) or is_outlet(x, on_boundary))

        # Inflow profile: first component a*x[1]*(1 - x[1]) with a = 4,
        # second component zero.
        inflow_scale = 4
        inflow = Expression(("a*x[1]*(1 - x[1])", "0.0"), degree=2, a=inflow_scale)

        # Velocity: prescribed inflow on the inlet, zero on the walls.
        self.bc_in = DirichletBC(W.sub(0), inflow, is_inlet)
        self.bc_other = DirichletBC(W.sub(0), Constant((0.0, 0.0)), is_wall)
        # Pressure: zero on the outlet.
        self.bc_out = DirichletBC(W.sub(1), Constant(0.0), is_outlet)

        # Homogeneous counterparts applied to the assembled Newton system.
        # The wall and outlet conditions are already zero, so the same
        # objects are reused.
        self.bc_in_zero = DirichletBC(W.sub(0), Constant((0.0, 0.0)), is_inlet)
        self.bc_other_zero = self.bc_other
        self.bc_out_zero = self.bc_out

        (self.select_p_map, self.reorder_p_map,
         self.select_u0_map, self.reorder_u0_map,
         self.select_u1_map, self.reorder_u1_map) = compute_dof_map_ns2d(W)

    def rearrange_u_data(self, u_data):
        """Scatter row-wise (u0, u1, p) data into the dof ordering of ``self.W``.

        ``u_data`` holds a block of velocity rows followed by
        ``len(self.reorder_p_map)`` pressure rows (presumably the same row
        order as ``self.coords`` -- confirm with callers). Columns 0 and 1 of
        the velocity rows give u0 and u1; column 2 of the pressure rows gives
        p. Each field is permuted with its ``reorder_*_map`` and written to
        the output positions given by the matching ``select_*_map``.

        Args:
            u_data: ``np.ndarray`` or ``torch.Tensor`` with at least 3 columns.

        Returns:
            Column vector of shape ``(len_u0 + len_u1 + len_p, 1)``: a float64
            numpy array for numpy input, or a tensor on the input's device and
            dtype for torch input. Positions not covered by any select map
            stay 0.

        Raises:
            ValueError: if ``u_data`` is neither a numpy array nor a torch
                tensor, or has fewer rows than there are pressure dofs.
        """
        # Validate the container type first so unsupported inputs get the
        # intended ValueError rather than failing inside the slicing below.
        if not isinstance(u_data, (np.ndarray, torch.Tensor)):
            raise ValueError("Unknown data type")

        len_u0 = len(self.reorder_u0_map)
        len_u1 = len(self.reorder_u1_map)
        len_p = len(self.reorder_p_map)
        tot_len = len_u0 + len_u1 + len_p

        # Split at an explicit row index: ``u_data[:-len_p]`` would yield an
        # empty velocity block if len_p were 0.
        n_vel_rows = u_data.shape[0] - len_p
        if n_vel_rows < 0:
            raise ValueError(
                f"u_data has {u_data.shape[0]} rows, fewer than the "
                f"{len_p} pressure dofs")
        u0_data = u_data[:n_vel_rows, 0:1]
        u1_data = u_data[:n_vel_rows, 1:2]
        p_data = u_data[n_vel_rows:, 2:3]

        if isinstance(u_data, np.ndarray):
            out = np.zeros((tot_len, 1))
        else:
            out = torch.zeros((tot_len, 1),
                              device=u_data.device, dtype=u_data.dtype)
        out[self.select_u0_map] = u0_data[self.reorder_u0_map]
        out[self.select_u1_map] = u1_data[self.reorder_u1_map]
        out[self.select_p_map] = p_data[self.reorder_p_map]
        return out

    def init_newton_system(self):
        """Assemble one Newton step for the steady incompressible NS equations.

        The current iterate is copied from ``self.u_data`` into a ``Function``
        on ``self.W``, and the Dirichlet values (``bc_in``, ``bc_other``,
        ``bc_out``) are imposed on it. The residual

            F = nu*inner(grad(u), grad(v)) + dot(dot(grad(u), u), v)
                - p*div(v) - q*div(u)        (integrated over dx, nu = 1/100)

        and its Jacobian ``dF`` are assembled. The homogeneous ("zero") BCs
        are applied to both, and the results are stored as a scipy CSR matrix
        ``self.A`` and a column vector ``self.b``.

        The right-hand side is shifted by ``A @ w``, so solving ``A x = b``
        gives the updated iterate ``x = w + dw`` (where ``J dw = -F``) rather
        than the increment ``dw``. On Dirichlet rows the zero BCs leave an
        identity row in ``A`` and a zero entry in the assembled residual, so
        those rows of ``x`` equal the boundary values already imposed on ``w``.

        Sets:
            self.A: ``scipy.sparse.csr_matrix`` Jacobian.
            self.b: ``np.ndarray`` of shape (n_rows, 1).
        """
        # Current iterate. Assumes self.u_data is a flat array in W's dof
        # order that dolfin accepts for a full-slice assignment -- TODO confirm.
        w = Function(self.W)
        w.vector()[:] = self.u_data
        v, q = TestFunctions(self.W)

        # Impose the (possibly non-zero) Dirichlet values on the iterate so
        # the shifted right-hand side below carries them into the solution.
        self.bc_in.apply(w.vector())
        self.bc_other.apply(w.vector())
        self.bc_out.apply(w.vector())

        # Residual and its Gateaux derivative (Jacobian) with respect to w.
        nu = 1/100
        u, p = split(w)
        F = nu*inner(grad(u), grad(v))*dx + dot(dot(grad(u), u), v)*dx \
            - p*div(v)*dx - q*div(u)*dx
        dF = derivative(F, w)

        A = assemble(dF)
        b = assemble(-F)

        # Homogeneous BCs: identity rows in A, zero entries in b.
        zero_bcs = (self.bc_in_zero, self.bc_other_zero, self.bc_out_zero)
        for bc in zero_bcs:
            bc.apply(A)
        for bc in zero_bcs:
            bc.apply(b)

        # PETSc returns (indptr, indices, data). Pass the shape explicitly
        # rather than letting scipy infer the column count from the largest
        # stored column index.
        petsc_mat = as_backend_type(A).mat()
        indptr, indices, data = petsc_mat.getValuesCSR()
        n_rows = len(indptr) - 1
        n_cols = petsc_mat.getSize()[1]
        self.A = sp.csr_matrix((data, indices, indptr), shape=(n_rows, n_cols))

        # Shift the RHS so the linear solve returns w + dw instead of dw.
        self.b = b.get_local().reshape(-1, 1)
        self.b += self.A @ w.vector()[:].reshape(-1, 1)
