import os
import shutil

from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig

from src.core.utils import (
    get_logger,
    get_server_dataset, 
    get_dataloader, 
    get_net_builder,
    load_config)

from src.algorithms import get_strategy, get_server_alg


def _print_config_summary(cfgs):
    """Print a debug summary of the top-level entries of the loaded config.

    Dict-valued entries are summarised by their key count; every other value
    is printed together with its type name.
    """
    print(f"[DEBUG] Config keys: {list(cfgs.keys())}")
    print(f"[DEBUG] server_alg: {cfgs.get('server_alg', 'NOT_FOUND')}")
    print(f"[DEBUG] client_alg: {cfgs.get('client_alg', 'NOT_FOUND')}")
    print(f"[DEBUG] strategy: {cfgs.get('strategy', 'NOT_FOUND')}")
    print("[DEBUG] All top-level config keys and values:")
    for key, value in cfgs.items():
        if isinstance(value, dict):
            print(f"  {key}: {{dict with {len(value)} keys}}")
        else:
            print(f"  {key}: {value} (type: {type(value).__name__})")


def _snapshot_config(cfg_path, save_path):
    """Copy the config file at ``cfg_path`` to ``<save_path>/config.yaml``.

    Creates ``save_path`` if it does not exist. When the source already *is*
    the destination (e.g. the run was launched from a previously saved
    snapshot), the copy is skipped instead of raising
    ``shutil.SameFileError``.
    """
    os.makedirs(save_path, exist_ok=True)
    try:
        shutil.copyfile(cfg_path, os.path.join(save_path, 'config.yaml'))
    except shutil.SameFileError:
        # Deliberate: the snapshot is already in place, nothing to copy.
        pass


def _build_dataloaders(cfgs, data_cfgs):
    """Build the server-side ``(train_loader, test_loader)`` pair.

    The training loader shuffles, drops the last incomplete batch and uses
    the configured ``data_sampler``; the test loader uses the helper's
    defaults for those options.
    """
    server_data = get_server_dataset(cfgs)
    train_loader = get_dataloader(dset=server_data['train'],
                                  batch_size=data_cfgs['train_bs'],
                                  shuffle=True,
                                  num_workers=data_cfgs['num_workers'],
                                  drop_last=True,
                                  data_sampler=data_cfgs['data_sampler'])
    test_loader = get_dataloader(dset=server_data['test'],
                                 batch_size=data_cfgs['test_bs'],
                                 num_workers=data_cfgs['num_workers'])
    return train_loader, test_loader


def _build_strategy_kwargs(cfgs, init_params):
    """Return keyword arguments for ``get_strategy``.

    Starts from a *copy* of ``cfgs['strategy_args']`` so the loaded config is
    not mutated; a missing or null ``strategy_args`` entry is treated as an
    empty mapping. ``initial_parameters`` defaults to ``init_params`` unless
    the config already supplies it.
    """
    strategy_args = dict(cfgs.get('strategy_args') or {})
    strategy_args.setdefault("initial_parameters", init_params)
    return strategy_args


def server_fn(context: Context):
    """Build the Flower server components for this run.

    Reads the YAML config path from ``context.run_config["config_path"]``,
    snapshots the config into ``<save_dir>/<save_name>``, builds the server
    dataloaders, model and server algorithm, optionally warms the server model
    up, and wraps everything in a strategy.

    Args:
        context: Flower run context; only ``run_config["config_path"]`` is read.

    Returns:
        ``ServerAppComponents`` holding the strategy and a ``ServerConfig``
        whose ``num_rounds`` is ``cfgs['Training']['total_round']``.

    Raises:
        KeyError: if a required config entry (e.g. ``Dataset.Server``,
            ``Training.Server``, ``save_dir``, ``save_name``, ``server_alg``,
            ``strategy``) is missing.
    """
    # Load config
    cfg_path = context.run_config["config_path"]
    print(f"[DEBUG] Loading config from: {cfg_path}")
    cfgs = load_config(cfg_path)
    _print_config_summary(cfgs)

    data_cfgs = cfgs['Dataset']['Server']
    train_cfgs = cfgs['Training']['Server']

    # Keep a copy of the exact config next to the run's outputs.
    save_path = os.path.join(cfgs['save_dir'], cfgs['save_name'])
    _snapshot_config(cfg_path, save_path)

    # Logger
    logger = get_logger(cfgs['save_name'], save_path)

    # Dataset
    train_loader, test_loader = _build_dataloaders(cfgs, data_cfgs)

    # Model (backbone)
    _net_builder = get_net_builder(net_name=cfgs['Model']['net'],
                                   from_name=cfgs['Model']['net_from_name'])

    # Init Server Model
    server_alg_value = cfgs['server_alg']
    print(f"[DEBUG] About to call get_server_alg with: '{server_alg_value}' (type: {type(server_alg_value)})")
    server = get_server_alg(alg=server_alg_value,
                            config=cfgs,
                            net_builder=_net_builder,
                            train_loader=train_loader,
                            test_loader=test_loader,
                            logger=logger)
    print(f"[DEBUG] get_server_alg returned: {server} (type: {type(server)})")

    # Warm-up: a missing/null entry means "no warm-up"; only a positive
    # epoch count triggers it (previously None or negatives reached warm_up).
    warmup_epochs = train_cfgs.get('warmup_epochs') or 0
    if warmup_epochs > 0:
        server.warm_up(epochs=warmup_epochs)

    # Seed the strategy with the (possibly warmed-up) server weights.
    init_params = ndarrays_to_parameters(server.get_model_parameters())

    # Strategy
    strategy = get_strategy(alg=cfgs['strategy'],
                            server=server,
                            **_build_strategy_kwargs(cfgs, init_params))

    config = ServerConfig(num_rounds=cfgs['Training']['total_round'])

    return ServerAppComponents(strategy=strategy, config=config)

# Create ServerApp.
# Module-level object exposing `server_fn` to Flower; presumably referenced by
# name (e.g. "<module>:app") from the Flower app configuration — verify there.
# `server_fn` must be passed by keyword: it is not ServerApp's first
# positional parameter.
app = ServerApp(server_fn=server_fn)