from typing import Union

import pytorch_lightning as pl

from ..io.persistence import get_checkpoints_callback
from ..io.logging import get_tb_logger


class Trainer(pl.Trainer):
    """``pl.Trainer`` preconfigured with this project's logger and checkpointing.

    The logger and (optionally) a checkpoint callback are built from
    ``task_name`` via the project helpers ``get_tb_logger`` and
    ``get_checkpoints_callback``. Every other argument is forwarded to
    ``pl.Trainer``.
    """

    def __init__(
        self,
        task_name: Union[str, list[str]],
        enable_checkpointing: bool = True,
        *args,
        **kwargs,
    ) -> None:
        """Build the trainer.

        Args:
            task_name: Task identifier (or list of identifiers) handed to
                ``get_tb_logger`` and ``get_checkpoints_callback``.
            enable_checkpointing: When true, append the callback returned by
                ``get_checkpoints_callback(task_name)``. Also forwarded to
                ``pl.Trainer`` unchanged.
            *args: Forwarded positionally to ``pl.Trainer``. Note that the
                first extra positional argument after ``task_name`` binds to
                ``enable_checkpointing``, not to ``pl.Trainer``.
            **kwargs: Forwarded to ``pl.Trainer``. A caller-supplied
                ``callbacks`` (a single callback or a list/tuple) is merged
                with the checkpoint callback. A caller-supplied ``logger``
                replaces the default one.
        """
        # Merge caller callbacks rather than passing ``callbacks`` twice, which
        # would raise TypeError. Copy so the caller's list is never mutated.
        user_callbacks = kwargs.pop("callbacks", None)
        if user_callbacks is None:
            callbacks = []
        elif isinstance(user_callbacks, (list, tuple)):
            callbacks = list(user_callbacks)
        else:
            callbacks = [user_callbacks]
        if enable_checkpointing:
            callbacks.append(get_checkpoints_callback(task_name))

        # Respect an explicitly passed ``logger``; only build the default one
        # when the caller did not provide any.
        # TODO: don't hardcode the paths
        if "logger" not in kwargs:
            kwargs["logger"] = get_tb_logger(task_name)

        # ``__init__`` must return None, so the parent call is not returned.
        super().__init__(
            *args,
            callbacks=callbacks,
            enable_checkpointing=enable_checkpointing,
            **kwargs,
        )
