

# Federated / training.py – sequential‑client FederatedTrainer (FedPer, fast-eval, ckpt-name fixed)
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Tuple
from datetime import datetime
import time, json, random, logging, pickle
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# AMP compatibility
try:
    from torch.amp import GradScaler, autocast
    AMP_MODE = "amp"
except ImportError:
    from torch.cuda.amp import GradScaler, autocast
    AMP_MODE = "cuda"

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# --- Project utils/server/metrics ---
from federated.server import Server
from common.function import single_mse as mse, single_mae as mae, r2_score, cvrmse

# Metric keys in display order; fmt() below prints the same four keys in this order.
DISPLAY_ORDER = ("loss", "r2", "mae", "cvrmse")
def printt(msg: str):
    """Print *msg* and flush stdout immediately so progress appears without buffering delay."""
    print(msg, end="\n", flush=True)
def fmt(m: dict) -> str:
    """Render a metrics dict as a fixed-width one-line summary.

    Reads the keys ``loss``, ``r2``, ``mae`` and ``cvrmse`` (a missing key
    raises ``KeyError``); other keys are ignored.
    """
    # (key, format spec) pairs in the order they are shown.
    layout = (("loss", "8.4f"), ("r2", "7.3f"), ("mae", "8.4f"), ("cvrmse", "8.4f"))
    return " | ".join(f"{key}:{format(m[key], spec)}" for key, spec in layout)

class FederatedTrainer:
    def __init__(
        self,
        args,
        server: Server,
        clients: List[nn.Module],                  # Each client model (ForecastWrapper or backbone+head)
        loaders: Dict[int, Dict[str, DataLoader]], # {cid: {"train":..., "valid":..., "test":...}}
        *,
        data_factory=None,         # DataFactory for inverse transform
        device: str = "cuda",
        global_epochs: int = 50,
        local_epochs: int = 1,
        lr: float = 1e-3,
        optimizer: str = "adam",
        frac: float = 1.0,         # Participation ratio per round (1.0 = all clients)
        experiment: str = "exp",
        seed: int = 0,
        resume: bool = False,
        early_stop: bool = True,
        patience: int = 5,
        grad_accum: int = 1,
        fed_algorithm: str = "fedper",
        fed_algorithm_config: dict = None,
    ) -> None:
        # logger
        self.logger = logging.getLogger("Trainer")
        if not self.logger.handlers:
            h = logging.StreamHandler()
            h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s │ %(message)s", "%m%d %H:%M:%S"))
            self.logger.addHandler(h)
        self.logger.setLevel(logging.INFO)

        # seeds
        random.seed(seed); np.random.seed(seed)
        torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

        # attrs
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.G, self.L = int(global_epochs), int(local_epochs)
        self.lr, self.opt_name = float(lr), optimizer.lower()
        self.frac = min(max(frac, 0.), 1.)
        self.early_stop, self.patience = bool(early_stop), int(patience)
        self.grad_accum = max(1, int(grad_accum))

        base = Path("output")/experiment
        self.ckpt_dir = base/"checkpoints"; self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.tb_dir   = base/"logs";        self.tb_dir.mkdir(parents=True, exist_ok=True)
        self.res_dir  = base/"results";     self.res_dir.mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(self.tb_dir.as_posix())

        # models & data
        self.server  = server  # Server doesn't need device placement
        self.clients = {i: c.to(self.device) for i, c in enumerate(clients)}
        self.loaders = loaders
        self.data_factory = data_factory  # Store data_factory for inverse transform

        # trackers
        self.start_ep = 0
        self.best_global = float("inf")
        self.no_improve = 0
        self.best_client: Dict[int, float] = defaultdict(lambda: float("inf"))
        self.args = args
        self.local_resume_data = {}  # Local-only resume data storage

        # Store federated algorithm info
        self.fed_algorithm = fed_algorithm.lower()
        self.fed_algorithm_config = fed_algorithm_config or {}


        if resume:
            self._load_ckpt()


    def _fedper_aggregate(self, selected_clients, client_weights):
        """Apply FedPer aggregation using new Server class."""
        if not selected_clients:
            return

        # Collect client shared parameters
        client_params = []
        for client in selected_clients:
            # Get shared parameters from client
            if hasattr(client, 'state_dict_for_server'):
                client_state = client.state_dict_for_server()
            else:
                # Fallback: assume client is the model directly
                client_state = {k: v for k, v in client.state_dict().items()
                               if not (k.startswith("head") or k.endswith("slot_embed.weight"))}
            client_params.append(client_state)

        # Use server's aggregate_and_update method
        updated_params = self.server.aggregate_and_update(client_params)

        # Broadcast updated parameters back to clients
        for client in selected_clients:
            if hasattr(client, 'load_shared_state_dict'):
                client.load_shared_state_dict(updated_params)
            else:
                # Fallback: update client state dict directly
                current_sd = client.state_dict()
                current_sd.update(updated_params)
                client.load_state_dict(current_sd)

    # ============== Global Round Loop ==============
    def train(self):
        for ep in range(self.start_ep, self.G):
            t0 = time.time()
            printt(f"\n=== [Global Epoch {ep+1}/{self.G}] ===")

            # ① Selected client sampling & server→client broadcast (shared parameters only)
            sel = self._select_clients()
            printt(f"Selected clients: {sel}")
            # Broadcast shared parameters to selected clients
            global_params = self.server.broadcast_parameters()
            for cid in sel:
                if global_params:
                    current_sd = self.clients[cid].state_dict()
                    current_sd.update(global_params)
                    self.clients[cid].load_state_dict(current_sd)

            # ② Local training (sequential training of selected clients with for-loop) and metric collection
            client_metrics, client_weights = self._train_selected(sel, ep)

            # ③ Server aggregation (FedPer aggregation)
            selected_clients = [self.clients[cid] for cid in sel]
            self._fedper_aggregate(selected_clients, client_weights)

            # ④ Global metric calculation (weighted average of client metrics)
            agg = {}
            for split in ['train', 'valid', 'test']:
                agg[split] = self.aggregate_client_metrics(client_metrics[split], client_weights)

            # ⑤ Round log
            dt = time.time() - t0
            printt(f"Global ({dt:6.1f}s)")
            printt(f"{ep:02d} Global Train - {fmt(agg['train'])}")
            printt(f"{ep:02d} Global Valid - {fmt(agg['valid'])}")
            printt(f"{ep:02d} Global Test  - {fmt(agg['test'])}")

            # ⑥ Early stopping & checkpoint
            val_loss = agg["valid"]["loss"]

            # Best model check and save (performed before early_stop call)
            if val_loss < self.best_global:
                self.best_global = val_loss  # Update best_global value
                self.no_improve = 0  # Reset early stop counter
                torch.save(self.server.global_state, self.res_dir/"global_best.pth")

                # scalers 디렉토리 생성
                scaler_dir = self.res_dir / "scalers"
                scaler_dir.mkdir(parents=True, exist_ok=True)

                # Save all client models when global best is achieved
                for cid in self.clients.keys():
                    client_best_path = self.res_dir / f"client_{cid}_best.pth"

                    # 모델 재구성에 필요한 정보 수집
                    sample_batch = next(iter(self.loaders[cid]["train"]))
                    if len(sample_batch) == 3:
                        x_sample, y_sample, time_features = sample_batch
                        enable_time_features = time_features is not None
                        time_dim = time_features.shape[-1] if enable_time_features else None
                    else:
                        x_sample, y_sample = sample_batch
                        enable_time_features = False
                        time_dim = None

                    feature_count = x_sample.shape[-1]
                    out_dim = y_sample.shape[-1] if y_sample.ndim >= 3 else 1

                    torch.save({
                        # 모델 파라미터
                        'state_dict': self.clients[cid].state_dict(),

                        # 기본 정보
                        'model_name': self.args.model_name,
                        'window_len': self.args.window,
                        'horizon': self.args.horizon,
                        'feature_count': feature_count,
                        'out_dim': out_dim,

                        # 모델 아키텍처
                        'dim': self.args.dim,
                        'num_steps': self.args.num_steps,

                        # QAP 파라미터
                        'num_queries': self.args.q if self.args.q is not None else 1,
                        'num_heads': 8,  # 기본값 (main.py의 qap_params와 동일)
                        'use_side_channel': True,
                        'enable_time_features': enable_time_features,
                        'time_dim': time_dim,

                        # 학습 정보
                        'epoch': ep,
                        'best_val_loss': val_loss,
                        'client_id': cid
                    }, client_best_path)

                    # Scaler 저장
                    if self.data_factory is not None:
                        try:
                            scaler = self.data_factory.get_client_scaler(cid)
                            scaler_path = scaler_dir / f"client_{cid}.pkl"

                            with open(scaler_path, 'wb') as f:
                                pickle.dump(scaler, f)

                            printt(f"  Saved scaler: {scaler_path.name}")
                        except Exception as e:
                            printt(f"  ⚠️ Failed to save scaler for client {cid}: {e}")

                self.logger.info(f"New best model: {val_loss:.10f}")
                printt(f"New best global model saved! loss: {val_loss:.4f}")
                printt(f"All client best models saved at global best epoch {ep+1}")

            # Early stop check
            stop_now = self._early_stop(val_loss)

            ckpt_path = self._save_ckpt(ep)
            printt(f"Checkpoint saved: {ckpt_path}")
            printt(f"  - Train Loss: {agg['train']['loss']:.4f} | "
                   f"Val Metrics: loss: {agg['valid']['loss']:.4f} | "
                   f"r2: {agg['valid']['r2']:.4f} | "
                   f"mae: {agg['valid']['mae']:.4f} | "
                   f"cvrmse: {agg['valid']['cvrmse']:.4f}")

            # ⑦ TensorBoard (record global average for train split only)
            for k, v in agg["train"].items():
                self.writer.add_scalar(f"Global/{k}", v, ep)

            if stop_now:
                printt(f"Early‑stopping triggered (patience={self.patience})")
                break

        # 학습 완료 후 결과 요약 저장
        summary_results = {
            'mode': 'federated',
            'global_epochs': self.G,
            'best_global_loss': float(self.best_global),
            'num_clients': len(self.clients),
            'config': {
                'lr': self.lr,
                'optimizer': self.opt_name,
                'local_epochs': self.L,
                'early_stop': self.early_stop,
                'patience': self.patience,
                'model_name': self.args.model_name,
                'window': self.args.window,
                'horizon': self.args.horizon
            }
        }

        summary_path = self.res_dir / "training_summary.json"
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary_results, f, indent=2, ensure_ascii=False)

        printt(f"Training summary saved: {summary_path}")

        self.writer.close()
        printt("✅ Training completed")

    def _save_local_checkpoint(self, client_id, model, optimizer, epoch, best_val_loss, no_improve_count, best_state):
        """Save individual client checkpoint for local-only mode"""
        local_ckpt_dir = self.ckpt_dir / "local_checkpoints"
        local_ckpt_dir.mkdir(exist_ok=True)

        # Safely save RNG state
        rng_torch = None
        rng_cuda = None
        try:
            rng_torch = torch.get_rng_state()
            if torch.cuda.is_available():
                rng_cuda = torch.cuda.get_rng_state_all()
        except Exception as e:
            print(f"⚠️ Client {client_id} RNG state save error: {e}")

        checkpoint_data = {
            "client_id": client_id,
            "epoch": epoch + 1,  # Next epoch to start
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_val_loss": best_val_loss,
            "no_improve_count": no_improve_count,
            "best_state": best_state,
            "rng_states": {
                "python": random.getstate(),
                "numpy": np.random.get_state(),
                "torch": rng_torch,
                "cuda": rng_cuda
            }
        }

        ckpt_path = local_ckpt_dir / f"client_{client_id}_epoch_{epoch:03d}.pth"
        torch.save(checkpoint_data, ckpt_path, _use_new_zipfile_serialization=False)

        # Clean up old checkpoints (keep only latest 3)
        client_ckpts = sorted(local_ckpt_dir.glob(f"client_{client_id}_epoch_*.pth"))
        if len(client_ckpts) > 3:
            for old_ckpt in client_ckpts[:-3]:
                old_ckpt.unlink()

        return ckpt_path

    def _load_local_checkpoint(self, client_id):
        """Load individual client checkpoint for local-only mode"""
        local_ckpt_dir = self.ckpt_dir / "local_checkpoints"
        if not local_ckpt_dir.exists():
            return None

        # Find latest checkpoint for the client
        client_ckpts = sorted(local_ckpt_dir.glob(f"client_{client_id}_epoch_*.pth"))
        if not client_ckpts:
            return None

        latest_ckpt = client_ckpts[-1]

        try:
            checkpoint_data = torch.load(latest_ckpt, map_location=self.device, weights_only=False)
            print(f"✅ Client {client_id} checkpoint loaded: {latest_ckpt.name}")
            return checkpoint_data
        except Exception as e:
            print(f"❌ Client {client_id} checkpoint load failed: {e}")
            return None

    def _init_local_resume(self):
        """Initialize resume state for local-only mode"""
        if not hasattr(self, 'resume') or not self.resume:
            return

        print("\n🔄 Checking local-only mode resume...")

        resumed_clients = 0
        for client_id in self.clients.keys():
            checkpoint_data = self._load_local_checkpoint(client_id)
            if checkpoint_data:
                self.local_resume_data[client_id] = checkpoint_data
                resumed_clients += 1

        if resumed_clients > 0:
            print(f"✅ Found checkpoints for {resumed_clients} clients.")
            print("Each client will resume training from where it was interrupted.")
        else:
            print("No checkpoints found. Starting training from scratch.")

    def local_only_train(self):
        """Local training only mode - independent client training without federation"""
        printt("\n=== Local Training Only Mode Started ===")

        # Initialize resume state
        self._init_local_resume()

        printt(f"Number of clients: {len(self.clients)}")
        printt(f"Local epochs: {self.L}")
        
        # Directory for local results
        local_results_dir = self.res_dir / "local_results"
        local_results_dir.mkdir(exist_ok=True)

        # Store results for each client
        client_results = {}
        client_sizes = {}  # Data sizes for weighted average

        total_start = time.time()

        # Independent training for each client (with early stop)
        for cid in self.clients.keys():
            # Check resume information
            resume_info = self.local_resume_data.get(cid, None)
            start_epoch = 0

            if resume_info:
                printt(f"\n--- Client {cid} local training resumed (from epoch {resume_info['epoch']}) ---")
                start_epoch = resume_info['epoch']
            else:
                printt(f"\n--- Client {cid} local training started ---")

            # Calculate dataset size (for weighted average)
            train_size = len(self.loaders[cid]["train"].dataset)
            client_sizes[cid] = train_size
            printt(f"Client {cid} training data size: {train_size}")

            # Local training with early stop
            total_training_start = time.time()
            model = self.clients[cid].to(self.device)
            train_loader = self.loaders[cid]["train"]
            valid_loader = self.loaders[cid]["valid"]

            criterion = nn.MSELoss()
            opt_cls = optim.Adam if self.opt_name == "adam" else optim.SGD
            optimizer = opt_cls(model.parameters(), lr=self.lr)
            # Suppress GradScaler warnings
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message=".*GradScaler.*")
                scaler = GradScaler()

            # Early stop variables
            best_val_loss = float("inf")
            no_improve_count = 0
            best_state = None
            start_epoch = 0

            # Restore resume state
            if resume_info:
                try:
                    model.load_state_dict(resume_info['model_state'])
                    optimizer.load_state_dict(resume_info['optimizer_state'])
                    best_val_loss = resume_info['best_val_loss']
                    no_improve_count = resume_info['no_improve_count']
                    best_state = resume_info['best_state']
                    start_epoch = resume_info['epoch']

                    # RNG 상태 복원 (안전하게)
                    if 'rng_states' in resume_info:
                        rng_states = resume_info['rng_states']
                        try:
                            random.setstate(rng_states['python'])
                            np.random.set_state(rng_states['numpy'])

                            # torch RNG state 복원
                            if 'torch' in rng_states and rng_states['torch'] is not None:
                                torch_state = rng_states['torch']
                                if isinstance(torch_state, torch.Tensor):
                                    torch.set_rng_state(torch_state)
                                else:
                                    print(f"  ⚠️ torch RNG state type mismatch: {type(torch_state)}")

                            # Restore cuda RNG state
                            if 'cuda' in rng_states and rng_states['cuda'] is not None and torch.cuda.is_available():
                                cuda_state = rng_states['cuda']
                                if isinstance(cuda_state, list) and all(isinstance(s, torch.Tensor) for s in cuda_state):
                                    torch.cuda.set_rng_state_all(cuda_state)
                                else:
                                    print(f"  ⚠️ cuda RNG state type mismatch: {type(cuda_state)}")
                        except Exception as e:
                            print(f"  ⚠️ Error restoring RNG state: {e}")
                    else:
                        print(f"  ⚠️ No RNG state information available")

                    printt(f"  State restored: best_val_loss={best_val_loss:.4f}, no_improve={no_improve_count}")
                except Exception as e:
                    printt(f"  ⚠️ State restore failed, starting from scratch: {e}")
                    start_epoch = 0

            model.train()

            for epoch in range(start_epoch, self.L):
                # Start timer for this epoch
                epoch_start = time.time()

                # Training phase
                epoch_loss = 0.0
                num_batches = 0

                for step, batch in enumerate(train_loader, 1):
                    x, y, tf = self._unpack_batch(batch)
                    x = x.to(self.device); y = y.to(self.device)
                    tf = tf.to(self.device) if tf is not None else None

                    if AMP_MODE == "amp":
                        ctx = autocast(device_type="cuda", dtype=torch.float16)
                    else:
                        ctx = autocast(dtype=torch.float16)

                    with ctx:
                        out = self._forward_model(model, x, tf)
                        if isinstance(out, (tuple, list)): out = out[0]
                        loss = criterion(out, y)
                        epoch_loss += loss.item()

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
                    num_batches += 1

                # Calculate metrics (train, valid, test)
                model.eval()
                train_metrics = self._metrics_full(model, train_loader, cid)
                val_metrics = self._metrics_full(model, valid_loader, cid)
                test_metrics = self._metrics_full(model, self.loaders[cid]["test"], cid)
                model.train()

                # Add TensorBoard logging
                for metric_name, value in train_metrics.items():
                    self.writer.add_scalar(f'client_{cid}/train/{metric_name}', value, epoch)
                for metric_name, value in val_metrics.items():
                    self.writer.add_scalar(f'client_{cid}/val/{metric_name}', value, epoch)
                for metric_name, value in test_metrics.items():
                    self.writer.add_scalar(f'client_{cid}/test/{metric_name}', value, epoch)

                # Progress output
                epoch_time = time.time() - epoch_start
                from datetime import datetime
                current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

                print(f"{epoch:02d}    'train_client{cid}'    Time:{epoch_time:.2f}    {fmt(train_metrics)}    {current_time}")
                print(f"{epoch:02d}    'valid_client{cid}'    Time:{epoch_time:.2f}    {fmt(val_metrics)}    {current_time}")
                print(f"{epoch:02d}    'test_client{cid}'     Time:{epoch_time:.2f}    {fmt(test_metrics)}    {current_time}")
                print("-" * 130)

                val_loss = val_metrics['loss']

                # Early stopping check
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    no_improve_count = 0
                    best_state = model.state_dict().copy()
                    # Record best epoch
                    self.writer.add_scalar(f'client_{cid}/best_epoch', epoch, epoch)
                else:
                    no_improve_count += 1

                # Record early stopping progress
                self.writer.add_scalar(f'client_{cid}/no_improve_count', no_improve_count, epoch)
                self.writer.add_scalar(f'client_{cid}/best_val_loss', best_val_loss, epoch)

                # Save checkpoint every few epochs (optional)
                if epoch % 5 == 0 or epoch == self.L - 1:  # Every 5 epochs or last
                    ckpt_path = self._save_local_checkpoint(
                        cid, model, optimizer, epoch, best_val_loss, no_improve_count, best_state
                    )
                    printt(f"  Checkpoint saved: {ckpt_path.name}")

                # Early stop condition
                if self.early_stop and no_improve_count >= self.patience:
                    printt(f"  Early stop at epoch {epoch+1}/{self.L} (patience={self.patience})")
                    # Record early stop occurrence
                    self.writer.add_scalar(f'client_{cid}/early_stop_epoch', epoch, epoch)
                    if best_state is not None:
                        model.load_state_dict(best_state)
                    break
            
            # Save final checkpoint after client training completion
            final_ckpt_path = self._save_local_checkpoint(
                cid, model, optimizer, epoch, best_val_loss, no_improve_count, best_state
            )
            printt(f"Final checkpoint saved: {final_ckpt_path.name}")

            elapsed_time = time.time() - total_training_start

            # Restore to best model state for final metric calculation
            if best_state is not None:
                model.load_state_dict(best_state)
                printt(f"  Restored to best model state (validation loss: {best_val_loss:.4f})")

            model.eval()
            final_metrics = {
                'train': self._metrics_full(model, self.loaders[cid]["train"], cid),
                'valid': self._metrics_full(model, self.loaders[cid]["valid"], cid),
                'test': self._metrics_full(model, self.loaders[cid]["test"], cid)
            }

            # Record final training time
            self.writer.add_scalar(f'client_{cid}/training_time', elapsed_time, 0)
            self.writer.add_scalar(f'client_{cid}/epochs_trained', epoch + 1, 0)

            printt(f"Client {cid} training completed ({elapsed_time:.1f}s, {epoch+1}/{self.L} epochs)")
            printt(f"  Train: {fmt(final_metrics['train'])}")
            printt(f"  Valid: {fmt(final_metrics['valid'])}")
            printt(f"  Test:  {fmt(final_metrics['test'])}")

            # 모델 재구성에 필요한 정보 수집
            sample_batch = next(iter(train_loader))
            if len(sample_batch) == 3:
                x_sample, y_sample, time_features = sample_batch
                enable_time_features = time_features is not None
                time_dim = time_features.shape[-1] if enable_time_features else None
            else:
                x_sample, y_sample = sample_batch
                enable_time_features = False
                time_dim = None

            feature_count = x_sample.shape[-1]
            out_dim = y_sample.shape[-1] if y_sample.ndim >= 3 else 1

            # ✅ 완전한 체크포인트 저장
            client_model_path = local_results_dir / f"client_{cid}_best.pth"
            torch.save({
                # 모델 파라미터
                'state_dict': model.state_dict(),

                # 기본 정보
                'model_name': self.args.model_name,
                'window_len': self.args.window,
                'horizon': self.args.horizon,
                'feature_count': feature_count,
                'out_dim': out_dim,

                # 모델 아키텍처
                'dim': self.args.dim,
                'num_steps': self.args.num_steps,

                # QAP 파라미터
                'num_queries': self.args.q if self.args.q is not None else 1,
                'num_heads': 8,
                'use_side_channel': True,
                'enable_time_features': enable_time_features,
                'time_dim': time_dim,

                # 학습 정보
                'metrics': final_metrics,
                'best_val_loss': best_val_loss
            }, client_model_path)
            printt(f"  Best model saved: {client_model_path}")

            # ✅ Scaler 저장
            if self.data_factory is not None:
                try:
                    scaler = self.data_factory.get_client_scaler(cid)

                    # scalers 디렉토리 생성
                    scaler_dir = self.res_dir / "scalers"
                    scaler_dir.mkdir(parents=True, exist_ok=True)

                    scaler_path = scaler_dir / f"client_{cid}.pkl"

                    with open(scaler_path, 'wb') as f:
                        pickle.dump(scaler, f)

                    printt(f"  Scaler saved: {scaler_path}")
                except Exception as e:
                    printt(f"  ⚠️ Scaler save failed: {e}")
            
            # Store results
            client_results[cid] = {
                'metrics': final_metrics,
                'train_size': train_size,
                'training_time': elapsed_time,
                'epochs_trained': epoch + 1,
                'best_val_loss': best_val_loss
            }
        
        # Create overall best model by averaging client best models
        printt(f"\n--- Creating overall best model (average of client best models) ---")
        self._create_overall_best_model_from_clients(local_results_dir, client_sizes)

        # Calculate weighted average metrics of best models
        printt(f"\n--- Overall Results (Weighted Average of Best Models) ---")
        weighted_metrics = self._compute_weighted_metrics(client_results, client_sizes)

        for split in ['train', 'valid', 'test']:
            printt(f"Overall {split.capitalize()}: {fmt(weighted_metrics[split])}")

        # Save results to JSON
        summary_results = {
            'overall_metrics': weighted_metrics,
            'client_results': {str(cid): result for cid, result in client_results.items()},
            'client_sizes': client_sizes,
            'total_training_time': time.time() - total_start,
            'config': {
                'local_epochs': self.L,
                'lr': self.lr,
                'optimizer': self.opt_name
            }
        }

        summary_path = local_results_dir / "local_training_summary.json"
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary_results, f, indent=2, ensure_ascii=False)

        printt(f"\n✅ Local training completed! Results saved: {summary_path}")
        printt(f"Total time taken: {time.time() - total_start:.1f}s")

        # TensorBoard usage guide
        printt(f"\n📊 View training progress with TensorBoard:")
        printt(f"   tensorboard --logdir={self.tb_dir}")
        printt(f"   Access http://localhost:6006 in browser")
        
        self.writer.close()

    def _create_overall_best_model_from_clients(self, local_results_dir, client_sizes):
        """Create ``global_best.pth`` by size-weighted averaging of the clients' best checkpoints (FedPer style).

        For every client id in ``self.clients``, ``local_results_dir / f"client_{cid}_best.pth"``
        is loaded when it exists and its ``'state_dict'`` entry is used. Only shared
        parameters are averaged: names starting with ``"backbone."`` except those
        ending in ``"slot_embed.weight"``. Each parameter is weighted by
        ``client_sizes[cid]`` and normalised over the loaded clients that contain it.

        Floating-point tensors are accumulated in at least float32 and cast back to
        their original dtype. Non-floating tensors (integer buffers) cannot be
        meaningfully averaged, so they are copied from the first loaded client,
        which preserves their dtype.

        The result is saved to ``self.res_dir / "global_best.pth"`` and merged into
        ``self.server.global_state``. This is a best-effort post-processing step:
        any exception is reported (message + traceback) instead of being raised.

        Args:
            local_results_dir: ``pathlib.Path`` directory holding the per-client checkpoints.
            client_sizes: mapping ``cid -> weight`` (e.g. number of samples) per client.
        """
        try:
            # Load every available client best checkpoint.
            client_states = {}
            for cid in self.clients.keys():
                client_model_path = local_results_dir / f"client_{cid}_best.pth"
                if client_model_path.exists():
                    checkpoint = torch.load(client_model_path, map_location=self.device, weights_only=False)
                    client_states[cid] = checkpoint['state_dict']  # weights are wrapped in the checkpoint dict
                    printt(f"  Loaded client_{cid} best model: {client_model_path.name}")
                else:
                    printt(f"  ⚠️ Client_{cid} best model not found: {client_model_path}")

            if not client_states:
                printt("  ❌ No client best models found. Cannot create overall best model.")
                return

            if sum(client_sizes[cid] for cid in client_states) <= 0:
                printt("  ❌ Client sizes sum to zero. Cannot create overall best model.")
                return

            global_state = {}
            first_cid = next(iter(client_states))
            for param_name, ref_tensor in client_states[first_cid].items():
                # FedPer: only the backbone is shared; the slot embedding stays per-client.
                if not param_name.startswith("backbone.") or param_name.endswith("slot_embed.weight"):
                    continue

                contributors = [
                    (client_sizes[cid], state[param_name])
                    for cid, state in client_states.items()
                    if param_name in state
                ]
                total_weight = float(sum(size for size, _ in contributors))
                if total_weight <= 0:
                    continue  # only zero-weight clients hold this parameter

                if not torch.is_floating_point(ref_tensor):
                    # Integer buffers cannot be averaged without changing dtype.
                    global_state[param_name] = ref_tensor.clone()
                    continue

                # Accumulate in >= float32 so half-precision weights do not lose accuracy.
                acc_dtype = torch.promote_types(ref_tensor.dtype, torch.float32)
                averaged = torch.zeros_like(ref_tensor, dtype=acc_dtype)
                for size, tensor in contributors:
                    averaged += tensor.to(device=averaged.device, dtype=acc_dtype) * (size / total_weight)
                global_state[param_name] = averaged.to(ref_tensor.dtype)

            # Save the averaged global best model.
            global_best_path = self.res_dir / "global_best.pth"
            torch.save(global_state, global_best_path)
            printt(f"  ✅ Overall best model saved: {global_best_path}")
            printt(f"  - Averaged {len(global_state)} shared parameters from {len(client_states)} clients")

            # Keep the server's global state consistent with the saved model.
            self.server.global_state.update(global_state)

        except Exception as e:
            printt(f"  ❌ Error creating overall best model: {e}")
            import traceback
            traceback.print_exc()

    # -------------- Helpers --------------
    def _select_clients(self) -> List[int]:
        ids = list(self.clients.keys())
        if 0 < self.frac < 1:
            k = max(1, int(len(ids) * self.frac))
            return random.sample(ids, k)
        return ids

    def _train_selected(self, sel: List[int], ep: int):
        """Train the selected clients one after another and evaluate each on every split.

        For each ``cid`` in ``sel``: run ``self._local_train(cid)``, evaluate the
        client model on its ``train`` / ``valid`` / ``test`` loaders with
        ``self._metrics_full``, and print one line per split. All three lines show
        the local training time and share a single timestamp.

        Args:
            sel: client ids to train, processed in this order.
            ep: current global epoch; only used as the printed prefix.

        Returns:
            ``(client_metrics, client_weights)``. ``client_metrics[split][cid]`` is
            the metric dict of client ``cid`` on ``split``, and
            ``client_weights[cid]`` is the length of that client's train dataset.
        """
        splits = ("train", "valid", "test")
        client_metrics = {split: {} for split in splits}
        client_weights = {}

        for cid in sel:
            train_time = self._local_train(cid)["time"]
            client_model = self.clients[cid]
            client_loaders = self.loaders[cid]

            per_split = {
                split: self._metrics_full(client_model, client_loaders[split], cid)
                for split in splits
            }
            for split, split_metrics in per_split.items():
                client_metrics[split][cid] = split_metrics
            client_weights[cid] = len(client_loaders["train"].dataset)

            stamp = f"{datetime.now():%Y-%m-%d %H:%M:%S}"
            for split in splits:
                printt(f"{ep:02d}\t'{split}_client{cid}'\tTime:{train_time:.2f}\t{fmt(per_split[split])}\t{stamp}")
            printt("-" * 130)

        return client_metrics, client_weights

    def _unpack_batch(self, batch: Tuple[torch.Tensor, ...]):
        # (x, y, tf) 또는 (x, y)
        if isinstance(batch, (list, tuple)):
            if len(batch) == 3:
                x, y, tf = batch
            elif len(batch) == 2:
                x, y = batch; tf = None
            else:
                raise ValueError(f"Unexpected batch tuple length: {len(batch)}")
        else:
            raise ValueError("Batch must be (x,y) or (x,y,tf).")
        return x, y, tf

    def _forward_model(self, model: nn.Module, x: torch.Tensor, tf: torch.Tensor | None):
        # (x, tf) 시그니처 지원 시 사용, 아니면 (x) 폴백
        if tf is not None:
            try:  return model(x, tf)
            except TypeError: pass
        return model(x)

    def _local_train(self, cid: int):
        model = self.clients[cid].to(self.device)
        train_loader = self.loaders[cid]["train"]

        criterion = nn.MSELoss()
        params    = model.parameters()
        opt_cls   = optim.Adam if self.opt_name == "adam" else optim.SGD
        optimizer = opt_cls(params, lr=self.lr)
        # Suppress GradScaler warnings
        import warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*GradScaler.*")
            scaler = GradScaler()
        accum     = self.grad_accum

        # ---- Train loop ----
        model.train()
        t0 = time.time()
        optimizer.zero_grad(set_to_none=True)
        last_loss = None  # ← 요청대로 노출

        # 🔧 배치 수 제한 제거 (전체 배치 처리)
        # max_batches_per_epoch = 50  # 비활성화
        
        for epoch in range(self.L):                   # local_epochs
            for step, batch in enumerate(train_loader, 1):
                # 🔧 배치 제한 제거됨 - 전체 배치 처리
                # if step > max_batches_per_epoch:
                #     print(f"      ↳ Reached max batches ({max_batches_per_epoch}) for epoch {epoch+1}")
                #     break
                    
                x, y, tf = self._unpack_batch(batch)
                x = x.to(self.device); y = y.to(self.device)
                tf = tf.to(self.device) if tf is not None else None

                # AMP autocast
                if AMP_MODE == "amp":
                    ctx = autocast(device_type="cuda", dtype=torch.float16)
                else:
                    ctx = autocast(dtype=torch.float16)

                with ctx:
                    out = self._forward_model(model, x, tf)
                    if isinstance(out, (tuple, list)): out = out[0]
                    loss = criterion(out, y) / accum
                    last_loss = float(loss.item() * accum)

                scaler.scale(loss).backward()

                if step % accum == 0 or step == len(train_loader):
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
                
                # 🔧 진행상황 출력 (필요시 활성화)
                # if step % 100 == 0:  # 매 100개 배치마다
                #     print(f"      ↳ Client {cid} - Processed {step} batches, loss: {last_loss:.4f}")

        dur = time.time() - t0
        model.eval()

        # 로컬 단계 평가는 제거. 학습 시간과 마지막 배치 손실만 보고.
        return {"time": dur, "train_loss": (0.0 if last_loss is None else last_loss)}

    # ---- Fast Torch-only metrics (for the end-of-round global average) ----
    @torch.no_grad()
    def _metrics_fast(self, model: nn.Module, loader: DataLoader):
        model.eval()  # 평가 모드로 전환 (Dropout 비활성화)
        sse = 0.0; sae = 0.0; sum_y = 0.0; sum_y2 = 0.0; N = 0
        it = enumerate(loader, 1)
        for i, batch in it:
            x, y, tf = self._unpack_batch(batch)
            x = x.to(self.device); y = y.to(self.device)
            tf = tf.to(self.device) if tf is not None else None
            out = self._forward_model(model, x, tf)
            if isinstance(out, (tuple, list)): out = out[0]

            diff = (y - out)
            sse += (diff * diff).sum().item()
            sae += diff.abs().sum().item()
            sum_y  += y.sum().item()
            sum_y2 += (y * y).sum().item()
            N += y.numel()


        if N == 0:
            return {"loss": 0.0, "r2": 0.0, "mae": 0.0, "cvrmse": 0.0}

        mse_v  = sse / N
        mae_v  = sae / N
        sst  = max(sum_y2 - (sum_y * sum_y) / N, 0.0)
        r2_v   = (1.0 - sse / sst) if sst > 0 else 0.0
        rmse = (sse / N) ** 0.5
        mean_abs_y = abs(sum_y) / N
        cvr = (rmse / (mean_abs_y + 1e-8)) * 100.0
        return {"loss": mse_v, "r2": r2_v, "mae": mae_v, "cvrmse": cvr}

    @torch.no_grad()
    def _metrics_full(self, model: nn.Module, loader: DataLoader, client_idx: int | None = None):
        """Evaluate ``model`` on all of ``loader`` and return MSE, R², MAE and CV(RMSE).

        All predictions and targets are gathered on the CPU first. The metrics are
        then computed once over the concatenated arrays with the module-level
        ``common.function`` helpers (``mse``, ``r2_score``, ``mae``, ``cvrmse``).
        R² is computed on arrays flattened to ``(n_samples, -1)``. No inverse
        transform is applied, so the metrics are in the scale the loader yields
        (the normalised scale).

        Args:
            model: network to evaluate; switched to eval mode.
            loader: yields ``(x, y)`` or ``(x, y, tf)`` batches.
            client_idx: kept for interface compatibility; not used by the computation.

        Returns:
            dict with keys ``"loss"`` (MSE), ``"r2"``, ``"mae"`` and ``"cvrmse"``.
            All values are 0.0 when the loader yields no batches.
        """
        model.eval()

        all_preds = []
        all_targets = []
        for batch in loader:
            x, y, tf = self._unpack_batch(batch)
            x = x.to(self.device)
            tf = tf.to(self.device) if tf is not None else None

            out = self._forward_model(model, x, tf)
            if isinstance(out, (tuple, list)):
                out = out[0]

            all_preds.append(out.cpu())
            # Targets are only needed on the host for the numpy metrics,
            # so they are never copied to the accelerator.
            all_targets.append(y.cpu())

        if not all_preds:
            return {"loss": 0.0, "r2": 0.0, "mae": 0.0, "cvrmse": 0.0}

        preds_np = torch.cat(all_preds, dim=0).numpy()
        targets_np = torch.cat(all_targets, dim=0).numpy()

        return {
            "loss": float(mse(targets_np, preds_np)),
            "r2": float(r2_score(
                targets_np.reshape(targets_np.shape[0], -1),
                preds_np.reshape(preds_np.shape[0], -1)
            )),
            "mae": float(mae(targets_np, preds_np)),
            "cvrmse": float(cvrmse(targets_np, preds_np))
        }


    def _compute_weighted_metrics(self, client_results, client_sizes):
        """Calculate weighted average metrics based on client data sizes"""
        total_size = sum(client_sizes.values())
        weighted_metrics = {'train': {}, 'valid': {}, 'test': {}}

        # Get metric key list from first client
        first_cid = next(iter(client_results.keys()))
        metric_keys = client_results[first_cid]['metrics']['train'].keys()

        # Calculate weighted average for each split
        for split in ['train', 'valid', 'test']:
            for metric_key in metric_keys:
                weighted_sum = 0.0
                for cid, result in client_results.items():
                    weight = client_sizes[cid] / total_size
                    weighted_sum += result['metrics'][split][metric_key] * weight
                weighted_metrics[split][metric_key] = weighted_sum
        
        return weighted_metrics

    def aggregate_client_metrics(self, client_metrics: Dict[int, Dict], weights: Dict[int, int]) -> Dict:
        """
        클라이언트별 메트릭을 데이터 크기 기반으로 가중 평균
        
        Args:
            client_metrics: {client_id: {'loss': ..., 'mae': ...}}
            weights: {client_id: num_samples}
        
        Returns:
            global_metrics: 가중 평균된 메트릭
        """
        if not client_metrics or not weights:
            return {"loss": 0.0, "r2": 0.0, "mae": 0.0, "cvrmse": 0.0}
        
        # 총 샘플 수 계산
        total_samples = sum(weights.values())
        if total_samples == 0:
            return {"loss": 0.0, "r2": 0.0, "mae": 0.0, "cvrmse": 0.0}
        
        # 메트릭 키 목록 (첫 번째 클라이언트에서 가져오기)
        first_client_id = next(iter(client_metrics.keys()))
        metric_keys = client_metrics[first_client_id].keys()
        
        # 가중 평균 계산
        weighted_metrics = {}
        for metric_key in metric_keys:
            weighted_sum = 0.0
            for client_id, metrics in client_metrics.items():
                weight = weights.get(client_id, 0)
                weighted_sum += metrics[metric_key] * weight
            
            weighted_metrics[metric_key] = weighted_sum / total_samples
        
        return weighted_metrics

    # -------------- Checkpoint / Resume / EarlyStop --------------
    def _ckpt_path(self, ep): return self.ckpt_dir/f"epoch_{ep:03d}.pth"

    def _save_ckpt(self, ep):
        # # 기존 RNG state 안전하게 저장 (주석 처리)
        # rng_torch = None
        # rng_cuda = None
        # try:
        #     rng_torch = torch.get_rng_state()
        #     if torch.cuda.is_available():
        #         rng_cuda = torch.cuda.get_rng_state_all()
        # except Exception as e:
        #     print(f"⚠️ RNG state 저장 중 오류: {e}")

        # 논문 재현성을 위한 직접 RNG state 저장 (오류 시 즉시 실패)
        data = {
            "epoch": ep+1,
            "server_state": self.server.global_state,
            "global_best": self.best_global,
            "client_best": dict(self.best_client),
            "no_improve": self.no_improve,
            "rng_python": random.getstate(),
            "rng_numpy": np.random.get_state(),
            "rng_torch": torch.get_rng_state(),  # 직접 저장 (ByteTensor)
            "rng_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        }
        # ★ 방법 B: 저장 파일명을 로더 규칙(epoch_***)에 맞춤
        path = self.ckpt_dir / f"epoch_{ep:03d}.pth"
        torch.save(data, path, _use_new_zipfile_serialization=False)
        return path

    def _load_ckpt(self):
        ckpts = sorted(self.ckpt_dir.glob("epoch_*.pth"))
        if not ckpts: return
        ck = torch.load(ckpts[-1], map_location=self.device, weights_only=False)
        self.server.global_state = ck["server_state"]
        self.best_global = ck.get("global_best", float("inf"))
        self.best_client.update(ck.get("client_best", {}))
        self.no_improve = ck.get("no_improve", 0)
        self.start_ep = ck["epoch"]
        random.setstate(ck["rng_python"]); np.random.set_state(ck["rng_numpy"])

        # # 기존 RNG state 복원 (타입 체크 및 변환) - 주석 처리
        # try:
        #     # RNG state 복원 (타입 체크 및 변환)
        #     if "rng_torch" in ck and ck["rng_torch"] is not None:
        #         rng_state = ck["rng_torch"]
        #         if isinstance(rng_state, torch.Tensor):
        #             torch.set_rng_state(rng_state)
        #         else:
        #             print(f"⚠️ RNG torch state 타입 불일치: {type(rng_state)}")
        #
        #     if "rng_cuda" in ck and ck["rng_cuda"] is not None and torch.cuda.is_available():
        #         cuda_state = ck["rng_cuda"]
        #         if isinstance(cuda_state, list) and all(isinstance(s, torch.Tensor) for s in cuda_state):
        #             torch.cuda.set_rng_state_all(cuda_state)
        #         else:
        #             print(f"⚠️ RNG cuda state 타입 불일치: {type(cuda_state)}")
        # except (TypeError, KeyError, RuntimeError) as e:
        #     print(f"⚠️ RNG state 복원 실패 - 건너뜁니다: {e}")

        # 논문 재현성을 위한 직접 RNG state 복원 (타입 체크 제거)
        if "rng_torch" in ck and ck["rng_torch"] is not None:
            torch.set_rng_state(ck["rng_torch"])  # 직접 복원

        if "rng_cuda" in ck and ck["rng_cuda"] is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(ck["rng_cuda"])  # 직접 복원

        self.logger.info(f"Resumed from {ckpts[-1].name} (epoch {self.start_ep})")

    def _early_stop(self, val_loss: float) -> bool:
        if not self.early_stop: return False
        # best_global 업데이트는 이미 위에서 처리됨
        if val_loss > self.best_global:
            self.no_improve += 1
        # val_loss == self.best_global인 경우는 카운터 유지
        return self.no_improve >= self.patience




# # Federated / training.py – sequential‑client FederatedTrainer (FedPer)
# from __future__ import annotations
# from pathlib import Path
# from typing import Dict, List, Tuple
# from datetime import datetime
# import time, json, random, logging
# from collections import defaultdict

# import numpy as np
# import torch
# import torch.nn as nn
# import torch.optim as optim

# # AMP 호환
# try:
#     from torch.amp import GradScaler, autocast
#     AMP_MODE = "amp"
# except ImportError:
#     from torch.cuda.amp import GradScaler, autocast
#     AMP_MODE = "cuda"

# from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter

# # --- 네 프로젝트 유틸/서버/메트릭 import (기존 경로 유지) ---
# from federated.server import Server
# from common.function import single_mse as mse, single_mae as mae, r2_score, cvrmse

# DISPLAY_ORDER = ("loss", "r2", "mae", "cvrmse")
# def printt(msg: str): print(msg, flush=True)
# def fmt(m: dict) -> str:
#     return (f"loss:{m['loss']:8.4f} | r2:{m['r2']:7.3f} | "
#             f"mae:{m['mae']:8.4f} | cvrmse:{m['cvrmse']:8.4f}")

# class FederatedTrainer:
#     def __init__(
#         self,
#         args,
#         server: Server,
#         clients: List[nn.Module],                  # 각 클라 모델 (ForecastWrapper or backbone+head)
#         loaders: Dict[int, Dict[str, DataLoader]], # {cid: {"train":..., "valid":..., "test":...}}
#         *,
#         device: str = "cuda",
#         global_epochs: int = 50,
#         local_epochs: int = 1,
#         lr: float = 1e-3,
#         optimizer: str = "adam",
#         frac: float = 1.0,         # 매 라운드 참여 비율(1.0 = 전 클라)
#         experiment: str = "exp",
#         seed: int = 0,
#         resume: bool = False,
#         early_stop: bool = True,
#         patience: int = 5,
#         grad_accum: int = 1,
#     ) -> None:
#         # logger
#         self.logger = logging.getLogger("Trainer")
#         if not self.logger.handlers:
#             h = logging.StreamHandler()
#             h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s │ %(message)s", "%m%d %H:%M:%S"))
#             self.logger.addHandler(h)
#         self.logger.setLevel(logging.INFO)

#         # seeds
#         random.seed(seed); np.random.seed(seed)
#         torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

#         # attrs
#         self.device = torch.device(device if torch.cuda.is_available() else "cpu")
#         self.G, self.L = int(global_epochs), int(local_epochs)
#         self.lr, self.opt_name = float(lr), optimizer.lower()
#         self.frac = min(max(frac, 0.), 1.)
#         self.early_stop, self.patience = bool(early_stop), int(patience)
#         self.grad_accum = max(1, int(grad_accum))

#         base = Path("output")/experiment
#         self.ckpt_dir = base/"checkpoints"; self.ckpt_dir.mkdir(parents=True, exist_ok=True)
#         self.tb_dir   = base/"logs";        self.tb_dir.mkdir(parents=True, exist_ok=True)
#         self.res_dir  = base/"results";     self.res_dir.mkdir(parents=True, exist_ok=True)
#         self.writer = SummaryWriter(self.tb_dir.as_posix())

#         # models & data
#         self.server  = server.to(self.device)
#         self.clients = {i: c.to(self.device) for i, c in enumerate(clients)}
#         self.loaders = loaders

#         # trackers
#         self.start_ep = 0
#         self.best_global = float("inf")
#         self.no_improve = 0
#         self.best_client: Dict[int, float] = defaultdict(lambda: float("inf"))
#         self.args = args

#         if resume:
#             self._load_ckpt()

#     # ============== Global Round Loop ==============
#     def train(self):
#         for ep in range(self.start_ep, self.G):
#             t0 = time.time()
#             printt(f"\n=== [Global Epoch {ep+1}/{self.G}] ===")

#             # ① 선택 클라이언트 샘플링 & 서버→클라 브로드캐스트(공유 파라미터만)
#             sel = self._select_clients()
#             printt(f"Selected clients: {sel}")
#             self.server.broadcast_shared([self.clients[cid] for cid in sel])

#             # ② 로컬 학습 (선택된 클라들을 for-loop으로 ‘순차’ 학습)
#             agg = self._train_selected(sel, ep)

#             # ③ 서버 집계 (FedPer: 공유(backbone.*)만 평균/갱신)
#             self.server.aggregate([self.clients[cid] for cid in sel])

#             # ④ 라운드 로그
#             dt = time.time() - t0
#             printt(f"Global ({dt:6.1f}s)")
#             printt(f"{ep:02d} Global Train - {fmt(agg['train'])}")
#             printt(f"{ep:02d} Global Valid - {fmt(agg['valid'])}")
#             printt(f"{ep:02d} Global Test  - {fmt(agg['test'])}")

#             # ⑤ 얼리스탑 & 체크포인트
#             val_loss = agg["valid"]["loss"]
#             stop_now = self._early_stop(val_loss)
#             if val_loss == self.best_global:
#                 torch.save(self.server.global_state_dict(), self.res_dir/"global_best.pth")
#                 self.logger.info(f"New best model: {val_loss:.10f}")
#                 printt(f"New best model saved! loss: {self.best_global:.4f}")

#             ckpt_path = self._save_ckpt(ep)
#             printt(f"Checkpoint saved: {ckpt_path}")
#             printt(f"  - Train Loss: {agg['train']['loss']:.4f} | "
#                    f"Val Metrics: loss: {agg['valid']['loss']:.4f} | "
#                    f"r2: {agg['valid']['r2']:.4f} | "
#                    f"mae: {agg['valid']['mae']:.4f} | "
#                    f"cvrmse: {agg['valid']['cvrmse']:.4f}")

#             # ⑥ TensorBoard (train split만 전역평균 기록)
#             for k, v in agg["train"].items():
#                 self.writer.add_scalar(f"Global/{k}", v, ep)

#             if stop_now:
#                 printt(f"Early‑stopping triggered (patience={self.patience})")
#                 break

#         self.writer.close()
#         printt("✅ Training completed")

#     # -------------- Helpers --------------
#     def _select_clients(self) -> List[int]:
#         ids = list(self.clients.keys())
#         if 0 < self.frac < 1:
#             k = max(1, int(len(ids) * self.frac))
#             return random.sample(ids, k)
#         return ids

#     def _train_selected(self, sel: List[int], ep: int):
#         from collections import defaultdict
#         agg = defaultdict(lambda: defaultdict(float))

#         for cid in sel:  # ← 여기서 “for loop 순차 학습”
#             met = self._local_train(cid)      # {'train':..., 'valid':..., 'test':..., 'time':...}
#             t   = met.pop("time")

#             for split in ("train", "valid", "test"):
#                 row = met[split]
#                 metric_line = " \t ".join(
#                     f"{k}: {row[k]:.4f}" for k in DISPLAY_ORDER if k in row
#                 )
#                 printt(f"{ep:02d}\t'{split}_client{cid}'\tTime:{t:.2f}\t{metric_line}\t"
#                        f"{datetime.now():%Y-%m-%d %H:%M:%S}")

#                 for k, v in row.items():
#                     agg[split][k] += v

#             printt("-"*120)
#             # client best (train-loss 기준)
#             if met["train"]["loss"] < self.best_client[cid]:
#                 self.best_client[cid] = met["train"]["loss"]
#                 torch.save(self.clients[cid].state_dict(), self.res_dir / f"client_{cid}_best.pth")

#         n = max(1, len(sel))
#         for split in ("train", "valid", "test"):
#             for k in list(agg[split].keys()):
#                 agg[split][k] /= n
#         return agg

#     def _unpack_batch(self, batch: Tuple[torch.Tensor, ...]):
#         # (x, y, tf) 또는 (x, y)
#         if isinstance(batch, (list, tuple)):
#             if len(batch) == 3:
#                 x, y, tf = batch
#             elif len(batch) == 2:
#                 x, y = batch; tf = None
#             else:
#                 raise ValueError(f"Unexpected batch tuple length: {len(batch)}")
#         else:
#             raise ValueError("Batch must be (x,y) or (x,y,tf).")
#         return x, y, tf

#     def _forward_model(self, model: nn.Module, x: torch.Tensor, tf: torch.Tensor | None):
#         # (x, tf) 시그니처 지원 시 사용, 아니면 (x) 폴백
#         if tf is not None:
#             try:  return model(x, tf)
#             except TypeError: pass
#         return model(x)

#     def _local_train(self, cid: int):
#         model = self.clients[cid].to(self.device)
#         train_loader = self.loaders[cid]["train"]
#         valid_loader = self.loaders[cid]["valid"]
#         test_loader  = self.loaders[cid]["test"]

#         criterion = nn.MSELoss()
#         params    = model.parameters()
#         opt_cls   = optim.Adam if self.opt_name == "adam" else optim.SGD
#         optimizer = opt_cls(params, lr=self.lr)
#         scaler    = GradScaler()
#         accum     = self.grad_accum

#         # ---- Train loop ----
#         model.train()
#         t0 = time.time()
#         optimizer.zero_grad(set_to_none=True)

#         for _ in range(self.L):                   # local_epochs
#             for step, batch in enumerate(train_loader, 1):
#                 x, y, tf = self._unpack_batch(batch)
#                 x = x.to(self.device); y = y.to(self.device)
#                 tf = tf.to(self.device) if tf is not None else None

#                 # AMP autocast
#                 if AMP_MODE == "amp":
#                     ctx = autocast(device_type="cuda", dtype=torch.float16)
#                 else:
#                     ctx = autocast(dtype=torch.float16)

#                 with ctx:
#                     out = self._forward_model(model, x, tf)
#                     if isinstance(out, (tuple, list)): out = out[0]
#                     loss = criterion(out, y) / accum

#                 scaler.scale(loss).backward()

#                 if step % accum == 0 or step == len(train_loader):
#                     scaler.step(optimizer)
#                     scaler.update()
#                     optimizer.zero_grad(set_to_none=True)

#         dur = time.time() - t0
#         model.eval()

#         return {
#             "train": self._metrics(model, train_loader),
#             "valid": self._metrics(model, valid_loader),
#             "test":  self._metrics(model,  test_loader),
#             "time":  dur,
#         }

#     @torch.no_grad()
#     def _metrics(self, model: nn.Module, loader: DataLoader):
#         preds, trues = [], []
#         for batch in loader:
#             x, y, tf = self._unpack_batch(batch)
#             x = x.to(self.device); tf = (tf.to(self.device) if tf is not None else None)
#             out = self._forward_model(model, x, tf)
#             if isinstance(out, (tuple, list)): out = out[0]
#             preds.append(out.cpu()); trues.append(y)

#         yp = torch.cat(preds).numpy()
#         yt = torch.cat(trues).numpy()
#         return {
#             "loss":   float(mse(yt, yp)),
#             "r2":     float(r2_score(yt.reshape(yt.shape[0], -1), yp.reshape(yp.shape[0], -1))),
#             "mae":    float(mae(yt, yp)),
#             "cvrmse": float(cvrmse(yt, yp)),
#         }

#     # -------------- Checkpoint / Resume / EarlyStop --------------
#     def _ckpt_path(self, ep): return self.ckpt_dir/f"epoch_{ep:03d}.pth"

#     def _save_ckpt(self, ep):
#         data = {
#             "epoch": ep+1,
#             "server_state": self.server.global_state_dict(),
#             "global_best": self.best_global,
#             "client_best": dict(self.best_client),
#             "no_improve": self.no_improve,
#             "rng_python": random.getstate(),
#             "rng_numpy": np.random.get_state(),
#             "rng_torch": torch.get_rng_state(),
#             "rng_cuda": torch.cuda.get_rng_state_all(),
#         }
#         path = self.ckpt_dir / f"global_epoch_{ep:03d}.pth"
#         torch.save(data, path)
#         return path

#     def _load_ckpt(self):
#         ckpts = sorted(self.ckpt_dir.glob("epoch_*.pth"))
#         if not ckpts: return
#         ck = torch.load(ckpts[-1], map_location=self.device, weights_only=False)
#         self.server.load_global_state_dict(ck["server_state"], strict=True)
#         self.best_global = ck.get("global_best", float("inf"))
#         self.best_client.update(ck.get("client_best", {}))
#         self.no_improve = ck.get("no_improve", 0)
#         self.start_ep = ck["epoch"]
#         random.setstate(ck["rng_python"]); np.random.set_state(ck["rng_numpy"])
#         torch.set_rng_state(ck["rng_torch"]); torch.cuda.set_rng_state_all(ck["rng_cuda"])
#         self.logger.info(f"Resumed from {ckpts[-1].name} (epoch {self.start_ep})")

#     def _early_stop(self, val_loss: float) -> bool:
#         if not self.early_stop: return False
#         if val_loss < self.best_global:
#             self.best_global = val_loss; self.no_improve = 0
#         else:
#             self.no_improve += 1
#         return self.no_improve >= self.patience
