from dataclasses import dataclass
from typing import Callable, List, Optional

import torch
from torch.utils.data import DataLoader, Dataset


@dataclass
class Client:
    """Bundle of per-client state: model, optimizer, datasets/loaders, and hooks.

    All fields except ``num_classes`` are required. The declaration order
    below is also the positional constructor order, so it is part of the
    public interface and must not be changed.

    NOTE(review): names such as ``server_lr``, ``public_loader`` and
    ``local_optimizer`` suggest this is a federated-learning participant.
    The code that consumes this container is not visible here, so confirm
    field semantics against callers.
    """

    # Task identifier; the set of accepted values is not visible here — TODO confirm.
    task: str
    # Model owned by this client.
    model: torch.nn.Module
    # Optimizer used for local updates; presumably built over ``model``'s
    # parameters — confirm at construction site.
    local_optimizer: torch.optim.Optimizer
    train_dataset: Dataset
    test_dataset: Dataset
    train_loader: DataLoader
    test_loader: DataLoader
    # Loader over data presumably shared across clients (e.g. for
    # distillation or alignment) — confirm against callers.
    public_loader: DataLoader
    # Local training hook; its call signature is defined by callers, not here.
    local_train_func: Callable
    # Local evaluation hook. The field keeps its historical name for backward
    # compatibility; ``local_eval_func`` (property below) is an alias whose
    # name matches ``local_train_func``.
    local_eval_fun: Callable
    # Extra training arguments; expected keys are not visible here.
    train_args: dict
    # Learning rate associated with the server side for this client —
    # exact use depends on the caller; confirm.
    server_lr: float
    # Number of classes; None unless supplied (the only field with a default).
    num_classes: Optional[int] = None

    # Not annotated, so dataclass does not treat this as a field: the
    # constructor, ``fields()`` and ``asdict()`` are unchanged.
    @property
    def local_eval_func(self) -> Callable:
        """Alias for ``local_eval_fun``, named consistently with ``local_train_func``."""
        return self.local_eval_fun

    @local_eval_func.setter
    def local_eval_func(self, value: Callable) -> None:
        """Assign through the alias; writes ``local_eval_fun``."""
        self.local_eval_fun = value
