import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

def create_federated_loaders(X_main, X_diff, y, edge_ids, batch_size, *, num_edges=9):
    """
    Split a dataset into one shuffled DataLoader per edge (federated client).

    Sample ``k`` is assigned to edge ``edge_ids[k]``. For every edge id in
    ``range(num_edges)`` a DataLoader yielding ``(x_main, x_diff, label)``
    batches is built; edges with no samples get ``None`` instead, so the
    returned list is always indexable by edge id.

    Args:
        X_main (numpy.ndarray): Original feature data, shape (N, ...);
            documented upstream as (N, 19, 6). Converted to float32.
        X_diff (numpy.ndarray): Difference feature data, same leading
            length N as ``X_main``. Converted to float32.
        y (numpy.ndarray): Integer label data in shape (N,). Converted to int64.
        edge_ids (array-like): Edge ID each data point belongs to, shape (N,).
            Lists are accepted and converted with ``np.asarray``.
        batch_size (int): Batch size for each DataLoader.
        num_edges (int, keyword-only): Number of edges to build loaders for
            (edge ids ``0 .. num_edges - 1``). Defaults to 9.

    Returns:
        list: ``num_edges`` entries, each a ``DataLoader`` or ``None`` when
        that edge has no data.

    Raises:
        ValueError: If ``X_diff``, ``y`` or ``edge_ids`` does not have the
            same length as ``X_main``.
    """
    print(f"\n--- Creating {num_edges} Federated DataLoaders ---")

    # Coerce to an ndarray: comparing a plain list with an int yields a single
    # bool rather than an element-wise mask, which would silently drop all data.
    edge_ids = np.asarray(edge_ids)

    n_samples = len(X_main)
    for name, arr in (("X_diff", X_diff), ("y", y), ("edge_ids", edge_ids)):
        if len(arr) != n_samples:
            raise ValueError(
                f"{name} has length {len(arr)}, expected {n_samples} "
                f"(length of X_main)"
            )

    X_tensor = torch.from_numpy(X_main).float()
    X_diff_tensor = torch.from_numpy(X_diff).float()
    y_tensor = torch.from_numpy(y).long()

    edge_loaders = []
    for i in range(num_edges):
        indices = np.where(edge_ids == i)[0]

        if len(indices) == 0:
            print(f"  Edge {i}: No data found, adding None.")
            # Keep a placeholder so list position still equals edge id.
            edge_loaders.append(None)
            continue

        index_tensor = torch.from_numpy(indices)
        edge_dataset = TensorDataset(
            X_tensor[index_tensor],
            X_diff_tensor[index_tensor],
            y_tensor[index_tensor],
        )

        edge_loader = DataLoader(
            dataset=edge_dataset,
            batch_size=batch_size,
            shuffle=True,
        )
        edge_loaders.append(edge_loader)
        print(f"  Edge {i}: Created DataLoader with {len(indices)} samples.")
    return edge_loaders