import numpy as np
from datasets.dataloaders import DataloadManager
from datasets.utils import get_data_dict, split_data_dirichlet


def get_loaders(n_clients, configs, data_path="./datasets/har", train_frac=0.8):
    """Build per-client training dataloaders and a shared test dataloader.

    The accelerometer and gyroscope sample collections are shuffled with one
    shared permutation, so that samples paired by position stay aligned. They
    are then split into a training portion and a held-out test portion. The
    training portion is partitioned across clients by ``split_data_dirichlet``
    using ``configs.non_iid_alpha``.

    Args:
        n_clients: Number of client training dataloaders to build.
        configs: Configuration object passed to ``DataloadManager``; must
            also expose ``non_iid_alpha``.
        data_path: Directory containing ``combined_acc.pkl`` and
            ``combined_gyro.pkl``. Defaults to the previously hard-coded
            location.
        train_frac: Fraction of (shuffled) samples used for training; the
            remainder becomes the test set. Must lie strictly between 0 and 1.

    Returns:
        A tuple ``(client_dataloaders, test_dataloader)``, where
        ``client_dataloaders`` is a list with one entry per client.

    Raises:
        ValueError: If ``train_frac`` is out of range, if the two modalities
            contain different numbers of samples, or if there are no samples.
    """
    import os

    if not 0.0 < train_frac < 1.0:
        raise ValueError(f"train_frac must be in (0, 1), got {train_frac!r}")

    dm = DataloadManager(configs)
    # list() mirrors what the original zip() iterated over.
    acc_all = list(get_data_dict(os.path.join(data_path, "combined_acc.pkl")))
    gyro_all = list(get_data_dict(os.path.join(data_path, "combined_gyro.pkl")))

    # Samples are paired by position. zip() would silently drop trailing
    # samples on a length mismatch, so fail loudly instead.
    if len(acc_all) != len(gyro_all):
        raise ValueError(
            f"accelerometer/gyroscope sample counts differ: "
            f"{len(acc_all)} vs {len(gyro_all)}"
        )
    n_samples = len(acc_all)
    if n_samples == 0:
        raise ValueError(f"no samples found under {data_path!r}")

    # Apply one permutation to both modalities so the pairs stay aligned.
    # Uses the global NumPy RNG, so seeding np.random controls the split.
    order = np.random.permutation(n_samples)
    acc_shuffled = tuple(acc_all[i] for i in order)
    gyro_shuffled = tuple(gyro_all[i] for i in order)

    split_idx = int(n_samples * train_frac)
    acc_train, acc_val = acc_shuffled[:split_idx], acc_shuffled[split_idx:]
    gyro_train, gyro_val = gyro_shuffled[:split_idx], gyro_shuffled[split_idx:]

    # NOTE(review): the class label appears to be stored at index -2 of each
    # accelerometer sample; confirm against get_data_dict's output format.
    labels = np.array([sample[-2] for sample in acc_train])
    data_indices = split_data_dirichlet(labels, n_clients, configs.non_iid_alpha)

    feat_shape = (256, 3)

    def make_loader(acc, gyro, shuffle):
        # Build fresh shape arrays on every call, as the original did, in case
        # the manager keeps or mutates them.
        return dm.set_dataloader(
            acc, gyro, shuffle=shuffle,
            default_feat_shape_a=np.array(feat_shape),
            default_feat_shape_b=np.array(feat_shape),
        )

    client_dataloaders = []
    for i in range(n_clients):
        client_idx = data_indices[i]
        # Debug output: number of training samples assigned to this client.
        print(len(client_idx))
        acc_ = [acc_train[idx] for idx in client_idx]
        gyro_ = [gyro_train[idx] for idx in client_idx]
        client_dataloaders.append(make_loader(acc_, gyro_, shuffle=True))

    test_dataloader = make_loader(acc_val, gyro_val, shuffle=False)
    return client_dataloaders, test_dataloader