import torch
from torch.utils.data import DataLoader, TensorDataset

class BaseFLModel:
    """Base interface for federated-learning model APIs.

    Subclasses are expected to override ``init_global``, ``client_update``
    and ``evaluate``; the base versions raise ``NotImplementedError``. The
    base class stores the config and device and provides a small helper for
    wrapping ``(x, y)`` tensors in a ``DataLoader``.
    """

    def __init__(self, cfg, device):
        """Store the model config object and the target device."""
        self.cfg = cfg
        self.device = device

    def get_requirements(self):
        """Return the default requirements dict for this model.

        A new dict is built on every call, so callers may mutate the result
        without affecting other callers.
        """
        return {
            "input_type": "features",
            "accept_precompute": True,
            "needs_encoder": False,
            "update_keys": ["head"],
        }

    def init_global(self, enc_info=None):
        """Create the initial global state. Must be overridden by subclasses."""
        raise NotImplementedError

    def client_update(self, global_state, client_data, round_idx, enc_mgr=None):
        """Run one client's local update for ``round_idx``. Must be overridden."""
        raise NotImplementedError

    @torch.no_grad()
    def evaluate(self, global_state, testset, enc_mgr=None):
        """Evaluate ``global_state`` on ``testset`` under ``torch.no_grad``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def _as_loader(self, tup, shuffle, batch_size, seed=0, pin_memory=None):
        """Wrap an ``(x, y)`` tensor pair in a ``DataLoader``.

        Args:
            tup: pair of tensors ``(x, y)`` with matching first dimension.
            shuffle: whether to shuffle samples each epoch.
            batch_size: number of samples per batch.
            seed: when shuffling, seeds a dedicated ``torch.Generator`` so the
                sample order is reproducible; ``None`` uses torch's global RNG.
            pin_memory: ``True``/``False`` to force page-locked batches;
                ``None`` (default) pins only when CUDA is available and both
                tensors are on the CPU.

        Returns:
            A ``torch.utils.data.DataLoader`` over ``TensorDataset(x, y)``.
        """
        x, y = tup
        dataset = TensorDataset(x, y)

        gen = None
        if shuffle and seed is not None:
            gen = torch.Generator()
            gen.manual_seed(int(seed))

        if pin_memory is None:
            # Pinning applies only to dense CPU tensors: pinning CUDA tensors
            # raises, and without CUDA it gives no benefit (torch just warns).
            pin_memory = (
                torch.cuda.is_available()
                and x.device.type == "cpu"
                and y.device.type == "cpu"
            )

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            generator=gen,
            pin_memory=bool(pin_memory),
        )

def load_model_api_from_cfg(cfg, device):
    """Instantiate the model API selected by ``cfg["model_name"]``.

    Accepted names (case-insensitive, surrounding whitespace ignored):
      - "lp_softmax", "softmax_lp", "lp-softmax" -> LPSoftmax
      - "ova_lp", "lp_ova", "ova-lp"             -> LPOvA
      - "pfpt"                                   -> PFPT
      - "fmoe"                                   -> FMoE

    Implementation modules are imported lazily, so only the selected one is
    loaded.

    Args:
        cfg: mapping with an optional ``"model_name"`` entry; it is passed
            through unchanged to the selected class.
        device: passed through unchanged to the selected class.

    Raises:
        ValueError: if ``model_name`` is missing, empty, ``None`` or unknown.
    """
    # ``or ""`` also covers an explicit None, which would crash on .lower().
    name = str(cfg.get("model_name") or "").strip().lower()

    # Every alias group is a real tuple (note the trailing commas): a bare
    # parenthesised string would make ``in`` a substring test, so "" or "pf"
    # would wrongly match "pfpt".
    if name in ("lp_softmax", "softmax_lp", "lp-softmax"):
        from .lp_softmax import LPSoftmax
        return LPSoftmax(cfg, device)
    elif name in ("ova_lp", "lp_ova", "ova-lp"):
        from .lp_ova import LPOvA
        return LPOvA(cfg, device)
    elif name in ("pfpt",):
        from .pfpt import PFPT
        return PFPT(cfg, device)
    elif name in ("fmoe",):
        from .fmoe import FMoE
        return FMoE(cfg, device)
    else:
        raise ValueError(f"unknown model_name: {name}")
