"""
data_loader.py
Load GLUE tasks (e.g., MNLI, SST-2, QQP, QNLI) and partition their training split
across clients for federated learning experiments. Supports IID, mild, and severe
non-IID partitioning.
"""

import random
import torch
from datasets import load_dataset
from collections import defaultdict
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

class TextDataset(Dataset):
    """Map-style torch dataset pairing tokenizer encodings with labels.

    ``encodings`` is a mapping from field name to a per-example sequence
    (for example, the output of a Hugging Face tokenizer). Each item is a dict
    holding one tensor per encoding field plus a ``'labels'`` tensor.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The dataset length is defined by the label list.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field_name, column in self.encodings.items():
            sample[field_name] = torch.tensor(column[idx])
        # Inserted last so it takes precedence over any 'labels' encoding field.
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

def _even_chunks(items, num_chunks):
    """Split ``items`` into ``num_chunks`` contiguous lists whose sizes differ by at most one.

    When ``len(items)`` is not divisible by ``num_chunks``, the first chunks each
    receive one extra element, so no element is dropped.
    """
    base, extra = divmod(len(items), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        end = start + base + (1 if i < extra else 0)
        chunks.append(items[start:end])
        start = end
    return chunks


def split_data_non_iid(dataset, num_clients, non_iid_type='iid', *, seed=None):
    """Partition ``dataset`` into ``num_clients`` lists of examples.

    Args:
        dataset: Iterable of example mappings; each must have a ``'label'`` key.
        num_clients: Number of partitions to produce (must be >= 1).
        non_iid_type: Partitioning scheme:
            - ``'iid'``: shuffle all examples and split them into near-equal chunks.
            - ``'mild'``: shuffle each label's examples and give every client a
              near-equal share of every label.
            - ``'severe'``: give each client a single label, assigned round-robin
              over labels in first-seen order. Clients that share a label split
              that label's examples between them instead of each getting a copy.
        seed: Optional seed for a private ``random.Random``. When ``None``, the
            global ``random`` module state is used (the original behaviour).

    Returns:
        A list of ``num_clients`` lists of examples. No example is assigned to
        more than one client. In 'iid' and 'mild', every example is assigned.
        In 'severe', labels beyond the first ``num_clients`` labels are unused.

    Raises:
        ValueError: If ``num_clients < 1`` or ``non_iid_type`` is unknown.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if non_iid_type not in ('iid', 'mild', 'severe'):
        raise ValueError(
            f"non_iid_type must be 'iid', 'mild' or 'severe', got {non_iid_type!r}"
        )

    # The `random` module exposes `shuffle`, so it can stand in for a Random instance.
    rng = random if seed is None else random.Random(seed)

    label2data = defaultdict(list)
    for example in dataset:
        label2data[example['label']].append(example)

    clients = [[] for _ in range(num_clients)]

    if non_iid_type == 'iid':
        # Flatten with a comprehension; sum(lists, []) is quadratic.
        all_data = [ex for examples in label2data.values() for ex in examples]
        rng.shuffle(all_data)
        clients = _even_chunks(all_data, num_clients)

    elif non_iid_type == 'mild':
        # NOTE(review): every client receives (nearly) the same label proportions,
        # so this split is label-stratified rather than skewed. Confirm it is the
        # intended "mild" setting.
        for examples in label2data.values():
            rng.shuffle(examples)
            for client, share in zip(clients, _even_chunks(examples, num_clients)):
                client.extend(share)

    else:  # 'severe'
        labels = list(label2data)
        if not labels:
            return clients
        # Group clients by their dominant label so that clients sharing a label
        # split its examples instead of each receiving a full duplicate copy.
        owners_by_label = defaultdict(list)
        for i in range(num_clients):
            owners_by_label[labels[i % len(labels)]].append(i)
        for label, owners in owners_by_label.items():
            examples = label2data[label]
            if len(owners) > 1:
                # Shuffle only when splitting, so shares do not follow dataset order.
                examples = list(examples)
                rng.shuffle(examples)
            for owner, share in zip(owners, _even_chunks(examples, len(owners))):
                clients[owner].extend(share)

    return clients

# Text column(s) of each GLUE config in the Hugging Face "glue" dataset.
# The second entry is None for single-sentence tasks.
_GLUE_TEXT_KEYS = {
    'cola': ('sentence', None),
    'sst2': ('sentence', None),
    'mrpc': ('sentence1', 'sentence2'),
    'stsb': ('sentence1', 'sentence2'),
    'qqp': ('question1', 'question2'),
    'mnli': ('premise', 'hypothesis'),
    'qnli': ('question', 'sentence'),
    'rte': ('sentence1', 'sentence2'),
    'wnli': ('sentence1', 'sentence2'),
}


def get_tokenized_dataset(task, tokenizer, num_clients=3, non_iid_type='iid'):
    """Load a GLUE task's training split, partition it, and tokenize each partition.

    Args:
        task: GLUE config name passed to ``load_dataset("glue", task)``
            (e.g. ``'sst2'``, ``'mnli'``, ``'qqp'``, ``'qnli'``).
        tokenizer: Callable invoked as ``tokenizer(texts, text_pairs_or_None,
            truncation=True, padding=True)`` and returning a mapping of encodings
            (e.g. a Hugging Face tokenizer).
        num_clients: Number of client partitions.
        non_iid_type: Partitioning scheme forwarded to ``split_data_non_iid``.

    Returns:
        A list with one ``TextDataset`` per client. A client whose partition is
        empty gets an empty ``TextDataset``.
    """
    dataset = load_dataset("glue", task)
    train_data = dataset['train']

    # Tasks not in the table fall back to the original sentence1/sentence2
    # heuristic; the pair column is used only if the dataset actually has it.
    key1, key2 = _GLUE_TEXT_KEYS.get(task, ('sentence1', 'sentence2'))
    if key2 is not None and key2 not in train_data.column_names:
        key2 = None

    clients_data = split_data_non_iid(train_data, num_clients, non_iid_type)
    tokenized_clients = []

    for data in clients_data:
        labels = [ex['label'] for ex in data]
        if not data:
            # Nothing to tokenize. Keep one (empty) dataset per client so the
            # output stays aligned with client indices.
            tokenized_clients.append(TextDataset({}, labels))
            continue

        first_texts = [ex[key1] for ex in data]
        second_texts = [ex[key2] for ex in data] if key2 is not None else None

        encodings = tokenizer(first_texts, second_texts, truncation=True, padding=True)
        tokenized_clients.append(TextDataset(encodings, labels))

    return tokenized_clients
