# Load the client train/test datasets and the per-user split of each.
# get_data is defined elsewhere; per the original note it assigns data
# instances to users randomly (not verifiable from this file).
dataset_train, dataset_test, dict_users_train, dict_users_test = get_data(args)
'''
Now we have:
    dataset_{train,test} = train/test data (and labels)
    dict_users_{train,test}[j] = list of data instance IDs assigned to the j-th user (length = shard_per_user * number of data instances per shard)
'''

'''
Federated training.
'''

def init_server_V0(args, prior, post):
    """Build the network that holds the initial server-side variances V0.

    Every parameter of a fresh ``get_model(args)`` network is overwritten,
    element-wise, with

        clamp(c1 * Sigma0 + c1 * lambda0 * (m0 - mu0) ** 2 + c1 * N * eps2,
              min=args.V0_min)

    where ``c1 = post['n0'] / (N + prior['nu0'])``, ``N = args.num_users`` and
    ``m0`` is the positionally matching parameter of ``post['m0']``. This is
    the V0 update used in the training loop without its client-spread term.

    Args:
        args: run configuration; reads ``num_users``, ``eps2`` and ``V0_min``
            here (plus whatever ``get_model`` reads).
        prior: dict with keys ``'mu0'``, ``'Sigma0'``, ``'lambda0'``, ``'nu0'``.
        post: dict with keys ``'n0'`` and ``'m0'`` (a network whose
            parameters are zipped, in order, with the new network's).

    Returns:
        The new network, with its parameters filled in place.
    """
    scale = post['n0'] / (args.num_users + prior['nu0'])
    dev_coef = scale * prior['lambda0']
    eps_term = scale * args.num_users * args.eps2
    base = scale * prior['Sigma0']

    v0_net = get_model(args)
    with torch.no_grad():  # in-place writes must not be tracked by autograd
        for target, mean_param in zip(v0_net.parameters(), post['m0'].parameters()):
            value = base + dev_coef * (mean_param - prior['mu0']) ** 2 + eps_term
            target.copy_(torch.clamp(value, min=args.V0_min))
    return v0_net

# Create the backbone (dropout) network. The same object becomes the server
# mean m0 below (post_phi['m0'] aliases it; it is not copied).
net = get_model(args)
# Total number of scalar parameters. Builtin sum: np.sum(generator) is
# deprecated in NumPy (it warns and falls back to the builtin anyway).
d = sum(p.numel() for p in net.parameters())

# Number of training instances held by each client.
Ds = [len(dict_users_train[idx]) for idx in range(args.num_users)]

# Prior p(\phi): fixed, constant hyper-parameters. nu0 = d + 2 depends on the
# parameter count.
prior = {'mu0': 0.0, 'Sigma0': 1.0, 'lambda0': 1.0, 'nu0': d + 2.0}

# Variational posterior q(\phi) = NIW(m0, V0, l0, n0). l0 and n0 add the total
# number of client training instances to the prior's lambda0 / nu0.
post_phi = {
    'm0': net,
    'V0': None,
    'l0': prior['lambda0'] + sum(Ds),
    'n0': prior['nu0'] + sum(Ds),
    'd': d,
}
# init_server_V0 reads post_phi['n0'] and post_phi['m0'], so it must run after
# the dict above is built.
post_phi['V0'] = init_server_V0(args, prior, post_phi)
# Flat vectors over args.local_part (weights2vec is defined elsewhere).
m0_vec = weights2vec(post_phi['m0'], args.local_part)
V0_vec = weights2vec(post_phi['V0'], args.local_part)

# Per-client posteriors q_i(\theta_i). p is the keep probability (1 - pdrop).
# NOTE(review): 'ms' holds num_users full deep copies of the backbone and is
# not read in this part of the file; check whether it is needed elsewhere.
post_ths = {
    'ms': [copy.deepcopy(post_phi['m0']) for _ in range(args.num_users)],
    'p': 1 - args.pdrop,
    'eps2': args.eps2,
}
lr = args.lr
for epoch in range(args.epochs):  # one communication round per epoch

    # Client sampling: a fraction args.frac of the users (at least one),
    # drawn without replacement.
    m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)

    # Local update of q_i(\theta_i) for each sampled client (simulated
    # sequentially, as if the clients worked in parallel). m0_vec and V0_vec
    # are only reassigned after every sampled client has trained, so all of
    # them see the same server vectors this round.
    m_vecs = []
    for idx in idxs_users:
        local_updater = LocalUpdater(args=args, dataset=dataset_train, idxs=dict_users_train[idx], p=post_ths['p'])
        net = local_updater.train(post_phi=post_phi, local_part=args.local_part, lr=lr, m0_vec=m0_vec, V0_vec=V0_vec)
        m_vecs.append(weights2vec(net, args.local_part))
    m_vecs = torch.stack(m_vecs, 0)  # dim 0 indexes the sampled clients
    m_mean = m_vecs.mean(0)  # shared by the m0 and V0 updates below

    # Server-side update of m0. The data term equals
    # p * (num_users * m_mean) / (num_users + lambda0): the sum over all users
    # is estimated from the sampled clients' mean.
    m0_coef_data = args.num_users * post_ths['p'] / (args.num_users + prior['lambda0'])
    m0_coef_prior = prior['lambda0'] / (args.num_users + prior['lambda0'])
    m0_vec = m0_coef_data * m_mean + m0_coef_prior * prior['mu0']

    # Server-side update of V0: the same prior terms as init_server_V0, plus a
    # client-spread term. sq_dev is the mean over sampled clients of
    # p*m_i**2 - 2*p*m_i*m0 + m0**2, i.e. E[(z*m_i - m0)**2] with z ~ Bernoulli(p).
    v0_coef = post_phi['n0'] / (args.num_users + prior['nu0'])
    v0_coef_dev = v0_coef * prior['lambda0']
    v0_coef_eps = v0_coef * args.num_users * args.eps2
    v0_coef_clients = v0_coef * args.num_users
    sq_dev = post_ths['p'] * (m_vecs ** 2).mean(0) - 2 * post_ths['p'] * m_mean * m0_vec + m0_vec ** 2
    V0_vec = torch.clamp(
        v0_coef * prior['Sigma0'] + v0_coef_dev * (m0_vec - prior['mu0']) ** 2 + v0_coef_eps + v0_coef_clients * sq_dev,
        min=args.V0_min,  # variances are bounded below by V0_min
    )
    # Load the flat vectors back into the server networks (vec2weights is
    # defined elsewhere; presumably an in-place copy over args.local_part).
    vec2weights(m0_vec, post_phi['m0'], args.local_part)
    vec2weights(V0_vec, post_phi['V0'], args.local_part)

    # Step decay: lr is multiplied by 0.1 after rounds epochs//2 and
    # (3*epochs)//4 (only once if the two coincide).
    if (epoch + 1) in [args.epochs//2, (args.epochs*3)//4]:
        lr *= 0.1

    # Test evaluation and report. data_ratio weights each user's metrics;
    # presumably it sums to 1, making these data-weighted averages (confirm).
    if (epoch + 1) % args.test_freq == 0:
        acc_test, loss_test, logprobs_test, targets_test, data_ratio = test_img_local_all(post_phi, args, dataset_test, dict_users_test)
        acc_test_avg = (acc_test*data_ratio).sum()
        loss_test_avg = (loss_test*data_ratio).sum()
        print(f'Round {epoch + 1:3d}/{args.epochs}: test loss {float(loss_test_avg):.4f}, test acc {float(acc_test_avg):.4f}')

'''
Personalization.
'''

if args.personalization:

    # Restore the server posterior q(\phi) from the checkpoint at save_path.
    # Entries are used by position: entry 0 is loaded into m0 and entry 1
    # into V0. strict=False tolerates missing or unexpected keys.
    model_state_dict = torch.load(save_path)
    for slot, key in enumerate(('m0', 'V0')):
        net = post_phi[key]
        net.load_state_dict(model_state_dict[slot], strict=False)
    m0_vec = weights2vec(post_phi['m0'], args.local_part)
    V0_vec = weights2vec(post_phi['V0'], args.local_part)
    # NOTE(review): unlike the federated loop, m0_vec / V0_vec are not passed
    # to train() below, which fine-tunes args.ft_part rather than
    # args.local_part. Presumably train() derives what it needs from post_phi;
    # confirm.

    # Fine-tune a personal model for every user, then evaluate it on that
    # user's test split.
    # NOTE(review): each iteration overwrites the per-user results, and this
    # block does not aggregate or report them.
    for user_idx in range(args.num_users):
        personal_updater = LocalUpdater(args=args, dataset=dataset_train, idxs=dict_users_train[user_idx], p=post_ths['p'])
        net = personal_updater.train(post_phi=post_phi, local_part=args.ft_part, lr=args.ft_lr, local_eps=args.ft_ep)
        test_acc, test_loss, logprobs, targets = single_test_img_local(
            net, dataset_test, args, idxs=dict_users_test[user_idx],
        )
