import torch
import numpy as np
import scipy.sparse as sp
from scipy.interpolate import NearestNDInterpolator
from dolfin import *

from src.tmpdedata import TimePDEData
from utils.fenics_utils import compute_dof_map_burgers2d


class Burgers2D_C(TimePDEData):
    """
    Prepare data for the 2D Burgers' equation (Burgers2d-C).

    The class loads a reference solution table, assembles the linear
    (mass + diffusion) operators on a periodic P2 vector space with FEniCS,
    and evaluates the nonlinear advection term with PyTorch autograd.
    """
    def __init__(self) -> None:
        super().__init__()
        # Reference table; lines starting with "%" are treated as comments.
        # init_problem() uses columns 0-1 as point coordinates and
        # columns 2-3 as the two velocity components.
        self.ref_data = np.loadtxt(
            "data/ref/burgers2d_0.dat", comments="%")

    def init_problem(self):
        """
        Assemble the discrete problem and store it on the instance.

        Sets ``self.coords`` (dof coordinates), ``self.u0`` (initial
        condition sampled at those coordinates), ``self.A`` and ``self.K``
        (scipy CSR matrices), ``self.b`` (dummy right-hand side) and the
        four dof index maps returned by ``compute_dof_map_burgers2d``.
        """
        # Create mesh and define function space
        bbox = self.config["Bounding Box"]
        with XDMFFile(MPI.comm_world,
            'data/burgers2d_mesh.xdmf') as infile:
            mesh = Mesh()
            infile.read(mesh)

        # Sub domain for Periodic boundary condition
        # NOTE(review): the period is hard-coded while inside() reads the
        # lower bounds from bbox; this presumably assumes the domain is
        # [0, 4] x [0, 4] -- confirm against the mesh/config.
        L = 4
        class PeriodicBoundary(SubDomain):
            # Left and bottom boundary are "target domain" G
            def inside(self, x, on_boundary):
                return bool((near(x[0], bbox[0]) or
                    near(x[1], bbox[2])) and on_boundary)

            # Map right boundary (H) to left boundary (G)
            # and top boundary (H) to bottom boundary (G)
            def map(self, x, y):
                if near(x[0], L) and near(x[1], L):
                    # Top-right corner: shift in both directions
                    y[0] = x[0] - L
                    y[1] = x[1] - L
                elif near(x[0], L):
                    # Right edge: shift in x only
                    y[0] = x[0] - L
                    y[1] = x[1]
                else:
                    # Everything else is shifted in y only
                    y[0] = x[0]
                    y[1] = x[1] - L

        # Build the function space: a degree-2 "P" vector element with the
        # periodic constraint (a single velocity space, not a mixed one).
        P2 = VectorElement("P", mesh.ufl_cell(), 2)
        W = FunctionSpace(mesh, P2,
            constrained_domain=PeriodicBoundary())

        # Get the dof coordinates of the first component's sub-space
        self.coords = W.sub(0).collapse()\
            .tabulate_dof_coordinates()

        # Initial condition (nearest-neighbour interpolation).
        # One interpolator over both value columns gives the same result as
        # one interpolator per column, but the neighbour search structure
        # is built only once.
        interpolate_uv = NearestNDInterpolator(
            self.ref_data[:, :2], self.ref_data[:, 2:4])
        self.u0 = interpolate_uv(self.coords)

        # Define variational problem
        # Time step: time span divided by the total number of sub-steps
        dt = (bbox[-1] - bbox[-2]) / (
            self.config["Grid Size (t-direction)"] *
            self.config["Sub-Time Intervals"])
        nu = 0.001
        u = TrialFunction(W)
        v = TestFunction(W)
        # Mass term plus dt * nu times the diffusion (gradient) term
        a = inner(u, v)*dx + \
            dt*nu*inner(grad(u), grad(v))*dx
        k = inner(u, v)*dx

        # Compute matrix and right-hand side vector
        A = assemble(a)
        # Mass matrix for computing right-hand side vector
        # at each time step
        K = assemble(k)

        def to_scipy_csr(matrix):
            # getValuesCSR() returns (indptr, indices, data); scipy expects
            # (data, indices, indptr), hence the reversal.
            return sp.csr_matrix(
                as_backend_type(matrix).mat().getValuesCSR()[::-1])

        # Convert PETSc matrices to scipy sparse arrays
        self.A = to_scipy_csr(A)
        self.b = np.array(0) # Dummy b
        self.K = to_scipy_csr(K)

        # Compute dof map
        self.select_u0_map, self.reorder_u0_map, \
            self.select_u1_map, self.reorder_u1_map = compute_dof_map_burgers2d(W)

    def rearrange_u_data(self, u_data):
        """
        Scatter the two solution components into one single-channel array.

        Channel 0 and channel 1 of ``u_data`` (last axis) are reordered along
        the first axis with ``reorder_u0_map`` / ``reorder_u1_map`` and
        written to the rows given by ``select_u0_map`` / ``select_u1_map``.

        Args:
            u_data: ``np.ndarray`` or ``torch.Tensor`` indexed as
                ``[row, column, channel]`` with at least two channels.

        Returns:
            Array of the same kind as the input with shape
            ``(len(reorder_u0_map) + len(reorder_u1_map), u_data.shape[1], 1)``.
            The torch result keeps the input device and dtype; the NumPy
            result is always float64 (``np.zeros`` default).

        Raises:
            ValueError: if ``u_data`` is neither an ndarray nor a Tensor.
        """
        tot_len = len(self.reorder_u0_map) + len(self.reorder_u1_map)
        u0_data = u_data[:, :, 0:1]
        u1_data = u_data[:, :, 1:2]

        out_shape = (tot_len, u_data.shape[1], 1)
        if isinstance(u_data, np.ndarray):
            # NOTE(review): input dtype is not propagated here, unlike the
            # torch branch below -- confirm whether that is intentional.
            out = np.zeros(out_shape)
        elif isinstance(u_data, torch.Tensor):
            out = torch.zeros(out_shape,
                device=u_data.device, dtype=u_data.dtype)
        else:
            raise ValueError("Unknown data type")
        out[self.select_u0_map] = u0_data[self.reorder_u0_map]
        out[self.select_u1_map] = u1_data[self.reorder_u1_map]
        return out

    def compute_tm_rhs(self, X, t, dt, u):
        """
        Evaluate ``dt`` times the negative advection term of the prediction.

        Spatial derivatives are obtained with ``torch.autograd.grad`` of each
        (time, component) channel of ``u`` with respect to ``X``, so ``u``
        must have been computed from ``X`` with gradient tracking enabled.

        Args:
            X: coordinates the prediction depends on; assumed to be of shape
                ``(n_points, n_dim)`` with ``n_dim >= 2`` -- TODO confirm.
            t: unused; kept for interface compatibility.
            dt: factor multiplied into the result.
            u: prediction indexed as ``[time, point, component]``; component
                0 and 1 are the two velocity components.

        Returns:
            Tensor shaped like ``u`` whose channels 0 and 1 hold
            ``-dt * (u * u_x + v * u_y)`` and ``-dt * (u * v_x + v * v_y)``.
            Any further channels are zero.
        """
        pred_Y = u
        res = torch.zeros_like(pred_Y, device=X.device)

        n_time = pred_Y.shape[0]
        n_points = pred_Y.shape[1]
        output_dim = pred_Y.shape[-1]
        n_channel = n_time * output_dim
        n_dim = X.shape[-1]
        # Flatten (time, component) into one channel axis, points first
        pred_Y = pred_Y.permute(1, 0, 2).reshape(n_points, n_channel)
        # Gradient buffer. It takes X's dtype because autograd returns
        # gradients in the dtype of X; a default-dtype buffer would silently
        # downcast float64 gradients to float32.
        dpred_Y = torch.zeros(
            n_points, n_channel, n_dim, dtype=X.dtype, device=X.device)

        # Compute the gradient for each channel in u
        for i in range(n_channel):
            channel = pred_Y[:, i]
            # retain_graph keeps the graph alive for the remaining channels
            dpred_Y[:, i] = torch.autograd.grad(
                channel, X,
                grad_outputs=torch.ones_like(channel), retain_graph=True
            )[0]
        # Back to [time, point, component, spatial-direction]
        dpred_Y = dpred_Y.reshape(n_points, n_time, output_dim, n_dim)
        dpred_Y = dpred_Y.permute(1, 0, 2, 3)

        pred_Y = pred_Y.reshape(n_points, n_time, output_dim).permute(1, 0, 2)
        vel_u = pred_Y[:, :, 0:1]
        vel_v = pred_Y[:, :, 1:2]
        u_x = dpred_Y[:, :, 0:1, 0]
        u_y = dpred_Y[:, :, 0:1, 1]
        v_x = dpred_Y[:, :, 1:2, 0]
        v_y = dpred_Y[:, :, 1:2, 1]

        # Negative advection term, one channel per velocity component
        res[:, :, 0:1] = -(vel_u * u_x + vel_v * u_y)
        res[:, :, 1:2] = -(vel_u * v_x + vel_v * v_y)

        return dt * res
