from pathlib import Path

import torch
import torch.nn as nn

from models.diffuser import WorldModel
from models.actor_critic import ActorCritic


class Agent(nn.Module):
    """Bundle a world model and an actor-critic into a single ``nn.Module``.

    Both components are stored as attributes of this module. Assuming they are
    ``nn.Module`` instances (``load_state_dict`` is called on them below), this
    registers them as submodules. Module-wide operations such as ``.to()`` or
    ``.state_dict()`` on the agent then cover both.
    """

    def __init__(self, wm: WorldModel, ac: ActorCritic) -> None:
        super().__init__()
        self.wm = wm
        self.ac = ac

    @property
    def device(self):
        """Device reported by the world model (delegates to ``self.wm.device``)."""
        return self.wm.device

    def load(self, path_to_ckpt: Path, load_wm: bool = True, load_ac: bool = True) -> None:
        """Restore component weights from a checkpoint file.

        Args:
            path_to_ckpt: Path to a file readable by ``torch.load``. It is
                expected (not verified here) to contain a dict with ``'wm'``
                and/or ``'ac'`` state dicts.
            load_wm: If True, load ``ckpt['wm']`` into the world model.
            load_ac: If True, load ``ckpt['ac']`` into the actor-critic.

        Raises:
            KeyError: if a requested component's key is missing from the
                checkpoint dict.
        """
        # Without map_location, tensors are restored onto the device they were
        # saved from, so a CUDA-saved checkpoint fails on a CPU-only machine.
        # Mapping onto the agent's own device avoids that. This assumes
        # self.wm.device is a torch.device or device string; TODO confirm.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(path_to_ckpt, map_location=self.device)
        if load_wm:
            self.wm.load_state_dict(ckpt['wm'])
        if load_ac:
            self.ac.load_state_dict(ckpt['ac'])
