# file: prism/core/base_objects.py
from abc import ABC, abstractmethod

import pytorch_lightning as pl
import torch.nn as nn

from prism.utils.config import AttrDict


class BaseModel(nn.Module, ABC):
    """Abstract root class for project neural-network modules.

    The configuration object handed to the constructor is kept verbatim on
    ``self.config``. Concrete subclasses must override :meth:`forward`;
    because ``forward`` is an ``abstractmethod``, instantiating a subclass
    that does not override it raises ``TypeError``.
    """

    def __init__(self, config):
        # Run nn.Module's initialiser first so attribute assignment below goes
        # through a fully set-up Module.
        super(BaseModel, self).__init__()
        self.config = config

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Compute the module output; must be overridden by subclasses."""
        raise NotImplementedError


class BaseSystem(pl.LightningModule, ABC):
    """Abstract root class for Lightning training systems.

    Keeps the original ``config`` object on ``self.config`` and records it as
    Lightning hyperparameters via ``save_hyperparameters``.
    """

    def __init__(self, config):
        super(BaseSystem, self).__init__()
        self.config = config
        # An AttrDict is converted with its ``to_dict()`` before being handed to
        # Lightning; any other config is passed through unchanged.
        # NOTE(review): presumably to_dict() yields a plain dict so the stored
        # hparams serialize cleanly — confirm against prism.utils.config.
        if isinstance(config, AttrDict):
            hparams = config.to_dict()
        else:
            hparams = config
        self.save_hyperparameters(hparams)


class BaseCallback(pl.Callback, ABC):
    """Abstract root class for project Lightning callbacks.

    Only stores the supplied configuration object on ``self.config``; all
    hook behaviour comes from ``pl.Callback`` or from subclasses.
    """

    def __init__(self, config):
        super(BaseCallback, self).__init__()
        self.config = config


class BaseLoss(nn.Module, ABC):
    """Abstract root class for loss modules.

    Behaves like :class:`BaseModel`: the constructor keeps ``config`` on
    ``self.config``, and subclasses must override :meth:`forward` (the class
    cannot be instantiated otherwise, since ``forward`` is abstract).
    """

    def __init__(self, config):
        # Initialise nn.Module state before setting any attributes.
        super(BaseLoss, self).__init__()
        self.config = config

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Return the loss value; must be overridden by subclasses."""
        raise NotImplementedError


class BaseMetric(ABC):
    """Abstract root class for evaluation metrics.

    The constructor keeps ``config`` on ``self.config``. Subclasses must
    implement :meth:`calculate`; because it is an ``abstractmethod``,
    instantiating a subclass that does not override it raises ``TypeError``.
    """

    def __init__(self, config):
        # Call the next initialiser in the MRO, as every other base class in
        # this module does. Without it, a mixin listed after BaseMetric in a
        # subclass's bases would never have its __init__ run.
        super().__init__()
        self.config = config

    @abstractmethod
    def calculate(self, **kwargs):
        """Compute the metric from keyword inputs; must be overridden."""
        raise NotImplementedError