from abc import ABC, abstractmethod


class BaseImitationLearning(ABC):
    """Abstract interface shared by every imitation learning agent.

    Concrete subclasses must supply four operations:

    * ``train`` -- fit the model from demonstration data,
    * ``sample_action`` -- choose an action for a given state,
    * ``save_model`` / ``load_model`` -- persist and restore the model.

    The base class itself performs no computation. It only records the
    constructor configuration on the instance so subclasses can read it.
    """

    def __init__(self, state_dim, action_dim, max_action, device, lr=3e-4):
        """Record the configuration common to all agents.

        Args:
            state_dim (int): Size of the state/observation vector.
            action_dim (int): Size of the action vector.
            max_action (float): Largest action value. It is only stored
                here; presumably subclasses use it to scale actions.
            device (torch.device): Compute device (CPU/GPU), stored as-is
                for subclasses.
            lr (float): Optimizer learning rate. Defaults to ``3e-4``.
        """
        # Keep the constructor arguments verbatim; subclasses read them back.
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.device = device
        self.lr = lr
        # Counter starting at zero; presumably advanced by subclasses while training.
        self.step = 0

    @abstractmethod
    def train(self, replay_buffer, iterations, batch_size=100, log_writer=None):
        """Fit the model on demonstration data.

        Args:
            replay_buffer: Source of expert demonstrations to learn from.
            iterations (int): How many optimisation steps to run.
            batch_size (int): Samples drawn per step. Defaults to 100.
            log_writer: Optional logger for training metrics.

        Returns:
            dict: Metrics collected during training.
        """
        ...

    @abstractmethod
    def sample_action(self, state):
        """Choose an action for ``state``.

        Args:
            state (np.ndarray or torch.Tensor): The current state.

        Returns:
            np.ndarray: The selected action.
        """
        ...

    @abstractmethod
    def save_model(self, dir, id=None):
        """Write the trained model to disk.

        Args:
            dir (str): Target directory.
            id (str, optional): Identifier used to version the saved model.
        """
        # NOTE: ``dir``/``id`` shadow builtins but are kept for caller compatibility.
        ...

    @abstractmethod
    def load_model(self, dir, id=None, map_location=None):
        """Restore a previously saved model from disk.

        Args:
            dir (str): Directory holding the saved model.
            id (str, optional): Identifier of the model version to load.
            map_location: Device remapping applied while loading.
        """
        # NOTE: ``dir``/``id`` shadow builtins but are kept for caller compatibility.
        ...