import numpy as np
from sklearn.metrics import accuracy_score

def calculate_acc(client_list):
    """Collect per-client accuracies and sample counts as NumPy arrays.

    Each client must provide ``get_train_acc()`` and ``get_test_acc()``
    methods plus ``X_train`` and ``X_test`` attributes exposing ``shape``.

    Args:
        client_list: Iterable of client objects.

    Returns:
        A 4-tuple of arrays, each with one entry per client, in the order
        ``(train_acc, train_sample_count, test_acc, test_sample_count)``.
        The sample counts are ``X_train.shape[0]`` and ``X_test.shape[0]``.
    """
    # Single pass over the clients: for each one the train accuracy is
    # queried before the test accuracy, then the two sample counts are read.
    per_client = [
        (
            client.get_train_acc(),
            client.get_test_acc(),
            client.X_train.shape[0],
            client.X_test.shape[0],
        )
        for client in client_list
    ]

    # Split the per-client rows into columns. Indexing (rather than
    # zip(*rows)) keeps an empty client list yielding four empty arrays.
    train_acc = np.asarray([row[0] for row in per_client])
    train_count = np.asarray([row[2] for row in per_client])
    test_acc = np.asarray([row[1] for row in per_client])
    test_count = np.asarray([row[3] for row in per_client])
    return train_acc, train_count, test_acc, test_count

def calcualte_acc_simple(average_train_acc_list, train_sample_list, average_test_acc_list, test_sample_list):
    """Summarise per-client accuracies as unweighted and sample-weighted means.

    The inputs may be NumPy arrays or plain sequences (lists, tuples); they
    are converted with ``np.asarray`` before any arithmetic.

    Args:
        average_train_acc_list: Train accuracy of each client.
        train_sample_list: Number of train samples of each client, used as
            the weights for the weighted train mean.
        average_test_acc_list: Test accuracy of each client.
        test_sample_list: Number of test samples of each client, used as
            the weights for the weighted test mean.

    Returns:
        A 4-tuple ``(client_mean_train, client_mean_test,
        average_mean_train, average_mean_test)``: the first two are plain
        means over clients, the last two are means weighted by sample count.

    Note:
        The division by the total sample count is not guarded; a total of
        zero is left to NumPy's division semantics.
    """
    # Coerce up front: without this, list inputs fail on ``list * list``
    # in the weighted sums below. For ndarray inputs this is a no-op.
    train_acc = np.asarray(average_train_acc_list)
    train_weights = np.asarray(train_sample_list)
    test_acc = np.asarray(average_test_acc_list)
    test_weights = np.asarray(test_sample_list)

    # Every client counts equally.
    client_mean_train = np.mean(train_acc)
    client_mean_test = np.mean(test_acc)

    # Every sample counts equally, so larger clients weigh more.
    average_mean_train = np.sum(train_acc * train_weights) / np.sum(train_weights)
    average_mean_test = np.sum(test_acc * test_weights) / np.sum(test_weights)
    return client_mean_train, client_mean_test, average_mean_train, average_mean_test

def calcualte_acc_by_group(client_list,client_index_of_each_group):
    """Compute accuracy summaries for all clients and for each client group.

    Per-client accuracies and sample counts are gathered once with
    ``calculate_acc``; ``calcualte_acc_simple`` is then applied to the full
    set and to each group's subset.

    Args:
        client_list: Client objects accepted by ``calculate_acc``.
        client_index_of_each_group: Iterable of index selections, one per
            group. Each selection is used directly to index the per-client
            NumPy arrays.

    Returns:
        A list of 4-tuples ``(client_mean_train, client_mean_test,
        average_mean_train, average_mean_test)``. The first entry covers
        all clients; the remaining entries follow the order of
        ``client_index_of_each_group``.
    """
    train_acc, train_count, test_acc, test_count = calculate_acc(client_list)

    # Legend describing, in order, the four fields of every result tuple.
    legend = (
        "average acc of train datasets across all clients (equal weight)",
        "average acc of test datasets across all clients (equal weight)",
        "average acc of train datasets across all clients (weight by number of data)",
        "average acc of test datasets across all clients (weight by number of data)",
    )
    for line in legend:
        print(line)

    def _summarise(tr_acc, tr_n, te_acc, te_n):
        # Unpack and repack so every stored entry is a plain 4-tuple.
        mean_tr, mean_te, weighted_tr, weighted_te = calcualte_acc_simple(
            tr_acc, tr_n, te_acc, te_n
        )
        return (mean_tr, mean_te, weighted_tr, weighted_te)

    # Entry 0 is the summary over every client; group summaries follow.
    summaries = [_summarise(train_acc, train_count, test_acc, test_count)]
    for member_idx in client_index_of_each_group:
        summaries.append(
            _summarise(
                train_acc[member_idx],
                train_count[member_idx],
                test_acc[member_idx],
                test_count[member_idx],
            )
        )
    return summaries