"""model/federated_model_factory.py  ── Federated Model Factory

Provides QAP-integrated backbone model construction and dimension-inference utilities.
"""

from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Backbone model import
from .train_model_factory import create_backbone_model


# ──────────────────────────────────────────────────────────────────────────────
# Simple Projection Module (a simple alternative to QAP)
# ──────────────────────────────────────────────────────────────────────────────

class SimpleProjectionModule(nn.Module):
    """Bring a client's raw feature width onto the common ``d_model`` width.

    A lightweight stand-in for QAP: a single linear layer followed by layer
    normalisation, with no attention involved.
    """

    def __init__(self, F_client: int, d_model: int):
        """
        Args:
            F_client: size of the last dimension of the incoming features.
            d_model: size of the last dimension of the produced features.
        """
        super().__init__()
        self.d_model = d_model
        # Per the original note these weights are personalised: F_client may
        # differ from one client to the next, so the layer cannot be shared.
        self.proj = nn.Linear(in_features=F_client, out_features=d_model)
        self.norm = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x: torch.Tensor):
        """Project ``x`` linearly, then layer-normalise the result.

        Args:
            x: tensor whose last dimension is ``F_client`` (documented by the
               callers in this file as ``[B, L, C_i]``).
        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``d_model``.
        """
        projected = self.proj(x)
        return self.norm(projected)


class ProjectionBackboneWrapper(nn.Module):
    """Simple projection + optional time-feature fusion + backbone.

    A simpler counterpart of ``QAPBackboneWrapper``: the raw input goes
    through ``SimpleProjectionModule`` instead of QAP, is optionally fused
    with projected time features, and is then handed to ``backbone``.
    """

    def __init__(self, backbone: nn.Module, F_client: int, d_model: int,
                 time_dim: Optional[int] = None, fuse: str = "cat"):
        """
        Args:
            backbone: module called on the fused ``[..., d_model]`` tensor.
            F_client: size of the last dimension of the raw input.
            d_model: common feature width produced by the projection.
            time_dim: size of the last dimension of the time features, or
                ``None`` to build no time-feature layers.
            fuse: ``"add"`` to sum projection and time embedding, ``"cat"``
                to concatenate them and project back to ``d_model``.

        Raises:
            ValueError: if ``fuse`` is neither ``"add"`` nor ``"cat"``.
        """
        super().__init__()
        # Explicit raise instead of `assert`: assertions are removed under
        # `python -O`, which would let a bad mode through and fail later in
        # forward() when `cat_proj` is None.
        if fuse not in ("add", "cat"):
            raise ValueError(f"fuse must be 'add' or 'cat', got {fuse}")

        # Simple projection (used in place of QAP)
        self.projection = SimpleProjectionModule(F_client, d_model)
        self.backbone = backbone
        self.d_model = d_model
        self.fuse = fuse

        # Time-feature layers (same layout as QAPBackboneWrapper)
        if time_dim is not None:
            self.time_proj = nn.Linear(time_dim, d_model)
            # Only "cat" needs a layer to map 2*d_model back to d_model.
            self.cat_proj = nn.Linear(d_model * 2, d_model) if fuse == "cat" else None
        else:
            self.time_proj = None
            self.cat_proj = None

    def forward(self, x: torch.Tensor, time_features: Optional[torch.Tensor] = None):
        """Project ``x``, fuse time features when available, run the backbone.

        Args:
            x: raw input, documented as ``[B, L, C]``.
            time_features: optional time features, documented as
                ``[B, L, t_dim]``. They are used only when the wrapper was
                built with ``time_dim``; otherwise they are ignored.

        Returns:
            Whatever ``backbone`` returns for the fused tensor.
        """
        proj_out = self.projection(x)  # [B, L, d_model]

        # Fusion happens only when both the input and the layers exist.
        if time_features is not None and self.time_proj is not None:
            time_emb = self.time_proj(time_features)  # [B, L, d_model]

            if self.fuse == "add":
                proj_out = proj_out + time_emb
            else:  # cat
                combined = torch.cat([proj_out, time_emb], dim=-1)  # [B, L, 2*d_model]
                proj_out = self.cat_proj(combined)  # [B, L, d_model]

        return self.backbone(proj_out)


# ──────────────────────────────────────────────────────────────────────────────
# QAP Module for server backbone
# ──────────────────────────────────────────────────────────────────────────────

class QAP_ServerModule(nn.Module):
    """QAP module for the server backbone; holds only the shared parameters.

    The client-specific slot embedding is not owned by this module: it is
    passed into ``forward`` by the caller on every call.
    """

    def __init__(self, d_model: int = 128, num_heads: int = 8, num_queries: int = 1,
                 use_side_channel: bool = True, dropout: float = 0.1):
        """
        Args:
            d_model: embedding width used throughout the module.
            num_heads: number of heads of the cross-attention layer.
            num_queries: number of learned query vectors. With more than one
                query, the per-query outputs are averaged in ``forward``.
            use_side_channel: when true, the attention output is fused with
                the mean and max over the feature axis through a small MLP.
            dropout: dropout probability applied to the attention output
                (defaults to the previously hard-coded 0.1).
        """
        super().__init__()

        # -- shared (server) parameters --
        self.value_proj = nn.Linear(1, d_model, bias=False)
        self.queries    = nn.Parameter(torch.randn(num_queries, d_model) / (d_model ** 0.5))
        self.attn       = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        if use_side_channel:
            self.fuse = nn.Sequential(
                nn.Linear(d_model * 3, d_model), nn.GELU(), nn.Linear(d_model, d_model)
            )
        else:
            self.fuse = nn.Identity()

        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

        self.d_model = d_model
        self.use_side_channel = use_side_channel

    def forward(self, x_blc: torch.Tensor, slot_embed: nn.Embedding):
        """Pool the feature axis of ``x_blc`` into ``d_model`` channels.

        Args:
            x_blc: input tensor of shape ``[B, L, C]``.
            slot_embed: client-specific embedding whose weight has shape
                ``[C, d_model]`` (one vector per input feature).

        Returns:
            Tensor of shape ``[B, L, d_model]``.

        Raises:
            ValueError: if ``x_blc`` is not 3-D or the slot embedding does
                not match ``(C, d_model)``.
        """
        if x_blc.ndim != 3:
            raise ValueError(f"x_blc must be [B, L, C], got shape {tuple(x_blc.shape)}")
        B, L, C = x_blc.shape
        d = self.d_model
        if tuple(slot_embed.weight.shape) != (C, d):
            raise ValueError(
                f"slot_embed weight must have shape ({C}, {d}), "
                f"got {tuple(slot_embed.weight.shape)}"
            )

        # BLC -> BCL so that every scalar value can be embedded on its own.
        x_bcl = x_blc.permute(0, 2, 1).contiguous()     # [B, C, L]

        # (1) value projection + local slot embedding
        v = self.value_proj(x_bcl.unsqueeze(-1))        # [B, C, L, d]
        slot_vec = slot_embed.weight.view(1, C, 1, d)
        x = self.norm(v + slot_vec)                     # [B, C, L, d]

        # (2) fold time into the batch axis and cross-attend over features
        x = x.permute(0, 2, 1, 3).contiguous()          # [B, L, C, d]
        x_flat = x.view(B * L, C, d)                    # [B*L, C, d]

        q = self.queries.unsqueeze(0).expand(B * L, -1, -1)      # [B*L, Q, d]
        z, _ = self.attn(q, x_flat, x_flat, need_weights=False)  # [B*L, Q, d]
        z = self.drop(z)

        # (3) side channel: append mean/max over the feature axis
        if isinstance(self.fuse, nn.Sequential):
            mean = x_flat.mean(1, keepdim=True)         # [B*L, 1, d]
            mx   = x_flat.max(1, keepdim=True).values   # [B*L, 1, d]
            z    = self.fuse(torch.cat([z, mean.expand_as(z), mx.expand_as(z)], dim=-1))

        # (4) collapse the query axis. Q == 1 keeps the original squeeze;
        # Q > 1 previously crashed in the reshape below, so the query outputs
        # are now averaged.
        z = z.squeeze(1) if z.shape[1] == 1 else z.mean(dim=1)   # [B*L, d]

        return z.reshape(B, L, d)                       # [B, L, d]


class QAPBackboneWrapper(nn.Module):
    """Combine a QAP module, optional time-feature fusion, and a backbone.

    The wrapper owns the client-specific slot embedding and passes it to the
    QAP module on every forward call. Time-feature handling follows the
    TimeSeriesPreprocessor style referenced by the original notes.
    """

    def __init__(self, backbone: nn.Module, qap_module: "QAP_ServerModule",
                 F_client: int, time_dim: Optional[int] = None, fuse: str = "cat"):
        """
        Args:
            backbone: module called on the fused ``[..., d_model]`` tensor.
            qap_module: QAP module; its ``d_model`` attribute fixes the
                embedding width used by this wrapper.
            F_client: number of rows of the slot embedding (one per input
                feature of this client).
            time_dim: size of the last dimension of the time features, or
                ``None`` to build no time-feature layers.
            fuse: ``"add"`` to sum QAP output and time embedding, ``"cat"``
                to concatenate them and project back to ``d_model``.

        Raises:
            ValueError: if ``fuse`` is neither ``"add"`` nor ``"cat"``.
        """
        super().__init__()
        # Explicit raise instead of `assert`: assertions are removed under
        # `python -O`, which would let a bad mode through and fail later in
        # forward() when `cat_proj` is None.
        if fuse not in ("add", "cat"):
            raise ValueError(f"fuse must be 'add' or 'cat', got {fuse}")

        self.qap = qap_module
        self.backbone = backbone
        self.d_model = qap_module.d_model
        self.fuse = fuse

        # Client-specific slot embedding (not shared)
        self.slot_embed = nn.Embedding(F_client, qap_module.d_model)

        # Time-feature layers
        if time_dim is not None:
            # Projects the time features to d_model.
            self.time_proj = nn.Linear(time_dim, self.d_model)
            # Only "cat" needs a layer to map 2*d_model back to d_model.
            self.cat_proj = nn.Linear(self.d_model * 2, self.d_model) if fuse == "cat" else None
        else:
            self.time_proj = None
            self.cat_proj = None

    def forward(self, x: torch.Tensor, time_features: Optional[torch.Tensor] = None):
        """Apply QAP, fuse time features when available, run the backbone.

        Args:
            x: raw input, documented as ``[B, L, C]``.
            time_features: optional time features, documented as
                ``[B, L, t_dim]``. They are used only when the wrapper was
                built with ``time_dim``; otherwise they are ignored.

        Returns:
            Whatever ``backbone`` returns for the fused tensor.
        """
        qap_out = self.qap(x, self.slot_embed)  # [B, L, d_model]

        # Fusion happens only when both the input and the layers exist.
        if time_features is not None and self.time_proj is not None:
            time_emb = self.time_proj(time_features)  # [B, L, d_model]

            if self.fuse == "add":
                qap_out = qap_out + time_emb
            else:  # cat
                combined = torch.cat([qap_out, time_emb], dim=-1)  # [B, L, 2*d_model]
                qap_out = self.cat_proj(combined)  # [B, L, d_model]

        return self.backbone(qap_out)


# ──────────────────────────────────────────────────────────────────────────────
# Utility: infer_dims
# ──────────────────────────────────────────────────────────────────────────────

def infer_dims(loader: DataLoader) -> Tuple[int, int]:
    """Return ``(input_size, out_dim)`` inferred from the first batch of ``loader``.

    The first batch must be a list or tuple shaped like one of:
      - ``(x, y)``
      - ``(x, y, time_feature)``  (the third element is ignored)

    ``x`` and ``y`` are converted with ``torch.as_tensor`` when they are not
    tensors already.

    Rules:
      - ``input_size`` is the last dimension of ``x`` (``C`` for ``[B, L, C]``).
      - ``out_dim`` is 1 when ``y`` has fewer than two dimensions.
      - Otherwise ``out_dim`` is the last dimension of ``y`` when that is
        greater than 1 or equal to ``input_size``, and 1 in every other case.
        A ``[B, L]`` target therefore yields ``L``.

    Raises:
        ValueError: if the loader yields no batch, the batch is not a list or
            tuple of length 2 or 3, or ``x`` has fewer than two dimensions.
    """
    # A bare StopIteration escaping from here is easy to misread (and is
    # swallowed inside generators), so report an empty loader explicitly.
    try:
        batch = next(iter(loader))
    except StopIteration:
        raise ValueError("Cannot infer dims: loader yielded no batches.") from None

    if not isinstance(batch, (list, tuple)):
        raise ValueError("Expected batch as (x,y) or (x,y,time_feature).")
    if len(batch) not in (2, 3):
        raise ValueError(f"Unexpected batch tuple length: {len(batch)}")
    x, y = batch[0], batch[1]

    # Input size from x
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(x)
    if x.ndim < 2:
        raise ValueError(f"x must be at least 2D, got {x.shape}")
    input_size = int(x.shape[-1])  # C

    # Output dim from y
    if not isinstance(y, torch.Tensor):
        y = torch.as_tensor(y)
    if y.ndim < 2:
        return input_size, 1

    last = int(y.shape[-1])
    # A trailing size of 0 or 1 collapses to 1 unless it mirrors x's feature
    # width, in which case the target is treated as matching x feature-wise.
    out_dim = last if (last > 1 or last == input_size) else 1
    return input_size, out_dim


# ──────────────────────────────────────────────────────────────────────────────
# Public factory APIs
# ──────────────────────────────────────────────────────────────────────────────

def build_backbone(
    name: str,
    *,
    dataloader: Optional[DataLoader] = None,
    input_size: Optional[int] = None,
    hidden_size: int = 128,
    num_steps: int = 4,
    max_length: int = 100,
    **kw,
) -> nn.Module:
    """Create a plain backbone model, resolving its input width if needed.

    No QAP or projection wrapper is applied here; the result of
    ``create_backbone_model`` is returned as-is. (The original note says QAP
    is handled upstream in DataFactory and the backbone receives ``d_model``
    wide input - that cannot be confirmed from this file.)

    Args:
        name: backbone identifier, forwarded as ``model_name``.
        dataloader: used only to infer ``input_size`` when it is not given.
        input_size: input feature width; takes precedence over ``dataloader``.
        hidden_size, num_steps, max_length: forwarded unchanged.
        **kw: extra keyword arguments forwarded unchanged.

    Raises:
        ValueError: if both ``input_size`` and ``dataloader`` are ``None``.
    """
    feature_width = input_size
    if feature_width is None:
        # Without an explicit width the first batch is the only source.
        if dataloader is None:
            raise ValueError("Either input_size or dataloader is required.")
        feature_width = infer_dims(dataloader)[0]

    return create_backbone_model(
        model_name=name,
        input_size=feature_width,
        hidden_size=hidden_size,
        num_steps=num_steps,
        max_length=max_length,
        **kw,
    )


# ---------------------------------------------------------------------------
# Public API of this module: the names exported by a star-import.
__all__ = [
    "SimpleProjectionModule",
    "ProjectionBackboneWrapper",
    "QAP_ServerModule",
    "QAPBackboneWrapper",
    "infer_dims",
    "build_backbone",
]