import torch
import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.tmpdedata import TimePDEData
from utils.fenics_utils import compute_dof_map_burgers2d


class GS(TimePDEData):
    """
    Prepare data for the 2D diffusion-reaction Gray-Scott model (GS).

    Two species are solved for: component 0 (``u``) and component 1 (``v``).
    The diffusion and linear reaction terms are assembled with FEniCS into
    the sparse matrices/vectors ``A``, ``b`` and ``K`` (see
    :meth:`init_problem`). The nonlinear coupling ``u * v**2`` is returned
    separately by :meth:`compute_tm_rhs`. Assuming the base class adds that
    term as a source (TODO confirm), the system is

        u_t = D_u * lap(u) + gamma * (1 - u) - u * v**2
        v_t = D_v * lap(v) - (gamma + kappa) * v + u * v**2
    """

    # Number of time snapshots the reference file is assumed to contain
    # (the raw rows are reshaped into this many consecutive blocks).
    _N_REF_TIMES = 21

    def __init__(self) -> None:
        super().__init__()
        # Reference solution; lines starting with '%' are skipped as comments.
        self.ref_data = np.loadtxt(
            "data/ref/grayscott.dat", comments="%")

    def preprocess(self):
        """
        Reshape the raw reference data into one row per spatial point.

        The raw array is assumed to hold ``_N_REF_TIMES`` consecutive blocks
        of rows, one block per time snapshot, with the same points in every
        block. Each row has 5 columns: columns 0-1 are the point coordinates
        and columns 3-4 are the two species. Column 2 is unused here
        (presumably time; TODO confirm).

        After this call ``self.ref_data`` has shape
        ``(n_points, 2 + 2 * _N_REF_TIMES)``. Each row is laid out as
        ``[x, y, u(t0), v(t0), u(t1), v(t1), ...]``.

        NOTE: ``self.ref_data`` is overwritten in place. Calling this twice
        would misinterpret the already-processed array, so call it once.
        """
        n_times = self._N_REF_TIMES
        blocks = self.ref_data.reshape(n_times, -1, 5)
        # Coordinates are taken from the first snapshot.
        coords = blocks[0, :, :2]
        # (time, point, species) -> (point, time, species) -> one row per point
        fields = np.transpose(blocks[:, :, 3:5], (1, 0, 2))\
            .reshape(-1, n_times * 2)
        self.ref_data = np.concatenate((coords, fields), axis=1)

    def init_problem(self):
        """
        Build the FEM discretisation and the initial condition.

        Sets the following attributes:

        * ``self.coords``: dof coordinates of the first scalar component.
        * ``self.u0``: initial values at those coordinates, shape ``(n, 2)``.
        * ``self.A``: SciPy CSR system matrix.
        * ``self.b``: constant right-hand side, shape ``(n_dofs, 1)``.
        * ``self.K``: SciPy CSR mass matrix.
        * the select/reorder dof maps used by :meth:`rearrange_u_data`.

        Reads ``self.config``, which is presumably populated by the base
        class (TODO confirm).
        """
        self.preprocess()

        # Create the mesh. bbox is presumably
        # [x_min, x_max, y_min, y_max, t_min, t_max] -- TODO confirm.
        bbox = self.config["Bounding Box"]
        mesh = RectangleMesh(
            Point(bbox[0], bbox[2]),
            Point(bbox[1], bbox[3]),
            self.config["Grid Size (x-direction)"],
            self.config["Grid Size (y-direction)"])

        # Vector-valued continuous P2 space holding (u, v). Both components
        # use the same P2 element, so this is NOT a Taylor-Hood (P2/P1) pair.
        P2 = VectorElement("P", mesh.ufl_cell(), 2)
        W = FunctionSpace(mesh, P2)

        # Coordinates of the dofs of one scalar component.
        self.coords = W.sub(0).collapse()\
            .tabulate_dof_coordinates()

        # Initial condition, evaluated directly at the dof coordinates:
        #   component 0: 1 - Gaussian bump centred at (-0.05, -0.02)
        #   component 1:     Gaussian bump centred at ( 0.05,  0.02)
        def ic_func(x, component):
            if component == 0:
                return 1 - np.exp(-80 * ((x[:, 0] + 0.05)**2 + (x[:, 1] + 0.02)**2))
            else:
                return np.exp(-80 * ((x[:, 0] - 0.05)**2 + (x[:, 1] - 0.02)**2))
        self.u0 = np.stack((
            ic_func(self.coords, 0),
            ic_func(self.coords, 1)
        ), axis=1)

        # Time step: the time span divided by (time-grid intervals x sub-steps
        # per interval).
        dt = (bbox[-1] - bbox[-2]) / (
            self.config["Grid Size (t-direction)"] *
            self.config["Sub-Time Intervals"])
        D = (1e-5, 5e-6)     # diffusivities of components 0 and 1
        gamma = 0.04         # feed rate
        kappa = 0.1 - gamma  # kill rate; component 1 decays at gamma + kappa
        u = TrialFunction(W)
        v = TestFunction(W)

        # Diffusion terms.
        def a_form(u, v):
            return inner(D[0]*grad(u[0]), grad(v[0]))*dx \
                + inner(D[1]*grad(u[1]), grad(v[1]))*dx

        # Linear reaction terms: feed of component 0, decay of component 1.
        def f_linear(u, v):
            return gamma*(1 - u[0])*v[0]*dx \
                - (gamma + kappa)*u[1]*v[1]*dx

        # Implicit step for the linear part: mass + dt*(diffusion - reaction).
        # lhs() keeps the terms bilinear in the trial function (the system
        # matrix). rhs() keeps the rest with its sign flipped, which here is
        # the constant feed term dt*gamma*v[0]*dx.
        F = inner(u, v)*dx + dt*(a_form(u, v) - f_linear(u, v))
        a = lhs(F)
        L = rhs(F)
        k = inner(u, v)*dx

        # Assemble the system matrix and the constant right-hand side vector.
        A = assemble(a)
        b = assemble(L)
        # Mass matrix, used to build the right-hand side at each time step.
        K = assemble(k)

        # PETSc's getValuesCSR() returns (indptr, indices, data). Reversing it
        # gives the (data, indices, indptr) tuple that scipy's csr_matrix
        # accepts.
        self.A = sp.csr_matrix(as_backend_type(A).mat().getValuesCSR()[::-1])
        self.b = b.get_local().reshape(-1, 1)
        self.K = sp.csr_matrix(as_backend_type(K).mat().getValuesCSR()[::-1])

        # Dof maps used by rearrange_u_data to scatter per-component data
        # into W's dof ordering.
        self.select_u0_map, self.reorder_u0_map, \
            self.select_u1_map, self.reorder_u1_map = compute_dof_map_burgers2d(W)

    def rearrange_u_data(self, u_data):
        """
        Scatter per-component data into the dof ordering of the vector space.

        Args:
            u_data: rank-3 ``np.ndarray`` or ``torch.Tensor``. The last axis
                holds the two components (index 0 and 1). The middle axis is
                carried through unchanged (presumably time; TODO confirm).

        Returns:
            A container of the same type and dtype (and device, for tensors)
            with shape ``(len(reorder_u0_map) + len(reorder_u1_map),
            u_data.shape[1], 1)``. Entries not selected by the maps are zero.

        Raises:
            ValueError: if ``u_data`` is neither a NumPy array nor a tensor.
        """
        len_u0 = len(self.reorder_u0_map)
        len_u1 = len(self.reorder_u1_map)
        tot_len = len_u0 + len_u1
        u0_data = u_data[:, :, 0:1]
        u1_data = u_data[:, :, 1:2]
        out_shape = (tot_len, u_data.shape[1], 1)

        if isinstance(u_data, np.ndarray):
            # Preserve the input dtype, as the torch branch does. Previously
            # float32 input was silently upcast to float64.
            out = np.zeros(out_shape, dtype=u_data.dtype)
        elif isinstance(u_data, torch.Tensor):
            out = torch.zeros(
                out_shape, device=u_data.device, dtype=u_data.dtype)
        else:
            raise ValueError("Unknown data type")
        out[self.select_u0_map] = u0_data[self.reorder_u0_map]
        out[self.select_u1_map] = u1_data[self.reorder_u1_map]
        return out

    def compute_tm_rhs(self, X, t, dt, u):
        """
        Return the nonlinear Gray-Scott reaction term ``(-u v^2, +u v^2)``.

        Args:
            X, t, dt: unused here (presumably required by the base-class
                interface).
            u: tensor whose last axis holds the two components (u, v).

        Returns:
            Tensor with the same leading shape as ``u`` and a last axis of
            size 2.
        """
        u_0 = u[..., 0:1]
        u_1 = u[..., 1:2]
        # The same product feeds both species with opposite signs.
        uv2 = u_0 * u_1**2
        return torch.concat((-uv2, uv2), dim=-1)
