import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from .baseline import BaseFLModel

class _Head(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__(); self.fc = nn.Linear(in_dim, out_dim, bias=True)
    def forward(self, x): return self.fc(x)

class LPSoftmax(BaseFLModel):
    """Linear-probe softmax classifier for federated training.

    The only trainable state is a single ``_Head`` (one ``nn.Linear``) mapping
    feature vectors of size ``in_dim`` to ``num_classes`` logits. The state
    exchanged with the server is ``{"head": head.state_dict()}``.

    Config keys read (defaults in parentheses):
      - ``seed`` (0): base seed for client-side shuffling.
      - ``model.in_dim`` (None): feature dimension. If unset, it is taken from
        ``enc_info["feature_dim"]`` in ``init_global``, or from the shape of
        the received head weights.
      - ``model.num_classes`` (10)
      - ``train.lr`` (1e-2), ``train.weight_decay`` (1e-4),
        ``train.local_epochs`` (1)
      - ``data.batch_size`` (128)
    """

    def __init__(self, cfg, device):
        super().__init__(cfg, device)
        self.seed = int(cfg.get("seed", 0))
        m, t = cfg.get("model", {}), cfg.get("train", {})
        self.in_dim = m.get("in_dim", None)
        self.num_classes = int(m.get("num_classes", 10))
        self.lr = float(t.get("lr", 1e-2))
        self.wd = float(t.get("weight_decay", 1e-4))
        self.epochs = int(t.get("local_epochs", 1))
        # Use .get with a default, like the other sections, so a config without
        # a "data" section falls back to the default batch size instead of
        # raising KeyError.
        self.bs = int(self.cfg.get("data", {}).get("batch_size", 128))

    def _load_head(self, state):
        """Build a ``_Head`` on ``self.device`` and load ``state`` into it.

        If ``self.in_dim`` is still unknown (``init_global`` was not called on
        this instance), the input width is read from ``state["fc.weight"]``.
        ``nn.Linear`` stores its weight as ``(out_features, in_features)``.
        ``self.in_dim`` itself is left untouched.
        """
        in_dim = self.in_dim if self.in_dim is not None else int(state["fc.weight"].shape[1])
        head = _Head(in_dim, self.num_classes).to(self.device)
        head.load_state_dict(state)
        return head

    def init_global(self, enc_info=None):
        """Create a freshly initialised head and return its state.

        Args:
            enc_info: optional mapping. ``enc_info["feature_dim"]`` supplies
                ``in_dim`` when it was not set in the config.

        Returns:
            ``{"head": state_dict}`` of a new ``_Head`` on ``self.device``.

        Raises:
            ValueError: if ``in_dim`` is unknown and ``enc_info`` is missing
                or lacks ``"feature_dim"``.
        """
        if self.in_dim is None:
            if enc_info is None or "feature_dim" not in enc_info:
                raise ValueError(
                    "LPSoftmax: model.in_dim is not set and enc_info does not "
                    "provide 'feature_dim'"
                )
            self.in_dim = int(enc_info["feature_dim"])
        head = _Head(self.in_dim, self.num_classes).to(self.device)
        return {"head": head.state_dict()}

    def client_update(self, global_state, client_data, round_idx, enc_mgr=None):
        """Train the head locally, starting from ``global_state``.

        Runs ``self.epochs`` passes of AdamW + cross-entropy over
        ``client_data``. Shuffling is seeded with ``self.seed + round_idx`` so
        each round sees a different (but reproducible) order.

        Returns:
            ``({"head": state_dict}, {"head": {"scalar": n}})``, where ``n`` is
            the number of labels processed summed over all local epochs. It
            therefore scales with ``local_epochs``.
        """
        loader = self._as_loader(client_data, shuffle=True, batch_size=self.bs,
                                 seed=(self.seed + round_idx))
        head = self._load_head(global_state["head"])
        head.train()
        opt = optim.AdamW(head.parameters(), lr=self.lr, weight_decay=self.wd)
        ce = nn.CrossEntropyLoss()
        n = 0
        for _ in range(self.epochs):
            for xb, yb in loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                opt.zero_grad()
                loss = ce(head(xb), yb)
                loss.backward()
                opt.step()
                n += yb.numel()
        return {"head": head.state_dict()}, {"head": {"scalar": int(n)}}

    @torch.no_grad()
    def evaluate(self, global_state, testset, enc_mgr=None):
        """Evaluate the head stored in ``global_state`` on ``testset``.

        Returns:
            ``(mean_loss, accuracy_percent)``. The loss is averaged per sample
            (summed cross-entropy divided by the sample count). Both values
            are 0.0 for an empty test set.
        """
        loader = self._as_loader(testset, shuffle=False, batch_size=self.bs)
        head = self._load_head(global_state["head"])
        head.eval()
        # Sum per-sample losses: averaging per-batch means would over-weight
        # a short final batch and would require len(loader).
        ce = nn.CrossEntropyLoss(reduction="sum")
        tot, corr, cnt = 0.0, 0, 0
        for xb, yb in loader:
            xb, yb = xb.to(self.device), yb.to(self.device)
            logits = head(xb)
            tot += ce(logits, yb).item()
            corr += (logits.argmax(1) == yb).sum().item()
            cnt += yb.numel()
        denom = max(1, cnt)
        return tot / denom, 100.0 * corr / denom
