# main.py – Federated training with FedPer/QAP

from __future__ import annotations
import argparse, random, sys, time, json
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn

from federated.server import Server
from federated.training import FederatedTrainer
from federated.client import create_client_model

# Import models and factories
from model import iTransformer, SpikeRNN, Spikformer, iSpikeformer, TimeMixer, DLinear
from model.federated_model_factory import build_backbone, infer_dims

# ---------------- utils ----------------
def set_seed(seed: int = 0):
    """Make runs reproducible.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then forces deterministic cuDNN kernels and disables cuDNN autotuning.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# ---------------- args -----------------
def get_args():
    """Parse the command-line options for a FedPer run.

    Returns:
        argparse.Namespace holding data, windowing, model and federated-training
        settings. ``--dataset`` is the only required option.
    """
    parser = argparse.ArgumentParser("FedQAP - Federated Time Series Forecasting with QAP")
    add = parser.add_argument

    # Data / client partitioning
    add("--dataset", type=str, required=True, help="Time series dataset name (e.g., electricity, traffic)")
    add("--num_clients", type=int, default=10, help="Number of clients")
    add("--max_features", type=int, default=None, help="Maximum number of features per client")
    add("--feature_overlap", type=float, default=0.0, help="Feature overlap ratio between clients (0.0~1.0)")

    # Windowing / batching (the federated algorithm itself is fixed to FedPer)
    add("--window", type=int, default=48)
    add("--horizon", type=int, default=24)  # forecast length (future steps)
    add("--batch_size", type=int, default=32)
    add("--num_workers", type=int, default=0)
    add("--timeenc", type=int, default=1)
    add("--time_freq", type=str, default=None)

    # Model / hyperparameters
    add("--model_name",
        choices=["itransformer", "spikernn", "spikformer", "ispikeformer", "timemixer", "dlinear"],
        default="spikernn")
    add("--dim", type=int, default=64)
    add("--num_steps", type=int, default=4)
    add("--q", type=int, default=None, help="QAP num_queries (default=1 if not specified)")
    add("--alignment_method", type=str, choices=["qap", "projection"],
        default="qap", help="Feature alignment method: qap or projection")

    # Federated optimisation
    add("--global_epochs", type=int, default=50)
    add("--local_epochs", type=int, default=1)
    add("--lr", type=float, default=1e-3)
    add("--optimizer", choices=["adam", "sgd"], default="adam")
    add("--frac", type=float, default=1.0)
    add("--early_stop", action="store_true")
    add("--patience", type=int, default=5)

    # Run bookkeeping
    add("--seed", type=int, default=42)
    add("--device", type=str, default="cuda")
    add("--experiment", type=str, default="run")
    add("--resume", action="store_true")
    add("--file_name", type=str, default=None,
        help="Save under output/<file_name>. Takes priority over --experiment if specified.")
    add("--local_only", action="store_true",
        help="Perform local training only (disable federated learning)")

    return parser.parse_args()

def _build_shared_backbone(args) -> nn.Module:
    """Build the backbone selected by ``args.model_name``, running at width ``args.dim``."""
    return build_backbone(
        name=args.model_name,
        input_size=args.dim,
        hidden_size=args.dim,
        num_steps=args.num_steps,
        max_length=args.window,
    )


def build_client_model(sample_batch, args, client_info, data_factory=None, client_idx=0) -> nn.Module:
    """Build a FedPer client model with QAP or projection input alignment.

    The alignment wrapper maps this client's raw channels onto ``args.dim`` and
    feeds the backbone chosen by ``args.model_name``. The result is passed to
    ``create_client_model`` together with the forecast horizon.

    Args:
        sample_batch: One batch from the client's train loader, either
            ``(x_raw, y_raw)`` or ``(x_raw, y_raw, time_features)``.
            ``time_features`` may be ``None``. ``x_raw`` must be rank 3; it is
            unpacked as ``(batch, length, channels)``.
        args: Parsed CLI namespace. Reads ``alignment_method``, ``dim``, ``q``,
            ``model_name``, ``num_steps``, ``window`` and ``horizon``.
        client_info: Mapping from the data factory. ``feature_count`` and
            ``client_id`` are read, with fallbacks.
        data_factory: Not used here; kept for interface compatibility.
        client_idx: Only used to build a fallback client id for logging.

    Returns:
        The client ``nn.Module`` returned by ``create_client_model``.

    Raises:
        ValueError: If ``args.alignment_method`` is not ``"qap"`` or ``"projection"``.
    """
    method = args.alignment_method
    if method not in ("qap", "projection"):
        raise ValueError(f"Unknown alignment_method: {method}")

    # Raw batches come as (x, y) or (x, y, time_features); only x and the
    # width of the time features matter for sizing the model.
    if len(sample_batch) == 3:
        x_raw, _, time_features = sample_batch
        time_dim = time_features.shape[-1] if time_features is not None else None
    else:
        x_raw, _ = sample_batch
        time_dim = None

    _, _, C_client = x_raw.shape

    F_client = client_info.get('feature_count', C_client)
    client_id = client_info.get('client_id', f'client_{client_idx}')

    print(f"[{client_id}] Raw input: {x_raw.shape} -> Output features: {F_client} (method: {method})")

    if method == "qap":
        from model.federated_model_factory import QAP_ServerModule, QAPBackboneWrapper

        # The QAP module is constructed before the backbone so that seeded
        # parameter initialisation consumes the RNG in a stable order.
        qap_module = QAP_ServerModule(
            d_model=args.dim,
            num_heads=8,
            num_queries=getattr(args, 'q', 1) or 1,  # --q unset (None) or 0 -> 1 query
            use_side_channel=True,
        )
        aligned_backbone = QAPBackboneWrapper(
            backbone=_build_shared_backbone(args),
            qap_module=qap_module,
            F_client=C_client,
            time_dim=time_dim,
            fuse="cat",
        )
    else:  # "projection"
        from model.federated_model_factory import ProjectionBackboneWrapper

        aligned_backbone = ProjectionBackboneWrapper(
            backbone=_build_shared_backbone(args),
            F_client=C_client,
            d_model=args.dim,
            time_dim=time_dim,
            fuse="cat",
        )

    # Predict the same features as the input.
    # NOTE(review): the wrapper is sized from the observed channel count
    # (C_client), but the head predicts F_client, taken from client_info. The
    # two agree only when 'feature_count' equals x_raw's last dimension; confirm.
    return create_client_model(aligned_backbone, args.horizon, F_client)

# ---------------- main -----------------
def _print_feature_assignment(factory, dataset_name: str) -> None:
    """Print each client's assigned raw feature indices, then summary statistics.

    ``factory.get_client_info()`` (called with no argument) is expected to return
    a mapping ``client_id -> dict`` with the keys 'feature_indices',
    'feature_count' and 'total_features'.
    """
    print(f"\n[Feature Assignment] Client-wise Feature Assignment Summary ({dataset_name} dataset)")
    print("="*80)
    all_client_info = factory.get_client_info()
    sorted_ids = sorted(all_client_info)

    for client_id in sorted_ids:
        info = all_client_info[client_id]
        feature_indices = info['feature_indices']
        feature_count = info['feature_count']

        print(f"   - {client_id}: {feature_count} features")
        if len(feature_indices) <= 8:
            print(f"   └ Features: {feature_indices}")
        else:
            # Long index lists are abbreviated to their first and last four entries.
            first_4 = feature_indices[:4]
            last_4 = feature_indices[-4:]
            print(f"   └ Features: {first_4} ... {last_4} (total {feature_count})")

    if not sorted_ids:
        # Without clients, the average below would divide by zero.
        print("   (no clients assigned)")
        print("="*80)
        return

    feature_counts = [info['feature_count'] for info in all_client_info.values()]
    print("\n[Feature Stats] Feature Assignment Statistics:")
    print(f"   └ Average features/client: {sum(feature_counts)/len(feature_counts):.1f}")
    print(f"   └ Min-Max: {min(feature_counts)}-{max(feature_counts)} features")
    # total_features is read from the last client in sorted order; presumably it
    # is the same for every client (the dataset-wide feature count).
    print(f"   └ Total features: {all_client_info[sorted_ids[-1]]['total_features']}")
    print("="*80)


def _save_meta(args, in_dim: int, exp_name: str) -> Path:
    """Write run metadata to ``output/<exp_name>/results/meta.json``.

    The metadata records the windowing, model and horizon settings, for later
    prediction/visualization.

    Returns:
        The path of the written ``meta.json``.
    """
    meta = {
        "data_cfg": {
            "window_len": args.window,
            "batch_size": args.batch_size,
            "timeenc": args.timeenc,
            "time_freq": args.time_freq,
            "num_workers": args.num_workers,
        },
        "model_name": args.model_name,
        "alignment_method": args.alignment_method,
        "input_size": in_dim,
        "hidden_size": args.dim,
        "num_steps": args.num_steps,
        "max_len": args.window,
        "horizon": args.horizon,
        "backbone_kw": ({"num_queries": int(args.q)} if args.q is not None else {}),
    }
    res_dir = Path("output") / exp_name / "results"
    res_dir.mkdir(parents=True, exist_ok=True)
    meta_path = res_dir / "meta.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    return meta_path


def main():
    """Run one experiment end to end.

    Steps:
      1. Parse CLI args and seed the RNGs.
      2. Build per-client train/val/test loaders. The val/test splits reuse the
         train factory's scaler.
      3. Build one client model per client.
      4. Run FedPer training, or local-only training with ``--local_only``,
         via ``FederatedTrainer``.
      5. Write ``output/<exp_name>/results/meta.json``.

    In ``--local_only`` mode the function returns right after training and
    does not write any metadata.
    """
    import os
    import traceback

    args = get_args()
    set_seed(args.seed)
    # Falls back to CPU whenever CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    # --file_name takes priority over --experiment (see its --help text).
    # Every output location below uses exp_name, matching what the trainer gets.
    exp_name = args.file_name or args.experiment
    method_label = args.alignment_method.upper()

    # Show where relative dataset paths will resolve from.
    print(f"Current working directory: {os.getcwd()}")
    print(f"Dataset path: {os.path.abspath('dataset/fl_clients')}")

    print("┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
    print(f"┃  Torch    : {torch.__version__}")
    print(f"┃  Device   : {device} (requested={args.device}, cuda available={torch.cuda.is_available()})")
    print(f"┃  Seed     : {args.seed}")
    print(f"┃  File Name: {exp_name}")
    print(f"┃  Dataset  : {args.dataset}")
    print(f"┃  Clients  : {args.num_clients}")
    print(f"┃  Max Feat : {args.max_features}")
    print(f"┃  Overlap  : {args.feature_overlap}")
    print(f"┃  Algorithm: FedPer (with {method_label})")
    print("┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")

    # 1) Split-wise loaders for all clients. Alignment (QAP/projection) is
    #    applied inside the model, so the loaders yield raw features.
    def get_split_loaders(split: str, universal_scaler=None):
        """Build a DataFactory and its per-client loaders for one split."""
        from data_provider.data_factory import DataFactory
        factory = DataFactory(
            dataset_file=args.dataset,
            num_clients=args.num_clients,
            max_features=args.max_features,
            feature_overlap=args.feature_overlap,
            split_ratios=(0.8, 0.1, 0.1),
            seed=args.seed,
            universal_scaler=universal_scaler
        )

        if split == "train":  # print the assignment summary only once
            _print_feature_assignment(factory, args.dataset)

        id_loaders = factory.create_client_dataloaders(
            split=split,
            window_len=args.window,
            horizon=args.horizon,
            batch_size=args.batch_size,
            num_workers=args.num_workers
        )
        return factory, id_loaders

    total_start = time.time()

    t1 = time.time()
    print("Creating data loaders...")
    # The train factory's scaler is reused for val/test, so they are normalised
    # with train statistics.
    data_factory, tr_list = get_split_loaders("train")
    train_scaler = data_factory.get_client_scaler()
    _, va_list = get_split_loaders("val", universal_scaler=train_scaler)
    _, te_list = get_split_loaders("test", universal_scaler=train_scaler)
    t2 = time.time()

    # create_client_dataloaders returns List[DataLoader], one per client.
    total_train = sum(len(loader.dataset) for loader in tr_list)
    total_valid = sum(len(loader.dataset) for loader in va_list)
    total_test = sum(len(loader.dataset) for loader in te_list)
    print(f"Data loaders created! ({t2-t1:.1f}s)")
    print(f"Dataset sizes - Train: {total_train}, Valid: {total_valid}, Test: {total_test}")

    total_loader_time = time.time() - total_start
    print(f"Total loader creation time: {total_loader_time:.1f}s")

    # String client ids are derived from the factory's client count, since the
    # loader list itself carries no ids.
    print("Extracting client IDs...")
    client_ids = [f"client_{i}" for i in range(data_factory.num_clients)]
    print(f"Found {len(client_ids)} clients")

    # Shape expected by the trainer: {client_index: {split_name: loader}}.
    loaders: Dict[int, Dict[str, torch.utils.data.DataLoader]] = {}
    def _assign(loader_list, split_name):
        for i, loader in enumerate(loader_list):
            loaders.setdefault(i, {})[split_name] = loader

    _assign(tr_list, "train")
    _assign(va_list, "valid")
    _assign(te_list, "test")

    print(f"Debug: client_ids = {client_ids}")

    if 0 not in loaders or not client_ids:
        raise RuntimeError("No client data loaders were created; check --dataset / --num_clients.")

    # Probe input/output dims with one batch from the first client.
    print("Getting sample batch for dimension check...")
    sample_batch = next(iter(loaders[0]["train"]))

    # Batches are (x, y) or (x, y, time_features).
    if len(sample_batch) == 3:
        sx, sy, time_features = sample_batch
        time_dim = time_features.shape[-1] if time_features is not None else None
        enable_time_features = time_features is not None
    else:
        sx, sy = sample_batch
        time_features = None
        time_dim = None
        enable_time_features = False

    in_dim = int(sx.shape[-1])  # raw feature count of client 0
    out_dim = int(sy.shape[-1]) if sy.ndim >= 3 else 1
    print(f"   raw_input_dim={in_dim}, output_dim={out_dim}")
    print(f"   Sample batch shapes: x={sx.shape}, y={sy.shape}")
    print(f"   Sample batch length: {len(sample_batch)}")
    if enable_time_features:
        print(f"   Time features shape: {time_features.shape}")
        first_client_info = data_factory.get_client_info(client_ids[0])
        raw_features = first_client_info.get('feature_count', in_dim)
        print(f"   Universal dataset: raw_features={raw_features}, time_features_dim={time_dim}")

    # 3) Server: created after the first client model is built.
    print(f"Building global model with {method_label} alignment for FedPer...")
    server = None

    # 4) Client models
    print("Building client models...")
    client_models: List[nn.Module] = []
    for idx, cid in enumerate(client_ids):
        sample_batch = next(iter(loaders[idx]["train"]))
        client_info = data_factory.get_client_info(cid)
        model = build_client_model(
            sample_batch, args, client_info,
            data_factory=data_factory,
            client_idx=idx
        ).to(device)

        # Run one forward pass on raw data to report the real output shape.
        model.eval()
        with torch.no_grad():
            if len(sample_batch) == 3:
                x_raw, _, time_features = sample_batch
                x_test = x_raw.to(device)
                tf_test = time_features.to(device) if time_features is not None else None
                actual_output = model(x_test, tf_test) if tf_test is not None else model(x_test)
            else:
                x_raw, _ = sample_batch
                x_test = x_raw.to(device)
                actual_output = model(x_test)
        # Restore training mode (nn.Module's default) so the eval() used for the
        # shape probe does not leak into training.
        model.train()

        print(f"   client[{idx}] ({cid}) raw_input_shape={x_raw.shape} -> {method_label} -> backbone -> head -> model_output_shape={actual_output.shape}")
        print(f"   Raw: {x_raw.shape[-1]} -> {method_label}: {args.dim} -> Final: {actual_output.shape[-1]}")
        client_models.append(model)

        if server is None:
            print("Creating server for FedPer...")
            # FedPer: a plain server with no global backbone.
            server = Server(global_backbone=None)
            print("Server created for FedPer (no global backbone needed)")

    # 5) Trainer with FedPer configuration
    print("Creating FederatedTrainer with FedPer...")
    trainer = FederatedTrainer(
        args=args,
        server=server,
        clients=client_models,      # list of nn.Module, one per client
        loaders=loaders,
        data_factory=data_factory,  # gives the trainer access to the scalers
        device=str(device),
        global_epochs=args.global_epochs,
        local_epochs=args.local_epochs,
        lr=args.lr,
        optimizer=args.optimizer,
        frac=args.frac,
        seed=args.seed,
        early_stop=args.early_stop,
        patience=args.patience,
        experiment=exp_name,
        resume=args.resume,
        fed_algorithm="fedper",
        fed_algorithm_config={},
    )
    print("FederatedTrainer created!")

    if args.local_only:
        print(f"\nStarting local-only training ({args.local_epochs} local epochs per client)...")
        print(f"Early stopping enabled: {args.early_stop}, Patience: {args.patience}")
        try:
            trainer.local_only_train()
            print("Local training completed successfully!")
            print("Program will now exit (local-only mode).")
        except Exception as e:  # top-level boundary: report and exit
            print(f"Local training failed: {e}")
            traceback.print_exc()
        return  # local-only mode exits here, success or failure; no meta.json is written

    training_ok = True
    print(f"\nStarting federated training ({args.global_epochs} global epochs)...")
    try:
        trainer.train()
        print("Federated training completed successfully!")
    except Exception as e:  # top-level boundary: report, then still write metadata
        training_ok = False
        print(f"Federated training failed: {e}")
        traceback.print_exc()

    # 6) Metadata for later prediction/visualization, stored next to the
    #    trainer's outputs under exp_name.
    meta_path = _save_meta(args, in_dim, exp_name)
    print(f"Saved run metadata to {meta_path}")

    # Prediction/visualization later:
    # from predict import predict
    # predict(exp_name, mode="fedper", device=str(device))

    print("Training finished." if training_ok else "Training finished with errors (see traceback above).")

if __name__ == "__main__":
    # Report wall-clock time for the whole run, including data loading.
    _t_start = time.time()
    main()
    print(f"Total time: {time.time() - _t_start:.1f}s")