import os

import torch
from torch.utils.data import DataLoader, Dataset

from datasets.utils import split_data_dirichlet
def load_data(dir="./datasets/modelnet"):
    """Load the training split as ``(modal1, modal2, labels)``.

    Reads ``train_1.pt``, ``train_2.pt`` and ``train_labels.pt`` from *dir*
    with ``torch.load`` and returns the three loaded objects unchanged.

    Args:
        dir: Directory holding the ``.pt`` files. May be a ``str`` or any
            path-like object (e.g. ``pathlib.Path``). The name shadows the
            ``dir`` builtin but is kept because callers may pass it by keyword.

    Returns:
        Tuple ``(modal1, modal2, labels)`` exactly as stored on disk.
    """
    # os.path.join accepts path-like objects and tolerates a trailing
    # separator, unlike plain string concatenation.
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code -- only load files from a trusted source.
    modal1 = torch.load(os.path.join(dir, "train_1.pt"), weights_only=False)
    modal2 = torch.load(os.path.join(dir, "train_2.pt"), weights_only=False)
    labels = torch.load(os.path.join(dir, "train_labels.pt"), weights_only=False)
    # NOTE(review): unlike load_test_data, no squeeze(2) is applied here --
    # confirm the training tensors are stored without that singleton axis.
    return modal1, modal2, labels

def load_test_data(dir="./datasets/modelnet"):
    """Load the test split as ``(modal1, modal2, labels)``.

    Reads ``test_1.pt``, ``test_2.pt`` and ``test_labels.pt`` from *dir*
    with ``torch.load``, then calls ``squeeze(2)`` on both modality tensors.
    ``squeeze(2)`` drops axis 2 only when its size is 1, so the modality
    tensors must have at least three dimensions. Labels are returned as stored.

    Args:
        dir: Directory holding the ``.pt`` files. May be a ``str`` or any
            path-like object (e.g. ``pathlib.Path``). The name shadows the
            ``dir`` builtin but is kept because callers may pass it by keyword.

    Returns:
        Tuple ``(modal1, modal2, labels)``.
    """
    # os.path.join accepts path-like objects and tolerates a trailing
    # separator, unlike plain string concatenation.
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code -- only load files from a trusted source.
    modal1 = torch.load(os.path.join(dir, "test_1.pt"), weights_only=False)
    modal2 = torch.load(os.path.join(dir, "test_2.pt"), weights_only=False)
    labels = torch.load(os.path.join(dir, "test_labels.pt"), weights_only=False)
    # NOTE(review): presumably the test files carry an extra singleton axis 2
    # that the training files lack -- confirm against how they were saved.
    modal1 = modal1.squeeze(2)
    modal2 = modal2.squeeze(2)
    return modal1, modal2, labels


class ModelNet40(Dataset):
    """Paired two-modality dataset with one label per sample.

    Each item is the list ``[modal1[idx], modal2[idx], labels[idx]]``. When a
    truthy ``transform`` is given, it is applied separately to each of the two
    modality entries. The label is always returned untransformed.
    """

    def __init__(self, modal1, modal2, labels, transform=None):
        # Indexable containers for the two modalities and the labels. They are
        # not checked for equal length here.
        # NOTE(review): the caller must keep them aligned along their first axis.
        self.modal1 = modal1
        self.modal2 = modal2
        self.labels = labels
        # Optional callable applied to each modality sample in __getitem__.
        self.transform = transform

    def __len__(self):
        # The dataset size is the first dimension of the label container.
        return self.labels.shape[0]

    def __getitem__(self, idx):
        first, second = self.modal1[idx], self.modal2[idx]
        if self.transform:
            first, second = self.transform(first), self.transform(second)
        return [first, second, self.labels[idx]]

def get_loaders(n_clients, configs, *, data_dir="./datasets/modelnet"):
    """Build one training DataLoader per client plus a shared test DataLoader.

    The training split is partitioned with ``split_data_dirichlet(labels,
    n_clients, configs.non_iid_alpha)``. Each returned index set becomes a
    ``ModelNet40`` subset wrapped in a ``DataLoader``, shuffled when non-empty.
    The test split is wrapped in a single unshuffled ``DataLoader``.

    Args:
        n_clients: Number of clients, forwarded to ``split_data_dirichlet``.
        configs: Object exposing ``batch_size`` and ``non_iid_alpha``.
        data_dir: Directory containing the train/test ``.pt`` files
            (keyword-only; defaults to the previous hard-coded location).

    Returns:
        ``(client_dataloaders, test_dataloader)``. ``client_dataloaders`` is a
        list with one loader per index set returned by the split, in order.
    """
    batch_size = configs.batch_size
    alpha = configs.non_iid_alpha

    modal1, modal2, labels = load_data(data_dir)
    data_indices = split_data_dirichlet(labels, n_clients, alpha)
    client_dataloaders = []
    for indices in data_indices:
        dataset = ModelNet40(modal1[indices], modal2[indices], labels[indices])
        # shuffle=True uses RandomSampler, which raises ValueError on an
        # empty dataset. Depending on split_data_dirichlet's implementation,
        # a client may receive no samples, so only shuffle non-empty subsets.
        # An empty client still gets a loader, so the list keeps one entry
        # per index set.
        client_dataloaders.append(
            DataLoader(dataset, batch_size=batch_size, shuffle=len(dataset) > 0)
        )

    test_modal1, test_modal2, test_labels = load_test_data(data_dir)
    test_dataset = ModelNet40(test_modal1, test_modal2, test_labels)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return client_dataloaders, test_dataloader