import torch
import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.tmpdedata import TimePDEData
from utils.fenics_utils import compute_dof_map_ns2d


class NS2D_LT(TimePDEData):
    """
    Prepare data for the 2D Navier-Stokes equation (NS2D_LT).

    The spatial problem is discretised with Taylor-Hood (P2 velocity / P1
    pressure) elements on a rectangular mesh. ``init_problem`` assembles the
    system matrix ``self.A`` (Dirichlet conditions applied), the mass matrix
    ``self.K`` and the stacked right-hand sides ``self.b`` for the first
    sub-time interval. ``compute_tm_rhs`` evaluates the convection term
    ``-dt * (u . grad) u`` from a network prediction.
    """

    # Amplitudes of the sin(pi t), sin(3 pi t) and sin(5 pi t) modes of the
    # inflow profile. Shared by the matrix setup and every RHS assembly so the
    # two can never drift apart.
    COEF_A1 = 1
    COEF_A2 = 1
    COEF_A3 = 1
    # Inflow velocity (x- and y-component) as DOLFIN C++ expression strings.
    _INFLOW_EXPR = (
        "sin(pi * x[1]) * (A1 * sin(pi * t) + A2 * sin(3 * pi * t) + A3 * sin(5 * pi * t))",
        "0.0")

    def __init__(self) -> None:
        super().__init__()
        # Reference solution table; lines starting with "%" are comments.
        self.ref_data = np.loadtxt("data/ref/ns_long.dat", comments="%")

    def _inflow_velocity(self, t):
        """Return the inflow-velocity ``Expression`` evaluated at time ``t``."""
        return Expression(
            self._INFLOW_EXPR, degree=2,
            A1=self.COEF_A1, A2=self.COEF_A2, A3=self.COEF_A3, t=t)

    def init_problem(self):
        """
        Build mesh, function space, boundary conditions and the linear system.

        Sets ``self.dt``, ``self.W``, ``self.coords``, ``self.u0``, the dof
        maps, ``self.bcs``, ``self.boundary_in``, ``self.A`` (system matrix
        with BCs applied, scipy CSR), ``self.K`` (mass matrix, scipy CSR) and
        ``self.b`` (right-hand sides of the first sub-time interval).
        """
        bbox = self.config["Bounding Box"]
        time_start, time_end = bbox[-2], bbox[-1]
        n_steps = self.config["Grid Size (t-direction)"]
        n_intervals = self.config["Sub-Time Intervals"]
        # Each of the n_intervals sub-intervals is resolved with n_steps steps.
        self.dt = dt = (time_end - time_start) / (n_steps * n_intervals)
        interval_len = (time_end - time_start) / n_intervals
        # Time points of the first sub-interval; their spacing is exactly dt.
        ts = np.linspace(
            time_start + dt,
            time_start + interval_len,
            n_steps
        )

        # Create mesh and define function space
        mesh = RectangleMesh(
            Point(bbox[0], bbox[2]), Point(bbox[1], bbox[3]),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (y-direction)"])
        # Taylor-Hood element: P2 vector velocity, P1 scalar pressure
        P2 = VectorElement("P", mesh.ufl_cell(), 2)
        P1 = FiniteElement("P", mesh.ufl_cell(), 1)
        TH = MixedElement([P2, P1])
        W = FunctionSpace(mesh, TH)
        self.W = W

        # Dof coordinates: every second velocity dof coordinate (the vector
        # space lists each node once per component -- assumes interleaved
        # x/y dofs, TODO confirm) followed by the pressure dof coordinates.
        W_u = W.sub(0).collapse()
        W_p = W.sub(1).collapse()
        self.coords = np.concatenate((
            W_u.tabulate_dof_coordinates()[::2],
            W_p.tabulate_dof_coordinates()
        ))

        # Initial condition: zero for every coordinate and output channel
        self.u0 = np.zeros((len(self.coords), self.config["Output Dimension"]))

        # Maps between node-ordered data and the mixed-space dof ordering
        self.select_p_map, self.reorder_p_map, \
            self.select_u0_map, self.reorder_u0_map, \
            self.select_u1_map, self.reorder_u1_map = compute_dof_map_ns2d(W)

        # Boundary x = x_min (bbox[0]): prescribed inflow velocity
        def boundary_in(x, on_boundary):
            return on_boundary and near(x[0], bbox[0])

        # Boundary x = x_max (bbox[1]): pressure fixed to zero
        def boundary_out(x, on_boundary):
            return on_boundary and near(x[0], bbox[1])

        # Boundaries y = y_min / y = y_max: zero velocity (walls)
        def boundary_other(x, on_boundary):
            return on_boundary and \
                (near(x[1], bbox[2]) or near(x[1], bbox[3]))

        # The inflow value only matters for RHS vectors; applying a BC to the
        # matrix does not depend on it, so any time (here ts[0]) works.
        bc_in = DirichletBC(W.sub(0), self._inflow_velocity(ts[0]), boundary_in)
        bc_other = DirichletBC(W.sub(0), Constant((0.0, 0.0)), boundary_other)
        bc_out = DirichletBC(W.sub(1), Constant(0.0), boundary_out)
        self.bcs = [bc_in, bc_other, bc_out]
        self.boundary_in = boundary_in

        # Variational problem: mass + dt * (viscous + pressure/continuity
        # coupling). The convective term is not part of A (presumably it is
        # supplied via compute_tm_rhs).
        nu = 1/100
        v, q = TestFunctions(W)
        u, p = TrialFunctions(W)
        a = inner(u, v)*dx + \
            dt*nu*inner(grad(u), grad(v))*dx \
            - dt*p*div(v)*dx - dt*q*div(u)*dx
        # Mass form; only the velocity block is non-zero.
        k = inner(u, v)*dx

        A = assemble(a)
        # Mass matrix for building the right-hand side at each time step
        K = assemble(k)

        # BCs are applied to A only; K is left untouched.
        for bc in self.bcs:
            bc.apply(A)

        # PETSc getValuesCSR() returns (indptr, indices, data); reversing it
        # yields the (data, indices, indptr) tuple scipy's constructor takes.
        csr = as_backend_type(A).mat().getValuesCSR()[::-1]
        self.A = sp.csr_matrix(csr)
        csr = as_backend_type(K).mat().getValuesCSR()[::-1]
        self.K = sp.csr_matrix(csr)

        # Compute the right-hand side vectors
        self.init_rhs(ts)

    def init_rhs(self, ts):
        """
        Stack the right-hand sides for all time points ``ts`` into ``self.b``.

        The result has shape ``(len(ts), n_dofs, 1)``. Vectors are assembled
        in the order of ``ts`` and concatenated once, avoiding the quadratic
        cost of growing the array inside the loop.

        Raises:
            ValueError: if ``ts`` is empty.
        """
        if len(ts) == 0:
            raise ValueError("init_rhs needs at least one time point")
        self.b = np.concatenate(
            [self.get_rhs_at_tm_step(t) for t in ts], axis=0)

    def get_rhs_at_tm_step(self, t):
        """
        Assemble the source-term right-hand side at time ``t``.

        Side effect: replaces ``self.bcs[0]`` with the inflow BC evaluated at
        ``t`` before applying all BCs to the vector.

        Returns:
            numpy array of shape ``(1, n_dofs, 1)``.
        """
        bc_in = DirichletBC(
            self.W.sub(0), self._inflow_velocity(t), self.boundary_in)
        self.bcs[0] = bc_in

        # Body force acting on the y-velocity component only
        f = Expression((
            "0.0",
            "-sin(pi * x[0]) * sin(pi * x[1]) * sin(pi * t)"),
            degree=2, t=t)
        v, _q = TestFunctions(self.W)
        L = self.dt*dot(f, v)*dx
        b = assemble(L)

        for bc in self.bcs:
            bc.apply(b)

        return b.get_local().reshape(1, -1, 1)

    def rearrange_u_data(self, u_data):
        """
        Scatter per-node (u, v, p) data into the mixed-space dof layout.

        Args:
            u_data: numpy array or torch tensor indexed as
                ``[row, k, channel]``. The last ``len(self.reorder_p_map)``
                rows carry pressure (channel 2); the rows before them carry
                velocity (channels 0 and 1).

        Returns:
            Array/tensor of the same kind with shape
            ``(n_u0 + n_u1 + n_p, u_data.shape[1], 1)``. Rows are filled
            through the select/reorder maps from ``compute_dof_map_ns2d``.

        Raises:
            ValueError: if ``u_data`` is neither a numpy array nor a tensor.
        """
        len_u0 = len(self.reorder_u0_map)
        len_u1 = len(self.reorder_u1_map)
        len_p = len(self.reorder_p_map)
        tot_len = len_u0 + len_u1 + len_p
        u0_data = u_data[:-len_p, :, 0:1]
        u1_data = u_data[:-len_p, :, 1:2]
        p_data = u_data[-len_p:, :, 2:3]

        if isinstance(u_data, np.ndarray):
            out = np.zeros((tot_len, u_data.shape[1], 1))
        elif isinstance(u_data, torch.Tensor):
            out = torch.zeros((tot_len, u_data.shape[1], 1),
                device=u_data.device, dtype=u_data.dtype)
        else:
            raise ValueError("Unknown data type")
        out[self.select_u0_map] = u0_data[self.reorder_u0_map]
        out[self.select_u1_map] = u1_data[self.reorder_u1_map]
        out[self.select_p_map] = p_data[self.reorder_p_map]
        return out

    def compute_tm_rhs(self, X, t, dt, u):
        """
        Convection term ``-dt * (u . grad) u`` for every time slice.

        Args:
            X: coordinates that ``u`` was computed from; must require grad.
                Gradients w.r.t. ``X`` have shape ``(n_points, n_dim)``, with
                components 0 and 1 used as x and y.
            t: unused.
            dt: factor the result is scaled by.
            u: prediction of shape ``(n_time, n_points, output_dim)``.
                Channels 0 and 1 are the velocity components.

        Returns:
            Tensor shaped like ``u``. Channel 0 holds ``-dt*(u u_x + v u_y)``
            and channel 1 holds ``-dt*(u v_x + v v_y)``; all other channels
            are zero.
        """
        res = torch.zeros_like(u, device=X.device)

        n_time = u.shape[0]
        n_points = u.shape[1]
        output_dim = u.shape[-1]
        n_channel = n_time * output_dim
        n_dim = X.shape[-1]
        # (n_points, n_time * output_dim); channel index = time * output_dim + d
        flat_Y = u.permute(1, 0, 2).reshape(n_points, n_channel)
        # Use X's dtype: autograd returns gradients in X's dtype, and a
        # default-dtype (float32) buffer would silently truncate float64 ones.
        dpred_Y = torch.zeros(
            n_points, n_channel, n_dim, device=X.device, dtype=X.dtype)

        for i in range(n_channel):
            # Only velocity derivatives (d = 0, 1) are used below; skip the
            # remaining channels (e.g. pressure) to save backward passes.
            if i % output_dim >= 2:
                continue
            dpred_Y[:, i] = torch.autograd.grad(
                flat_Y[:, i], X,
                grad_outputs=torch.ones_like(flat_Y[:, i]), retain_graph=True
            )[0]
        dpred_Y = dpred_Y.reshape(n_points, n_time, output_dim, n_dim)
        dpred_Y = dpred_Y.permute(1, 0, 2, 3)

        vel_u = u[:, :, 0:1]
        vel_v = u[:, :, 1:2]
        u_x = dpred_Y[:, :, 0:1, 0]
        u_y = dpred_Y[:, :, 0:1, 1]
        v_x = dpred_Y[:, :, 1:2, 0]
        v_y = dpred_Y[:, :, 1:2, 1]

        res[:, :, 0:1] = -(vel_u * u_x + vel_v * u_y)
        res[:, :, 1:2] = -(vel_u * v_x + vel_v * v_y)

        return dt * res
