from typing import Dict, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader

# Short alias for ``torch.Tensor``, used in the type annotations of this module.
T = torch.Tensor


class ProbModel(nn.Module):
    """Base class for probabilistic models fed by a PyTorch ``DataLoader``.

    Subclasses are expected to override :meth:`log_prob`,
    :meth:`sample_minibatch`, :meth:`reset_parameters` and :meth:`predict`;
    the base versions raise ``NotImplementedError``. :meth:`pretrain` is an
    optional hook that does nothing by default.

    Args:
        dataloader: source of minibatches; stored as ``self.dataloader``.

    Raises:
        TypeError: if ``dataloader`` is not a ``torch.utils.data.DataLoader``.
    """

    def __init__(self, dataloader: DataLoader):
        super().__init__()
        # Explicit check instead of ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if not isinstance(dataloader, DataLoader):
            raise TypeError(
                "dataloader must be a torch.utils.data.DataLoader, "
                f"got {type(dataloader).__name__}"
            )
        self.dataloader = dataloader

    def log_prob(self, data: T, target: T) -> Dict[str, T]:
        """Return log-probability terms for a ``(data, target)`` batch.

        Must be implemented by subclasses; the result is a dict with string
        keys and tensor values.

        If minibatches have to be sampled due to memory constraints, the
        stored dataloader can supply them. A single minibatch is drawn with::

            data, target = next(iter(self.dataloader))

        Each call builds a fresh iterator and takes only its first batch, so
        it can be repeated indefinitely without exhausting the loader. The
        batch is random only if the dataloader shuffles (e.g.
        ``shuffle=True``); otherwise the same first batch is returned on
        every call.
        """
        raise NotImplementedError

    def sample_minibatch(self) -> Tuple[T, T]:
        """Return one ``(data, target)`` minibatch.

        Must be implemented by subclasses. Hybrid Monte Carlo samplers need a
        constant ``(data, target)`` tuple while computing a trajectory, so the
        batch returned here is intended to be held fixed for that duration.
        """
        raise NotImplementedError

    def reset_parameters(self) -> None:
        """Re-initialise the model parameters. Must be implemented by subclasses."""
        raise NotImplementedError

    def predict(self, chain):
        """Produce predictions from ``chain``. Must be implemented by subclasses.

        NOTE(review): ``chain`` presumably holds sampled parameter states
        from an MCMC run — confirm against the concrete subclasses.
        """
        raise NotImplementedError

    def pretrain(self) -> None:
        """Optional hook run before sampling; the base implementation does nothing."""
        pass
