"""
Base classes for equation solvers in TensorGalerkin
"""

import torch  
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional


class EquationDatasetSequential(nn.Module, ABC):
    """Base class for time-dependent equations.

    Concrete equations must implement :meth:`compute_residual`. The class
    mixes in ``ABC`` so that ``@abstractmethod`` is actually enforced:
    without an ABCMeta metaclass the decorator is a no-op, and a subclass
    that forgot to override ``compute_residual`` could still be
    instantiated. With ``ABC``, such a subclass raises ``TypeError`` at
    construction time instead of failing later.

    Attributes
    ----------
        mesh:
            Initialised to ``None`` here; presumably populated by
            subclasses (TODO confirm expected type).
        graph:
            Initialised to ``None`` here; presumably populated by
            subclasses (TODO confirm expected type).
        steps:
            Initialised to ``None`` here; presumably the number of time
            steps, set by subclasses (TODO confirm).
    """

    def __init__(self):
        super().__init__()
        # Placeholders only; concrete equations are expected to fill them in.
        self.mesh = None
        self.graph = None
        self.steps = None

    @abstractmethod
    def compute_residual(self, U: torch.Tensor) -> torch.Tensor:
        """
        Compute the residual of the equation.

        Parameters:
        -----------
            U: torch.Tensor (n_node,)
                Prediction from the model

        Returns:
        --------
            R: torch.Tensor (n_node,)
                Residual of the equation
        """
        # Kept so that an explicit ``super().compute_residual(U)`` from a
        # subclass still fails loudly rather than returning None.
        raise NotImplementedError()
    

class EquationDatasetStatic(nn.Module, ABC):
    """Base class for static (steady-state) equations.

    Concrete equations must implement :meth:`form`. The class mixes in
    ``ABC`` so that ``@abstractmethod`` is actually enforced: without an
    ABCMeta metaclass the decorator is a no-op, and a subclass missing
    ``form`` could still be instantiated. With ``ABC``, such a subclass
    raises ``TypeError`` at construction time.

    Attributes
    ----------
        mesh:
            Initialised to ``None`` here; presumably populated by
            subclasses (TODO confirm expected type).
        graph:
            Initialised to ``None`` here; presumably populated by
            subclasses (TODO confirm expected type).
    """

    def __init__(self):
        super().__init__()
        # Placeholders only; concrete equations are expected to fill them in.
        self.mesh = None
        self.graph = None

    @abstractmethod
    def form(self, phi: torch.Tensor, basis: torch.Tensor,
             gradphi: torch.Tensor, gradbasis: torch.Tensor) -> torch.Tensor:
        """
        Weak form of the equation

        Parameters:
        -----------
            phi: torch.Tensor [batch_size]
                Solution values at quadrature points
            basis: torch.Tensor [batch_size, n_basis]
                Basis function values. NOTE(review): previously documented
                as [batch_size, n_basis, 2], identical to ``gradbasis``;
                scalar basis values suggest [batch_size, n_basis] - confirm
                against the implementations.
            gradphi: torch.Tensor [batch_size, 2]
                Gradient of solution at quadrature points
            gradbasis: torch.Tensor [batch_size, n_basis, 2]
                Gradient of basis functions

        Returns:
        --------
            torch.Tensor [batch_size, n_basis]
                Weak form contribution
        """
        # Kept so that an explicit ``super().form(...)`` from a subclass
        # still fails loudly rather than returning None.
        raise NotImplementedError()

    # NOTE: earlier commented-out overrides of ``to()`` / ``type()`` were
    # removed: they did not return ``self`` and did not match the
    # ``nn.Module`` signatures. Subclasses holding tensors such as
    # shape_val, shape_grad, weight, JxW, jac_det, elements, ele2node,
    # boundary_mask or boundary_value should register them with
    # ``self.register_buffer(name, tensor)`` so that ``nn.Module.to()`` and
    # ``nn.Module.type()`` move/cast them automatically.