import os
import torch
import torch.nn as nn
import torch.optim as optim
from dataclasses import dataclass

@dataclass
class ModelCheckpoint:
    """Bundle two modules and an optimizer so their state is saved together.

    The file written by :meth:`save` contains a dict with the keys
    ``'hide'``, ``'find'`` and ``'opt'``, each holding the corresponding
    object's ``state_dict()``.
    """

    # Directory that checkpoint files are written into (created on demand).
    ckpt_dir: str
    # First module; its state_dict is stored under the 'hide' key.
    hide: nn.Module
    # Second module; its state_dict is stored under the 'find' key.
    find: nn.Module
    # Optimizer; its state_dict is stored under the 'opt' key.
    opt: optim.Optimizer

    def save(self, name: str) -> str:
        """Write the current state of ``hide``, ``find`` and ``opt`` to disk.

        The data is first written to a sibling ``<path>.tmp`` file, which is
        then moved over the final path with :func:`os.replace`. An
        interrupted save therefore never leaves a truncated file under the
        final name. An existing checkpoint with the same name is replaced
        only after the new one has been completely written.

        Args:
            name: File name relative to ``ckpt_dir``. It may include
                sub-directories, which are created if missing.

        Returns:
            The full path of the written checkpoint file.
        """
        path = os.path.join(self.ckpt_dir, name)
        parent = os.path.dirname(path)
        if parent:
            # torch.save does not create missing parent directories.
            os.makedirs(parent, exist_ok=True)

        checkpoint = {
            'hide': self.hide.state_dict(),
            'find': self.find.state_dict(),
            'opt': self.opt.state_dict(),
        }

        tmp_path = path + '.tmp'
        try:
            torch.save(checkpoint, tmp_path)
            # Swap the finished file into place only after the write succeeded.
            os.replace(tmp_path, path)
        except BaseException:
            # Don't leave a half-written temp file behind; propagate the error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        return path