"""
Author:*** Time:2024/05/15
"""
from local_update_method.FedTOGA_client import *


def train_FedTOGA(data_obj, act_prob,
                  learning_rate, batch_size, epoch, com_amount, test_per,
                  weight_decay, model_func, init_model, alpha_coef, beta,
                  sch_step, sch_gamma, rho, kappa, Np, rand_seed=0, lr_decay_per_round=1):
    """Run FedTOGA federated training and return per-round metrics.

    Each of the ``com_amount`` communication rounds samples a subset of
    clients, trains each sampled client locally with ``train_model_toga``
    starting from the current global model (``all_model``), and then rebuilds
    the global model from the selected clients' parameters plus the
    accumulated per-client drift (``hist_params_diffs``). Evaluation uses a
    separate model (``server_model``) that is the plain mean of every
    client's latest parameters.

    Args:
        data_obj: Dataset container; this function reads ``n_client``,
            ``client_x``, ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        act_prob: Per-round probability that each client is selected.
        learning_rate: Base local learning rate; round ``i`` uses
            ``learning_rate * lr_decay_per_round ** i``.
        batch_size, epoch, weight_decay, beta, sch_step, sch_gamma, rho,
        kappa, Np: Forwarded to ``train_model_toga``; their exact meaning is
            defined there. ``epoch`` is also used to scale ``delta``.
        com_amount: Number of communication rounds.
        test_per: Evaluate every ``test_per`` rounds.
        model_func: Zero-argument factory that builds a fresh model.
        init_model: Model whose parameters initialise every client and the
            global model.
        alpha_coef: Base coefficient; each client receives
            ``alpha_coef / weight_list[client]``.
        rand_seed: Offset added to the per-round NumPy seed used for client
            sampling.
        lr_decay_per_round: Multiplicative per-round learning-rate decay.

    Returns:
        Tuple ``(test_perf, train_perf, divergence_perf)``. ``test_perf`` and
        ``train_perf`` have shape ``(com_amount, 2)`` with rows
        ``[loss, accuracy]``. ``divergence_perf`` has shape ``(com_amount,)``.
        Only rows for evaluation rounds (``(i + 1) % test_per == 0``) are
        filled; all other rows stay zero.

    Side effects:
        Reseeds NumPy's global RNG via ``np.random.seed`` every round, and
        prints progress and metrics to stdout.
    """
    n_client = data_obj.n_client
    client_x = data_obj.client_x;
    client_y = data_obj.client_y

    # Union of all clients' training data; used only to report train metrics.
    cent_x = np.concatenate(client_x, axis=0)
    cent_y = np.concatenate(client_y, axis=0)

    # Per-client weight proportional to its sample count, normalised so the
    # weights sum to n_client (i.e. the average weight is 1).
    weight_list = np.asarray([len(client_y[i]) for i in range(n_client)])
    weight_list = weight_list / np.sum(weight_list) * n_client

    train_perf = np.zeros((com_amount, 2))
    test_perf = np.zeros((com_amount, 2))
    divergence_perf = np.zeros(com_amount)
    n_par = len(get_mdl_params([model_func()])[0])
    # Running sum of (local params - global params) for each client, taken
    # over every round in which that client was selected.
    hist_params_diffs = np.zeros((n_client, n_par)).astype('float32')
    init_par_list = get_mdl_params([init_model], n_par)[0]
    # Latest flattened parameters per client; every row starts at the init params.
    client_params_list = np.ones(n_client).astype('float32').reshape(-1, 1) * init_par_list.reshape(1,
                                                                                                    -1)  # n_client X n_par
    client_models = list(range(n_client))
    # Correction vector passed to train_model_toga; zero in the first round.
    delta = np.zeros(n_par).astype('float32')

    # NOTE(review): avg_model is not used again in this function.
    avg_model = model_func().to(device)
    avg_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))

    server_model = model_func().to(device)
    server_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))

    # Global model that selected clients warm-start from each round.
    all_model = model_func().to(device)
    all_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))
    all_model_param = get_mdl_params([all_model], n_par)[0]
    for i in range(com_amount):
        # Select each client independently with probability act_prob. If no
        # client is picked, bump the seed and redraw, so at least one client
        # participates. The seed depends only on (i, rand_seed, retry count).
        inc_seed = 0
        while (True):
            np.random.seed(i + rand_seed + inc_seed)
            act_list = np.random.uniform(size=n_client)
            act_clients = act_list <= act_prob
            selected_clients = np.sort(np.where(act_clients)[0])
            inc_seed += 1
            if len(selected_clients) != 0:
                break

        print('Communication Round', i + 1, flush=True)
        print('Selected Clients: %s' % (', '.join(['%2d' % item for item in selected_clients])))
        all_model_param_tensor = torch.tensor(all_model_param, dtype=torch.float32, device=device)

        # Drop the previous round's client models before training new ones.
        del client_models
        client_models = list(range(n_client))
        # Sum over the selected clients of (local params - global params).
        delta_sum = np.zeros(n_par).astype('float32')
        for client in selected_clients:
            train_x = client_x[client]
            train_y = client_y[client]

            client_models[client] = model_func().to(device)

            model = client_models[client]
            # Warm start from current avg model
            # NOTE(review): named_parameters() excludes buffers (e.g. BatchNorm
            # running stats); with the default strict load_state_dict this
            # presumably assumes a model without buffers — confirm.
            model.load_state_dict(copy.deepcopy(dict(all_model.named_parameters())))
            for params in model.parameters():
                params.requires_grad = True
            # Scale down
            # Clients with more data (larger weight) get a smaller alpha.
            alpha_coef_adpt = alpha_coef / weight_list[client]  # adaptive alpha coef
            # print(alpha_coef_adpt)
            hist_params_diffs_curr = torch.tensor(hist_params_diffs[client], dtype=torch.float32, device=device)
            # NOTE(review): the positional literal 5 below is an argument of
            # train_model_toga whose meaning is not visible here — confirm.
            client_models[client] = train_model_toga(model, model_func, alpha_coef_adpt, beta,
                                                     all_model_param_tensor, hist_params_diffs_curr,
                                                     torch.tensor(delta, dtype=torch.float32, device=device),
                                                     train_x, train_y, learning_rate * (lr_decay_per_round ** i),
                                                     batch_size, epoch, 5, weight_decay,
                                                     data_obj.dataset, sch_step, sch_gamma, rho, kappa, Np, print_verbose=False)
            curr_model_par = get_mdl_params([client_models[client]], n_par)[0]
            # Accumulate this client's drift away from the current global model.
            hist_params_diffs[client] += curr_model_par - all_model_param
            client_params_list[client] = curr_model_par
            delta_sum += curr_model_par - all_model_param

        # Negative mean drift of the selected clients, further divided by
        # `epoch`. It is passed to every client trained in the next round.
        delta = -delta_sum / (len(selected_clients) * epoch)
        avg_mdl_param_sel = np.mean(client_params_list[selected_clients], axis=0)
        # New global params: the mean of this round's selected clients plus the
        # mean accumulated drift over ALL clients. Unselected clients
        # contribute their earlier sums, or zeros if they were never selected.
        all_model_param = avg_mdl_param_sel + np.mean(hist_params_diffs, axis=0)
        all_model = set_client_from_params(model_func().to(device), all_model_param)
        # Evaluation model: plain mean of every client's latest params,
        # including clients that were not selected this round.
        # NOTE(review): unlike all_model, this model is not moved to `device`
        # here; presumably set_client_from_params / get_acc_loss handle
        # device placement — confirm.
        server_model = set_client_from_params(model_func(), np.mean(client_params_list, axis=0))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = get_acc_loss(data_obj.test_x, data_obj.test_y,
                                               server_model, data_obj.dataset, 0)
            print("****FedTOGA Cur All Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                  % (i + 1, acc_test, loss_test))
            test_perf[i] = [loss_test, acc_test]

            # Train metrics are computed on the union of all clients' data;
            # the loss_test/acc_test names are reused here.
            loss_test, acc_test = get_acc_loss(cent_x, cent_y,
                                               server_model, data_obj.dataset, 0)
            print("****FedTOGA Cur All Communication %3d, Train Accuracy: %.4f, Loss: %.4f"
                  % (i + 1, acc_test, loss_test), flush=True)

            train_perf[i] = [loss_test, acc_test]
            # Mean squared distance of the selected clients' params from their mean.
            divergence = see_divergence(avg_mdl_param_sel, selected_clients, client_params_list)
            divergence_perf[i] = divergence
            print(divergence)
        # Freeze model
        # NOTE(review): server_model is rebuilt every round and is not
        # returned, so this freeze has no effect visible in this function.
        for params in server_model.parameters():
            params.requires_grad = False

    return test_perf, train_perf, divergence_perf



def see_divergence(servermodel, selected_clients, client_params_list):
    """Return the mean squared L2 distance of selected clients from a reference.

    Args:
        servermodel: Flattened reference parameter vector (in this file, the
            mean of the selected clients' parameters).
        selected_clients: Indices into ``client_params_list``. Must be
            non-empty; an empty selection raises ``ZeroDivisionError``.
        client_params_list: Indexable collection of flattened client
            parameter vectors.

    Returns:
        ``sum_k ||client_params_list[k] - servermodel||^2 / len(selected_clients)``.
    """
    squared_dists = (
        np.linalg.norm(client_params_list[idx] - servermodel) ** 2
        for idx in selected_clients
    )
    return sum(squared_dists) / len(selected_clients)


def get_params_list_with_shape(model, param_list, device):
    """Split a flat parameter vector into tensors shaped like ``model``'s parameters.

    Args:
        model: Module whose ``parameters()`` order and shapes define the split.
        param_list: Flat tensor holding all parameters concatenated in
            ``model.parameters()`` order (presumably 1-D; it is sliced
            contiguously and moved with ``.to``).
        device: Target device for the returned tensors.

    Returns:
        List of tensors, one per parameter of ``model``. Each is the matching
        slice of ``param_list``, reshaped to that parameter's shape and moved
        to ``device``.
    """
    vec_with_shape = []
    idx = 0
    for param in model.parameters():
        length = param.numel()
        vec_with_shape.append(param_list[idx:idx + length].reshape(param.shape).to(device))
        # Advance past this parameter's slice; without this, every parameter
        # would be read from the start of param_list.
        idx += length
    return vec_with_shape