import torch
import torch.nn as nn

from .rnn_mawm import MAWM


class GWM(nn.Module):
    """Container model that bundles an encoder, action encoder, RNN and decoder.

    The four components are stored as attributes. Any of them that is an
    ``nn.Module`` is therefore registered as a child module, so its
    parameters show up in ``parameters()`` / ``state_dict()`` and follow
    ``.to(device)`` / ``.train()`` / ``.eval()``.

    NOTE(review): the roles below are inferred from the argument names only;
    how they are wired together is not defined yet (``forward`` is still a
    TODO, so calling the model raises ``NotImplementedError`` via
    ``nn.Module``'s default ``forward``).

    Args:
        encoder: presumably encodes observations -- TODO confirm.
        action_encoder: presumably encodes actions -- TODO confirm.
        rnn: recurrent core -- TODO confirm expected input/output layout.
        decoder: presumably maps RNN outputs back to predictions -- TODO confirm.
    """

    def __init__(self, encoder, action_encoder, rnn, decoder):
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # submodules; previously the arguments were discarded, leaving the
        # model with no parameters at all.
        self.encoder = encoder
        self.action_encoder = action_encoder
        self.rnn = rnn
        self.decoder = decoder
        # TODO: implement forward() combining the components above.
