import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from .baseline import BaseFLModel
from torch.utils.data import TensorDataset, DataLoader


class _Head(nn.Module):
    """Single affine layer mapping ``in_dim`` features to ``out_dim`` outputs."""

    def __init__(self, in_dim, out_dim):
        """Create the linear layer (with bias) of shape ``in_dim -> out_dim``."""
        super().__init__()
        # The attribute name ``fc`` determines the state-dict keys
        # ("fc.weight" / "fc.bias"), so it must not be renamed.
        self.fc = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x):
        """Return the raw outputs of the linear layer applied to ``x``."""
        logits = self.fc(x)
        return logits

def _matrix_to_heads(sd, C):
    """Split a multi-output head state dict into ``C`` single-row state dicts.

    Args:
        sd: Mapping with "fc.weight" and "fc.bias" tensors.
        C: Number of rows to extract, starting at row 0.

    Returns:
        A list of ``C`` dicts; entry ``c`` holds independent copies of
        row ``c`` of the weight (kept 2-D) and element ``c`` of the bias
        (kept 1-D).
    """
    weight = sd["fc.weight"]
    bias = sd["fc.bias"]
    per_class = []
    for row in range(C):
        stop = row + 1
        # Range slicing keeps the leading dimension; clone() detaches the
        # storage from the source tensors.
        per_class.append(
            {
                "fc.weight": weight[row:stop].clone(),
                "fc.bias": bias[row:stop].clone(),
            }
        )
    return per_class

def _heads_to_matrix(heads):
    """Stack per-class head state dicts back into one state dict.

    Args:
        heads: Sequence of mappings, each with "fc.weight" and "fc.bias"
            tensors, in the row order they should appear in the result.

    Returns:
        A dict with "fc.weight" and "fc.bias", each the concatenation of the
        corresponding entries along dimension 0.
    """
    weight_rows = []
    bias_rows = []
    for head in heads:
        weight_rows.append(head["fc.weight"])
        bias_rows.append(head["fc.bias"])
    stacked_weight = torch.cat(weight_rows, dim=0)
    stacked_bias = torch.cat(bias_rows, dim=0)
    return {"fc.weight": stacked_weight, "fc.bias": stacked_bias}



class LPOvA(BaseFLModel):
    """Federated one-vs-all (OvA) linear head trained on feature inputs.

    ``client_update`` has two modes selected by ``round_idx``:

    * Stage 1 (``round_idx < num_stage1_rounds``): every locally present (and
      eligible) class gets its own single-output head, seeded from the
      matching row of the global head and trained only on that class's
      samples with target 1.
    * Otherwise: the full ``num_classes``-output head is trained with
      BCE-with-logits on one-hot targets.
    """

    def __init__(self, cfg, device):
        """Read hyper-parameters from ``cfg`` ("model", "train", "data", "seed")."""
        super().__init__(cfg, device)
        m, t = cfg.get("model", {}), cfg.get("train", {})
        self.seed = int(cfg.get("seed", 0))
        # May stay None here; init_global() then fills it from enc_info.
        self.in_dim = m.get("in_dim", None)
        self.C = int(m.get("num_classes", 10))
        self.stage1_rounds = int(m.get("num_stage1_rounds", 0))
        # Classes with fewer local samples than this are skipped in stage 1
        # (0 disables the filter).
        self.min_samples_stage1 = int(m.get("min_samples_stage1", 0))
        self.lr = float(t.get("lr", 1e-2))
        self.wd = float(t.get("weight_decay", 1e-4))
        self.epochs = int(t.get("local_epochs", 1))
        # NOTE(review): relies on the base class having stored ``cfg`` as
        # ``self.cfg`` -- confirm against BaseFLModel.
        self.bs = int(self.cfg["data"].get("batch_size", 128))

    def _new_head1(self):
        """Return a fresh single-output head on ``self.device``."""
        return _Head(self.in_dim, 1).to(self.device)

    def _make_adamw(self, params):
        """Build an AdamW optimizer, preferring the fused implementation.

        Falls back to the default implementation when ``fused`` is either an
        unknown keyword (``TypeError``) or is rejected for these parameters
        (``RuntimeError``).
        """
        # Materialise first: ``params`` is usually a generator, and a failed
        # first attempt could otherwise leave nothing for the fallback.
        params = list(params)
        try:
            return optim.AdamW(params, lr=self.lr, weight_decay=self.wd, fused=True)
        except (TypeError, RuntimeError):
            return optim.AdamW(params, lr=self.lr, weight_decay=self.wd)

    def get_requirements(self):
        """Declare that this model consumes feature inputs."""
        return {"input_type": "features"}

    def init_global(self, enc_info=None):
        """Create the initial global state: ``{"head": <state dict>}``.

        If ``in_dim`` was not configured it is taken from
        ``enc_info["feature_dim"]``; an ``AssertionError`` is raised when that
        is unavailable.
        """
        if self.in_dim is None:
            assert enc_info and "feature_dim" in enc_info, "feature_dim 필요"
            self.in_dim = int(enc_info["feature_dim"])
        headC = _Head(self.in_dim, self.C).to(self.device)
        return {"head": headC.state_dict()}

    @torch.no_grad()
    def _count_labels(self, loader):
        """Return a length-``C`` long tensor with per-class label counts."""
        cnt = torch.zeros(self.C, dtype=torch.long)
        for _, yb in loader:
            cnt += torch.bincount(yb, minlength=self.C)
        return cnt

    def _build_pos_loaders_once(self, X_cpu, y_cpu):
        """Build one positives-only DataLoader per eligible local class.

        Args:
            X_cpu: Feature tensor, indexed along dimension 0 by sample.
            y_cpu: Integer class labels, one per sample.

        Returns:
            ``(loaders, eligible)`` where ``loaders`` maps class index to a
            shuffling DataLoader yielding ``(features, ones)`` batches and
            ``eligible`` is a length-``C`` bool tensor of classes that meet
            ``min_samples_stage1``.
        """
        C = self.C
        loaders = {}
        binc = torch.bincount(y_cpu, minlength=C)
        if self.min_samples_stage1 > 0:
            eligible = binc >= self.min_samples_stage1
        else:
            eligible = torch.ones(C, dtype=torch.bool)
        # Pinning only helps host-to-GPU copies; requesting it without CUDA
        # just produces a warning.
        pin = torch.device(self.device).type == "cuda"
        for c in torch.nonzero(binc > 0, as_tuple=True)[0].tolist():
            if not eligible[c]:
                continue
            idx = (y_cpu == c).nonzero(as_tuple=True)[0]
            Xc = X_cpu.index_select(0, idx)
            # Every sample of class c is a positive for head c.
            yc = torch.ones(Xc.shape[0], 1, dtype=torch.float32)
            loaders[c] = DataLoader(
                TensorDataset(Xc, yc),
                batch_size=self.bs,
                shuffle=True,
                pin_memory=pin,
                num_workers=0,
                drop_last=False,
            )
        return loaders, eligible

    def _stage1_train_with_pos_loaders(self, seed_rows, pos_loaders, eligible_mask):
        """Train one single-output head per class on its positive samples.

        Args:
            seed_rows: Per-class single-row state dicts (length ``C``) used to
                initialise each head and as the fallback for untrained rows.
            pos_loaders: Mapping class index -> positives-only DataLoader.
            eligible_mask: Indexable by class; falsy entries are skipped.

        Returns:
            ``(state, pos_counts, seen_total)``: the re-assembled ``C``-row
            head state dict, a long tensor with the number of samples each
            head processed (summed over epochs), and their overall total.
        """
        heads, optims = {}, {}
        bce = nn.BCEWithLogitsLoss()
        pos_counts = torch.zeros(self.C, dtype=torch.long)
        seen_total = 0

        for c in pos_loaders:
            h = self._new_head1()
            h.load_state_dict(seed_rows[c])
            h.train(True)
            heads[c] = h
            optims[c] = self._make_adamw(h.parameters())

        for _ in range(self.epochs):
            for c, loader_c in pos_loaders.items():
                if not eligible_mask[c]:
                    continue
                h, opt = heads[c], optims[c]
                for xb, y1 in loader_c:
                    xb = xb.to(self.device, non_blocking=True, dtype=torch.float32)
                    y1 = y1.to(self.device, non_blocking=True, dtype=torch.float32)
                    seen_total += y1.size(0)

                    opt.zero_grad(set_to_none=True)
                    logit = h(xb)
                    loss = bce(logit, y1)
                    loss.backward()
                    opt.step()

                    pos_counts[c] += y1.size(0)

        # Trained rows live on self.device while untouched seed rows keep the
        # device of the incoming global state; align them so the
        # concatenation cannot fail on a device mismatch.
        merged = []
        for c in range(self.C):
            row = heads[c].state_dict() if c in heads else seed_rows[c]
            merged.append({k: v.to(self.device) for k, v in row.items()})
        upd_headC = _heads_to_matrix(merged)
        return upd_headC, pos_counts, seen_total

    def client_update(self, global_state, client_data, round_idx, enc_mgr=None):
        """Run one round of local training starting from ``global_state``.

        Args:
            global_state: Dict with a "head" state dict.
            client_data: Local data; in stage 1 it is unpacked as ``(X, y)``
                tensors, otherwise it is only passed to ``_as_loader``.
            round_idx: Current round; rounds below ``num_stage1_rounds`` use
                the per-class stage-1 procedure.
            enc_mgr: Unused here.

        Returns:
            Stage 1: ``({"head": state}, {"head": {"row": per_class_counts}})``
            or ``(None, None)`` when no class is trainable locally.
            Otherwise: ``({"head": state}, {"head": {"scalar": n_seen}})``.
            The second element presumably serves as aggregation weights --
            verify against the server code.
        """
        # NOTE(review): the loader is unused in stage 1 but is still built
        # first, in case ``_as_loader`` has seeding side effects.
        loader = self._as_loader(client_data, shuffle=True, batch_size=self.bs, seed=(self.seed + round_idx))
        is_stage1 = (round_idx < self.stage1_rounds)
        if is_stage1:
            seed_rows = _matrix_to_heads(global_state["head"], self.C)
            X_cpu, y_cpu = client_data
            pos_loaders, eligible_mask = self._build_pos_loaders_once(
                X_cpu.cpu(), y_cpu.cpu()
            )
            if not pos_loaders:
                return None, None

            upd_headC, pos_counts, _ = self._stage1_train_with_pos_loaders(
                seed_rows, pos_loaders, eligible_mask
            )

            return {"head": upd_headC}, {"head": {"row": pos_counts.cpu()}}

        headC = _Head(self.in_dim, self.C).to(self.device)
        headC.load_state_dict(global_state["head"])
        headC.train(True)
        opt = self._make_adamw(headC.parameters())
        bce = nn.BCEWithLogitsLoss()
        total_seen = 0
        for _ in range(self.epochs):
            for xb, yb in loader:
                xb = xb.to(self.device, non_blocking=True, dtype=torch.float32)
                yb = yb.to(self.device, non_blocking=True)
                total_seen += yb.numel()
                opt.zero_grad(set_to_none=True)
                logits = headC(xb)
                # One-vs-all targets: 1 for the true class, 0 elsewhere.
                y_ova = F.one_hot(yb, num_classes=self.C).float()
                loss = bce(logits, y_ova)
                loss.backward()
                opt.step()
        return {"head": headC.state_dict()}, {"head": {"scalar": int(total_seen)}}

    @torch.no_grad()
    def evaluate(self, global_state, testset, enc_mgr=None):
        """Evaluate the global head on ``testset``.

        Returns:
            ``(loss, accuracy)``: cross-entropy averaged over batches (not
            over samples) and top-1 accuracy in percent.
        """
        loader = self._as_loader(testset, shuffle=False, batch_size=self.bs)
        headC = _Head(self.in_dim, self.C).to(self.device)
        headC.load_state_dict(global_state["head"])
        headC.eval()
        ce = nn.CrossEntropyLoss()
        tot, corr, nbat, cnt = 0.0, 0, 0, 0
        for xb, yb in loader:
            # Cast to float32 as in training so non-float32 features do not
            # clash with the float32 head weights.
            xb = xb.to(self.device, dtype=torch.float32)
            yb = yb.to(self.device)
            logits = headC(xb)
            tot += ce(logits, yb).item()
            corr += (logits.argmax(1) == yb).sum().item()
            cnt += yb.numel()
            nbat += 1
        return tot / max(1, nbat), 100.0 * corr / max(1, cnt)
