import pickle

from flwr.common import Context
from flwr.client import ClientApp

from src.core.utils import (
    get_client_specific_data,
    get_dataloader, 
    get_net_builder,
    load_config)

from src.algorithms import get_client_alg


def client_fn(context: Context):
    """Build the Flower client for the simulation node described by ``context``.

    Reads the YAML experiment config named by ``run_config["config_path"]``,
    checks it agrees with the simulation's partition count, loads this
    partition's slice of the pickled client datasets, wraps it in a training
    DataLoader, and instantiates the client algorithm named by
    ``cfgs["client_alg"]``.

    Args:
        context: Flower run context. Reads ``run_config["config_path"]`` and
            ``node_config["partition-id"]`` / ``node_config["num-partitions"]``.

    Returns:
        Whatever ``client.to_client()`` returns for the constructed client
        algorithm object.

    Raises:
        ValueError: If ``Dataset.num_clients`` in the YAML config differs from
            the simulation's ``num-partitions``.
    """
    # Load config
    cfgs = load_config(context.run_config["config_path"])
    data_cfgs = cfgs['Dataset']['Client']

    # Identify client: the simulation partition id is used as the client id.
    cid = int(context.node_config["partition-id"])
    num_partitions = int(context.node_config["num-partitions"])
    num_clients = cfgs['Dataset']['num_clients']

    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let a mismatched config slip through unchecked.
    if num_clients != num_partitions:
        raise ValueError(
            f"[!] Mismatch: YAML num_clients={num_clients} vs simulation num_partitions={num_partitions}"
        )

    # Load client-specific dataset.
    # NOTE(security): pickle.load can execute arbitrary code; `clients_path`
    # must point at a trusted file. The whole pickle is re-read on every
    # client_fn call.
    with open(cfgs['Dataset']['clients_path'], "rb") as f:
        clients_data = pickle.load(f)['clients_dict']

    c_data = get_client_specific_data(cfgs, cid, clients_set=clients_data)
    train_loader = get_dataloader(dset=c_data,
                                  batch_size=data_cfgs['bs'],
                                  shuffle=True,
                                  num_workers=data_cfgs['num_workers'],
                                  drop_last=False)

    # Model (backbone): the builder itself is handed to the client algorithm,
    # not an instantiated network; presumably the algorithm builds the model
    # with the right num_classes itself — TODO confirm.
    net_builder = get_net_builder(net_name=cfgs['Model']['net'],
                                  from_name=cfgs['Model']['net_from_name'])

    # Initialize client
    client = get_client_alg(alg=cfgs["client_alg"],
                            cid=cid,
                            config=cfgs,
                            net_builder=net_builder,
                            train_loader=train_loader)

    return client.to_client()

# Flower ClientApp entry point: the runtime calls `client_fn` with a Context
# to construct the client for each simulated node.
app = ClientApp(client_fn=client_fn)
