import warnings
import torch
import torch.nn as nn

from lib.strategies.base import BaseStrategy

# NOTE(review): this executes at import time and installs an "ignore" filter
# that matches every warning category from every module in the process, not
# just this file. It will also hide deprecation/runtime warnings raised by
# torch and other libraries. If only specific noisy warnings are meant to be
# silenced, consider passing `category=`/`message=` or scoping the filter
# with `warnings.catch_warnings()`.
warnings.filterwarnings("ignore")

class FlavaLowerBoundStrategy(BaseStrategy):
    """Strategy that adds nothing beyond ``BaseStrategy`` except a plain ``forward``.

    The constructor only forwards its core configuration to ``BaseStrategy``.
    ``forward`` calls the wrapped model directly, with no extra processing.

    NOTE(review): judging by the name, this is presumably the "lower bound"
    baseline (plain training with no continual-learning mechanism) for a
    FLAVA model. Confirm against the experiment configuration.
    """

    def __init__(
        self,
        model: nn.Module,
        stream: object,
        n_epochs: int,
        lr: float,
        batch_size: int,
        output_filename: str,
        device: torch.device,
        **kwargs,
    ) -> None:
        """Hand the core training configuration to ``BaseStrategy``.

        Args:
            model: Network to be driven by the strategy.
            stream: Passed to ``BaseStrategy`` unchanged; its expected
                structure is defined there.
            n_epochs: Passed to ``BaseStrategy`` unchanged.
            lr: Passed to ``BaseStrategy`` unchanged.
            batch_size: Passed to ``BaseStrategy`` unchanged.
            output_filename: Passed to ``BaseStrategy`` unchanged.
            device: Passed to ``BaseStrategy`` unchanged.
            **kwargs: Accepted but discarded; they are NOT forwarded to
                ``BaseStrategy``. NOTE(review): presumably this lets a
                generic caller pass options meant for other strategies.
                Be aware that a misspelled keyword is swallowed silently.
        """
        base_config = {
            "model": model,
            "stream": stream,
            "n_epochs": n_epochs,
            "lr": lr,
            "batch_size": batch_size,
            "device": device,
            "output_filename": output_filename,
        }
        super().__init__(**base_config)

    def forward(self, inputs):
        """Run the wrapped model on ``inputs`` and return its output as-is.

        NOTE(review): ``self.model`` is assumed to be assigned by
        ``BaseStrategy.__init__``; it is not set in this class.
        """
        network = self.model
        return network(inputs)