from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch
from datasets.utils import get_data_dict, split_data_dirichlet

class IMEOCAPDataset(Dataset):
    """Map-style dataset over per-sample multimodal features.

    ``data`` is a mapping that must provide the keys 'A_feat', 'V_feat',
    'L_feat', 'label' and 'lengths'. Each value is indexed with the sample
    index; the dataset length is taken from ``data['label']``.
    """

    def __init__(self, data):
        # Read the five per-sample containers in a fixed order.
        (self.A_feat,
         self.V_feat,
         self.L_feat,
         self.labels,
         self.lengths) = (data[key] for key in
                          ('A_feat', 'V_feat', 'L_feat', 'label', 'lengths'))

    def __len__(self):
        """Number of samples, as given by the label container."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of its features, label and length.

        Note the output key is 'length' (singular), matching what
        ``collate_fn`` reads, while the input key was 'lengths'.
        """
        return dict(
            A_feat=self.A_feat[idx],  # audio sequence (original comment: (length, 130))
            V_feat=self.V_feat[idx],  # visual sequence
            L_feat=self.L_feat[idx],  # text sequence
            label=self.labels[idx],
            length=self.lengths[idx],
        )

def collate_fn(batch):
    """Collate a list of ``IMEOCAPDataset`` samples into padded batch tensors.

    Args:
        batch: list of sample dicts. Each dict has 'A_feat', 'V_feat' and
            'L_feat', which are per-sample sequence tensors padded along their
            first (time) dimension, and 'label', a tensor. Other keys such as
            'length' are ignored.

    Returns:
        Tuple ``(A, V, L, labels)``. Each modality is zero-padded to the
        longest sequence in the batch, giving shape ``(batch, max_len, ...)``.
        ``labels`` holds the per-sample labels stacked along a new dim 0.
    """
    A_feats = [item['A_feat'] for item in batch]
    V_feats = [item['V_feat'] for item in batch]
    L_feats = [item['L_feat'] for item in batch]
    labels = torch.stack([item['label'] for item in batch], dim=0)
    # NOTE(review): per-sample lengths are not part of the return value (the
    # previous version stacked them and then discarded them). If downstream
    # masking needs them, the return signature has to change together with
    # its callers.

    # Pad every modality to the batch's longest sequence (batch-first layout).
    A_feats_padded = pad_sequence(A_feats, batch_first=True, padding_value=0)
    V_feats_padded = pad_sequence(V_feats, batch_first=True, padding_value=0)
    L_feats_padded = pad_sequence(L_feats, batch_first=True, padding_value=0)

    return A_feats_padded, V_feats_padded, L_feats_padded, labels


def _subset_data(data, indices):
    """Return a new dataset dict holding only the samples at ``indices``.

    The feature entries become plain lists of per-sample items. 'label' and
    'lengths' are indexed directly with ``indices``, so those containers must
    support sequence indexing (e.g. tensors or numpy arrays).
    """
    return {
        'A_feat': [data['A_feat'][i] for i in indices],
        'V_feat': [data['V_feat'][i] for i in indices],
        'L_feat': [data['L_feat'][i] for i in indices],
        'label': data['label'][indices],
        'lengths': data['lengths'][indices],
    }


def get_loaders(n_clients, configs, data_path="./datasets/cap/cap.pkl",
                train_ratio=0.8):
    """Build per-client training loaders and a shared test loader.

    The data dict loaded from ``data_path`` is split in stored order. The first
    ``int(N * train_ratio)`` samples form the training pool and the rest form
    the test set (no shuffling happens before the split). The training pool is
    then partitioned across clients with ``split_data_dirichlet`` using
    ``configs.non_iid_alpha``.

    Args:
        n_clients: number of federated clients; one training loader is built
            per client.
        configs: object providing ``non_iid_alpha`` and ``batch_size``.
        data_path: path passed to ``get_data_dict``.
        train_ratio: fraction of samples (in stored order) used for training;
            must lie in (0, 1].

    Returns:
        ``(client_dataloaders, test_dataloader)``. The first item is a list of
        ``n_clients`` shuffling ``DataLoader``s. The second is a non-shuffling
        ``DataLoader`` over the held-out samples.

    Raises:
        ValueError: if ``train_ratio`` is outside (0, 1].
    """
    if not 0.0 < train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in (0, 1], got {train_ratio!r}")

    combined = get_data_dict(data_path)
    data_size = len(combined["label"])
    split_idx = int(data_size * train_ratio)

    train_data = _subset_data(combined, range(split_idx))
    test_data = _subset_data(combined, range(split_idx, data_size))

    # Presumably returns one index collection per client into train_data;
    # verify against split_data_dirichlet.
    data_indices = split_data_dirichlet(
        train_data['label'], n_clients, configs.non_iid_alpha
    )

    client_dataloaders = []
    for i in range(n_clients):
        client_dataset = IMEOCAPDataset(_subset_data(train_data, data_indices[i]))
        # A small Dirichlet alpha can leave a client with no samples.
        # DataLoader(shuffle=True) builds a RandomSampler, which raises
        # ValueError on an empty dataset, so only shuffle non-empty clients.
        client_dataloaders.append(DataLoader(
            client_dataset, batch_size=configs.batch_size,
            shuffle=len(client_dataset) > 0, collate_fn=collate_fn,
        ))

    # Evaluation does not benefit from shuffling; a fixed order keeps
    # per-sample outputs reproducible across runs.
    test_dataloader = DataLoader(
        IMEOCAPDataset(test_data), batch_size=configs.batch_size,
        shuffle=False, collate_fn=collate_fn,
    )
    return client_dataloaders, test_dataloader