"""
Base model class that defines the interface for all models.
"""
from abc import ABC, abstractmethod
import torch
import torch.nn as nn


class BaseModel(nn.Module, ABC):
    """Base class for all models in the project.

    Subclasses must implement :meth:`forward` and :meth:`loss`. Because the
    class mixes in :class:`abc.ABC`, instantiating a subclass that leaves
    either method unimplemented raises ``TypeError``.

    Weight persistence is provided by :meth:`save` / :meth:`load`, which
    round-trip this module's ``state_dict()``.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        """Forward pass of the model.

        Args:
            x: Input tensor.

        Returns:
            Output tensor.
        """
        pass

    @abstractmethod
    def loss(self, pred, target):
        """Compute the loss for the model.

        Args:
            pred: Model predictions.
            target: Ground truth targets.

        Returns:
            Loss value.
        """
        pass

    def save(self, path) -> None:
        """Save model weights (the module's ``state_dict``).

        Only parameters and buffers are written, not the model class itself.
        Restore them with :meth:`load` on an instance of the same class.

        Args:
            path: Destination passed to ``torch.save`` (a file path or a
                writable file-like object).
        """
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None, strict=True) -> None:
        """Load model weights previously written by :meth:`save`.

        Args:
            path: Source passed to ``torch.load`` (a file path or a readable
                file-like object).
            map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to
                load a checkpoint that was saved from GPU tensors on a
                CPU-only machine. The default ``None`` keeps ``torch.load``'s
                own default behaviour.
            strict: Forwarded to ``load_state_dict``. When ``True`` (the
                default), the checkpoint keys must match this model's keys
                exactly. Otherwise ``load_state_dict`` raises
                ``RuntimeError``.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Only a state_dict is expected here, so passing
        # weights_only=True (available since PyTorch 1.13, default since 2.6)
        # would be safer. Confirm the minimum supported torch version first.
        state_dict = torch.load(path, map_location=map_location)
        self.load_state_dict(state_dict, strict=strict)