import torch
import numpy as np
import scipy.sparse as sp
from dolfin import *

from src.tmpdedata import TimePDEData


class KS(TimePDEData):
    """
    Prepare data for the Kuramoto-Sivashinsky Model (KS).

    The linear part of the equation is discretised in ``init_problem`` with
    P3 Lagrange finite elements and a backward-Euler step. The nonlinear
    convective term is supplied explicitly by ``compute_tm_rhs``.
    Coefficients used below: alpha = 100/16 (u*u_x), beta = 100/16**2
    (u_xx) and gamma = 100/16**4 (u_xxxx). Presumably the base class
    combines both parts into
    u_t + alpha*u*u_x + beta*u_xx + gamma*u_xxxx = 0 — confirm against
    TimePDEData.
    """

    def __init__(self) -> None:
        super().__init__()
        # Raw reference solution; lines starting with '%' are comments.
        self.ref_data = np.loadtxt(
            "data/ref/Kuramoto_Sivashinsky.dat", comments="%")

    def preprocess(self):
        """
        Reshape the raw reference data into one row per spatial point.

        The loaded array is grouped into blocks of 251 consecutive rows of 3
        columns. For every block, the first column of its first row and the
        third column of all 251 rows are kept. The result has shape
        ``(n_blocks, 1 + 251)`` and replaces ``self.ref_data``.

        NOTE(review): this assumes the file rows are ``(x, t, u)``, ordered
        so that each x has 251 consecutive time samples. Confirm against the
        data file. The method overwrites ``self.ref_data``, so calling it a
        second time would try to reshape already-processed data.
        """
        n_t = 251  # time samples per spatial point in the reference file
        blocks = self.ref_data.reshape(-1, n_t, 3)
        self.ref_data = np.concatenate((
            blocks[:, 0, :1],                   # first column of first row
            blocks[:, :, 2:].reshape(-1, n_t),  # third column of every row
        ), axis=1)

    def init_problem(self):
        """
        Build the finite-element discretisation of the linear KS operator.

        Calls ``preprocess`` first, then sets:
            coords: dof coordinates of the P3 Lagrange function space.
            u0: initial condition ``cos(x) * (1 + sin(x))`` at ``coords``.
            A: scipy CSR matrix of the backward-Euler operator
                ``(u, v) - dt*beta*(u_x, v_x) - dt*gamma*(u_xxx, v_x)``.
            K: scipy CSR mass matrix ``(u, v)``. It is used to compute the
                right-hand side vector at each time step.
            b: dummy placeholder.
        """
        self.preprocess()

        # 1D mesh over [bbox[0], bbox[1]].
        bbox = self.config["Bounding Box"]
        mesh = IntervalMesh(
            self.config["Grid Size (x-direction)"],
            bbox[0], bbox[1])

        # Cubic Lagrange space. Presumably degree 3 was chosen because u_xxx
        # vanishes cell-wise for polynomials of degree < 3.
        V = FunctionSpace(mesh, "Lagrange", 3)

        # Coordinates of the degrees of freedom.
        self.coords = V.tabulate_dof_coordinates()

        # Initial condition, interpolated at the dof coordinates.
        self.u0 = np.cos(self.coords) * (1 + np.sin(self.coords))

        beta = 100 / 16**2   # coefficient of u_xx
        gamma = 100 / 16**4  # coefficient of u_xxxx

        # Variational problem.
        u = TrialFunction(V)
        v = TestFunction(V)
        u_x = u.dx(0)
        u_xx = u_x.dx(0)
        u_xxx = u_xx.dx(0)
        v_x = v.dx(0)

        # Step size: the time span split into
        # (grid size in t) * (sub-time intervals) steps.
        # NOTE(review): bbox[-2:] is presumably the time interval — confirm.
        dt = (bbox[-1] - bbox[-2]) / (
            self.config["Grid Size (t-direction)"] *
            self.config["Sub-Time Intervals"])

        # Backward Euler for u_t + beta*u_xx + gamma*u_xxxx = (explicit RHS).
        # Each higher-order term is integrated by parts once. No boundary
        # integrals are included:
        #   (u_xx, v)   -> -(u_x, v_x)
        #   (u_xxxx, v) -> -(u_xxx, v_x)
        a = u*v*dx - dt*beta*u_x*v_x*dx - dt*gamma*u_xxx*v_x*dx
        k = inner(u, v)*dx

        # System matrix, then the mass matrix for the per-step RHS.
        A = assemble(a)
        K = assemble(k)

        self.A = self._to_scipy_csr(A)
        self.b = np.array(0)  # Dummy b
        self.K = self._to_scipy_csr(K)

    @staticmethod
    def _to_scipy_csr(matrix):
        """Convert an assembled dolfin (PETSc-backed) matrix to scipy CSR."""
        # getValuesCSR() yields (indptr, indices, data), while scipy's
        # constructor expects (data, indices, indptr).
        indptr, indices, data = as_backend_type(matrix).mat().getValuesCSR()
        return sp.csr_matrix((data, indices, indptr))

    def compute_tm_rhs(self, X, t, dt, u):
        """
        Return the explicit nonlinear term ``dt * (-alpha * u * u_x)``.

        Args:
            X: input coordinates, differentiated via autograd. Assumed to
                have shape (n_points, n_dim), with column 0 as the spatial
                coordinate x (TODO confirm). It must require grad, and ``u``
                must be computed from it.
            t: unused.
            dt: step size that scales the returned term.
            u: predictions of shape (n_time, n_points, output_dim). Only
                output channel 0 enters the result.

        Returns:
            Tensor of shape (n_time, n_points, 1).

        NOTE(review): the gradients are taken without ``create_graph=True``,
        so the ``u_x`` factor carries no autograd history. If this term must
        be differentiated (e.g. inside a training loss), that flag is
        needed. Confirm intent.
        """
        n_time, n_points, output_dim = u.shape[0], u.shape[1], u.shape[-1]
        n_channel = n_time * output_dim
        n_dim = X.shape[-1]

        # Flatten (time, output) into channels, so that each channel is one
        # scalar field per point and can be differentiated w.r.t. X.
        pred_Y = u.permute(1, 0, 2).reshape(n_points, n_channel)

        # Gradients come back with X's dtype. Allocate the buffer with that
        # dtype as well, so float64 inputs are not silently truncated to the
        # global default dtype (float32 unless changed).
        dpred_Y = torch.zeros(
            n_points, n_channel, n_dim, dtype=X.dtype, device=X.device)
        for i in range(n_channel):
            dpred_Y[:, i] = torch.autograd.grad(
                pred_Y[:, i], X,
                grad_outputs=torch.ones_like(pred_Y[:, i]), retain_graph=True
            )[0]

        # Restore the layout (n_time, n_points, output_dim, n_dim).
        dpred_Y = dpred_Y.reshape(n_points, n_time, output_dim, n_dim)
        dpred_Y = dpred_Y.permute(1, 0, 2, 3)

        alpha = 100 / 16  # coefficient of u*u_x
        u_0 = u[:, :, 0:1]
        u_x = dpred_Y[:, :, 0:1, 0]

        return dt * (-alpha * u_0 * u_x)
