import torch.nn as nn

from models.diffuser import WorldModel
from models.actor_critic import ActorCritic


class Agent(nn.Module):
    """Bundle a world model, an actor-critic and two auxiliary components.

    The class only stores its collaborators and exposes the world model's
    device; it defines no ``forward`` of its own in this view.

    Attributes:
        wm: The world model handed to the constructor.
        ac: The actor-critic handed to the constructor.
        tok: Stored as given (presumably a tokenizer -- TODO confirm).
        gpt: Stored as given (presumably a GPT-style model -- TODO confirm).
    """

    def __init__(self, wm: WorldModel, ac: ActorCritic, tok, gpt) -> None:
        """Store the four components on the module.

        Args:
            wm: World model; also the source of the ``device`` property.
            ac: Actor-critic component.
            tok: Kept as ``self.tok``; type not visible from this file.
            gpt: Kept as ``self.gpt``; type not visible from this file.
        """
        super().__init__()
        # Targets are bound left to right, so the attribute order stays
        # wm, ac, tok, gpt. nn.Module records submodules in assignment
        # order, which determines parameters()/state_dict() ordering.
        self.wm, self.ac = wm, ac
        self.tok, self.gpt = tok, gpt

    @property
    def device(self):
        """Return the ``device`` attribute reported by the world model."""
        world_model = self.wm
        return world_model.device
