from abc import ABC, abstractmethod
import torch

class Product(ABC):
    """Interface that every concrete product implementation must fulfil.

    ``Product`` cannot be instantiated directly. A subclass becomes
    instantiable only once it overrides all three abstract methods:
    :meth:`get_copy`, :meth:`transition` and :meth:`cost`.

    Attributes:
        state_dim: Integer given to the constructor. Presumably this is the
            dimension of the state tensor, but that is not checked here.
        initial_state: Tensor given to the constructor. It is stored by
            reference, so no copy or validation is performed.
    """

    def __init__(self, state_dim: int, initial_state: torch.Tensor):
        # Both values are kept exactly as passed in. Nothing is cloned or
        # checked, so callers that later mutate the tensor also change
        # ``self.initial_state``.
        self.initial_state = initial_state
        self.state_dim = state_dim

    @abstractmethod
    def get_copy(self):
        """Produce a copy of this product.

        Subclasses define what "copy" means (shallow vs. deep). The base
        implementation only raises.

        Raises:
            NotImplementedError: always, when reached via ``super()``.
        """
        raise NotImplementedError("")

    @abstractmethod
    def transition(self, t: int, state: torch.Tensor, control: torch.Tensor) -> torch.Tensor:
        """Advance ``state`` by one step at time ``t`` under ``control``.

        Presumably this returns the next state. The exact shapes of
        ``state``, ``control`` and the result are left to subclasses
        (TODO confirm).

        Raises:
            NotImplementedError: always, when reached via ``super()``.
        """
        raise NotImplementedError("")

    @abstractmethod
    def cost(self, t: int, state: torch.Tensor, control: torch.Tensor) -> torch.Tensor:
        """Evaluate the cost of applying ``control`` in ``state`` at time ``t``.

        The shape of the returned tensor (scalar vs. per-batch) is not fixed
        by this interface. Verify against the subclasses and callers.

        Raises:
            NotImplementedError: always, when reached via ``super()``.
        """
        raise NotImplementedError("")