# data_factory.py
"""Data Factory: Split universal time series datasets by client"""

from typing import List, Tuple, Dict, Optional
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from pandas.tseries.frequencies import to_offset
from common.function import time_features

# Import data loading functionality
from .data_loader import load_timeseries_data


# ============================================================
# Time Series Dataset Classes
# ============================================================

class SplitData:
    """
    Sliding-window time-series dataset backed by a single file.

    Loads a CSV/TXT/HDF5 file via ``load_timeseries_data`` and builds:
     - self.raw_data: (n, m) numerical data (cast to float32 on load and
       scaled when ``normalize=True``)
     - self.time_feats: (n, t) float32 time features, or None when the file
       provides no dates
    then selects the window range that belongs to ``dataset_name``
    (train / valid / test, in that chronological order).

    __getitem__ return value:
     - X: (window, m + t)  raw values followed by time features
                            ((window, m) when there are no time features)
     - y: (horizon, m)     raw values only; time features are never in y
    """

    def __init__(
        self,
        file: str,
        window: int,
        horizon: int,
        train_ratio: float,
        test_ratio: float,
        dataset_name: str,
        normalize: bool = False,  # normalization switch
        scaler: Optional[object] = None,  # pre-fitted scaler; required for non-train splits when normalize=True
        scaler_type: str = "standard",  # scaler type ("minmax" or "standard") used when fitting a new scaler
    ):
        # 1) save parameters
        self.window       = window
        self.horizon      = horizon
        self.train_ratio  = train_ratio
        self.test_ratio   = test_ratio
        self.dataset_name = dataset_name  # "train" / "valid" / "val" / "test"
        self.normalize    = normalize
        self.scaler       = scaler  # only raw_data is normalized, so a single scaler suffices
        self.scaler_type  = scaler_type

        # 2) load file using the project data loader
        raw, dates_df = load_timeseries_data(file)
        self.raw_data = raw.astype(np.float32)
        self.n, self.m = self.raw_data.shape
        self.dates_df = dates_df

        # 3) time features (computed once for the whole series)
        if dates_df is not None:
            idx = pd.DatetimeIndex(dates_df["date"])
            try:
                inferred = pd.infer_freq(idx)
            except ValueError:
                # infer_freq needs at least 3 timestamps; treat that like an
                # un-inferable frequency and fall back to hourly.
                inferred = None
            freq = (inferred or "H").lower()
            tf = time_features(dates_df, freq=freq)  # (n, t)
            self.time_feats = tf.astype(np.float32)
        else:
            self.time_feats = None

        # 4) resolve the window range first so that a new scaler can be fitted
        #    on train rows only (fitting on every row leaks valid/test stats)
        self._split(self.train_ratio, self.test_ratio, self.dataset_name)

        # 5) normalization
        if self.normalize:
            if self.scaler is None:
                if self.dataset_name != "train":
                    raise ValueError(f"scaler must be provided for {self.dataset_name} split. (pass scaler fitted on train)")
                if self.scaler_type == "minmax":
                    self.scaler = MinMaxScaler()
                elif self.scaler_type == "standard":
                    self.scaler = StandardScaler()
                else:
                    raise ValueError("scaler_type must be 'minmax' or 'standard'.")
                # Train sample i reads rows [i, i + window + horizon) for
                # i < length, so these are exactly the rows train touches.
                fit_end = self.start_idx + self.length + self.window + self.horizon - 1
                self.scaler.fit(self.raw_data[:fit_end])
            # every split is transformed with the (train-fitted) scaler
            self.raw_data = self.scaler.transform(self.raw_data)

    def _split(self, tr, te, ds):
        """Set ``start_idx``/``length`` of the window range for split ``ds``.

        Windows are ordered train -> valid -> test. ``tr``/``te`` are fractions
        of the total number of windows; valid gets the remainder.

        Raises:
            ValueError: if the series cannot hold a single window, or ``ds``
                is not one of train/valid/val/test.
        """
        total = self.n - self.window - self.horizon + 1
        if total <= 0:
            raise ValueError(
                f"series of length {self.n} is too short for "
                f"window={self.window} + horizon={self.horizon}"
            )
        train_n = int(total * tr)
        test_n  = int(total * te)
        valid_n = total - train_n - test_n

        if ds == "train":
            self.start_idx, self.length = 0, train_n
        elif ds in ("valid", "val"):
            self.start_idx, self.length = train_n, valid_n
        elif ds == "test":
            self.start_idx, self.length = train_n + valid_n, test_n
        else:
            raise ValueError("dataset_name must be one of train/valid/val/test")

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        """Return ``(X, y)`` for the i-th window of this split (see class docstring)."""
        real_i = i + self.start_idx

        # 1) input rows (window, m) - normalized already if normalize=True
        X = self.raw_data[real_i : real_i + self.window]

        # 2) append time features -> (window, m + t)
        if self.time_feats is not None:
            tX = self.time_feats[real_i : real_i + self.window]
            X = np.concatenate([X, tX], axis=1)

        # 3) targets (horizon, m): the rows right after the window, values only
        y = self.raw_data[
            real_i + self.window : real_i + self.window + self.horizon
        ]

        return X.astype(np.float32), y.astype(np.float32)

    def get_index(self):
        """Return the date of each sample's first input row for this split.

        Requires ``dates_df`` (files without dates raise TypeError here).
        """
        return self.dates_df['date'].iloc[self.start_idx : self.start_idx + self.length].reset_index(drop=True)

    def inverse_transform(self, data):
        """Undo normalization with the dataset's scaler; identity when unscaled."""
        if self.scaler is not None:
            return self.scaler.inverse_transform(data)
        else:
            return data  # return as is if not normalized


class MultiClientDataset(SplitData):
    """
    Sliding-window dataset restricted to one client's raw feature columns.

    Built either from a file (delegating to SplitData) or from in-memory
    ``raw_data`` + ``dates_df``. ``client_features`` selects the raw columns
    this client sees; None means all columns.

    __getitem__ returns, with k = len(client_features):
     - with time features:    (X_raw[window, k], y[horizon, k], X_time[window, t])
     - without time features: (X[window, k], y[horizon, k])
    """

    def __init__(
        self,
        file: str = None,  # file path (optional)
        raw_data: np.ndarray = None,  # (n, m) raw data passed directly
        dates_df: pd.DataFrame = None,  # frame with a "date" column, passed directly
        window: int = None,
        horizon: int = None,
        train_ratio: float = None,
        test_ratio: float = None,
        dataset_name: str = None,
        client_features: List[int] = None,  # raw feature indices assigned to this client (None -> all)
        normalize: bool = False,
        scaler: Optional[object] = None,
        scaler_type: str = "standard",
    ):
        if file is not None:
            # file-based construction
            super().__init__(
                file=file,
                window=window,
                horizon=horizon,
                train_ratio=train_ratio,
                test_ratio=test_ratio,
                dataset_name=dataset_name,
                normalize=normalize,
                scaler=scaler,
                scaler_type=scaler_type,
            )
        else:
            # in-memory construction
            if raw_data is None or dates_df is None:
                raise ValueError("Either file or (raw_data, dates_df) must be provided")

            # the same attributes SplitData.__init__ sets
            self.window = window
            self.horizon = horizon
            self.train_ratio = train_ratio
            self.test_ratio = test_ratio
            self.dataset_name = dataset_name
            self.normalize = normalize
            self.scaler = scaler
            self.scaler_type = scaler_type

            self.raw_data = raw_data.astype(np.float32)
            self.n, self.m = self.raw_data.shape
            self.dates_df = dates_df

            # time features (dates_df is guaranteed non-None on this path)
            idx = pd.DatetimeIndex(dates_df["date"])
            try:
                inferred = pd.infer_freq(idx)
            except ValueError:
                # infer_freq needs at least 3 timestamps; fall back to hourly
                inferred = None
            freq = (inferred or "H").lower()
            self.time_feats = time_features(dates_df, freq=freq).astype(np.float32)

            # resolve the split first so a new scaler sees train rows only
            self._split(self.train_ratio, self.test_ratio, self.dataset_name)

            if self.normalize:
                if self.scaler is None:
                    if self.dataset_name != "train":
                        raise ValueError(f"scaler must be provided for {self.dataset_name} split. (pass scaler fitted on train)")
                    if self.scaler_type == "minmax":
                        self.scaler = MinMaxScaler()
                    elif self.scaler_type == "standard":
                        self.scaler = StandardScaler()
                    else:
                        raise ValueError("scaler_type must be 'minmax' or 'standard'.")
                    # rows touched by train samples: [0, start + length + window + horizon - 1)
                    fit_end = self.start_idx + self.length + self.window + self.horizon - 1
                    self.scaler.fit(self.raw_data[:fit_end])
                self.raw_data = self.scaler.transform(self.raw_data)

        # None used to index as np.newaxis (X[:, None]); treat it as "all columns"
        if client_features is None:
            client_features = list(range(self.m))
        self.client_features = client_features
        self.time_feat_dim = self.time_feats.shape[1] if self.time_feats is not None else 0

    def __getitem__(self, i):
        """Return this client's columns for window ``i`` (see class docstring)."""
        start = i + self.start_idx
        boundary = start + self.window
        cols = self.client_features

        # slice the client's columns directly instead of concatenating time
        # features onto the raw block and splitting them apart again
        X_raw = self.raw_data[start:boundary][:, cols].astype(np.float32)
        y = self.raw_data[boundary:boundary + self.horizon][:, cols].astype(np.float32)

        if self.time_feats is None:
            return X_raw, y
        X_time = self.time_feats[start:boundary].astype(np.float32)
        return X_raw, y, X_time

    def get_feature_count(self):
        """Number of raw features assigned to this client (time features excluded)."""
        return len(self.client_features)


def assign_random_features(
    num_clients: int,
    total_features: int,
    max_features: int,
    feature_overlap: float = 0.0,
    seed: int = 42
) -> List[List[int]]:
    """Assign a sorted list of distinct random feature indices to each client.

    Each client draws ``k = randint(1, max_features)`` (``max_features`` is
    clipped to ``total_features``) and receives exactly ``k`` distinct indices
    from ``range(total_features)``.

    When ``feature_overlap > 0``, every client after the first takes
    ``int(k * feature_overlap)`` of its indices (capped by availability) from
    the union of indices held by earlier clients. It fills the rest with
    random indices it does not already hold.

    Note: re-seeds the global ``random`` and ``numpy.random`` generators with
    ``seed`` as a side effect.

    Args:
        num_clients: Number of clients.
        total_features: Number of available feature columns.
        max_features: Upper bound on features per client.
        feature_overlap: Fraction of each later client's features drawn from
            previously assigned ones.
        seed: RNG seed.

    Returns:
        One sorted index list per client.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)

    max_features = min(max_features, total_features)

    client_features: List[List[int]] = []
    # union of indices held by earlier clients (built in the same insertion
    # order as rebuilding it from client_features each iteration)
    previously_assigned = set()

    for client_id in range(num_clients):
        num_features = random.randint(1, max_features)

        if feature_overlap == 0.0 or client_id == 0 or not previously_assigned:
            # no overlap requested (or nothing to overlap with yet)
            chosen = random.sample(range(total_features), num_features)
        else:
            num_shared = int(num_features * feature_overlap)
            shared = random.sample(
                list(previously_assigned),
                min(num_shared, len(previously_assigned))
            )
            remaining = num_features - len(shared)
            if remaining > 0:
                # draw from indices not already chosen so the client ends up
                # with exactly num_features distinct features
                taken = set(shared)
                pool = [f for f in range(total_features) if f not in taken]
                chosen = shared + random.sample(pool, remaining)
            else:
                chosen = shared

        chosen = sorted(chosen)
        client_features.append(chosen)
        previously_assigned.update(chosen)

    return client_features


# ============================================================
# QAP Local Slot Classes
# ============================================================

class QAP_LocalSlot(nn.Module):
    """
    Query-attention pooling over a client's channels, applied per time step.

    Input: x [B, L, C]  (BLC only; C must equal F_client)
    Output:
      - out_scalar=False -> [B, L, d_model]
      - out_scalar=True  -> [B, L]
    Local parameters: slot_embed  (FedPer)
    Server shared: value_proj, queries, attn, fuse, norm, out_proj (FedAvg)

    With ``num_queries > 1`` the per-query outputs are averaged (identical to
    the single-query result when ``num_queries == 1``).
    """
    def __init__(self, F_client: int, d_model: int = 128,
                 num_heads: int = 8, num_queries: int = 1,
                 use_side_channel: bool = True, out_scalar: bool = False,
                 dropout: float = 0.1):
        super().__init__()

        # -- Local (per client) --
        self.slot_embed = nn.Embedding(F_client, d_model)

        # -- Shared (server) --
        # (creation order kept stable so seeded initialization is reproducible)
        self.value_proj = nn.Linear(1, d_model, bias=False)
        self.queries    = nn.Parameter(torch.randn(num_queries, d_model) / (d_model ** 0.5))
        self.attn       = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        if use_side_channel:
            self.fuse = nn.Sequential(
                nn.Linear(d_model * 3, d_model), nn.GELU(), nn.Linear(d_model, d_model)
            )
        else:
            self.fuse = nn.Identity()

        self.out_scalar = out_scalar
        if out_scalar:
            self.out_proj = nn.Linear(d_model, 1)

        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x_blc: torch.Tensor):      # x: [B, L, C]
        """Pool the C channel values at every time step into one d_model vector.

        Raises:
            ValueError: if the channel count differs from ``F_client``.
        """
        # 0) BLC -> BCL
        x_bcl = x_blc.permute(0, 2, 1).contiguous()     # [B, C, L]
        B, F, L = x_bcl.shape
        if F != self.slot_embed.num_embeddings:
            raise ValueError(
                f"expected {self.slot_embed.num_embeddings} channels, got {F}"
            )
        d = self.value_proj.out_features

        # 1) project scalar values + add the local per-channel slot embedding
        v = self.value_proj(x_bcl.unsqueeze(-1))        # [B, C, L, d]
        slot_vec = self.slot_embed.weight.view(1, F, 1, d)
        x = self.norm(v + slot_vec)                     # [B, C, L, d]

        # 2) fold time into the batch and cross-attend queries over channels
        x = x.permute(0, 2, 1, 3).contiguous()          # [B, L, C, d]
        x_flat = x.view(B * L, F, d)                    # [B*L, C, d]

        q = self.queries.unsqueeze(0).expand(B * L, -1, -1)   # [B*L, Q, d]
        z, _ = self.attn(q, x_flat, x_flat, need_weights=False)  # [B*L, Q, d]
        z = self.drop(z)

        # 3) enrich with channel mean/max side channels
        if isinstance(self.fuse, nn.Sequential):
            mean = x_flat.mean(1, keepdim=True)         # [B*L, 1, d]
            mx   = x_flat.max(1, keepdim=True).values   # [B*L, 1, d]
            z    = self.fuse(torch.cat([z, mean.expand_as(z), mx.expand_as(z)], dim=-1))

        # 4) combine the query axis (mean; exact no-op reduction when Q == 1)
        z = z.mean(dim=1)                               # [B*L, d]
        z = z.view(B, L, d)                             # [B, L, d]

        # 5) scalar output if requested
        if self.out_scalar:
            z = self.out_proj(z).squeeze(-1)            # [B, L]

        return z


class TimeSeriesPreprocessor(nn.Module):
    """
    Per-client input encoder: QAP channel pooling fused with projected time features.

    Pipeline:
    1. Raw data (B, W, C)          -> QAP_LocalSlot -> (B, W, d)
    2. Time features (B, W, t_dim) -> Linear        -> (B, W, d)
    3. Fusion: "add" sums both embeddings; "cat" concatenates them to
       (B, W, 2d) and projects back to (B, W, d).
    """
    def __init__(self, F_client: int, d_model: int, t_dim: int,
                 num_heads: int = 8, num_queries: int = 1,
                 use_side_channel: bool = True, fuse: str = "cat"):
        super().__init__()
        assert fuse in ("add", "cat")

        # Channel pooling (its parameters live under the "backbone." prefix)
        self.backbone = QAP_LocalSlot(
            F_client=F_client,
            d_model=d_model,
            num_heads=num_heads,
            num_queries=num_queries,
            use_side_channel=use_side_channel,
            out_scalar=False,
        )
        # Time features -> model width
        self.t_proj = nn.Linear(t_dim, d_model)

        self.fuse = fuse
        # Only the concatenating fusion needs a 2d -> d projection
        if fuse == "cat":
            self.cat_proj = nn.Linear(2 * d_model, d_model)

    def forward(self, x_blc: torch.Tensor, time_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x_blc: Raw time series [B, L, C]
            time_features: Time features [B, L, t_dim]
        Returns:
            Fused embeddings [B, L, d_model]
        """
        value_emb = self.backbone(x_blc)          # [B, L, d]
        time_emb = self.t_proj(time_features)     # [B, L, d]

        if self.fuse != "add":
            # "cat": join on the feature axis, then project 2d -> d
            return self.cat_proj(torch.cat((value_emb, time_emb), dim=-1))
        return value_emb + time_emb


class TimeSeriesDataset(Dataset):
    """
    Wraps a per-client base dataset and optionally runs a preprocessor per sample.

    Output by mode:
    - raw mode (no preprocessor): (x_raw:[W,C], y_raw:[H,C], time_features:[W,t_dim])
    - preprocessed mode:          (x_emb:[W,d], y_raw:[H,C])

    Time features come from the base sample when it returns three items;
    otherwise they are sliced from features precomputed over the full series,
    aligned with the base dataset's ``start_idx``.
    """
    def __init__(self, base_ds: "MultiClientDataset", timeenc: int = 1, time_freq: str = None,
                 preprocessor: "TimeSeriesPreprocessor" = None, device: str = "cpu"):
        self.base = base_ds
        self.window_len = int(base_ds.window)
        self.preprocessor = preprocessor
        self.device = torch.device(device)

        # Precompute time features for the whole series
        dt = pd.to_datetime(self.base.dates_df["date"])
        dt_df = pd.DataFrame({"date": dt})
        if time_freq is None:
            # most common spacing between consecutive timestamps
            diffs = dt.diff().dropna()
            time_freq = to_offset(diffs.mode().iloc[0]).freqstr if len(diffs) > 0 else "H"
        self.time_freq = time_freq

        tf_np = time_features(dt_df, timeenc=timeenc, freq=time_freq)
        self.tf_full = torch.tensor(tf_np, dtype=torch.float32)
        self.t_dim = self.tf_full.shape[1]
        self.feature_count = self._resolve_feature_count(self.base)

        # The preprocessor is only used for inference here
        if self.preprocessor is not None:
            self.preprocessor.to(self.device).eval()

    @staticmethod
    def _resolve_feature_count(base) -> int:
        """Return the base dataset's feature count, or 0 when it cannot be determined."""
        count = getattr(base, "feature_count", None)
        if count is None:
            getter = getattr(base, "get_feature_count", None)
            if callable(getter):
                try:
                    count = getter()
                except TypeError:
                    # e.g. a base without a usable feature list
                    count = None
        return int(count or 0)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx: int):
        base_item = self.base[idx]

        if len(base_item) == 3:
            # base supplies its own time features: (x_raw, y_raw, time_features)
            x_wc, y_hc, tf_wt = base_item
            tf_wt = torch.tensor(tf_wt, dtype=torch.float32)
        else:
            # (x_raw, y_raw): slice precomputed features. Sample idx starts at
            # absolute row start_idx + idx, so apply the same offset here.
            x_wc, y_hc = base_item
            s = int(getattr(self.base, "start_idx", 0)) + idx
            e = s + self.window_len
            tf_wt = self.tf_full[s:e]
            if 0 < len(tf_wt) < self.window_len:
                # pad a truncated tail by repeating the last available row
                last = tf_wt[-1:, :]
                tf_wt = torch.cat([tf_wt, last.repeat(self.window_len - len(tf_wt), 1)], dim=0)

        x_wc = torch.tensor(x_wc, dtype=torch.float32)
        y_hc = torch.tensor(y_hc, dtype=torch.float32)

        # raw mode
        if self.preprocessor is None:
            return x_wc, y_hc, tf_wt                  # [W,C], [H,C], [W,t_dim]

        # preprocessed mode: embed a batch of one on the preprocessor's device
        with torch.no_grad():
            x_emb = self.preprocessor(
                x_wc.unsqueeze(0).to(self.device),    # [1,W,C]
                tf_wt.unsqueeze(0).to(self.device),   # [1,W,t]
            ).squeeze(0)                              # [W,d]

        # y is not used by the preprocessor, so it stays on the CPU
        return x_emb.cpu(), y_hc                      # [W,d], [H,C]


# ============================================================
# DataFactory: factory that builds per-client datasets from one universal time series
# ============================================================
class DataFactory:
    def __init__(
        self,
        dataset_file: str,  # path to the universal time-series file (required)
        num_clients: int,  # number of clients (required)
        max_features: int = None,  # per-client feature cap (None -> all features)
        feature_overlap: float = 0.0,  # fraction of features shared between clients
        split_ratios: Tuple[float, float, float] = (0.8, 0.1, 0.1),
        scale_y: bool = True,
        seed: int = 42,
        universal_scaler = None,  # scaler fitted on train, to be reused
    ):
        """
        Load the universal dataset and assign a random feature subset to each client.

        Args:
            dataset_file: File passed to ``load_timeseries_data``.
            num_clients: Number of clients to create feature subsets for.
            max_features: Upper bound on features per client; None means the
                total number of columns (``self.max_features`` keeps the value
                as given).
            feature_overlap: Forwarded to ``assign_random_features``.
            split_ratios: (train, val, test) fractions; val is the remainder
                when datasets are split.
            scale_y: Stored as-is.
            seed: Seed for the feature assignment.
            universal_scaler: Optional pre-fitted scaler reused for all splits.
        """
        # configuration, stored exactly as given
        self.split_ratios = split_ratios
        self.scale_y = scale_y
        self.seed = seed
        self.dataset_file = dataset_file
        self.num_clients = num_clients
        self.max_features = max_features
        self.feature_overlap = feature_overlap

        # only define the attribute when a scaler was supplied; other
        # methods test for its presence with hasattr
        if universal_scaler is not None:
            self._universal_scaler = universal_scaler

        raw_data, dates_df = load_timeseries_data(dataset_file)
        total_features = raw_data.shape[1]
        effective_max = total_features if max_features is None else max_features

        self.client_features = assign_random_features(
            num_clients=num_clients,
            total_features=total_features,
            max_features=effective_max,
            feature_overlap=feature_overlap,
            seed=seed,
        )

        # full (unsplit, unscaled) data shared by every client
        self.raw_data = raw_data
        self.dates_df = dates_df
        self.total_features = total_features

        print(f"[DataFactory] Universal dataset: {dataset_file}")
        print(f"[DataFactory] {num_clients} clients, {total_features} total features, max_features={effective_max}")
        print(f"[DataFactory] Feature overlap: {feature_overlap}, split={self.split_ratios}")

        # cache
        self._split_cache: Dict[str, Dict[str, pd.DataFrame]] = {}

    # --- Public API ---
    def get_client_info(self, client_id: str = None) -> Dict[str, object]:
        """Describe one client, or every client when ``client_id`` is None.

        Args:
            client_id: Name of the form ``"client_<idx>"``. None returns a
                mapping from every client name to its description.

        Returns:
            For a single client, a dict with keys ``client_id``,
            ``feature_count``, ``feature_indices``, ``total_features``,
            ``data_shape`` and ``dataset_file``. For None, a dict mapping
            ``"client_<i>"`` to such a dict.

        Raises:
            ValueError: if ``client_id`` does not name a client of this factory.
        """
        if client_id is None:
            return {
                f"client_{i}": self._describe_client(f"client_{i}", i)
                for i in range(self.num_clients)
            }

        if client_id.startswith("client_"):
            try:
                client_idx = int(client_id.split("_")[1])
            except (IndexError, ValueError):
                client_idx = -1  # unparsable -> treated as not found
            if 0 <= client_idx < self.num_clients:
                return self._describe_client(client_id, client_idx)
        raise ValueError(f"Client {client_id} not found")

    def _describe_client(self, name: str, idx: int) -> Dict[str, object]:
        """Build the info dict for client ``idx``, reported under ``name``."""
        return {
            'client_id': name,
            'feature_count': len(self.client_features[idx]),
            'feature_indices': self.client_features[idx],
            'total_features': self.total_features,
            'data_shape': self.raw_data.shape,
            'dataset_file': self.dataset_file
        }

    def get_client_scaler(self, client_idx: int = None):
        """
        Return the scaler shared by all clients.

        Args:
            client_idx: Accepted for API symmetry and ignored, because every
                client uses the same universal scaler.

        Returns:
            The universal scaler. It is the one passed to the constructor, the
            one produced by ``fit_scaler``, or the one created with the first
            train split. Depending on origin it may be a StandardScaler or a
            MinMaxScaler.

        Raises:
            ValueError: if no universal scaler has been set yet.
        """
        try:
            return self._universal_scaler
        except AttributeError:
            raise ValueError("Universal scaler not initialized. Create train datasets first.") from None

    def fit_scaler(self, scaler_type: str = "standard") -> object:
        """Fit a fresh scaler on every row of ``raw_data`` and make it the universal scaler.

        Note: all rows are used, including those later assigned to the
        validation/test windows.

        Args:
            scaler_type: "standard" (StandardScaler) or "minmax" (MinMaxScaler).

        Returns:
            The fitted scaler (also stored as the universal scaler).

        Raises:
            ValueError: for an unknown ``scaler_type``.
        """
        for name, scaler_cls in (("standard", StandardScaler), ("minmax", MinMaxScaler)):
            if scaler_type == name:
                scaler = scaler_cls()
                break
        else:
            raise ValueError("scaler_type must be 'standard' or 'minmax'")

        scaler.fit(self.raw_data)
        self._universal_scaler = scaler
        return scaler

    def transform_data(self, data: np.ndarray, scaler: object = None) -> np.ndarray:
        """Scale ``data`` with ``scaler``, or with the universal scaler when omitted.

        Raises:
            ValueError: if no scaler is given and no universal scaler exists.
        """
        return self._resolve_scaler(scaler).transform(data)

    def inverse_transform_data(self, data: np.ndarray, scaler: object = None) -> np.ndarray:
        """Undo scaling of ``data`` with ``scaler``, or with the universal scaler when omitted.

        Raises:
            ValueError: if no scaler is given and no universal scaler exists.
        """
        return self._resolve_scaler(scaler).inverse_transform(data)

    def _resolve_scaler(self, scaler: object = None) -> object:
        """Return ``scaler`` if given, else the universal scaler (ValueError if neither exists)."""
        if scaler is not None:
            return scaler
        if not hasattr(self, '_universal_scaler'):
            raise ValueError("No scaler provided and no universal scaler available")
        return self._universal_scaler

    def create_client_datasets(
        self,
        split: str,
        window_len: int,
        *,
        horizon: int = 0,
    ) -> List[Tuple[str, "MultiClientDataset"]]:
        """
        Build one MultiClientDataset per client for ``split``.

        Every split, train included, is normalized with the single universal
        scaler, so train/val/test share identical statistics. For the train
        split the universal scaler is created on first use by a temporary
        train dataset over all columns, unless one was already set by the
        constructor, ``fit_scaler`` or an earlier call.

        Args:
            split: "train", "val" or "test".
            window_len: Input window length W.
            horizon: Forecast horizon H (default 0).

        Returns:
            List of ``("client_<idx>", dataset)`` tuples. Clients with no
            features, or whose dataset construction raises, are skipped with a
            printed message, so list positions need not match client indices.

        Raises:
            ValueError: for val/test when no universal scaler exists yet.
        """
        assert split in ("train", "val", "test")
        datasets: List[Tuple[str, "MultiClientDataset"]] = []

        # MultiClientDataset needs a dates frame; synthesize hourly dates if absent
        dates_df = self.dates_df
        if dates_df is None:
            dummy_dates = pd.date_range('2020-01-01', periods=len(self.raw_data), freq='H')
            dates_df = pd.DataFrame({'date': dummy_dates})

        common_kwargs = dict(
            raw_data=self.raw_data,
            dates_df=dates_df,
            window=window_len,
            horizon=horizon,
            train_ratio=self.split_ratios[0],
            test_ratio=self.split_ratios[2],
            dataset_name=split,
            normalize=True,
            scaler_type="standard",
        )

        if split == "train":
            if not hasattr(self, '_universal_scaler'):
                # The scaler is fitted on every raw column regardless of the
                # client subset, so request all columns explicitly (this also
                # works when there are no clients).
                temp_dataset = MultiClientDataset(
                    client_features=list(range(self.total_features)),
                    scaler=None,
                    **common_kwargs,
                )
                self._universal_scaler = temp_dataset.scaler
                print(f"[DataFactory] Created universal scaler from train data")
        elif hasattr(self, '_universal_scaler'):
            print(f"[DataFactory] Using existing universal scaler for {split} split")
        else:
            raise ValueError(f"No universal scaler available for {split} split. Create train dataset first.")

        # Same scaler for all splits: refitting per train dataset would diverge
        # from val/test whenever the universal scaler came from the constructor
        # or fit_scaler.
        scaler_to_use = self._universal_scaler

        for client_id in range(self.num_clients):
            try:
                client_features = self.client_features[client_id]
                if len(client_features) == 0:
                    print(f"[DataFactory] Skip client {client_id}: no features assigned")
                    continue

                dataset = MultiClientDataset(
                    client_features=client_features,
                    scaler=scaler_to_use,
                    **common_kwargs,
                )
                datasets.append((f"client_{client_id}", dataset))
            except Exception as e:
                # best-effort: a failing client is reported and skipped
                print(f"[DataFactory] Error creating dataset for client {client_id}: {e}")

        return datasets

    def create_aligned_datasets(self, split: str, window_len: int, horizon: int,
                               alignment_method: str = "qap", **alignment_kwargs):
        """
        Alias of ``create_client_datasets``, kept for backward compatibility.

        Args:
            split: "train", "val" or "test".
            window_len: Window length.
            horizon: Prediction horizon.
            alignment_method: Accepted for compatibility and ignored.
            **alignment_kwargs: Accepted for compatibility and ignored.

        Returns:
            The ``(client_name, dataset)`` list produced by
            ``create_client_datasets``; this method applies no alignment or
            embedding of its own.
        """
        # alignment_method / alignment_kwargs are intentionally unused
        return self.create_client_datasets(split, window_len, horizon=horizon)

    def create_client_dataloaders(
        self,
        split: str,
        window_len: int,
        batch_size: int = 32,
        num_workers: int = 0,
        *,
        horizon: int = 0,
    ) -> List[DataLoader]:
        """
        Build one DataLoader per client dataset for ``split``.

        Train loaders shuffle and drop the last incomplete batch; val/test
        loaders keep order and every sample. Loaders follow the order of
        ``create_client_datasets`` (skipped clients produce no loader).
        """
        client_datasets = self.create_client_datasets(
            split=split, window_len=window_len, horizon=horizon
        )
        is_train = split == "train"
        use_pinned = torch.cuda.is_available()  # pinning only helps host->GPU copies
        return [
            DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=is_train,
                drop_last=is_train,
                num_workers=num_workers,
                pin_memory=use_pinned,
                persistent_workers=num_workers > 0,  # reuse workers across epochs
            )
            for _client_name, dataset in client_datasets
        ]

    def get_max_features(self) -> int:
        """Return the largest per-client feature count (raises ValueError when there are no clients)."""
        return max(map(len, self.client_features))

    def create_qap_preprocessor(self, client_id: int, d_model: int = 64,
                               num_heads: int = 8, num_queries: int = 1,
                               use_side_channel: bool = True, fuse: str = "cat") -> QAP_LocalSlot:
        """
        Create the QAP channel-pooling module for one client.

        Despite the name, this returns a bare ``QAP_LocalSlot`` with
        ``out_scalar=False`` and no time-feature fusion. ``fuse`` is accepted
        only for API compatibility and is ignored.

        Args:
            client_id: 0-based client index.
            d_model, num_heads, num_queries, use_side_channel: Forwarded to
                QAP_LocalSlot.
            fuse: Unused.

        Raises:
            ValueError: if ``client_id`` is outside ``[0, num_clients)``.
        """
        if client_id < 0:
            raise ValueError(f"Client ID {client_id} must be non-negative")
        if client_id >= self.num_clients:
            raise ValueError(f"Client ID {client_id} exceeds number of clients {self.num_clients}")

        # The channel count is the only client-specific dimension; time
        # features are not consumed by QAP_LocalSlot, so no t_dim is needed.
        return QAP_LocalSlot(
            F_client=len(self.client_features[client_id]),
            d_model=d_model,
            num_heads=num_heads,
            num_queries=num_queries,
            use_side_channel=use_side_channel,
            out_scalar=False
        )

    def create_qap_module(self, client_id: int, d_model: int = 64,
                         num_heads: int = 8, num_queries: int = 1,
                         use_side_channel: bool = True, out_scalar: bool = False) -> QAP_LocalSlot:
        """
        Create a QAP_LocalSlot sized for one client's feature count.

        Args:
            client_id: 0-based client index.
            d_model, num_heads, num_queries, use_side_channel, out_scalar:
                Forwarded to QAP_LocalSlot.

        Raises:
            ValueError: if ``client_id`` is outside ``[0, num_clients)``.
        """
        # Negative ids would otherwise silently index from the end of the list
        if client_id < 0:
            raise ValueError(f"Client ID {client_id} must be non-negative")
        if client_id >= self.num_clients:
            raise ValueError(f"Client ID {client_id} exceeds number of clients {self.num_clients}")

        return QAP_LocalSlot(
            F_client=len(self.client_features[client_id]),
            d_model=d_model,
            num_heads=num_heads,
            num_queries=num_queries,
            use_side_channel=use_side_channel,
            out_scalar=out_scalar
        )

    def create_preprocessed_datasets(
        self,
        split: str,
        window_len: int,
        horizon: int,
        d_model: int = 64,
        timeenc: int = 1,
        time_freq: Optional[str] = None,
        num_heads: int = 8,
        num_queries: int = 1,
        use_side_channel: bool = True,
        fuse: str = "cat",
        device: str = "cpu"
    ) -> List[Tuple[str, TimeSeriesDataset]]:
        """
        Create per-client datasets with a QAP TimeSeriesPreprocessor attached.

        Args:
            split: Forwarded to ``create_client_datasets`` to select the split.
            window_len: Forwarded to ``create_client_datasets``.
            horizon: Forwarded to ``create_client_datasets``.
            d_model: Forwarded to ``TimeSeriesPreprocessor``.
            timeenc: Forwarded to ``TimeSeriesDataset``.
            time_freq: Frequency used for ``time_features`` (``"h"`` when None)
                and forwarded to ``TimeSeriesDataset``.
            num_heads: Forwarded to ``TimeSeriesPreprocessor``.
            num_queries: Forwarded to ``TimeSeriesPreprocessor``.
            use_side_channel: Forwarded to ``TimeSeriesPreprocessor``.
            fuse: Forwarded to ``TimeSeriesPreprocessor``.
            device: Forwarded to ``TimeSeriesDataset``.

        Returns:
            List of ``(client_id, TimeSeriesDataset)`` pairs, one per base client dataset.
        """
        import time
        start_time = time.perf_counter()  # monotonic clock for elapsed-time reporting

        # Create base datasets
        base_datasets = self.create_client_datasets(
            split=split, window_len=window_len, horizon=horizon
        )

        # The time-feature width depends only on self.dates_df and time_freq, not on
        # the client, so compute it once instead of re-running time_features() per client.
        if self.dates_df is not None:
            t_dim = time_features(self.dates_df, freq=time_freq or "h").shape[1]
        else:
            t_dim = 4  # default when no dates are available

        preprocessed_datasets = []
        for client_id, base_ds in base_datasets:
            # Client ids are parsed as "<prefix>_<index>"; the index selects the feature list.
            client_idx = int(client_id.split("_")[1])
            F_client = len(self.client_features[client_idx])

            preprocessor = TimeSeriesPreprocessor(
                F_client=F_client,
                d_model=d_model,
                t_dim=t_dim,
                num_heads=num_heads,
                num_queries=num_queries,
                use_side_channel=use_side_channel,
                fuse=fuse
            )

            final_ds = TimeSeriesDataset(
                base_ds, timeenc=timeenc, time_freq=time_freq,
                preprocessor=preprocessor, device=device
            )

            preprocessed_datasets.append((client_id, final_ds))

        print(f"Created {len(preprocessed_datasets)} preprocessed datasets ({time.perf_counter()-start_time:.1f}s)")
        return preprocessed_datasets

    def create_preprocessed_dataloaders(
        self,
        split: str,
        window_len: int,
        horizon: int,
        batch_size: int = 32,
        num_workers: int = 0,
        d_model: int = 64,
        timeenc: int = 1,
        time_freq: str = None,
        num_heads: int = 8,
        num_queries: int = 1,
        use_side_channel: bool = True,
        fuse: str = "cat",
        device: str = "cpu"
    ) -> List[Tuple[str, DataLoader]]:
        """
        Wrap each preprocessed client dataset in a DataLoader.

        All non-loader arguments are forwarded to ``create_preprocessed_datasets``.
        The training split is shuffled and drops its last incomplete batch; other
        splits keep order and every sample. Worker processes persist across epochs
        only when ``num_workers > 0``, and memory is pinned only when CUDA is available.

        Returns:
            List of ``(client_id, DataLoader)`` pairs, in dataset order.
        """
        client_datasets = self.create_preprocessed_datasets(
            split=split,
            window_len=window_len,
            horizon=horizon,
            d_model=d_model,
            timeenc=timeenc,
            time_freq=time_freq,
            num_heads=num_heads,
            num_queries=num_queries,
            use_side_channel=use_side_channel,
            fuse=fuse,
            device=device,
        )

        is_train = split == "train"
        loader_options = {
            "batch_size": batch_size,
            "shuffle": is_train,
            "drop_last": is_train,
            "num_workers": num_workers,
            "pin_memory": torch.cuda.is_available(),
            "persistent_workers": num_workers > 0,
        }
        return [
            (cid, DataLoader(ds, **loader_options))
            for cid, ds in client_datasets
        ]
