from torch.utils.data import Dataset, DataLoader
import torch
from datasets.utils import get_data_dict, split_data_dirichlet

class AVEDataset(Dataset):
    """Index-aligned collection of images, audio spectrograms and labels.

    All three inputs are stored as given and looked up with the same index,
    so they are expected to be parallel sequences of equal length.
    """

    def __init__(self, images, audios, labels):
        """Keep references to the three parallel collections."""
        self.images, self.audios, self.labels = images, audios, labels

    def __len__(self):
        """Return the number of samples, taken from the image collection."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(spectrogram, image, label)`` for position ``idx``.

        Note the order: the audio spectrogram comes first, then the image.
        """
        # NOTE: spectrograms are supplied ready-made via ``audios``. An
        # earlier, now disabled, variant computed them here with librosa
        # (load at 22050 Hz, tile/crop to 3 s, clip to [-1, 1], then the log
        # magnitude of an STFT with n_fft=512 and hop_length=353).
        frame = self.images[idx]
        spec = self.audios[idx]
        target = self.labels[idx]
        return spec, frame, target


def get_loaders(n_clients, configs, data_path="./datasets/ave/"):
    """Build per-client training loaders and a shared test loader for AVE.

    Training data is read from ``merged_data.pkl`` and test data from
    ``merged_test_data.pkl``, both located in ``data_path``. The training
    samples are divided among clients using the index lists returned by
    ``split_data_dirichlet(labels, n_clients, configs.non_iid_alpha)``
    (presumably a Dirichlet-based non-IID split; see that helper for the
    exact semantics).

    Args:
        n_clients: Number of client loaders to create.
        configs: Object exposing ``non_iid_alpha`` (forwarded to
            ``split_data_dirichlet``) and ``batch_size`` (used for every
            loader).
        data_path: Directory containing the two pickle files. Defaults to
            the location that used to be hard-coded.

    Returns:
        A tuple ``(client_dataloaders, test_dataloader)``.
        ``client_dataloaders`` is a list of ``n_clients`` shuffled
        ``DataLoader`` objects; ``test_dataloader`` is a single unshuffled
        ``DataLoader`` over the whole test set. Each batch is ordered
        ``(spectrogram, image, label)``, as produced by ``AVEDataset``.
    """
    # Function-scope import keeps this block self-contained.
    import os

    # os.path.join avoids the doubled separator that plain concatenation
    # produced when ``data_path`` already ends with "/".
    train_dict = get_data_dict(os.path.join(data_path, "merged_data.pkl"))
    test_dict = get_data_dict(os.path.join(data_path, "merged_test_data.pkl"))

    video = train_dict['images']
    audio = train_dict['spectrograms']
    labels = train_dict['labels']

    # data_indices[i] holds the training-sample indices assigned to client i.
    data_indices = split_data_dirichlet(labels, n_clients, configs.non_iid_alpha)

    client_dataloaders = []
    for i in range(n_clients):
        client_idx = data_indices[i]
        # Assumes the stored arrays support selection by an index
        # collection (e.g. tensors) -- TODO confirm against get_data_dict.
        video_ = video[client_idx]
        labels_ = labels[client_idx]
        # Insert a singleton axis at dim 1 of the spectrogram tensor.
        audio_ = torch.unsqueeze(audio[client_idx], dim=1)
        client_dataloaders.append(
            DataLoader(
                AVEDataset(video_, audio_, labels_),
                batch_size=configs.batch_size,
                shuffle=True,
            )
        )

    # The test split gets the same singleton axis and is never shuffled.
    test_dataset = AVEDataset(
        test_dict['images'],
        test_dict['spectrograms'].unsqueeze(dim=1),
        test_dict['labels'],
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=configs.batch_size, shuffle=False
    )
    return client_dataloaders, test_dataloader