# federated/client.py  (lightweight version)

from __future__ import annotations
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Public API: the names exported by a star-import of this module.
__all__ = ["Client", "ForecastHead", "ForecastWrapper", "create_client_model"]


# ──────────────────────────────────────────────────────────────────────────────
# ❶ Forecast head & wrapper
# ──────────────────────────────────────────────────────────────────────────────

class ForecastHead(nn.Module):
    """Linear projection followed by a reshape to ``(B, horizon, out_dim)``.

    The projection is an ``nn.LazyLinear``, so its input width is fixed by the
    first tensor that passes through ``forward``.

    Args:
        horizon: Number of forecast steps in the output.
        out_dim: Number of values predicted per forecast step.
    """

    def __init__(self, horizon: int, out_dim: int) -> None:
        super().__init__()
        self._h, self._d = horizon, out_dim
        self.proj = nn.LazyLinear(horizon * out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and reshape the result to ``(B, horizon, out_dim)``.

        Args:
            x: Tensor whose first dimension is the batch. Any trailing
                dimensions beyond the second are collapsed into one.

        Returns:
            Tensor of shape ``(x.size(0), horizon, out_dim)``.
        """
        if x.dim() > 2:
            # flatten() copies only when it has to, so non-contiguous inputs
            # (e.g. a transposed feature map) work; view() would raise here.
            x = x.flatten(1)
        y = self.proj(x)                        # (B, H*D)
        # Explicit batch size (rather than -1) keeps empty batches reshapeable.
        return y.view(x.size(0), self._h, self._d)


class ForecastWrapper(nn.Module):
    """Backbone → flatten (if needed) → ForecastHead.

    Args:
        backbone: Feature extractor. It is called as ``backbone(x)`` or, when
            time features are supplied, as ``backbone(x, time_features)``.
        horizon: Forecast horizon forwarded to ``ForecastHead``.
        out_dim: Output dimension forwarded to ``ForecastHead``.
    """

    def __init__(self, backbone: nn.Module, horizon: int, out_dim: int) -> None:
        super().__init__()
        self.backbone = backbone
        self.head = ForecastHead(horizon, out_dim)

    def forward(self, x: torch.Tensor, time_features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the backbone, normalise its output to 2-D, and apply the head.

        Args:
            x: Input passed to the backbone.
            time_features: Optional second positional argument for backbones
                that accept one (the original note cites QAPBackboneWrapper).

        Returns:
            The output of ``self.head`` applied to the flattened features.
        """
        if time_features is None:
            feat = self.backbone(x)
        else:
            try:
                feat = self.backbone(x, time_features)
            except TypeError:
                # Best-effort fallback for backbones that take ``x`` only.
                # NOTE(review): a TypeError raised *inside* a backbone that does
                # accept time_features is also caught here and triggers a retry
                # without them.
                feat = self.backbone(x)

        # Backbones that return several values: the first one is the features.
        if isinstance(feat, (tuple, list)):
            feat = feat[0]
        # Convert BEFORE touching tensor-only attributes/methods; otherwise a
        # non-tensor output fails on ``.ndim`` / ``.flatten(1)`` below.
        if not isinstance(feat, torch.Tensor):
            feat = torch.as_tensor(feat)
        if feat.ndim > 2:                      # (B, seq, ch) → (B, F)
            feat = feat.flatten(1)
        return self.head(feat)


# ──────────────────────────────────────────────────────────────────────────────
# ❷ Client wrapper
# ──────────────────────────────────────────────────────────────────────────────

class Client(nn.Module):
    """Federated client wrapper holding a single ``nn.Module``.

    Modes:
      - ``"fedavg"``: ``model`` is the backbone itself.
      - ``"fedper"``: ``model`` is a ``ForecastWrapper`` (backbone + head).

    Shared parameters are the ones whose state-dict key starts with
    ``backbone.``. In ``"fedavg"`` mode, when the model has no such keys
    (i.e. it is a bare backbone), every parameter is treated as shared so the
    server exchange is not silently empty.

    Args:
        model: The module this client trains and evaluates.
        mode: Either ``"fedavg"`` or ``"fedper"``.
        loader: Optional data loader stored on the client; not used here.
        cid: Optional client identifier stored on the client; not used here.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """

    _MODES = ("fedavg", "fedper")
    _SHARED_PREFIX = "backbone."

    def __init__(self, model: nn.Module, mode: str = "fedper",
                 loader: Optional[DataLoader] = None, cid: Optional[str | int] = None):
        super().__init__()
        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.mode = mode
        self.model = model
        self.loader = loader
        self.cid = cid

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Apply the wrapped model to ``x``; extra arguments are ignored."""
        # Per the original note: DataFactory yields (x_emb, y_emb) 2-tuples,
        # so no time-feature argument is needed.
        return self.model(x)

    # ---- Server send/receive utilities ----
    def _shares_whole_model(self, model_sd: Dict[str, torch.Tensor]) -> bool:
        """Return True when every key of ``model_sd`` counts as shared.

        That is the case only in ``"fedavg"`` mode with a model that exposes
        no ``backbone.*`` keys, i.e. the model is the bare backbone.
        """
        if self.mode != "fedavg":
            return False
        return not any(key.startswith(self._SHARED_PREFIX) for key in model_sd)

    @torch.no_grad()
    def state_dict_for_server(self) -> Dict[str, torch.Tensor]:
        """Extract only the shared parameters to send to the server.

        Returns:
            Cloned tensors for every ``backbone.*`` key, or for every key when
            the client runs FedAvg on a bare backbone.
        """
        sd = self.model.state_dict()
        share_all = self._shares_whole_model(sd)
        prefix = self._SHARED_PREFIX
        return {k: v.clone() for k, v in sd.items()
                if share_all or k.startswith(prefix)}

    @torch.no_grad()
    def load_backbone_state_dict(self, sd: Dict[str, torch.Tensor], strict: bool = False):
        """Overwrite only the shared parameters with the server broadcast.

        Non-shared entries of ``sd`` are ignored, so local parameters such as
        ``head.*`` keep their current values.

        Args:
            sd: State dict received from the server.
            strict: Forwarded to ``load_state_dict`` of the wrapped model.
        """
        full = self.model.state_dict()
        # Decide from the model's own keys, before merging anything incoming.
        share_all = self._shares_whole_model(full)
        prefix = self._SHARED_PREFIX
        full.update((k, v) for k, v in sd.items()
                    if share_all or k.startswith(prefix))
        self.model.load_state_dict(full, strict=strict)


# ──────────────────────────────────────────────────────────────────────────────
# ❸ Factory function
# ──────────────────────────────────────────────────────────────────────────────

def create_client_model(backbone: nn.Module, horizon: int, out_dim: int) -> ForecastWrapper:
    """Build the FedPer client model by wrapping a backbone with a forecast head.

    Args:
        backbone: Backbone module to wrap (with QAP integration if needed).
        horizon: Forecast horizon.
        out_dim: Output dimension (number of features to predict).

    Returns:
        A ``ForecastWrapper`` constructed from the given arguments.
    """
    wrapper = ForecastWrapper(backbone=backbone, horizon=horizon, out_dim=out_dim)
    return wrapper